diff --git a/aiohttp/websocket.py b/aiohttp/websocket.py index 9996fe56d60..b1018bb18cd 100644 --- a/aiohttp/websocket.py +++ b/aiohttp/websocket.py @@ -218,7 +218,7 @@ def do_handshake(method, headers, transport, protocols=()): break else: raise errors.HttpBadRequest( - 'Client protocols {} don’t overlap server-known ones {}' + 'Client protocols {!r} don’t overlap server-known ones {!r}' .format(protocols, req_protocols)) # check supported version @@ -246,8 +246,9 @@ def do_handshake(method, headers, transport, protocols=()): if protocol: response_headers.append(('SEC-WEBSOCKET-PROTOCOL', protocol)) - # response code, headers, parser, writer + # response code, headers, parser, writer, protocol return (101, response_headers, WebSocketParser, - WebSocketWriter(transport)) + WebSocketWriter(transport), + protocol) diff --git a/tests/test_websocket.py b/tests/test_websocket.py index bb101456e08..08d4ee80d5e 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -419,18 +419,54 @@ def test_protocol_key(self): websocket.do_handshake, self.message.method, self.message.headers, self.transport) + def gen_ws_headers(self, protocols=''): + key = base64.b64encode(os.urandom(16)).decode() + hdrs = [('UPGRADE', 'websocket'), + ('CONNECTION', 'upgrade'), + ('SEC-WEBSOCKET-VERSION', '13'), + ('SEC-WEBSOCKET-KEY', key)] + if protocols: + hdrs += [('SEC-WEBSOCKET-PROTOCOL', protocols)] + return hdrs, key + def test_handshake(self): - sec_key = base64.b64encode(os.urandom(16)).decode() + hdrs, sec_key = self.gen_ws_headers() - self.headers.extend([('UPGRADE', 'websocket'), - ('CONNECTION', 'upgrade'), - ('SEC-WEBSOCKET-VERSION', '13'), - ('SEC-WEBSOCKET-KEY', sec_key)]) - status, headers, parser, writer = websocket.do_handshake( + self.headers.extend(hdrs) + status, headers, parser, writer, protocol = websocket.do_handshake( self.message.method, self.message.headers, self.transport) self.assertEqual(status, 101) + self.assertIsNone(protocol) key = base64.b64encode( hashlib.sha1(sec_key.encode() + websocket.WS_KEY).digest()) headers = dict(headers) self.assertEqual(headers['SEC-WEBSOCKET-ACCEPT'], key.decode()) + + def test_handshake_protocol(self): + '''Tests if one protocol is returned by do_handshake''' + proto = 'chat' + + self.headers.extend(self.gen_ws_headers(proto)[0]) + _, resp_headers, _, _, protocol = websocket.do_handshake( + self.message.method, self.message.headers, self.transport, + protocols=[proto]) + + self.assertEqual(protocol, proto) + + #also test if we reply with the protocol + resp_headers = dict(resp_headers) + self.assertEqual(resp_headers['SEC-WEBSOCKET-PROTOCOL'], proto) + + def test_handshake_protocol_agreement(self): + '''Tests if the right protocol is selected given multiple''' + best_proto = 'chat' + wanted_protos = ['best', 'chat', 'worse_proto'] + server_protos = 'worse_proto,chat' + + self.headers.extend(self.gen_ws_headers(server_protos)[0]) + _, resp_headers, _, _, protocol = websocket.do_handshake( + self.message.method, self.message.headers, self.transport, + protocols=wanted_protos) + + self.assertEqual(protocol, best_proto)