diff --git a/src/simple_websocket/ws.py b/src/simple_websocket/ws.py index 871bee5..68d60b9 100644 --- a/src/simple_websocket/ws.py +++ b/src/simple_websocket/ws.py @@ -41,6 +41,7 @@ def __init__(self, sock=None, connection_type=None, receive_bytes=4096, ping_interval=None, max_message_size=None, thread_class=None, event_class=None, selector_class=None): self.sock = sock + self.subprotocol = None self.receive_bytes = receive_bytes self.ping_interval = ping_interval self.max_message_size = max_message_size @@ -129,6 +130,19 @@ def close(self, reason=None, message=None): pass self.connected = False + def choose_subprotocol(self, request): # pragma: no cover + """Choose a subprotocol to use for the WebSocket connection. + + The default implementation does not accept any subprotocols. Subclasses + can override this method to implement subprotocol negotiation. + + :param request: A ``Request`` object. + + The method should return the subprotocol to use, or ``None`` if no + subprotocol is chosen. + """ + return None + def _thread(self): sel = None if self.ping_interval: @@ -166,7 +180,9 @@ def _handle_events(self): for event in self.ws.events(): try: if isinstance(event, Request): + self.subprotocol = self.choose_subprotocol(event) out_data += self.ws.send(AcceptConnection( + subprotocol=self.subprotocol, extensions=[PerMessageDeflate()])) elif isinstance(event, CloseConnection): if self.is_server: @@ -248,6 +264,8 @@ class Server(Base): does this in its own different way. Werkzeug, Gunicorn, Eventlet and Gevent are the only web servers that are currently supported. + :param subprotocols: A list of supported subprotocols, or ``None`` (the + default) to disable subprotocol negotiation. :param receive_bytes: The size of the receive buffer, in bytes. The default is 4096. :param ping_interval: Send ping packets to clients at the requested @@ -270,10 +288,11 @@ class from the Python standard library. ``selectors.DefaultSelector`` class from the Python standard library. """ - def __init__(self, environ, receive_bytes=4096, ping_interval=None, - max_message_size=None, thread_class=None, event_class=None, - selector_class=None): + def __init__(self, environ, subprotocols=None, receive_bytes=4096, + ping_interval=None, max_message_size=None, thread_class=None, + event_class=None, selector_class=None): self.environ = environ + self.subprotocols = subprotocols self.mode = 'unknown' sock = None if 'werkzeug.socket' in environ: @@ -316,12 +335,23 @@ def handshake(self): self.ws.receive_data(in_data) self.connected = self._handle_events() + def choose_subprotocol(self, request): + print(request.subprotocols) + print(self.subprotocols) + for subprotocol in request.subprotocols: + if subprotocol in self.subprotocols: + return subprotocol + return None + class Client(Base): """This class implements a WebSocket client. :param url: The connection URL. Both ``ws://`` and ``wss://`` URLs are accepted. + :param subprotocols: The name of the subprotocol to use, or a list of + subprotocol names in order of preference. Set to + ``None`` (the default) to not use a subprotocol. :param receive_bytes: The size of the receive buffer, in bytes. The default is 4096. :param ping_interval: Send ping packets to the server at the requested @@ -344,9 +374,9 @@ class from the Python standard library. objects. The default is the `threading.Event`` class from the Python standard library. """ - def __init__(self, url, receive_bytes=4096, ping_interval=None, - max_message_size=None, ssl_context=None, thread_class=None, - event_class=None, selector_class=None): + def __init__(self, url, subprotocols=None, receive_bytes=4096, + ping_interval=None, max_message_size=None, ssl_context=None, + thread_class=None, event_class=None, selector_class=None): parsed_url = urlsplit(url) is_secure = parsed_url.scheme in ['https', 'wss'] self.host = parsed_url.hostname @@ -354,6 +384,9 @@ def __init__(self, url, receive_bytes=4096, ping_interval=None, self.path = parsed_url.path if parsed_url.query: self.path += '?' + parsed_url.query + self.subprotocols = subprotocols or [] + if isinstance(self.subprotocols, str): + self.subprotocols = [self.subprotocols] sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) if is_secure: # pragma: no cover @@ -369,7 +402,8 @@ def __init__(self, url, receive_bytes=4096, ping_interval=None, thread_class=thread_class, event_class=event_class) def handshake(self): - out_data = self.ws.send(Request(host=self.host, target=self.path)) + out_data = self.ws.send(Request(host=self.host, target=self.path, + subprotocols=self.subprotocols or [])) self.sock.send(out_data) in_data = self.sock.recv(self.receive_bytes) @@ -379,6 +413,7 @@ def handshake(self): raise ConnectionError(event.status_code) elif not isinstance(event, AcceptConnection): # pragma: no cover raise ConnectionError(400) + self.subprotocol = event.subprotocol self.connected = True def close(self, reason=None, message=None): diff --git a/tests/test_simple_websocket_client.py b/tests/test_simple_websocket_client.py index b08bbcd..851776b 100644 --- a/tests/test_simple_websocket_client.py +++ b/tests/test_simple_websocket_client.py @@ -9,12 +9,12 @@ class SimpleWebSocketClientTestCase(unittest.TestCase): - def get_client(self, mock_wsconn, url, events=[]): + def get_client(self, mock_wsconn, url, events=[], subprotocols=None): mock_wsconn().events.side_effect = \ [iter(ev) for ev in [[AcceptConnection()]] + events + [[CloseConnection(1000)]]] mock_wsconn().send = lambda x: str(x).encode('utf-8') - return simple_websocket.Client(url) + return simple_websocket.Client(url, subprotocols=subprotocols) @mock.patch('simple_websocket.ws.socket.socket') @mock.patch('simple_websocket.ws.WSConnection') @@ -33,6 +33,28 @@ def test_make_client(self, mock_wsconn, mock_socket): assert client.port == 80 assert client.path == '/ws?a=1' + @mock.patch('simple_websocket.ws.socket.socket') + @mock.patch('simple_websocket.ws.WSConnection') + def test_make_client_subprotocol(self, mock_wsconn, mock_socket): + mock_socket.return_value.recv.return_value = b'x' + client = self.get_client(mock_wsconn, 'ws://example.com/ws?a=1', + subprotocols='foo') + assert client.subprotocols == ['foo'] + client.sock.send.assert_called_with( + b"Request(host='example.com', target='/ws?a=1', extensions=[], " + b"extra_headers=[], subprotocols=['foo'])") + + @mock.patch('simple_websocket.ws.socket.socket') + @mock.patch('simple_websocket.ws.WSConnection') + def test_make_client_subprotocols(self, mock_wsconn, mock_socket): + mock_socket.return_value.recv.return_value = b'x' + client = self.get_client(mock_wsconn, 'ws://example.com/ws?a=1', + subprotocols=['foo', 'bar']) + assert client.subprotocols == ['foo', 'bar'] + client.sock.send.assert_called_with( + b"Request(host='example.com', target='/ws?a=1', extensions=[], " + b"extra_headers=[], subprotocols=['foo', 'bar'])") + @mock.patch('simple_websocket.ws.socket.socket') @mock.patch('simple_websocket.ws.WSConnection') def test_send(self, mock_wsconn, mock_socket): diff --git a/tests/test_simple_websocket_server.py b/tests/test_simple_websocket_server.py index 2c1c80d..3d10e03 100644 --- a/tests/test_simple_websocket_server.py +++ b/tests/test_simple_websocket_server.py @@ -9,9 +9,13 @@ class SimpleWebSocketServerTestCase(unittest.TestCase): - def get_server(self, mock_wsconn, environ, events=[], **kwargs): + def get_server(self, mock_wsconn, environ, events=[], + client_subprotocols=None, server_subprotocols=None, + **kwargs): mock_wsconn().events.side_effect = \ - [iter(ev) for ev in [[Request(host='example.com', target='/ws')]] + + [iter(ev) for ev in [[ + Request(host='example.com', target='/ws', + subprotocols=client_subprotocols or [])]] + events + [[CloseConnection(1000, 'bye')]]] mock_wsconn().send = lambda x: str(x).encode('utf-8') environ.update({ @@ -21,7 +25,8 @@ def get_server(self, mock_wsconn, environ, events=[], **kwargs): 'HTTP_SEC_WEBSOCKET_KEY': 'Iv8io/9s+lYFgZWcXczP8Q==', 'HTTP_SEC_WEBSOCKET_VERSION': '13', }) - return simple_websocket.Server(environ, **kwargs) + return simple_websocket.Server( + environ, subprotocols=server_subprotocols, **kwargs) @mock.patch('simple_websocket.ws.WSConnection') def test_werkzeug(self, mock_wsconn): @@ -217,3 +222,29 @@ def test_ping_pong(self, mock_time, mock_wsconn): assert mock_socket.send.call_args_list[1][0][0].startswith(b'Ping') assert mock_socket.send.call_args_list[2][0][0].startswith(b'Ping') assert mock_socket.send.call_args_list[3][0][0].startswith(b'Close') + + @mock.patch('simple_websocket.ws.WSConnection') + def test_subprotocols(self, mock_wsconn): + mock_socket = mock.MagicMock() + mock_socket.recv.return_value = b'x' + + server = self.get_server(mock_wsconn, { + 'werkzeug.socket': mock_socket, + }, client_subprotocols=['foo', 'bar'], server_subprotocols=['bar']) + while server.connected: + time.sleep(0.01) + assert server.subprotocol == 'bar' + + server = self.get_server(mock_wsconn, { + 'werkzeug.socket': mock_socket, + }, client_subprotocols=['foo'], server_subprotocols=['foo', 'bar']) + while server.connected: + time.sleep(0.01) + assert server.subprotocol == 'foo' + + server = self.get_server(mock_wsconn, { + 'werkzeug.socket': mock_socket, + }, client_subprotocols=['foo'], server_subprotocols=['bar', 'baz']) + while server.connected: + time.sleep(0.01) + assert server.subprotocol is None