Skip to content

Commit

Permalink
broke websocket API and added tests
Browse files Browse the repository at this point in the history
  • Loading branch information
flying-sheep committed Sep 12, 2014
1 parent a5b200e commit 92ce18b
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 9 deletions.
7 changes: 4 additions & 3 deletions aiohttp/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def do_handshake(method, headers, transport, protocols=()):
break
else:
raise errors.HttpBadRequest(
'Client protocols {} don’t overlap server-known ones {}'
'Client protocols {!r} don’t overlap server-known ones {!r}'
.format(protocols, req_protocols))

# check supported version
Expand Down Expand Up @@ -246,8 +246,9 @@ def do_handshake(method, headers, transport, protocols=()):
if protocol:
response_headers.append(('SEC-WEBSOCKET-PROTOCOL', protocol))

# response code, headers, parser, writer
# response code, headers, parser, writer, protocol
return (101,
response_headers,
WebSocketParser,
WebSocketWriter(transport))
WebSocketWriter(transport),
protocol)
48 changes: 42 additions & 6 deletions tests/test_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,18 +419,54 @@ def test_protocol_key(self):
websocket.do_handshake,
self.message.method, self.message.headers, self.transport)

def gen_ws_headers(self, protocols=''):
key = base64.b64encode(os.urandom(16)).decode()
hdrs = [('UPGRADE', 'websocket'),
('CONNECTION', 'upgrade'),
('SEC-WEBSOCKET-VERSION', '13'),
('SEC-WEBSOCKET-KEY', key)]
if protocols:
hdrs += [('SEC-WEBSOCKET-PROTOCOL', protocols)]
return hdrs, key

def test_handshake(self):
sec_key = base64.b64encode(os.urandom(16)).decode()
hdrs, sec_key = self.gen_ws_headers()

self.headers.extend([('UPGRADE', 'websocket'),
('CONNECTION', 'upgrade'),
('SEC-WEBSOCKET-VERSION', '13'),
('SEC-WEBSOCKET-KEY', sec_key)])
status, headers, parser, writer = websocket.do_handshake(
self.headers.extend(hdrs)
status, headers, parser, writer, protocol = websocket.do_handshake(
self.message.method, self.message.headers, self.transport)
self.assertEqual(status, 101)
self.assertIsNone(protocol)

key = base64.b64encode(
hashlib.sha1(sec_key.encode() + websocket.WS_KEY).digest())
headers = dict(headers)
self.assertEqual(headers['SEC-WEBSOCKET-ACCEPT'], key.decode())

def test_handshake_protocol(self):
'''Tests if one protocol is returned by do_handshake'''
proto = 'chat'

self.headers.extend(self.gen_ws_headers(proto)[0])
_, resp_headers, _, _, protocol = websocket.do_handshake(
self.message.method, self.message.headers, self.transport,
protocols=[proto])

self.assertEqual(protocol, proto)

#also test if we reply with the protocol
resp_headers = dict(resp_headers)
self.assertEqual(resp_headers['SEC-WEBSOCKET-PROTOCOL'], proto)

def test_handshake_protocol_agreement(self):
'''Tests if the right protocol is selected given multiple'''
best_proto = 'chat'
wanted_protos = ['best', 'chat', 'worse_proto']
server_protos = 'worse_proto,chat'

self.headers.extend(self.gen_ws_headers(server_protos)[0])
_, resp_headers, _, _, protocol = websocket.do_handshake(
self.message.method, self.message.headers, self.transport,
protocols=wanted_protos)

self.assertEqual(protocol, best_proto)

0 comments on commit 92ce18b

Please sign in to comment.