Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Address issue #116: make existing hook take place before the handshake #202

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 50 additions & 35 deletions websockets/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,15 @@ def handler(self):
# Since this method doesn't have a caller able to handle exceptions,
# it attemps to log relevant ones and close the connection properly.
try:

try:
path = yield from self.handshake(
origins=self.origins, subprotocols=self.subprotocols,
response_headers = (
yield from self.pre_handshake(origins=self.origins))
if response_headers is None:
# Then the response was already written.
return

yield from self.handshake(
response_headers, subprotocols=self.subprotocols,
extra_headers=self.extra_headers)
except ConnectionError as exc:
logger.debug(
Expand All @@ -89,13 +94,8 @@ def handler(self):
self.writer.write(response.encode())
raise

# Subclasses can customize get_response_status() or handshake() to
# reject the handshake, typically after checking authentication.
if path is None:
return

try:
yield from self.ws_handler(self, path)
yield from self.ws_handler(self, self.path)
except Exception as exc:
if self._is_server_shutting_down(exc):
yield from self.fail_connection(1001)
Expand Down Expand Up @@ -234,22 +234,52 @@ def get_response_status(self, set_header):
It is declared as a coroutine because such authentication checks are
likely to require network requests.

The connection is closed immediately after sending the response when
the status code is not ``HTTPStatus.SWITCHING_PROTOCOLS``.
A return value of ``None`` means to continue to the opening
handshake, which would result in ``HTTPStatus.SWITCHING_PROTOCOLS``
on success. If the return value is not ``None``, the connection is
closed immediately after sending the response.

Call ``set_header(key, value)`` to set additional response headers.

"""
return SWITCHING_PROTOCOLS
return None

@asyncio.coroutine
def handshake(self, origins=None, subprotocols=None, extra_headers=None):
def pre_handshake(self, origins=None):
"""
Perform the server side of the opening handshake.
Do pre-handshake response handling.

If provided, ``origins`` is a list of acceptable HTTP Origin values.
Include ``''`` if the lack of an origin is acceptable.

Return the response headers so far, or None if the response has
already been handled.

"""
yield from self.read_http_request()

request_headers = self.request_headers
get_header = lambda k: request_headers.get(k, '')
self.origin = self.process_origin(get_header, origins)

response_headers = []
set_header = lambda k, v: response_headers.append((k, v))
set_header('Server', USER_AGENT)

status = yield from self.get_response_status(set_header)
if status is None:
return response_headers

yield from self.write_http_response(status, response_headers)
self.opening_handshake.set_result(False)
yield from self.close_connection(force=True)

@asyncio.coroutine
def handshake(self, response_headers, subprotocols=None,
extra_headers=None):
"""
Perform the server side of the opening handshake.

If provided, ``subprotocols`` is a list of supported subprotocols in
order of decreasing preference.

Expand All @@ -263,48 +293,33 @@ def handshake(self, origins=None, subprotocols=None, extra_headers=None):
Return the URI of the request.

"""
path, headers = yield from self.read_http_request()
headers = self.request_headers
get_header = lambda k: headers.get(k, '')

key = check_request(get_header)

self.origin = self.process_origin(get_header, origins)
self.subprotocol = self.process_subprotocol(get_header, subprotocols)

headers = []
set_header = lambda k, v: headers.append((k, v))

set_header('Server', USER_AGENT)

status = yield from self.get_response_status(set_header)

# Abort the connection if the status code isn't 101.
if status.value != SWITCHING_PROTOCOLS.value:
yield from self.write_http_response(status, headers)
self.opening_handshake.set_result(False)
yield from self.close_connection(force=True)
return
set_header = lambda k, v: response_headers.append((k, v))

# Status code is 101, establish the connection.
if self.subprotocol:
set_header('Sec-WebSocket-Protocol', self.subprotocol)
if extra_headers is not None:
if callable(extra_headers):
extra_headers = extra_headers(path, self.raw_request_headers)
extra_headers = extra_headers(self.path,
self.raw_request_headers)
if isinstance(extra_headers, collections.abc.Mapping):
extra_headers = extra_headers.items()
for name, value in extra_headers:
set_header(name, value)
build_response(set_header, key)

yield from self.write_http_response(status, headers)
yield from self.write_http_response(SWITCHING_PROTOCOLS,
response_headers)

assert self.state == CONNECTING
self.state = OPEN
self.opening_handshake.set_result(True)

return path


class WebSocketServer:
"""
Expand Down
30 changes: 30 additions & 0 deletions websockets/test_client_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def handler(ws, path):
# Order by status code.
UNAUTHORIZED = http.HTTPStatus.UNAUTHORIZED
FORBIDDEN = http.HTTPStatus.FORBIDDEN
NOT_FOUND = http.HTTPStatus.NOT_FOUND
except AttributeError: # pragma: no cover
class UNAUTHORIZED:
value = 401
Expand All @@ -52,6 +53,10 @@ class FORBIDDEN:
value = 403
phrase = 'Forbidden'

class NOT_FOUND:
value = 404
phrase = 'Not Found'


@contextmanager
def temp_test_server(test, **kwds):
Expand Down Expand Up @@ -301,6 +306,31 @@ def get_response_status(self, set_header):
self.assertIsInstance(request_headers, http.client.HTTPMessage)
self.assertEqual(request_headers.get('origin'), 'http://otherhost')

def test_get_response_status_precedes_handshake(self):
class State:
handshake_called = False

class HandshakeStoringProtocol(WebSocketServerProtocol):
@asyncio.coroutine
def get_response_status(self, set_header):
if self.path != '/valid':
return NOT_FOUND
return (yield from super().get_response_status(set_header))

@asyncio.coroutine
def handshake(self, *args, **kwargs):
State.handshake_called = True
return (yield from super().handshake(*args, **kwargs))

with self.temp_server(create_protocol=HandshakeStoringProtocol):
with self.assertRaises(InvalidStatus) as cm:
self.start_client(path='invalid')
self.assertEqual(cm.exception.code, 404)
self.assertFalse(State.handshake_called)
# Check that our overridden handshake() is working correctly.
self.start_client(path='valid')
self.assertTrue(State.handshake_called)

def assert_client_raises_code(self, code):
with self.assertRaises(InvalidStatus) as raised:
self.start_client()
Expand Down