Skip to content

Commit

Permalink
Address issue #216: rename the klass argument to create_protocol.
Browse files Browse the repository at this point in the history
  • Loading branch information
cjerdonek authored and aaugustin committed Jul 29, 2017
1 parent f859e2f commit 725675e
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 15 deletions.
4 changes: 2 additions & 2 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ Server

.. automodule:: websockets.server

.. autofunction:: serve(ws_handler, host=None, port=None, *, klass=WebSocketServerProtocol, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origins=None, subprotocols=None, extra_headers=None, **kwds)
.. autofunction:: serve(ws_handler, host=None, port=None, *, create_protocol=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origins=None, subprotocols=None, extra_headers=None, **kwds)

.. autoclass:: WebSocketServer

Expand All @@ -50,7 +50,7 @@ Client

.. automodule:: websockets.client

.. autofunction:: connect(uri, *, klass=WebSocketClientProtocol, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origin=None, subprotocols=None, extra_headers=None, **kwds)
.. autofunction:: connect(uri, *, create_protocol=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origin=None, subprotocols=None, extra_headers=None, **kwds)

.. autoclass:: WebSocketClientProtocol(*, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None)

Expand Down
5 changes: 5 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ Changelog

*In development*

* Renamed :func:`~websockets.server.serve()` and
:func:`~websockets.client.connect()`'s ``klass`` argument to
``create_protocol`` to reflect that it can also be a callable.
For backwards compatibility, ``klass`` is still supported.

* :func:`~websockets.server.serve` can be used as an asynchronous context
manager on Python ≥ 3.5.

Expand Down
6 changes: 4 additions & 2 deletions docs/cheatsheet.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ Server
the handler exits normally or with an exception.

* You may subclass :class:`~websockets.server.WebSocketServerProtocol` and
pass it in the ``klass`` keyword argument for advanced customization.
pass it or a factory function as the ``create_protocol`` argument for
advanced customization.

Client
------
Expand All @@ -34,7 +35,8 @@ Client
* On Python ≥ 3.5, you can also use it as an asynchronous context manager.

* You may subclass :class:`~websockets.server.WebSocketClientProtocol` and
pass it in the ``klass`` keyword argument for advanced customization.
pass it or a factory function as the ``create_protocol`` argument for
advanced customization.

* Call :meth:`~websockets.protocol.WebSocketCommonProtocol.recv` and
:meth:`~websockets.protocol.WebSocketCommonProtocol.send` to receive and
Expand Down
14 changes: 11 additions & 3 deletions websockets/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,7 @@ def handshake(self, wsuri,


@asyncio.coroutine
def connect(uri, *,
klass=WebSocketClientProtocol,
def connect(uri, *, create_protocol=None, klass=None,
timeout=10, max_size=2 ** 20, max_queue=2 ** 5,
read_limit=2 ** 16, write_limit=2 ** 16,
loop=None, legacy_recv=False,
Expand All @@ -156,6 +155,13 @@ def connect(uri, *,
``read_limit``, and ``write_limit`` optional arguments is described in the
documentation of :class:`~websockets.protocol.WebSocketCommonProtocol`.
The ``create_protocol`` parameter allows customizing the
:class:`WebSocketClientProtocol` class used. The argument should be a
callable or class accepting the same arguments as
:class:`WebSocketClientProtocol` and that returns a
:class:`WebSocketClientProtocol` instance. It defaults to
:class:`WebSocketClientProtocol`.
:func:`connect` also accepts the following optional arguments:
* ``origin`` sets the Origin HTTP header
Expand All @@ -175,13 +181,15 @@ def connect(uri, *,
if loop is None:
loop = asyncio.get_event_loop()

create_protocol = create_protocol or klass or WebSocketClientProtocol

wsuri = parse_uri(uri)
if wsuri.secure:
kwds.setdefault('ssl', True)
elif kwds.get('ssl') is not None:
raise ValueError("connect() received a SSL context for a ws:// URI. "
"Use a wss:// URI to enable TLS.")
factory = lambda: klass(
factory = lambda: create_protocol(
host=wsuri.host, port=wsuri.port, secure=wsuri.secure,
timeout=timeout, max_size=max_size, max_queue=max_queue,
read_limit=read_limit, write_limit=write_limit,
Expand Down
13 changes: 11 additions & 2 deletions websockets/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ def wait_closed(self):

@asyncio.coroutine
def serve(ws_handler, host=None, port=None, *,
klass=WebSocketServerProtocol,
create_protocol=None, klass=None,
timeout=10, max_size=2 ** 20, max_queue=2 ** 5,
read_limit=2 ** 16, write_limit=2 ** 16,
loop=None, legacy_recv=False,
Expand Down Expand Up @@ -440,6 +440,13 @@ def serve(ws_handler, host=None, port=None, *,
set the ``ssl`` keyword argument to a :class:`~ssl.SSLContext` to enable
TLS.
The ``create_protocol`` parameter allows customizing the
:class:`WebSocketServerProtocol` class used. The argument should be a
callable or class accepting the same arguments as
:class:`WebSocketServerProtocol` and that returns a
:class:`WebSocketServerProtocol` instance. It defaults to
:class:`WebSocketServerProtocol`.
The behavior of the ``timeout``, ``max_size``, and ``max_queue``,
``read_limit``, and ``write_limit`` optional arguments is described in the
documentation of :class:`~websockets.protocol.WebSocketCommonProtocol`.
Expand Down Expand Up @@ -472,10 +479,12 @@ def serve(ws_handler, host=None, port=None, *,
if loop is None:
loop = asyncio.get_event_loop()

create_protocol = create_protocol or klass or WebSocketServerProtocol

ws_server = WebSocketServer(loop)

secure = kwds.get('ssl') is not None
factory = lambda: klass(
factory = lambda: create_protocol(
ws_handler, ws_server,
host=host, port=port, secure=secure,
timeout=timeout, max_size=max_size, max_queue=max_queue,
Expand Down
73 changes: 67 additions & 6 deletions websockets/test_client_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,14 @@ def handler(ws, path):


try:
# Order by status code.
UNAUTHORIZED = http.HTTPStatus.UNAUTHORIZED
FORBIDDEN = http.HTTPStatus.FORBIDDEN
except AttributeError: # pragma: no cover
class UNAUTHORIZED:
value = 401
phrase = 'Unauthorized'

class FORBIDDEN:
value = 403
phrase = 'Forbidden'
Expand Down Expand Up @@ -94,13 +100,28 @@ def with_client(*args, **kwds):
return with_manager(temp_test_client, *args, **kwds)


class ForbiddenWebSocketServerProtocol(WebSocketServerProtocol):
class UnauthorizedServerProtocol(WebSocketServerProtocol):

@asyncio.coroutine
def get_response_status(self, set_header):
return UNAUTHORIZED


class ForbiddenServerProtocol(WebSocketServerProtocol):

@asyncio.coroutine
def get_response_status(self, set_header):
return FORBIDDEN


class FooClientProtocol(WebSocketClientProtocol):
pass


class BarClientProtocol(WebSocketClientProtocol):
pass


class ClientServerTests(unittest.TestCase):

secure = False
Expand Down Expand Up @@ -268,7 +289,7 @@ def get_response_status(self, set_header):
status = yield from super().get_response_status(set_header)
return status

with self.temp_server(klass=SaveAttributesProtocol):
with self.temp_server(create_protocol=SaveAttributesProtocol):
self.start_client(path='foo/bar', origin='http://otherhost')
self.assertEqual(attrs['origin'], 'http://otherhost')
self.assertEqual(attrs['path'], '/foo/bar')
Expand All @@ -280,10 +301,50 @@ def get_response_status(self, set_header):
self.assertIsInstance(request_headers, http.client.HTTPMessage)
self.assertEqual(request_headers.get('origin'), 'http://otherhost')

@with_server(klass=ForbiddenWebSocketServerProtocol)
def test_authentication(self):
with self.assertRaises(InvalidStatus):
def assert_client_raises_code(self, code):
with self.assertRaises(InvalidStatus) as raised:
self.start_client()
self.assertEqual(raised.exception.code, code)

@with_server(create_protocol=UnauthorizedServerProtocol)
def test_server_create_protocol(self):
self.assert_client_raises_code(401)

@with_server(create_protocol=(lambda *args, **kwargs:
UnauthorizedServerProtocol(*args, **kwargs)))
def test_server_create_protocol_function(self):
self.assert_client_raises_code(401)

@with_server(klass=UnauthorizedServerProtocol)
def test_server_klass(self):
self.assert_client_raises_code(401)

@with_server(create_protocol=ForbiddenServerProtocol,
klass=UnauthorizedServerProtocol)
def test_server_create_protocol_over_klass(self):
self.assert_client_raises_code(403)

@with_server()
@with_client('path', create_protocol=FooClientProtocol)
def test_client_create_protocol(self):
self.assertIsInstance(self.client, FooClientProtocol)

@with_server()
@with_client('path', create_protocol=(
lambda *args, **kwargs: FooClientProtocol(*args, **kwargs)))
def test_client_create_protocol_function(self):
self.assertIsInstance(self.client, FooClientProtocol)

@with_server()
@with_client('path', klass=FooClientProtocol)
def test_client_klass(self):
self.assertIsInstance(self.client, FooClientProtocol)

@with_server()
@with_client('path', create_protocol=BarClientProtocol,
klass=FooClientProtocol)
def test_client_create_protocol_over_klass(self):
self.assertIsInstance(self.client, BarClientProtocol)

@with_server()
@with_client('subprotocol')
Expand Down Expand Up @@ -437,7 +498,7 @@ def test_server_shuts_down_during_connection_handling(self):
# Websocket connection terminates with 1001 Going Away.
self.assertEqual(self.client.close_code, 1001)

@with_server(klass=ForbiddenWebSocketServerProtocol)
@with_server(create_protocol=ForbiddenServerProtocol)
def test_invalid_status_error_during_client_connect(self):
with self.assertRaises(InvalidStatus) as raised:
self.start_client()
Expand Down

0 comments on commit 725675e

Please sign in to comment.