From 5276a41702594cb58785352a444a2ab7b710cf74 Mon Sep 17 00:00:00 2001 From: Chris Jerdonek Date: Sun, 9 Jul 2017 15:26:11 -0700 Subject: [PATCH 1/8] Add pre_handshake() hook. --- websockets/server.py | 91 +++++++++++++++++++++++++------------------- 1 file changed, 51 insertions(+), 40 deletions(-) diff --git a/websockets/server.py b/websockets/server.py index cf04d6eaf..dcbfda8e8 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -65,15 +65,20 @@ def handler(self): # Since this method doesn't have a caller able to handle exceptions, # it attemps to log relevant ones and close the connection properly. try: - try: - path = yield from self.handshake( - origins=self.origins, subprotocols=self.subprotocols, - extra_headers=self.extra_headers) - except ConnectionError as exc: - logger.debug( - "Connection error in opening handshake", exc_info=True) - raise + response_headers = yield from self.pre_handshake(origins=self.origins) + if response_headers is None: + # Then the response was already written. + return + + try: + yield from self.handshake( + response_headers, subprotocols=self.subprotocols, + extra_headers=self.extra_headers) + except ConnectionError as exc: + logger.debug( + "Connection error in opening handshake", exc_info=True) + raise except Exception as exc: if self._is_server_shutting_down(exc): response = ('HTTP/1.1 503 Service Unavailable\r\n\r\n' @@ -89,13 +94,8 @@ def handler(self): self.writer.write(response.encode()) raise - # Subclasses can customize get_response_status() or handshake() to - # reject the handshake, typically after checking authentication. - if path is None: - return - try: - yield from self.ws_handler(self, path) + yield from self.ws_handler(self, self.path) except Exception as exc: if self._is_server_shutting_down(exc): yield from self.fail_connection(1001) @@ -234,22 +234,50 @@ def get_response_status(self, set_header): It is declared as a coroutine because such authentication checks are likely to require network requests. - The connection is closed immediately after sending the response when - the status code is not ``HTTPStatus.SWITCHING_PROTOCOLS``. + A return value of ``None`` means to continue to the opening + handshake, which would result in ``HTTPStatus.SWITCHING_PROTOCOLS`` + on success. If the return value is not ``None``, the connection is + closed immediately after sending the response. Call ``set_header(key, value)`` to set additional response headers. """ - return SWITCHING_PROTOCOLS + return None @asyncio.coroutine - def handshake(self, origins=None, subprotocols=None, extra_headers=None): + def pre_handshake(self, origins=None): """ - Perform the server side of the opening handshake. + Handle the response, if possible, before the opening handshake. If provided, ``origins`` is a list of acceptable HTTP Origin values. Include ``''`` if the lack of an origin is acceptable. + Return the response headers so far, or None if the response has + already been handled. + + """ + yield from self.read_http_request() + request_headers = self.request_headers + get_header = lambda k: request_headers.get(k, '') + self.origin = self.process_origin(get_header, origins) + + response_headers = [] + set_header = lambda k, v: response_headers.append((k, v)) + set_header('Server', USER_AGENT) + + status = yield from self.get_response_status(set_header) + if status is None: + return response_headers + + yield from self.write_http_response(status, headers) + self.opening_handshake.set_result(False) + yield from self.close_connection(force=True) + + @asyncio.coroutine + def handshake(self, response_headers, subprotocols=None, extra_headers=None): + """ + Perform the server side of the opening handshake. + If provided, ``subprotocols`` is a list of supported subprotocols in order of decreasing preference. @@ -263,48 +291,31 @@ def handshake(self, origins=None, subprotocols=None, extra_headers=None): Return the URI of the request. """ - path, headers = yield from self.read_http_request() + headers = self.request_headers get_header = lambda k: headers.get(k, '') - key = check_request(get_header) - self.origin = self.process_origin(get_header, origins) self.subprotocol = self.process_subprotocol(get_header, subprotocols) - - headers = [] - set_header = lambda k, v: headers.append((k, v)) - - set_header('Server', USER_AGENT) - - status = yield from self.get_response_status(set_header) - - # Abort the connection if the status code isn't 101. - if status.value != SWITCHING_PROTOCOLS.value: - yield from self.write_http_response(status, headers) - self.opening_handshake.set_result(False) - yield from self.close_connection(force=True) - return + set_header = lambda k, v: response_headers.append((k, v)) # Status code is 101, establish the connection. if self.subprotocol: set_header('Sec-WebSocket-Protocol', self.subprotocol) if extra_headers is not None: if callable(extra_headers): - extra_headers = extra_headers(path, self.raw_request_headers) + extra_headers = extra_headers(self.path, self.raw_request_headers) if isinstance(extra_headers, collections.abc.Mapping): extra_headers = extra_headers.items() for name, value in extra_headers: set_header(name, value) build_response(set_header, key) - yield from self.write_http_response(status, headers) + yield from self.write_http_response(SWITCHING_PROTOCOLS, response_headers) assert self.state == CONNECTING self.state = OPEN self.opening_handshake.set_result(True) - return path - class WebSocketServer: """ From 4a5df1bf5301a9b8ada3fc0ea85fa3c6d22d6d3f Mon Sep 17 00:00:00 2001 From: Chris Jerdonek Date: Sun, 9 Jul 2017 15:51:03 -0700 Subject: [PATCH 2/8] Fix variable name. --- websockets/server.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/websockets/server.py b/websockets/server.py index dcbfda8e8..59dd0b1ab 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -257,6 +257,7 @@ def pre_handshake(self, origins=None): """ yield from self.read_http_request() + request_headers = self.request_headers get_header = lambda k: request_headers.get(k, '') self.origin = self.process_origin(get_header, origins) @@ -269,7 +270,7 @@ def pre_handshake(self, origins=None): if status is None: return response_headers - yield from self.write_http_response(status, headers) + yield from self.write_http_response(status, response_headers) self.opening_handshake.set_result(False) yield from self.close_connection(force=True) From 26506a9f4db8202d0d95c75e2f1f5886bef6f250 Mon Sep 17 00:00:00 2001 From: Chris Jerdonek Date: Fri, 14 Jul 2017 14:58:36 -0700 Subject: [PATCH 3/8] Add a test that get_response_status() precedes handshake(). --- websockets/server.py | 2 +- websockets/test_client_server.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/websockets/server.py b/websockets/server.py index 59dd0b1ab..545390e15 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -247,7 +247,7 @@ def get_response_status(self, set_header): @asyncio.coroutine def pre_handshake(self, origins=None): """ - Handle the response, if possible, before the opening handshake. + Do pre-handshake response handling. If provided, ``origins`` is a list of acceptable HTTP Origin values. Include ``''`` if the lack of an origin is acceptable. diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 0edc1408c..996d4d079 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -301,6 +301,34 @@ def get_response_status(self, set_header): self.assertIsInstance(request_headers, http.client.HTTPMessage) self.assertEqual(request_headers.get('origin'), 'http://otherhost') + def test_get_response_status_precedes_handshake(self): + class State: + handshake_called = False + + class HandshakeStoringProtocol(WebSocketServerProtocol): + @asyncio.coroutine + def get_response_status(self, set_header): + if self.path != '/valid': + return http.HTTPStatus.NOT_FOUND + + status = yield from super().get_response_status(set_header) + return status + + @asyncio.coroutine + def handshake(self, *args, **kwargs): + State.handshake_called = True + result = yield from super().handshake(*args, **kwargs) + return result + + with self.temp_server(create_protocol=HandshakeStoringProtocol): + with self.assertRaises(InvalidHandshake) as cm: + self.start_client(path='invalid') + self.assertEqual(str(cm.exception), 'Bad status code: 404') + self.assertFalse(State.handshake_called) + # Check that our overridden handshake() is working correctly. + self.start_client(path='valid') + self.assertTrue(State.handshake_called) + def assert_client_raises_code(self, code): with self.assertRaises(InvalidStatus) as raised: self.start_client() From 3fb330be1fbbc1b4a8354c20b3a1c4cf813e78b1 Mon Sep 17 00:00:00 2001 From: Chris Jerdonek Date: Fri, 14 Jul 2017 22:01:55 -0700 Subject: [PATCH 4/8] Combine return and yield from lines. --- websockets/test_client_server.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 996d4d079..0ad2d3910 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -310,15 +310,12 @@ class HandshakeStoringProtocol(WebSocketServerProtocol): def get_response_status(self, set_header): if self.path != '/valid': return http.HTTPStatus.NOT_FOUND - - status = yield from super().get_response_status(set_header) - return status + return (yield from super().get_response_status(set_header)) @asyncio.coroutine def handshake(self, *args, **kwargs): State.handshake_called = True - result = yield from super().handshake(*args, **kwargs) - return result + return (yield from super().handshake(*args, **kwargs)) with self.temp_server(create_protocol=HandshakeStoringProtocol): with self.assertRaises(InvalidHandshake) as cm: From 13b4fc4e1e8a1d24db2c6a26e61cbb418cd10257 Mon Sep 17 00:00:00 2001 From: Chris Jerdonek Date: Fri, 14 Jul 2017 22:10:13 -0700 Subject: [PATCH 5/8] Handle http.HTTPStatus.NOT_FOUND for Python < 3.5. --- websockets/test_client_server.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 0ad2d3910..20a1d16af 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -43,6 +43,7 @@ def handler(ws, path): # Order by status code. UNAUTHORIZED = http.HTTPStatus.UNAUTHORIZED FORBIDDEN = http.HTTPStatus.FORBIDDEN + NOT_FOUND = http.HTTPStatus.NOT_FOUND except AttributeError: # pragma: no cover class UNAUTHORIZED: value = 401 @@ -52,6 +53,10 @@ class FORBIDDEN: value = 403 phrase = 'Forbidden' + class NOT_FOUND: + value = 404 + phrase = 'Not Found' + @contextmanager def temp_test_server(test, **kwds): @@ -309,7 +314,7 @@ class HandshakeStoringProtocol(WebSocketServerProtocol): @asyncio.coroutine def get_response_status(self, set_header): if self.path != '/valid': - return http.HTTPStatus.NOT_FOUND + return NOT_FOUND return (yield from super().get_response_status(set_header)) @asyncio.coroutine From e604fee9dd806a8d947eb91092096c2f41f2dfd9 Mon Sep 17 00:00:00 2001 From: Chris Jerdonek Date: Tue, 18 Jul 2017 17:54:17 -0700 Subject: [PATCH 6/8] Fix flake8. --- websockets/server.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/websockets/server.py b/websockets/server.py index 545390e15..1a47cdfbb 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -66,7 +66,8 @@ def handler(self): # it attemps to log relevant ones and close the connection properly. try: try: - response_headers = yield from self.pre_handshake(origins=self.origins) + response_headers = ( + yield from self.pre_handshake(origins=self.origins)) if response_headers is None: # Then the response was already written. return @@ -275,7 +276,8 @@ def pre_handshake(self, origins=None): yield from self.close_connection(force=True) @asyncio.coroutine - def handshake(self, response_headers, subprotocols=None, extra_headers=None): + def handshake(self, response_headers, subprotocols=None, + extra_headers=None): """ Perform the server side of the opening handshake. @@ -304,14 +306,16 @@ def handshake(self, response_headers, subprotocols=None, extra_headers=None): set_header('Sec-WebSocket-Protocol', self.subprotocol) if extra_headers is not None: if callable(extra_headers): - extra_headers = extra_headers(self.path, self.raw_request_headers) + extra_headers = extra_headers(self.path, + self.raw_request_headers) if isinstance(extra_headers, collections.abc.Mapping): extra_headers = extra_headers.items() for name, value in extra_headers: set_header(name, value) build_response(set_header, key) - yield from self.write_http_response(SWITCHING_PROTOCOLS, response_headers) + yield from self.write_http_response(SWITCHING_PROTOCOLS, + response_headers) assert self.state == CONNECTING self.state = OPEN From 0d440a53ea764d411093795675999a3cb571a87a Mon Sep 17 00:00:00 2001 From: Chris Jerdonek Date: Fri, 21 Jul 2017 19:40:06 -0700 Subject: [PATCH 7/8] Fix ConnectionError handling. --- websockets/server.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/websockets/server.py b/websockets/server.py index 1a47cdfbb..659cd3566 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -72,14 +72,13 @@ def handler(self): # Then the response was already written. return - try: - yield from self.handshake( - response_headers, subprotocols=self.subprotocols, - extra_headers=self.extra_headers) - except ConnectionError as exc: - logger.debug( - "Connection error in opening handshake", exc_info=True) - raise + yield from self.handshake( + response_headers, subprotocols=self.subprotocols, + extra_headers=self.extra_headers) + except ConnectionError as exc: + logger.debug( + "Connection error in opening handshake", exc_info=True) + raise except Exception as exc: if self._is_server_shutting_down(exc): response = ('HTTP/1.1 503 Service Unavailable\r\n\r\n' From 02b5a296db77e8fc096d4e97b6dc4401f76dbb20 Mon Sep 17 00:00:00 2001 From: Chris Jerdonek Date: Fri, 21 Jul 2017 19:48:28 -0700 Subject: [PATCH 8/8] Use the new InvalidStatus.code instead of checking str(exception). --- websockets/test_client_server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 20a1d16af..a527c24cc 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -323,9 +323,9 @@ def handshake(self, *args, **kwargs): return (yield from super().handshake(*args, **kwargs)) with self.temp_server(create_protocol=HandshakeStoringProtocol): - with self.assertRaises(InvalidHandshake) as cm: + with self.assertRaises(InvalidStatus) as cm: self.start_client(path='invalid') - self.assertEqual(str(cm.exception), 'Bad status code: 404') + self.assertEqual(cm.exception.code, 404) self.assertFalse(State.handshake_called) # Check that our overridden handshake() is working correctly. self.start_client(path='valid')