From 17ccbddd116c581887b4612947a2d6c57192e2f2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 9 Sep 2017 17:45:43 +0200 Subject: [PATCH] Only set close code when receiving close frame. The RFC is clear about this: > _The WebSocket Connection Close Code_ is defined as the status code > (Section 7.4) contained in the first Close control frame received by > the application implementing this protocol. Also: * Differentiate between closing and failing the connection in tests. * Remove a test for setting the close code based on a local close code. --- websockets/protocol.py | 6 ++-- websockets/test_protocol.py | 62 +++++++++++++++++++------------------ 2 files changed, 35 insertions(+), 33 deletions(-) diff --git a/websockets/protocol.py b/websockets/protocol.py index d92f3f795..078d9b994 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -544,10 +544,10 @@ def read_data_frame(self, max_size): if frame.opcode == OP_CLOSE: # Make sure the close frame is valid before echoing it. code, reason = parse_close(frame.data) + self.close_code, self.close_reason = code, reason if self.state == OPEN: # 7.1.3. The WebSocket Closing Handshake is Started yield from self.write_frame(OP_CLOSE, frame.data) - self.close_code, self.close_reason = code, reason self.closing_handshake.set_result(True) return @@ -686,7 +686,6 @@ def fail_connection(self, code=1011, reason=''): frame_data = serialize_close(code, reason) yield from self.write_frame(OP_CLOSE, frame_data) if not self.closing_handshake.done(): - self.close_code, self.close_reason = code, reason self.closing_handshake.set_result(False) yield from self.close_connection() @@ -738,8 +737,9 @@ def connection_lost(self, exc): self.state = CLOSED if not self.opening_handshake.done(): self.opening_handshake.set_result(False) + if self.close_code is None: + self.close_code = 1006 if not self.closing_handshake.done(): - self.close_code, self.close_reason = 1006, '' self.closing_handshake.set_result(False) if not self.connection_closed.done(): self.connection_closed.set_result(None) diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index 7faaddc37..96aaf48e9 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -234,9 +234,23 @@ def assertNoFrameSent(self): def assertConnectionClosed(self, code, message): # The following line guarantees that connection_lost was called. self.assertEqual(self.protocol.state, CLOSED) + # A close frame was received. self.assertEqual(self.protocol.close_code, code) self.assertEqual(self.protocol.close_reason, message) + def assertConnectionFailed(self, code, message): + # The following line guarantees that connection_lost was called. + self.assertEqual(self.protocol.state, CLOSED) + # No close frame was received. + self.assertEqual(self.protocol.close_code, 1006) + self.assertEqual(self.protocol.close_reason, '') + # A close frame was sent -- unless the connection was already lost. + if code == 1006: + self.assertNoFrameSent() + else: + self.assertOneFrameSent( + True, OP_CLOSE, serialize_close(code, message)) + @contextlib.contextmanager def assertCompletesWithin(self, min_time, max_time): t0 = self.loop.time() @@ -307,24 +321,24 @@ def test_recv_on_closed_connection(self): def test_recv_protocol_error(self): self.receive_frame(Frame(True, OP_CONT, 'café'.encode('utf-8'))) self.process_invalid_frames() - self.assertConnectionClosed(1002, '') + self.assertConnectionFailed(1002, '') def test_recv_unicode_error(self): self.receive_frame(Frame(True, OP_TEXT, 'café'.encode('latin-1'))) self.process_invalid_frames() - self.assertConnectionClosed(1007, '') + self.assertConnectionFailed(1007, '') def test_recv_text_payload_too_big(self): self.protocol.max_size = 1024 self.receive_frame(Frame(True, OP_TEXT, 'café'.encode('utf-8') * 205)) self.process_invalid_frames() - self.assertConnectionClosed(1009, '') + self.assertConnectionFailed(1009, '') def test_recv_binary_payload_too_big(self): self.protocol.max_size = 1024 self.receive_frame(Frame(True, OP_BINARY, b'tea' * 342)) self.process_invalid_frames() - self.assertConnectionClosed(1009, '') + self.assertConnectionFailed(1009, '') def test_recv_text_no_max_size(self): self.protocol.max_size = None # for test coverage @@ -346,7 +360,7 @@ def read_message(): self.process_invalid_frames() with self.assertRaises(Exception): self.loop.run_until_complete(self.protocol.worker_task) - self.assertConnectionClosed(1011, '') + self.assertConnectionFailed(1011, '') def test_recv_cancelled(self): recv = self.ensure_future(self.protocol.recv()) @@ -534,14 +548,14 @@ def test_fragmented_text_payload_too_big(self): self.receive_frame(Frame(False, OP_TEXT, 'café'.encode('utf-8') * 100)) self.receive_frame(Frame(True, OP_CONT, 'café'.encode('utf-8') * 105)) self.process_invalid_frames() - self.assertConnectionClosed(1009, '') + self.assertConnectionFailed(1009, '') def test_fragmented_binary_payload_too_big(self): self.protocol.max_size = 1024 self.receive_frame(Frame(False, OP_BINARY, b'tea' * 171)) self.receive_frame(Frame(True, OP_CONT, b'tea' * 171)) self.process_invalid_frames() - self.assertConnectionClosed(1009, '') + self.assertConnectionFailed(1009, '') def test_fragmented_text_no_max_size(self): self.protocol.max_size = None # for test coverage @@ -570,18 +584,22 @@ def test_unterminated_fragmented_text(self): # Missing the second part of the fragmented frame. self.receive_frame(Frame(True, OP_BINARY, b'tea')) self.process_invalid_frames() - self.assertConnectionClosed(1002, '') + self.assertConnectionFailed(1002, '') def test_close_handshake_in_fragmented_text(self): self.receive_frame(Frame(False, OP_TEXT, 'ca'.encode('utf-8'))) self.receive_frame(Frame(True, OP_CLOSE, b'')) self.process_invalid_frames() + # The RFC may have overlooked this case: it says that control frames + # can be interjected in the middle of a fragmented message and that a + # close frame must be echoed. Even though there's an unterminated + # message, technically, the closing handshake was successful. self.assertConnectionClosed(1005, '') def test_connection_close_in_fragmented_text(self): self.receive_frame(Frame(False, OP_TEXT, 'ca'.encode('utf-8'))) self.process_invalid_frames() - self.assertConnectionClosed(1006, '') + self.assertConnectionFailed(1006, '') # Test miscellaneous code paths to ensure full coverage. @@ -589,7 +607,7 @@ def test_connection_lost(self): # Test calling connection_lost without going through close_connection. self.protocol.connection_lost(None) - self.assertConnectionClosed(1006, '') + self.assertConnectionFailed(1006, '') def test_ensure_connection_before_opening_handshake(self): self.protocol.state = CONNECTING @@ -683,33 +701,17 @@ def test_close_protocol_error(self): invalid_close_frame = Frame(True, OP_CLOSE, b'\x00') self.receive_frame(invalid_close_frame) self.receive_eof_if_client() + self.run_loop_once() self.loop.run_until_complete(self.protocol.close(reason='close')) - self.assertConnectionClosed(1002, '') + self.assertConnectionFailed(1002, '') def test_close_connection_lost(self): self.receive_eof() + self.run_loop_once() self.loop.run_until_complete(self.protocol.close(reason='close')) - self.assertConnectionClosed(1006, '') - - def test_remote_close_race_with_failing_connection(self): - self.make_drain_slow() - - # Fail the connection while answering a close frame from the client. - self.loop.call_soon(self.receive_frame, self.remote_close) - self.loop.call_later( - MS, self.ensure_future, self.protocol.fail_connection()) - # The client expects the server to close the connection. - # Simulate it instead of waiting for the connection timeout. - self.loop.call_later(MS, self.receive_eof_if_client) - - with self.assertRaises(ConnectionClosed): - self.loop.run_until_complete(self.protocol.recv()) - - # The closing handshake was completed by fail_connection. - self.assertConnectionClosed(1011, '') - self.assertOneFrameSent(*self.remote_close) + self.assertConnectionFailed(1006, '') def test_local_close_during_recv(self): recv = self.ensure_future(self.protocol.recv())