Skip to content

Commit

Permalink
Support for subprotocol negotiation (Fixes #17)
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed Aug 7, 2022
1 parent 9ce12fb commit 04baf87
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 12 deletions.
49 changes: 42 additions & 7 deletions src/simple_websocket/ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def __init__(self, sock=None, connection_type=None, receive_bytes=4096,
ping_interval=None, max_message_size=None,
thread_class=None, event_class=None, selector_class=None):
self.sock = sock
self.subprotocol = None
self.receive_bytes = receive_bytes
self.ping_interval = ping_interval
self.max_message_size = max_message_size
Expand Down Expand Up @@ -129,6 +130,19 @@ def close(self, reason=None, message=None):
pass
self.connected = False

def choose_subprotocol(self, request): # pragma: no cover
"""Choose a subprotocol to use for the WebSocket connection.
The default implementation does not accept any subprotocols. Subclasses
can override this method to implement subprotocol negotiation.
:param request: A ``Request`` object.
The method should return the subprotocol to use, or ``None`` if no
subprotocol is chosen.
"""
return None

def _thread(self):
sel = None
if self.ping_interval:
Expand Down Expand Up @@ -166,7 +180,9 @@ def _handle_events(self):
for event in self.ws.events():
try:
if isinstance(event, Request):
self.subprotocol = self.choose_subprotocol(event)
out_data += self.ws.send(AcceptConnection(
subprotocol=self.subprotocol,
extensions=[PerMessageDeflate()]))
elif isinstance(event, CloseConnection):
if self.is_server:
Expand Down Expand Up @@ -248,6 +264,8 @@ class Server(Base):
does this in its own different way. Werkzeug, Gunicorn,
Eventlet and Gevent are the only web servers that are
currently supported.
:param subprotocols: A list of supported subprotocols, or ``None`` (the
default) to disable subprotocol negotiation.
:param receive_bytes: The size of the receive buffer, in bytes. The
default is 4096.
:param ping_interval: Send ping packets to clients at the requested
Expand All @@ -270,10 +288,11 @@ class from the Python standard library.
``selectors.DefaultSelector`` class from the Python
standard library.
"""
def __init__(self, environ, receive_bytes=4096, ping_interval=None,
max_message_size=None, thread_class=None, event_class=None,
selector_class=None):
def __init__(self, environ, subprotocols=None, receive_bytes=4096,
ping_interval=None, max_message_size=None, thread_class=None,
event_class=None, selector_class=None):
self.environ = environ
self.subprotocols = subprotocols
self.mode = 'unknown'
sock = None
if 'werkzeug.socket' in environ:
Expand Down Expand Up @@ -316,12 +335,23 @@ def handshake(self):
self.ws.receive_data(in_data)
self.connected = self._handle_events()

def choose_subprotocol(self, request):
print(request.subprotocols)
print(self.subprotocols)
for subprotocol in request.subprotocols:
if subprotocol in self.subprotocols:
return subprotocol
return None


class Client(Base):
"""This class implements a WebSocket client.
:param url: The connection URL. Both ``ws://`` and ``wss://`` URLs are
accepted.
:param subprotocols: The name of the subprotocol to use, or a list of
subprotocol names in order of preference. Set to
``None`` (the default) to not use a subprotocol.
:param receive_bytes: The size of the receive buffer, in bytes. The
default is 4096.
:param ping_interval: Send ping packets to the server at the requested
Expand All @@ -344,16 +374,19 @@ class from the Python standard library.
objects. The default is the `threading.Event`` class
from the Python standard library.
"""
def __init__(self, url, receive_bytes=4096, ping_interval=None,
max_message_size=None, ssl_context=None, thread_class=None,
event_class=None, selector_class=None):
def __init__(self, url, subprotocols=None, receive_bytes=4096,
ping_interval=None, max_message_size=None, ssl_context=None,
thread_class=None, event_class=None, selector_class=None):
parsed_url = urlsplit(url)
is_secure = parsed_url.scheme in ['https', 'wss']
self.host = parsed_url.hostname
self.port = parsed_url.port or (443 if is_secure else 80)
self.path = parsed_url.path
if parsed_url.query:
self.path += '?' + parsed_url.query
self.subprotocols = subprotocols or []
if isinstance(self.subprotocols, str):
self.subprotocols = [self.subprotocols]

sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
if is_secure: # pragma: no cover
Expand All @@ -369,7 +402,8 @@ def __init__(self, url, receive_bytes=4096, ping_interval=None,
thread_class=thread_class, event_class=event_class)

def handshake(self):
out_data = self.ws.send(Request(host=self.host, target=self.path))
out_data = self.ws.send(Request(host=self.host, target=self.path,
subprotocols=self.subprotocols or []))
self.sock.send(out_data)

in_data = self.sock.recv(self.receive_bytes)
Expand All @@ -379,6 +413,7 @@ def handshake(self):
raise ConnectionError(event.status_code)
elif not isinstance(event, AcceptConnection): # pragma: no cover
raise ConnectionError(400)
self.subprotocol = event.subprotocol
self.connected = True

def close(self, reason=None, message=None):
Expand Down
26 changes: 24 additions & 2 deletions tests/test_simple_websocket_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@


class SimpleWebSocketClientTestCase(unittest.TestCase):
def get_client(self, mock_wsconn, url, events=[]):
def get_client(self, mock_wsconn, url, events=[], subprotocols=None):
mock_wsconn().events.side_effect = \
[iter(ev) for ev in
[[AcceptConnection()]] + events + [[CloseConnection(1000)]]]
mock_wsconn().send = lambda x: str(x).encode('utf-8')
return simple_websocket.Client(url)
return simple_websocket.Client(url, subprotocols=subprotocols)

@mock.patch('simple_websocket.ws.socket.socket')
@mock.patch('simple_websocket.ws.WSConnection')
Expand All @@ -33,6 +33,28 @@ def test_make_client(self, mock_wsconn, mock_socket):
assert client.port == 80
assert client.path == '/ws?a=1'

@mock.patch('simple_websocket.ws.socket.socket')
@mock.patch('simple_websocket.ws.WSConnection')
def test_make_client_subprotocol(self, mock_wsconn, mock_socket):
mock_socket.return_value.recv.return_value = b'x'
client = self.get_client(mock_wsconn, 'ws://example.com/ws?a=1',
subprotocols='foo')
assert client.subprotocols == ['foo']
client.sock.send.assert_called_with(
b"Request(host='example.com', target='/ws?a=1', extensions=[], "
b"extra_headers=[], subprotocols=['foo'])")

@mock.patch('simple_websocket.ws.socket.socket')
@mock.patch('simple_websocket.ws.WSConnection')
def test_make_client_subprotocols(self, mock_wsconn, mock_socket):
mock_socket.return_value.recv.return_value = b'x'
client = self.get_client(mock_wsconn, 'ws://example.com/ws?a=1',
subprotocols=['foo', 'bar'])
assert client.subprotocols == ['foo', 'bar']
client.sock.send.assert_called_with(
b"Request(host='example.com', target='/ws?a=1', extensions=[], "
b"extra_headers=[], subprotocols=['foo', 'bar'])")

@mock.patch('simple_websocket.ws.socket.socket')
@mock.patch('simple_websocket.ws.WSConnection')
def test_send(self, mock_wsconn, mock_socket):
Expand Down
37 changes: 34 additions & 3 deletions tests/test_simple_websocket_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,13 @@


class SimpleWebSocketServerTestCase(unittest.TestCase):
def get_server(self, mock_wsconn, environ, events=[], **kwargs):
def get_server(self, mock_wsconn, environ, events=[],
client_subprotocols=None, server_subprotocols=None,
**kwargs):
mock_wsconn().events.side_effect = \
[iter(ev) for ev in [[Request(host='example.com', target='/ws')]] +
[iter(ev) for ev in [[
Request(host='example.com', target='/ws',
subprotocols=client_subprotocols or [])]] +
events + [[CloseConnection(1000, 'bye')]]]
mock_wsconn().send = lambda x: str(x).encode('utf-8')
environ.update({
Expand All @@ -21,7 +25,8 @@ def get_server(self, mock_wsconn, environ, events=[], **kwargs):
'HTTP_SEC_WEBSOCKET_KEY': 'Iv8io/9s+lYFgZWcXczP8Q==',
'HTTP_SEC_WEBSOCKET_VERSION': '13',
})
return simple_websocket.Server(environ, **kwargs)
return simple_websocket.Server(
environ, subprotocols=server_subprotocols, **kwargs)

@mock.patch('simple_websocket.ws.WSConnection')
def test_werkzeug(self, mock_wsconn):
Expand Down Expand Up @@ -217,3 +222,29 @@ def test_ping_pong(self, mock_time, mock_wsconn):
assert mock_socket.send.call_args_list[1][0][0].startswith(b'Ping')
assert mock_socket.send.call_args_list[2][0][0].startswith(b'Ping')
assert mock_socket.send.call_args_list[3][0][0].startswith(b'Close')

@mock.patch('simple_websocket.ws.WSConnection')
def test_subprotocols(self, mock_wsconn):
mock_socket = mock.MagicMock()
mock_socket.recv.return_value = b'x'

server = self.get_server(mock_wsconn, {
'werkzeug.socket': mock_socket,
}, client_subprotocols=['foo', 'bar'], server_subprotocols=['bar'])
while server.connected:
time.sleep(0.01)
assert server.subprotocol == 'bar'

server = self.get_server(mock_wsconn, {
'werkzeug.socket': mock_socket,
}, client_subprotocols=['foo'], server_subprotocols=['foo', 'bar'])
while server.connected:
time.sleep(0.01)
assert server.subprotocol == 'foo'

server = self.get_server(mock_wsconn, {
'werkzeug.socket': mock_socket,
}, client_subprotocols=['foo'], server_subprotocols=['bar', 'baz'])
while server.connected:
time.sleep(0.01)
assert server.subprotocol is None

0 comments on commit 04baf87

Please sign in to comment.