Skip to content

Commit

Permalink
Only set close code when receiving close frame.
Browse files Browse the repository at this point in the history
The RFC is clear about this:

> _The WebSocket Connection Close Code_ is defined as the status code
> (Section 7.4) contained in the first Close control frame received by
> the application implementing this protocol.

Also:

* Differentiate between closing and failing the connection in tests.
* Remove a test for setting the close code based on a local close code.
  • Loading branch information
aaugustin committed Sep 9, 2017
1 parent 3ecd547 commit 17ccbdd
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 33 deletions.
6 changes: 3 additions & 3 deletions websockets/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,10 +544,10 @@ def read_data_frame(self, max_size):
if frame.opcode == OP_CLOSE:
# Make sure the close frame is valid before echoing it.
code, reason = parse_close(frame.data)
self.close_code, self.close_reason = code, reason
if self.state == OPEN:
# 7.1.3. The WebSocket Closing Handshake is Started
yield from self.write_frame(OP_CLOSE, frame.data)
self.close_code, self.close_reason = code, reason
self.closing_handshake.set_result(True)
return

Expand Down Expand Up @@ -686,7 +686,6 @@ def fail_connection(self, code=1011, reason=''):
frame_data = serialize_close(code, reason)
yield from self.write_frame(OP_CLOSE, frame_data)
if not self.closing_handshake.done():
self.close_code, self.close_reason = code, reason
self.closing_handshake.set_result(False)
yield from self.close_connection()

Expand Down Expand Up @@ -738,8 +737,9 @@ def connection_lost(self, exc):
self.state = CLOSED
if not self.opening_handshake.done():
self.opening_handshake.set_result(False)
if self.close_code is None:
self.close_code = 1006
if not self.closing_handshake.done():
self.close_code, self.close_reason = 1006, ''
self.closing_handshake.set_result(False)
if not self.connection_closed.done():
self.connection_closed.set_result(None)
Expand Down
62 changes: 32 additions & 30 deletions websockets/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,9 +234,23 @@ def assertNoFrameSent(self):
def assertConnectionClosed(self, code, message):
# The following line guarantees that connection_lost was called.
self.assertEqual(self.protocol.state, CLOSED)
# A close frame was received.
self.assertEqual(self.protocol.close_code, code)
self.assertEqual(self.protocol.close_reason, message)

def assertConnectionFailed(self, code, message):
# The following line guarantees that connection_lost was called.
self.assertEqual(self.protocol.state, CLOSED)
# No close frame was received.
self.assertEqual(self.protocol.close_code, 1006)
self.assertEqual(self.protocol.close_reason, '')
# A close frame was sent -- unless the connection was already lost.
if code == 1006:
self.assertNoFrameSent()
else:
self.assertOneFrameSent(
True, OP_CLOSE, serialize_close(code, message))

@contextlib.contextmanager
def assertCompletesWithin(self, min_time, max_time):
t0 = self.loop.time()
Expand Down Expand Up @@ -307,24 +321,24 @@ def test_recv_on_closed_connection(self):
def test_recv_protocol_error(self):
self.receive_frame(Frame(True, OP_CONT, 'café'.encode('utf-8')))
self.process_invalid_frames()
self.assertConnectionClosed(1002, '')
self.assertConnectionFailed(1002, '')

def test_recv_unicode_error(self):
self.receive_frame(Frame(True, OP_TEXT, 'café'.encode('latin-1')))
self.process_invalid_frames()
self.assertConnectionClosed(1007, '')
self.assertConnectionFailed(1007, '')

def test_recv_text_payload_too_big(self):
self.protocol.max_size = 1024
self.receive_frame(Frame(True, OP_TEXT, 'café'.encode('utf-8') * 205))
self.process_invalid_frames()
self.assertConnectionClosed(1009, '')
self.assertConnectionFailed(1009, '')

def test_recv_binary_payload_too_big(self):
self.protocol.max_size = 1024
self.receive_frame(Frame(True, OP_BINARY, b'tea' * 342))
self.process_invalid_frames()
self.assertConnectionClosed(1009, '')
self.assertConnectionFailed(1009, '')

def test_recv_text_no_max_size(self):
self.protocol.max_size = None # for test coverage
Expand All @@ -346,7 +360,7 @@ def read_message():
self.process_invalid_frames()
with self.assertRaises(Exception):
self.loop.run_until_complete(self.protocol.worker_task)
self.assertConnectionClosed(1011, '')
self.assertConnectionFailed(1011, '')

def test_recv_cancelled(self):
recv = self.ensure_future(self.protocol.recv())
Expand Down Expand Up @@ -534,14 +548,14 @@ def test_fragmented_text_payload_too_big(self):
self.receive_frame(Frame(False, OP_TEXT, 'café'.encode('utf-8') * 100))
self.receive_frame(Frame(True, OP_CONT, 'café'.encode('utf-8') * 105))
self.process_invalid_frames()
self.assertConnectionClosed(1009, '')
self.assertConnectionFailed(1009, '')

def test_fragmented_binary_payload_too_big(self):
self.protocol.max_size = 1024
self.receive_frame(Frame(False, OP_BINARY, b'tea' * 171))
self.receive_frame(Frame(True, OP_CONT, b'tea' * 171))
self.process_invalid_frames()
self.assertConnectionClosed(1009, '')
self.assertConnectionFailed(1009, '')

def test_fragmented_text_no_max_size(self):
self.protocol.max_size = None # for test coverage
Expand Down Expand Up @@ -570,26 +584,30 @@ def test_unterminated_fragmented_text(self):
# Missing the second part of the fragmented frame.
self.receive_frame(Frame(True, OP_BINARY, b'tea'))
self.process_invalid_frames()
self.assertConnectionClosed(1002, '')
self.assertConnectionFailed(1002, '')

def test_close_handshake_in_fragmented_text(self):
self.receive_frame(Frame(False, OP_TEXT, 'ca'.encode('utf-8')))
self.receive_frame(Frame(True, OP_CLOSE, b''))
self.process_invalid_frames()
# The RFC may have overlooked this case: it says that control frames
# can be interjected in the middle of a fragmented message and that a
# close frame must be echoed. Even though there's an unterminated
# message, technically, the closing handshake was successful.
self.assertConnectionClosed(1005, '')

def test_connection_close_in_fragmented_text(self):
self.receive_frame(Frame(False, OP_TEXT, 'ca'.encode('utf-8')))
self.process_invalid_frames()
self.assertConnectionClosed(1006, '')
self.assertConnectionFailed(1006, '')

# Test miscellaneous code paths to ensure full coverage.

def test_connection_lost(self):
# Test calling connection_lost without going through close_connection.
self.protocol.connection_lost(None)

self.assertConnectionClosed(1006, '')
self.assertConnectionFailed(1006, '')

def test_ensure_connection_before_opening_handshake(self):
self.protocol.state = CONNECTING
Expand Down Expand Up @@ -683,33 +701,17 @@ def test_close_protocol_error(self):
invalid_close_frame = Frame(True, OP_CLOSE, b'\x00')
self.receive_frame(invalid_close_frame)
self.receive_eof_if_client()
self.run_loop_once()
self.loop.run_until_complete(self.protocol.close(reason='close'))

self.assertConnectionClosed(1002, '')
self.assertConnectionFailed(1002, '')

def test_close_connection_lost(self):
self.receive_eof()
self.run_loop_once()
self.loop.run_until_complete(self.protocol.close(reason='close'))

self.assertConnectionClosed(1006, '')

def test_remote_close_race_with_failing_connection(self):
self.make_drain_slow()

# Fail the connection while answering a close frame from the client.
self.loop.call_soon(self.receive_frame, self.remote_close)
self.loop.call_later(
MS, self.ensure_future, self.protocol.fail_connection())
# The client expects the server to close the connection.
# Simulate it instead of waiting for the connection timeout.
self.loop.call_later(MS, self.receive_eof_if_client)

with self.assertRaises(ConnectionClosed):
self.loop.run_until_complete(self.protocol.recv())

# The closing handshake was completed by fail_connection.
self.assertConnectionClosed(1011, '')
self.assertOneFrameSent(*self.remote_close)
self.assertConnectionFailed(1006, '')

def test_local_close_during_recv(self):
recv = self.ensure_future(self.protocol.recv())
Expand Down

0 comments on commit 17ccbdd

Please sign in to comment.