Source code for molten.contrib.websockets

# This file is a part of molten.
# Copyright (C) 2018 CLEARTYPE SRL <[email protected]>
# molten is free software; you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or (at
# your option) any later version.
# molten is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
# License for more details.
# You should have received a copy of the GNU Lesser General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

# This module is tested using the autobahn testsuite:

import io
import logging
import selectors
import socket
import struct
import time
from base64 import b64encode
from concurrent.futures import Future, ThreadPoolExecutor
from hashlib import sha1
from inspect import Parameter
from typing import Any, Callable, Optional, Pattern, Union

from molten import (
    HTTP_400, HTTP_426, BaseApp, DependencyResolver, Environ, HeaderMissing, HTTPError, MoltenError,
    Request, RequestHandled, Response, Route, TestClient
)
from molten.http.headers import Headers, HeadersDict
from molten.http.query_params import ParamsDict, QueryParams

# gevent is an optional dependency of molten, but this module cannot work
# without it, so fail at import time with an actionable message.
try:
    import gevent
except ImportError:  # pragma: no cover
    raise ImportError("'gevent' package missing. Run 'pip install gevent'.")

LOGGER = logging.getLogger(__name__)

#: The pre-shared key defined in the Websocket spec.
PSK = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"

#: The amount of bytes to request per recv call.
CHUNKSIZE = 16 * 1024

#: The maximum number of bytes text and binary frames can contain.
MAX_MESSAGE_SIZE = 16 * 1024 * 1024

#: The maximum number of bytes data frames can contain.
# NOTE(review): the original value was lost; a single frame can never be
# accepted if it is larger than a whole message, so the message limit is
# used as the upper bound here -- confirm against the upstream source.
MAX_DATA_FRAME_PAYLOAD_SIZE = MAX_MESSAGE_SIZE

#: The maximum number of bytes control frames can contain.
MAX_CONTROL_FRAME_PAYLOAD_SIZE = 125

#: Continuation frame.  Only valid if received after non-final text or binary frames.
OP_CONTINUATION = 0x0

#: A frame with a utf-8 encoded payload.  May or may not be final.
OP_TEXT = 0x1

#: A frame containing binary data.  May or may not be final.
OP_BINARY = 0x2

#: A frame signaling that the connection should be closed.  Always final.
OP_CLOSE = 0x8

#: A frame signaling that a PONG frame should be sent to the client.  Always final.
OP_PING = 0x9

#: A heartbeat frame.  Always final.
OP_PONG = 0xA

#: The set of data frame opcodes.
DATA_FRAME_OPCODES = {OP_CONTINUATION, OP_TEXT, OP_BINARY}

#: The set of control frame opcodes.
CONTROL_FRAME_OPCODES = {OP_CLOSE, OP_PING, OP_PONG}

#: The set of all valid opcodes.
OPCODES = DATA_FRAME_OPCODES | CONTROL_FRAME_OPCODES

#: The set of valid close message status codes.
VALID_STATUS_CODES = {1000, 1001, 1002, 1003, 1007, 1008, 1009, 1010, 1011}

#: The set of reserved close message status codes.
RESERVED_STATUS_CODES = {1004, 1005, 1006, 1015}

#: The set of supported versions.
SUPPORTED_VERSIONS = {"7", "8", "13"}

#: The set of supported versions as a string.
# Sorted numerically so the advertised header value is deterministic.
SUPPORTED_VERSIONS_STR = ",".join(sorted(SUPPORTED_VERSIONS, key=int))

#: The payload that is returned as part of the connection upgrade process.
# The two trailing empty items produce the blank line that terminates the
# HTTP header section.
UPGRADE_RESPONSE_TEMPLATE = b"\r\n".join([
    b"HTTP/1.1 101 Switching Protocols",
    b"connection: upgrade",
    b"upgrade: websocket",
    b"server: molten",
    b"sec-websocket-accept: %(websocket_accept)s",
    b"",
    b"",
])

class WebsocketError(MoltenError):
    """Root of the exception hierarchy used by the websockets module.
    """


class WebsocketProtocolError(WebsocketError):
    """Signals that the peer (or the data it sent) violated the
    websocket protocol.
    """


class WebsocketMessageTooLargeError(WebsocketProtocolError):
    """Signals that an incoming message carries more data than is
    allowed.
    """


class WebsocketFrameTooLargeError(WebsocketError):
    """Signals that a payload is too large to be sent as a single
    frame.
    """


class WebsocketClosedError(WebsocketError):
    """Signals an attempt to send a message over a closed socket.
    """
class _BufferedStream:
    """A buffered IO stream backed by a socket.

    This makes data frame parsing simple and efficient because the
    data frame reader can request data in small byte chunks, whereas
    this will read the data in large chunks from the socket under the
    hood.
    """

    __slots__ = ["buf", "closed", "socket"]

    def __init__(self, socket: socket.socket) -> None:
        self.buf = b""
        self.closed = False
        self.socket = socket

    def read(self, n: int) -> bytes:
        """Read up to *n* bytes, blocking until they are available or
        the peer closes the connection.

        Fewer than *n* bytes are returned only once EOF was reached.
        """
        while not self.closed and len(self.buf) < n:
            data = self.socket.recv(CHUNKSIZE)
            if not data:
                # EOF.  Fall through to the slicing below so that the
                # leftover bytes are consumed and not handed out again
                # by a subsequent read.
                self.closed = True
                break

            self.buf += data

        data, self.buf = self.buf[:n], self.buf[n:]
        return data

    def expect(self, n: int) -> bytes:
        """Read exactly *n* bytes.

        Raises:
          WebsocketProtocolError: If EOF is reached before *n* bytes
            could be read.
        """
        data = self.read(n)
        if len(data) != n:
            raise WebsocketProtocolError("Unexpected EOF while reading from socket.")
        return data

    def write(self, data: bytes) -> None:
        """Send all of *data* over the underlying socket.
        """
        self.socket.sendall(data)

    def close(self) -> None:
        """Shut down and close the underlying socket.  Failures are
        logged rather than raised.
        """
        try:
            # Same value as the literal ``True`` (1) used previously.
            self.socket.shutdown(socket.SHUT_WR)
            self.socket.close()
        except Exception as e:
            LOGGER.warning("Failed to close underlying socket: %s", e)


class _DataFrameHeader:
    """The header of a single websocket frame: the FIN bit, the
    reserved flag bits, the opcode, the payload length and the
    optional 4-byte masking key.
    """

    __slots__ = ["fin", "flags", "opcode", "length", "mask"]

    RSV1_MASK = 0x40
    RSV2_MASK = 0x20
    RSV3_MASK = 0x10
    FIN_MASK = MASK_MASK = 0x80
    FLAGS_MASK = RSV1_MASK | RSV2_MASK | RSV3_MASK
    OPCODE_MASK = 0x0F
    LENGTH_MASK = 0x7F

    def __init__(self, fin: bool = False, flags: int = 0, opcode: int = 0, length: int = 0, mask: Optional[bytearray] = None) -> None:  # noqa
        self.fin = fin
        self.flags = flags
        self.opcode = opcode
        self.length = length
        self.mask = mask or bytearray()

    def mask_data(self, data: bytes) -> bytes:
        """XOR the first ``self.length`` bytes of *data* with the
        masking key.  Applying it twice restores the input.
        """
        data_array = bytearray(data)
        for i in range(self.length):
            data_array[i] ^= self.mask[i % 4]
        return bytes(data_array)

    @classmethod
    def from_stream(cls, stream: _BufferedStream) -> "_DataFrameHeader":
        """Read a data frame header from the input stream.
        """
        read = stream.expect
        fb, sb = struct.unpack("!BB", read(2))
        header = cls(
            fb & cls.FIN_MASK == cls.FIN_MASK,
            fb & cls.FLAGS_MASK,
            fb & cls.OPCODE_MASK,
        )

        # 126 and 127 are escape values: the real length follows as a
        # 16-bit or 64-bit big-endian integer respectively.
        length = sb & cls.LENGTH_MASK
        if length == 126:
            header.length = struct.unpack("!H", read(2))[0]
        elif length == 127:
            header.length = struct.unpack("!Q", read(8))[0]
        else:
            header.length = length

        if sb & cls.MASK_MASK == cls.MASK_MASK:
            header.mask = bytearray(read(4))

        return header

    def to_stream(self, stream: _BufferedStream) -> None:
        """Write this header to the output stream.

        Raises:
          WebsocketFrameTooLargeError: If the length cannot be encoded.
        """
        output = bytearray()

        # First byte: FIN bit, the three reserved bits and the opcode.
        fb = self.opcode | (self.flags & self.FLAGS_MASK)
        if self.fin:
            fb |= self.FIN_MASK
        output.append(fb)

        # Second byte: MASK bit plus either the length itself or one
        # of the escape values followed by the extended length.
        sb = self.MASK_MASK if self.mask else 0
        if self.length < 126:
            output.append(sb | self.length)
        elif self.length <= 0xFFFF:
            output.append(sb | 126)
            output.extend(struct.pack("!H", self.length))
        elif self.length <= 0xFFFFFFFFFFFFFFF:
            output.append(sb | 127)
            output.extend(struct.pack("!Q", self.length))
        else:
            raise WebsocketFrameTooLargeError(f"{self.length} bytes cannot fit in a single frame.")

        if self.mask:
            output.extend(self.mask)

        stream.write(output)


class _DataFrame:
    """A single websocket frame: a header plus its (unmasked) payload.
    """

    __slots__ = ["header", "data"]

    def __init__(self, header: _DataFrameHeader, data: bytes = b"") -> None:
        self.header = header
        self.data = data

    @classmethod
    def from_stream(cls, stream: _BufferedStream) -> "_DataFrame":
        """Read a data frame from an input stream.

        Raises:
          WebsocketProtocolError: On unknown opcodes, reserved flags
            or premature EOF.
          WebsocketMessageTooLargeError: If the announced payload
            exceeds the limit for this kind of frame.
        """
        header = _DataFrameHeader.from_stream(stream)
        if header.opcode not in OPCODES:
            raise WebsocketProtocolError(f"Invalid opcode 0x{header.opcode:x}.")

        if header.flags != 0:
            raise WebsocketProtocolError("Reserved flags must not be set.")

        if header.opcode in CONTROL_FRAME_OPCODES:
            max_size = MAX_CONTROL_FRAME_PAYLOAD_SIZE
        else:
            max_size = MAX_DATA_FRAME_PAYLOAD_SIZE

        # Check the announced size before reading so oversized
        # payloads are never buffered.
        if header.length > max_size:
            raise WebsocketMessageTooLargeError(f"Payload exceeds {max_size} bytes.")

        data = stream.expect(header.length)
        if header.mask:
            data = header.mask_data(data)

        return cls(header, data)

    def to_stream(self, stream: _BufferedStream) -> None:
        """Write this data frame to the output stream.
        """
        self.header.to_stream(stream)
        if self.header.mask:
            stream.write(self.header.mask_data(self.data))
        else:
            stream.write(self.data)
class Message:
    """A websocket message, composed of one or more data frames.
    """

    __slots__ = ["buf"]

    def __init__(self, message: bytes = b"") -> None:
        self.buf = io.BytesIO(message)

    @classmethod
    def from_frame(cls, frame: _DataFrame) -> "Message":
        """Build a message whose payload is the data of *frame*.
        """
        message = cls()
        message.buf.write(frame.data)
        return message

    def add_frame(self, frame: _DataFrame) -> None:  # pragma: no cover
        """Append a continuation frame.  Only message types that can be
        fragmented override this.
        """
        raise NotImplementedError(f"{type(self).__name__} does not implement add_frame()")

    def to_stream(self, stream: _BufferedStream) -> None:
        """Write this message to the output stream.

        The message is sent as one final, unmasked frame whose opcode
        is looked up by the message's class.
        """
        output = self.get_output()
        header = _DataFrameHeader(fin=True, opcode=OPCODES_BY_MESSAGE[type(self)], length=len(output))
        frame = _DataFrame(header, output)  # type: ignore
        frame.to_stream(stream)

    def get_data(self) -> bytes:
        """Get this message's data as a bytestring.
        """
        return self.buf.getvalue()

    def get_text(self) -> str:
        """Get this message's contents as text.

        Raises:
          WebsocketProtocolError: If the payload is not valid UTF-8.
        """
        try:
            return self.buf.getvalue().decode("utf-8")
        except UnicodeDecodeError:
            raise WebsocketProtocolError("Invalid UTF-8 payload.") from None

    def get_output(self) -> Union[bytes, bytearray, memoryview]:
        """Get this message's output payload.  CloseMessage hooks into
        this to prepend the status code to the payload.
        """
        return self.buf.getbuffer()
class CloseMessage(Message):
    """Received (or sent) when the connection should be closed.  Close
    messages sent by the client are automatically handled by
    receive().

    Attributes:
      code(int): The close status code.
    """

    __slots__ = ["buf", "code"]

    def __init__(self, code: int = 1000, reason: str = "") -> None:
        self.buf = io.BytesIO(reason.encode("utf-8"))
        self.code = code

    @classmethod
    def from_frame(cls, frame: _DataFrame) -> "Message":
        """Parse a close frame into a status code and a reason.

        An empty payload yields the default code 1000.

        Raises:
          WebsocketProtocolError: If the status code is truncated, out
            of range, reserved or unknown.
        """
        code = 1000
        # Work on locals so the caller's frame is left untouched.
        reason = frame.data
        if reason:
            code_data, reason = reason[:2], reason[2:]
            if len(code_data) < 2:
                raise WebsocketProtocolError("Expected status code in close message payload.")

            code = struct.unpack("!H", code_data)[0]
            if code < 1000 or code > 4999:
                raise WebsocketProtocolError(f"Invalid status code {code}.")
            elif code in RESERVED_STATUS_CODES:
                raise WebsocketProtocolError(f"Status code {code} is reserved.")
            elif code < 3000 and code not in VALID_STATUS_CODES:
                # Codes below 3000 must be explicitly known; 3000-4999
                # are accepted as-is.
                raise WebsocketProtocolError(f"Invalid status code {code}.")

        message = cls()
        message.code = code
        message.buf.write(reason)
        return message

    def get_output(self) -> Union[bytes, bytearray, memoryview]:
        """Return the big-endian status code followed by the reason.
        """
        return struct.pack("!H", self.code) + self.buf.getvalue()
class BinaryMessage(Message):
    """A message containing binary data.
    """

    def add_frame(self, frame: _DataFrame) -> None:
        """Append the payload of a continuation frame to this message.

        Raises:
          WebsocketMessageTooLargeError: If the combined payload would
            exceed MAX_MESSAGE_SIZE.  This is a subclass of
            WebsocketProtocolError.
        """
        # Release the view before writing: a BytesIO cannot be resized
        # while a buffer export is alive.
        with self.buf.getbuffer() as view:
            current_size = len(view)

        if current_size + len(frame.data) > MAX_MESSAGE_SIZE:
            raise WebsocketMessageTooLargeError(f"Message exceeds {MAX_MESSAGE_SIZE} bytes.")

        self.buf.write(frame.data)
class TextMessage(BinaryMessage):
    """A message whose payload is text.  The text is stored UTF-8
    encoded.
    """

    def __init__(self, message: str = "") -> None:
        encoded = message.encode("utf-8")
        super().__init__(encoded)
class PingMessage(Message):
    """A PING message.

    receive() handles these automatically.
    """
class PongMessage(Message):
    """A PONG message.

    receive() handles these automatically.
    """
#: A mapping from message classes to opcodes.
OPCODES_BY_MESSAGE = {
    CloseMessage: OP_CLOSE,
    BinaryMessage: OP_BINARY,
    TextMessage: OP_TEXT,
    PingMessage: OP_PING,
    PongMessage: OP_PONG,
}
class Websocket:
    """Represents a single websocket connection.  These are used for
    bi-directional communication with a websocket client.

    Websockets are *not* thread-safe.

    Example:

      >>> from molten import annotate
      >>> from molten.contrib.websockets import CloseMessage, Websocket

      >>> @annotate(supports_ws=True)
      ... def echo(sock: Websocket):
      ...     while not sock.closed:
      ...         message = sock.receive()
      ...         if isinstance(message, CloseMessage):
      ...             break
      ...
      ...         sock.send(message)

    Attributes:
      closed(bool): Whether or not this socket has been closed.
    """

    __slots__ = ["closed", "stream"]

    def __init__(self, stream: _BufferedStream) -> None:
        self.closed = False
        self.stream = stream

    def receive(self, *, timeout: Optional[float] = None) -> Optional[Message]:
        """Waits for a message from the client for up to *timeout*
        seconds.

        Ping frames are answered and pong frames are skipped
        transparently.  A close frame is acknowledged, marks this
        socket as closed and is returned to the caller.

        Returns:
          The next complete message, or None if the socket is already
          closed.

        Raises:
          WebsocketProtocolError: If the client violates the protocol.
        """
        if self.closed:
            return None

        with gevent.Timeout(timeout):
            # The message currently being assembled from fragments.
            message = None
            while True:
                frame = _DataFrame.from_stream(self.stream)
                opcode = frame.header.opcode
                if opcode == OP_TEXT:
                    if message is not None:
                        raise WebsocketProtocolError("Unexpected text frame.")

                    message = TextMessage.from_frame(frame)

                elif opcode == OP_BINARY:
                    if message is not None:
                        raise WebsocketProtocolError("Unexpected binary frame.")

                    message = BinaryMessage.from_frame(frame)

                elif opcode == OP_CONTINUATION:
                    if message is None:
                        raise WebsocketProtocolError("Unexpected continuation frame.")

                    message.add_frame(frame)

                elif opcode == OP_CLOSE:
                    if not frame.header.fin:
                        raise WebsocketProtocolError("Close frame is not final.")

                    message = CloseMessage.from_frame(frame)
                    # get_text() also validates that the reason is UTF-8.
                    self.close(CloseMessage(reason=message.get_text()))
                    return message

                elif opcode == OP_PING:
                    if not frame.header.fin:
                        raise WebsocketProtocolError("Ping frame is not final.")

                    # Control frames may arrive between fragments, so
                    # the partially assembled message is kept as is.
                    self.send(PongMessage(frame.data))
                    continue

                elif opcode == OP_PONG:
                    if not frame.header.fin:
                        raise WebsocketProtocolError("Pong frame is not final.")

                    continue

                else:
                    raise WebsocketProtocolError(f"Unexpected frame with opcode 0x{opcode:x}.")

                if frame.header.fin:
                    return message

    def send(self, message: Message) -> None:
        """Send a message to the client.

        Raises:
          WebsocketClosedError: If this socket has already been closed.
        """
        if self.closed:
            raise WebsocketClosedError("Websocket already closed.")

        message.to_stream(self.stream)

    def close(self, message: Optional[Message] = None) -> None:
        """Close this websocket and send a close message to the client.

        Note:
          This does not close the underlying websocket as it's better
          to let gunicorn handle that by itself.
        """
        try:
            self.send(message or CloseMessage())
        except WebsocketClosedError:
            # Already closed: nothing left to tell the client.
            pass
        finally:
            self.closed = True
class _WebsocketComponent:
    """Dependency-injection component that hands a Websocket object
    to handlers requesting one.

    Users of this module don't need to provide this to the App object
    themselves; the middleware registers it automatically.
    """

    __slots__ = ["websocket"]

    is_cacheable = True
    is_singleton = False

    def __init__(self, websocket: Websocket) -> None:
        self.websocket = websocket

    def can_handle_parameter(self, parameter: Parameter) -> bool:
        # Only parameters annotated with exactly the Websocket class
        # are served by this component.
        requested_type = parameter.annotation
        return requested_type is Websocket

    def resolve(self) -> Websocket:
        return self.websocket
class WebsocketsMiddleware:
    """A middleware that handles websocket upgrades.

    Warning:
      Please note that this functionality is currently
      gunicorn-specific and it requires the use of async workers in
      order to function correctly.

    Parameters:
      origin_re: An optional regular expression that can be used to
        validate the origin of incoming browser requests.
    """

    __slots__ = ["origin_re"]

    def __init__(self, origin_re: Optional[Pattern[str]] = None) -> None:
        self.origin_re = origin_re

    def handle_exception(self, exception: BaseException, websocket: Websocket) -> None:
        """Called whenever an unhandled exception occurs in middleware
        or a handler.  Overwrite this in a subclass to implement
        custom error handling for websocket handlers.

        If you do overwrite this, don't forget to close the websocket
        connection when necessary.
        """
        LOGGER.exception("Unhandled error from websocket handler.")
        # The "too large" errors are checked first because
        # WebsocketMessageTooLargeError is also a protocol error and
        # would otherwise be reported as 1002 instead of 1009
        # ("message too big").
        if isinstance(exception, (WebsocketMessageTooLargeError, WebsocketFrameTooLargeError)):
            websocket.close(CloseMessage(1009, str(exception)))

        elif isinstance(exception, WebsocketProtocolError):
            websocket.close(CloseMessage(1002, str(exception)))

        else:
            websocket.close(CloseMessage(1011, "Internal server error."))

    def __call__(self, handler: Callable[..., Any]) -> Callable[..., Response]:
        def handle(
                resolver: DependencyResolver,
                request: Request,
                environ: Environ,
                route: Optional[Route],
        ) -> Response:
            # Routes that did not opt into websockets are handled as
            # regular HTTP requests.
            if route is None or not getattr(route.handler, "supports_ws", False):
                return handler()

            try:
                connection = request.headers["connection"]
                upgrade = request.headers["upgrade"]
                websocket_key = request.headers["sec-websocket-key"]
                websocket_version = request.headers["sec-websocket-version"]
            except HeaderMissing as e:
                raise HTTPError(HTTP_400, {"errors": {str(e): "this header is required"}})

            try:
                origin = request.headers["origin"]
            except HeaderMissing:
                origin = ""

            if self.origin_re and not self.origin_re.match(origin):
                raise HTTPError(HTTP_400, {"error": "invalid origin"})

            if "upgrade" not in connection.lower() or "websocket" not in upgrade.lower():
                raise HTTPError(HTTP_400, {"error": "invalid upgrade request"})

            if websocket_version not in SUPPORTED_VERSIONS:
                return Response(HTTP_426, headers={"sec-websocket-version": SUPPORTED_VERSIONS_STR})

            # TODO: Implement extension handling.
            # TODO: Implement subprotocol handling.
            # The handshake is written straight to the raw socket that
            # gunicorn exposes in the environ.
            stream = _BufferedStream(environ["gunicorn.socket"])
            stream.write(UPGRADE_RESPONSE_TEMPLATE % {
                b"websocket_accept": b64encode(sha1(f"{websocket_key}{PSK}".encode()).digest()),
            })

            websocket = Websocket(stream)
            resolver.add_component(_WebsocketComponent(websocket))

            try:
                handler()
            except Exception as e:
                # Resolve through the DI container so that overridden
                # handle_exception methods can request dependencies.
                handle_exception = resolver.resolve(self.handle_exception)
                handle_exception(exception=e)
            finally:
                websocket.close(CloseMessage())

            # The response has already been written to the socket.
            raise RequestHandled("websocket request was upgraded")
        return handle
class _WebsocketsTestConnection:
    """Context manager that wraps a client-side Websocket and forwards
    attribute access to it.
    """

    __slots__ = ["__future", "__socket"]

    def __init__(self, future: Future, socket: Websocket) -> None:  # type: ignore
        self.__socket = socket
        self.__future = future

    def close(self) -> None:
        # Ask the server side to shut down, then wait on the future.
        # result() re-raises any exception stored in it.
        try:
            goodbye = CloseMessage()
            self.__socket.send(goodbye)
        except WebsocketClosedError:
            pass
        finally:
            self.__future.result()

    def __enter__(self) -> "_WebsocketsTestConnection":
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        self.close()

    def __getattr__(self, name: str) -> Any:
        # Only reached for names not found on the proxy itself.
        wrapped = self.__socket
        return getattr(wrapped, name)
class WebsocketsTestClient(TestClient):
    """This is a subclass of the standard test client that adds an
    additional method called :meth:`.connect` that may be used to
    connect to websocket endpoints.

    Example:

      >>> client = WebsocketsTestClient(app)
      >>> with client.connect("/echo") as sock:
      ...     sock.send(TextMessage("hi!"))
      ...     assert sock.receive(timeout=1).get_text() == "hi!"

    Note:
      In order for :meth:`receive's<Websocket.receive>` "timeout"
      parameter to work, you need use gevent to monkeypatch sockets
      before running your tests.
    """

    def __init__(self, app: BaseApp) -> None:
        self.app = app
        self.executor = ThreadPoolExecutor(max_workers=8)

    def connect(
            self,
            path: str,
            headers: Optional[Union[HeadersDict, Headers]] = None,
            params: Optional[Union[ParamsDict, QueryParams]] = None,
            auth: Optional[Callable[[Request], Request]] = None,
    ) -> _WebsocketsTestConnection:
        """Initiate a websocket connection against the application.

        Parameters:
          path: The request path.
          headers: Optional request headers.
          params: Optional query params.
          auth: An optional function that can be used to add auth
            headers to the request.

        Raises:
          RuntimeError: If no upgrade response arrives in time.
          ValueError: If the response is not the expected upgrade
            response.
        """
        headers = headers or Headers()
        headers["connection"] = "upgrade"
        headers["upgrade"] = "websocket"
        headers["sec-websocket-key"] = b64encode(b"a" * 16).decode()
        headers["sec-websocket-version"] = "13"

        # client_sock plays the role of the socket gunicorn would
        # provide; server_sock is the end the test talks through.
        client_sock, server_sock = socket.socketpair()

        def get(*args, **kwargs):
            try:
                return self.get(*args, **kwargs)
            finally:
                client_sock.close()

        def prepare_environ(environ: Environ) -> Environ:
            environ["gunicorn.socket"] = client_sock
            return environ

        # Execute the websocket handler in a background thread because
        # it may block while waiting on the socket.  Keep a reference
        # to it so we can keep track of exceptions that occur in the
        # handler.
        future = self.executor.submit(
            get, path, headers, params,
            auth=auth,
            prepare_environ=prepare_environ,
        )

        # The accept value matches the fixed key sent above.
        expected_response = UPGRADE_RESPONSE_TEMPLATE % {
            b"websocket_accept": b"3SC6TZx4582OZaOogPVxMx5CGS0=",
        }

        # Consume the upgrade response and make sure it looks right.
        response_data = _chomp(server_sock, expected_response, timeout=5)
        if response_data != expected_response:
            # Nobody will use this end of the pair; closing it also
            # unblocks a handler that is still waiting on its end.
            server_sock.close()
            if response_data is None:
                raise RuntimeError("timed out while waiting for upgrade response")
            raise ValueError(f"Invalid upgrade response: {response_data!r}. Did you connect() to a standard endpoint?")

        websocket = Websocket(_BufferedStream(server_sock))
        return _WebsocketsTestConnection(future, websocket)
def _chomp(sock: socket.socket, expected: bytes, timeout: float) -> Union[bytes, None]: try: res = b'' sock.setblocking(False) with selectors.DefaultSelector() as sel: sel.register(sock, selectors.EVENT_READ) deadline = time.monotonic() + timeout while True: events =, deadline - time.monotonic())) if not events: return None for _, _ in events: data = sock.recv(CHUNKSIZE) if not data: return res res += data if res == expected: return res finally: sock.setblocking(True)