diff --git a/docs/api.rst b/docs/api.rst index 9dd5f8f88..26fdc25bc 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -32,7 +32,7 @@ Server .. automodule:: websockets.server - .. autofunction:: serve(ws_handler, host=None, port=None, *, klass=WebSocketServerProtocol, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origins=None, subprotocols=None, extra_headers=None, **kwds) + .. autofunction:: serve(ws_handler, host=None, port=None, *, create_protocol=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origins=None, subprotocols=None, extra_headers=None, **kwds) .. autoclass:: WebSocketServer @@ -50,7 +50,7 @@ Client .. automodule:: websockets.client - .. autofunction:: connect(uri, *, klass=WebSocketClientProtocol, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origin=None, subprotocols=None, extra_headers=None, **kwds) + .. autofunction:: connect(uri, *, create_protocol=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origin=None, subprotocols=None, extra_headers=None, **kwds) .. autoclass:: WebSocketClientProtocol(*, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None) diff --git a/docs/changelog.rst b/docs/changelog.rst index 6b0241239..747948ee8 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -6,6 +6,11 @@ Changelog *In development* +* Renamed :func:`~websockets.server.serve()` and + :func:`~websockets.client.connect()`'s ``klass`` argument to + ``create_protocol`` to reflect that it can also be a callable. + For backwards compatibility, ``klass`` is still supported. + * :func:`~websockets.server.serve` can be used as an asynchronous context manager on Python ≥ 3.5. diff --git a/docs/cheatsheet.rst b/docs/cheatsheet.rst index cf6897257..5ee2c221f 100644 --- a/docs/cheatsheet.rst +++ b/docs/cheatsheet.rst @@ -23,7 +23,8 @@ Server the handler exits normally or with an exception. * You may subclass :class:`~websockets.server.WebSocketServerProtocol` and - pass it in the ``klass`` keyword argument for advanced customization. + pass it or a factory function as the ``create_protocol`` argument for + advanced customization. Client ------ @@ -34,7 +35,8 @@ Client * On Python ≥ 3.5, you can also use it as an asynchronous context manager. * You may subclass :class:`~websockets.server.WebSocketClientProtocol` and - pass it in the ``klass`` keyword argument for advanced customization. + pass it or a factory function as the ``create_protocol`` argument for + advanced customization. * Call :meth:`~websockets.protocol.WebSocketCommonProtocol.recv` and :meth:`~websockets.protocol.WebSocketCommonProtocol.send` to receive and diff --git a/websockets/client.py b/websockets/client.py index 143ec37a0..4053c2863 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -130,8 +130,7 @@ def handshake(self, wsuri, @asyncio.coroutine -def connect(uri, *, - klass=WebSocketClientProtocol, +def connect(uri, *, create_protocol=None, klass=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, legacy_recv=False, @@ -156,6 +155,13 @@ def connect(uri, *, ``read_limit``, and ``write_limit`` optional arguments is described in the documentation of :class:`~websockets.protocol.WebSocketCommonProtocol`. + The ``create_protocol`` parameter allows customizing the + :class:`WebSocketClientProtocol` class used. The argument should be a + callable or class accepting the same arguments as + :class:`WebSocketClientProtocol` and that returns a + :class:`WebSocketClientProtocol` instance. It defaults to + :class:`WebSocketClientProtocol`. + :func:`connect` also accepts the following optional arguments: * ``origin`` sets the Origin HTTP header @@ -175,13 +181,15 @@ def connect(uri, *, if loop is None: loop = asyncio.get_event_loop() + create_protocol = create_protocol or klass or WebSocketClientProtocol + wsuri = parse_uri(uri) if wsuri.secure: kwds.setdefault('ssl', True) elif kwds.get('ssl') is not None: raise ValueError("connect() received a SSL context for a ws:// URI. " "Use a wss:// URI to enable TLS.") - factory = lambda: klass( + factory = lambda: create_protocol( host=wsuri.host, port=wsuri.port, secure=wsuri.secure, timeout=timeout, max_size=max_size, max_queue=max_queue, read_limit=read_limit, write_limit=write_limit, diff --git a/websockets/server.py b/websockets/server.py index 279d814df..5c938aa25 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -407,7 +407,7 @@ def wait_closed(self): @asyncio.coroutine def serve(ws_handler, host=None, port=None, *, - klass=WebSocketServerProtocol, + create_protocol=None, klass=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, legacy_recv=False, @@ -440,6 +440,13 @@ def serve(ws_handler, host=None, port=None, *, set the ``ssl`` keyword argument to a :class:`~ssl.SSLContext` to enable TLS. + The ``create_protocol`` parameter allows customizing the + :class:`WebSocketServerProtocol` class used. The argument should be a + callable or class accepting the same arguments as + :class:`WebSocketServerProtocol` and that returns a + :class:`WebSocketServerProtocol` instance. It defaults to + :class:`WebSocketServerProtocol`. + The behavior of the ``timeout``, ``max_size``, and ``max_queue``, ``read_limit``, and ``write_limit`` optional arguments is described in the documentation of :class:`~websockets.protocol.WebSocketCommonProtocol`. @@ -472,10 +479,12 @@ def serve(ws_handler, host=None, port=None, *, if loop is None: loop = asyncio.get_event_loop() + create_protocol = create_protocol or klass or WebSocketServerProtocol + ws_server = WebSocketServer(loop) secure = kwds.get('ssl') is not None - factory = lambda: klass( + factory = lambda: create_protocol( ws_handler, ws_server, host=host, port=port, secure=secure, timeout=timeout, max_size=max_size, max_queue=max_queue, diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 980356ee9..0edc1408c 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -40,8 +40,14 @@ def handler(ws, path): try: + # Order by status code. + UNAUTHORIZED = http.HTTPStatus.UNAUTHORIZED FORBIDDEN = http.HTTPStatus.FORBIDDEN except AttributeError: # pragma: no cover + class UNAUTHORIZED: + value = 401 + phrase = 'Unauthorized' + class FORBIDDEN: value = 403 phrase = 'Forbidden' @@ -94,13 +100,28 @@ def with_client(*args, **kwds): return with_manager(temp_test_client, *args, **kwds) -class ForbiddenWebSocketServerProtocol(WebSocketServerProtocol): +class UnauthorizedServerProtocol(WebSocketServerProtocol): + + @asyncio.coroutine + def get_response_status(self, set_header): + return UNAUTHORIZED + + +class ForbiddenServerProtocol(WebSocketServerProtocol): @asyncio.coroutine def get_response_status(self, set_header): return FORBIDDEN +class FooClientProtocol(WebSocketClientProtocol): + pass + + +class BarClientProtocol(WebSocketClientProtocol): + pass + + class ClientServerTests(unittest.TestCase): secure = False @@ -268,7 +289,7 @@ def get_response_status(self, set_header): status = yield from super().get_response_status(set_header) return status - with self.temp_server(klass=SaveAttributesProtocol): + with self.temp_server(create_protocol=SaveAttributesProtocol): self.start_client(path='foo/bar', origin='http://otherhost') self.assertEqual(attrs['origin'], 'http://otherhost') self.assertEqual(attrs['path'], '/foo/bar') @@ -280,10 +301,50 @@ def get_response_status(self, set_header): self.assertIsInstance(request_headers, http.client.HTTPMessage) self.assertEqual(request_headers.get('origin'), 'http://otherhost') - @with_server(klass=ForbiddenWebSocketServerProtocol) - def test_authentication(self): - with self.assertRaises(InvalidStatus): + def assert_client_raises_code(self, code): + with self.assertRaises(InvalidStatus) as raised: self.start_client() + self.assertEqual(raised.exception.code, code) + + @with_server(create_protocol=UnauthorizedServerProtocol) + def test_server_create_protocol(self): + self.assert_client_raises_code(401) + + @with_server(create_protocol=(lambda *args, **kwargs: + UnauthorizedServerProtocol(*args, **kwargs))) + def test_server_create_protocol_function(self): + self.assert_client_raises_code(401) + + @with_server(klass=UnauthorizedServerProtocol) + def test_server_klass(self): + self.assert_client_raises_code(401) + + @with_server(create_protocol=ForbiddenServerProtocol, + klass=UnauthorizedServerProtocol) + def test_server_create_protocol_over_klass(self): + self.assert_client_raises_code(403) + + @with_server() + @with_client('path', create_protocol=FooClientProtocol) + def test_client_create_protocol(self): + self.assertIsInstance(self.client, FooClientProtocol) + + @with_server() + @with_client('path', create_protocol=( + lambda *args, **kwargs: FooClientProtocol(*args, **kwargs))) + def test_client_create_protocol_function(self): + self.assertIsInstance(self.client, FooClientProtocol) + + @with_server() + @with_client('path', klass=FooClientProtocol) + def test_client_klass(self): + self.assertIsInstance(self.client, FooClientProtocol) + + @with_server() + @with_client('path', create_protocol=BarClientProtocol, + klass=FooClientProtocol) + def test_client_create_protocol_over_klass(self): + self.assertIsInstance(self.client, BarClientProtocol) @with_server() @with_client('subprotocol') @@ -437,7 +498,7 @@ def test_server_shuts_down_during_connection_handling(self): # Websocket connection terminates with 1001 Going Away. self.assertEqual(self.client.close_code, 1001) - @with_server(klass=ForbiddenWebSocketServerProtocol) + @with_server(create_protocol=ForbiddenServerProtocol) def test_invalid_status_error_during_client_connect(self): with self.assertRaises(InvalidStatus) as raised: self.start_client()