From 656818496e136139c1daeb9ab86c8253b7bbad2d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sun, 17 Dec 2023 20:28:09 +0000 Subject: [PATCH] Support the WebSocket Denial Response ASGI extension (#1916) Co-authored-by: Marcelo Trylesinski --- tests/protocols/test_websocket.py | 349 +++++++++++++++++- tests/response.py | 5 +- uvicorn/_types.py | 2 +- .../protocols/websockets/websockets_impl.py | 45 ++- uvicorn/protocols/websockets/wsproto_impl.py | 61 ++- 5 files changed, 447 insertions(+), 15 deletions(-) diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index 1ba20fc7c..df2415f2e 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -10,13 +10,16 @@ from websockets.extensions.permessage_deflate import ClientPerMessageDeflateFactory from websockets.typing import Subprotocol +from tests.response import Response from tests.utils import run_server from uvicorn._types import ( ASGIReceiveCallable, + ASGIReceiveEvent, ASGISendCallable, Scope, WebSocketCloseEvent, WebSocketDisconnectEvent, + WebSocketResponseStartEvent, ) from uvicorn.config import Config from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol @@ -55,6 +58,21 @@ async def asgi(self): break +async def wsresponse(url): + """ + A simple websocket connection request and response helper + """ + url = url.replace("ws:", "http:") + headers = { + "connection": "upgrade", + "upgrade": "websocket", + "Sec-WebSocket-Key": "x3JJHMbDL1EzLkh9GBhXDw==", + "Sec-WebSocket-Version": "13", + } + async with httpx.AsyncClient() as client: + return await client.get(url, headers=headers) + + @pytest.mark.anyio async def test_invalid_upgrade( ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", @@ -942,7 +960,10 @@ async def test_server_reject_connection( http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", unused_tcp_port: int, ): + disconnected_message: ASGIReceiveEvent = {} # type: ignore + async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): + nonlocal disconnected_message assert scope["type"] == "websocket" # Pull up first recv message. @@ -955,15 +976,241 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable # This doesn't raise `TypeError`: # See https://github.com/encode/uvicorn/issues/244 + disconnected_message = await receive() + + async def websocket_session(url): + with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info: + async with websockets.client.connect(url): + pass # pragma: no cover + assert exc_info.value.status_code == 403 + + config = Config( + app=app, + ws=ws_protocol_cls, + http=http_protocol_cls, + lifespan="off", + port=unused_tcp_port, + ) + async with run_server(config): + await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}") + + assert disconnected_message == {"type": "websocket.disconnect", "code": 1006} + + +@pytest.mark.anyio +async def test_server_reject_connection_with_response( + ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", + http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", + unused_tcp_port: int, +): + disconnected_message = {} + + async def app(scope, receive, send): + nonlocal disconnected_message + assert scope["type"] == "websocket" + assert "websocket.http.response" in scope["extensions"] + + # Pull up first recv message. message = await receive() - assert message["type"] == "websocket.disconnect" + assert message["type"] == "websocket.connect" + + # Reject the connection with a response + response = Response(b"goodbye", status_code=400) + await response(scope, receive, send) + disconnected_message = await receive() + + async def websocket_session(url): + response = await wsresponse(url) + assert response.status_code == 400 + assert response.content == b"goodbye" + + config = Config( + app=app, + ws=ws_protocol_cls, + http=http_protocol_cls, + lifespan="off", + port=unused_tcp_port, + ) + async with run_server(config): + await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}") + + assert disconnected_message == {"type": "websocket.disconnect", "code": 1006} + + +@pytest.mark.anyio +async def test_server_reject_connection_with_multibody_response( + ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", + http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", + unused_tcp_port: int, +): + disconnected_message: ASGIReceiveEvent = {} # type: ignore + + async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): + nonlocal disconnected_message + assert scope["type"] == "websocket" + assert "extensions" in scope + assert "websocket.http.response" in scope["extensions"] + + # Pull up first recv message. + message = await receive() + assert message["type"] == "websocket.connect" + await send( + { + "type": "websocket.http.response.start", + "status": 400, + "headers": [ + (b"Content-Length", b"20"), + (b"Content-Type", b"text/plain"), + ], + } + ) + await send( + { + "type": "websocket.http.response.body", + "body": b"x" * 10, + "more_body": True, + } + ) + await send({"type": "websocket.http.response.body", "body": b"y" * 10}) + disconnected_message = await receive() async def websocket_session(url: str): - try: + response = await wsresponse(url) + assert response.status_code == 400 + assert response.content == (b"x" * 10) + (b"y" * 10) + + config = Config( + app=app, + ws=ws_protocol_cls, + http=http_protocol_cls, + lifespan="off", + port=unused_tcp_port, + ) + async with run_server(config): + await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}") + + assert disconnected_message == {"type": "websocket.disconnect", "code": 1006} + + +@pytest.mark.anyio +async def test_server_reject_connection_with_invalid_status( + ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", + http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", + unused_tcp_port: int, +): + # this test checks that even if there is an error in the response, the server + # can successfully send a 500 error back to the client + async def app(scope, receive, send): + assert scope["type"] == "websocket" + assert "websocket.http.response" in scope["extensions"] + + # Pull up first recv message. + message = await receive() + assert message["type"] == "websocket.connect" + + message = { + "type": "websocket.http.response.start", + "status": 700, # invalid status code + "headers": [(b"Content-Length", b"0"), (b"Content-Type", b"text/plain")], + } + await send(message) + message = { + "type": "websocket.http.response.body", + "body": b"", + } + await send(message) + + async def websocket_session(url): + response = await wsresponse(url) + assert response.status_code == 500 + assert response.content == b"Internal Server Error" + + config = Config( + app=app, + ws=ws_protocol_cls, + http=http_protocol_cls, + lifespan="off", + port=unused_tcp_port, + ) + async with run_server(config): + await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}") + + +@pytest.mark.anyio +async def test_server_reject_connection_with_body_nolength( + ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", + http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", + unused_tcp_port: int, +): + # test that the server can send a response with a body but no content-length + async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): + assert scope["type"] == "websocket" + assert "extensions" in scope + assert "websocket.http.response" in scope["extensions"] + + # Pull up first recv message. + message = await receive() + assert message["type"] == "websocket.connect" + + await send( + { + "type": "websocket.http.response.start", + "status": 403, + "headers": [], + } + ) + await send({"type": "websocket.http.response.body", "body": b"hardbody"}) + + async def websocket_session(url): + response = await wsresponse(url) + assert response.status_code == 403 + assert response.content == b"hardbody" + if ws_protocol_cls == WSProtocol: # pragma: no cover + # wsproto automatically makes the message chunked + assert response.headers["transfer-encoding"] == "chunked" + else: # pragma: no cover + # websockets automatically adds a content-length + assert response.headers["content-length"] == "8" + + config = Config( + app=app, + ws=ws_protocol_cls, + http=http_protocol_cls, + lifespan="off", + port=unused_tcp_port, + ) + async with run_server(config): + await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}") + + +@pytest.mark.anyio +async def test_server_reject_connection_with_invalid_msg( + ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", + http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", + unused_tcp_port: int, +): + async def app(scope, receive, send): + assert scope["type"] == "websocket" + assert "websocket.http.response" in scope["extensions"] + + # Pull up first recv message. + message = await receive() + assert message["type"] == "websocket.connect" + + message = { + "type": "websocket.http.response.start", + "status": 404, + "headers": [(b"Content-Length", b"0"), (b"Content-Type", b"text/plain")], + } + await send(message) + # send invalid message. This will raise an exception here + await send(message) + + async def websocket_session(url): + with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info: async with websockets.client.connect(url): pass # pragma: no cover - except Exception: - pass + assert exc_info.value.status_code == 404 config = Config( app=app, @@ -976,6 +1223,100 @@ async def websocket_session(url: str): await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}") +@pytest.mark.anyio +async def test_server_reject_connection_with_missing_body( + ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", + http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", + unused_tcp_port: int, +): + async def app(scope, receive, send): + assert scope["type"] == "websocket" + assert "websocket.http.response" in scope["extensions"] + + # Pull up first recv message. + message = await receive() + assert message["type"] == "websocket.connect" + + message = { + "type": "websocket.http.response.start", + "status": 404, + "headers": [(b"Content-Length", b"0"), (b"Content-Type", b"text/plain")], + } + await send(message) + # no further message + + async def websocket_session(url): + with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info: + async with websockets.client.connect(url): + pass # pragma: no cover + assert exc_info.value.status_code == 404 + + config = Config( + app=app, + ws=ws_protocol_cls, + http=http_protocol_cls, + lifespan="off", + port=unused_tcp_port, + ) + async with run_server(config): + await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}") + + +@pytest.mark.anyio +async def test_server_multiple_websocket_http_response_start_events( + ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", + http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", + unused_tcp_port: int, +): + """ + The server should raise an exception if it sends multiple + websocket.http.response.start events. + """ + exception_message: typing.Optional[str] = None + + async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): + nonlocal exception_message + assert scope["type"] == "websocket" + assert "extensions" in scope + assert "websocket.http.response" in scope["extensions"] + + # Pull up first recv message. + message = await receive() + assert message["type"] == "websocket.connect" + + start_event: WebSocketResponseStartEvent = { + "type": "websocket.http.response.start", + "status": 404, + "headers": [(b"Content-Length", b"0"), (b"Content-Type", b"text/plain")], + } + await send(start_event) + try: + await send(start_event) + except Exception as exc: + exception_message = str(exc) + + async def websocket_session(url: str): + with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info: + async with websockets.client.connect(url): + pass + assert exc_info.value.status_code == 404 + + config = Config( + app=app, + ws=ws_protocol_cls, + http=http_protocol_cls, + lifespan="off", + port=unused_tcp_port, + ) + async with run_server(config): + await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}") + + assert exception_message == ( + "Expected ASGI message 'websocket.http.response.body' but got " + "'websocket.http.response.start'." + ) + + @pytest.mark.anyio async def test_server_can_read_messages_in_buffer_after_close( ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", diff --git a/tests/response.py b/tests/response.py index 774dee6fa..55766d3f1 100644 --- a/tests/response.py +++ b/tests/response.py @@ -10,9 +10,10 @@ def __init__(self, content, status_code=200, headers=None, media_type=None): self.set_content_length() async def __call__(self, scope, receive, send) -> None: + prefix = "websocket." if scope["type"] == "websocket" else "" await send( { - "type": "http.response.start", + "type": prefix + "http.response.start", "status": self.status_code, "headers": [ [key.encode(), value.encode()] @@ -20,7 +21,7 @@ async def __call__(self, scope, receive, send) -> None: ], } ) - await send({"type": "http.response.body", "body": self.body}) + await send({"type": prefix + "http.response.body", "body": self.body}) def render(self, content) -> bytes: if isinstance(content, bytes): diff --git a/uvicorn/_types.py b/uvicorn/_types.py index be96d940b..2f689d960 100644 --- a/uvicorn/_types.py +++ b/uvicorn/_types.py @@ -199,7 +199,7 @@ class WebSocketResponseStartEvent(TypedDict): class WebSocketResponseBodyEvent(TypedDict): type: Literal["websocket.http.response.body"] body: bytes - more_body: bool + more_body: NotRequired[bool] class WebSocketDisconnectEvent(TypedDict): diff --git a/uvicorn/protocols/websockets/websockets_impl.py b/uvicorn/protocols/websockets/websockets_impl.py index 94f40f233..d22f3c539 100644 --- a/uvicorn/protocols/websockets/websockets_impl.py +++ b/uvicorn/protocols/websockets/websockets_impl.py @@ -29,6 +29,8 @@ WebSocketConnectEvent, WebSocketDisconnectEvent, WebSocketReceiveEvent, + WebSocketResponseBodyEvent, + WebSocketResponseStartEvent, WebSocketScope, WebSocketSendEvent, ) @@ -196,6 +198,7 @@ async def process_request( "headers": asgi_headers, "subprotocols": subprotocols, "state": self.app_state.copy(), + "extensions": {"websocket.http.response": {}}, } task = self.loop.create_task(self.run_asgi()) task.add_done_callback(self.on_task_complete) @@ -302,14 +305,31 @@ async def asgi_send(self, message: "ASGISendEvent") -> None: self.handshake_started_event.set() self.closed_event.set() + elif message_type == "websocket.http.response.start": + message = cast("WebSocketResponseStartEvent", message) + self.logger.info( + '%s - "WebSocket %s" %d', + self.scope["client"], + get_path_with_query_string(self.scope), + message["status"], + ) + # websockets requires the status to be an enum. look it up. + status = http.HTTPStatus(message["status"]) + headers = [ + (name.decode("latin-1"), value.decode("latin-1")) + for name, value in message.get("headers", []) + ] + self.initial_response = (status, headers, b"") + self.handshake_started_event.set() + else: msg = ( - "Expected ASGI message 'websocket.accept' or 'websocket.close', " - "but got '%s'." + "Expected ASGI message 'websocket.accept', 'websocket.close', " + "or 'websocket.http.response.start' but got '%s'." ) raise RuntimeError(msg % message_type) - elif not self.closed_event.is_set(): + elif not self.closed_event.is_set() and self.initial_response is None: await self.handshake_completed_event.wait() if message_type == "websocket.send": @@ -333,8 +353,25 @@ async def asgi_send(self, message: "ASGISendEvent") -> None: ) raise RuntimeError(msg % message_type) + elif self.initial_response is not None: + if message_type == "websocket.http.response.body": + message = cast("WebSocketResponseBodyEvent", message) + body = self.initial_response[2] + message["body"] + self.initial_response = self.initial_response[:2] + (body,) + if not message.get("more_body", False): + self.closed_event.set() + else: + msg = ( + "Expected ASGI message 'websocket.http.response.body' " + "but got '%s'." + ) + raise RuntimeError(msg % message_type) + else: - msg = "Unexpected ASGI message '%s', after sending 'websocket.close'." + msg = ( + "Unexpected ASGI message '%s', after sending 'websocket.close' " + "or response already completed." + ) raise RuntimeError(msg % message_type) async def asgi_receive( diff --git a/uvicorn/protocols/websockets/wsproto_impl.py b/uvicorn/protocols/websockets/wsproto_impl.py index d682eb9f9..db95f57a3 100644 --- a/uvicorn/protocols/websockets/wsproto_impl.py +++ b/uvicorn/protocols/websockets/wsproto_impl.py @@ -15,6 +15,8 @@ WebSocketAcceptEvent, WebSocketCloseEvent, WebSocketEvent, + WebSocketResponseBodyEvent, + WebSocketResponseStartEvent, WebSocketScope, WebSocketSendEvent, ) @@ -63,6 +65,9 @@ def __init__( self.handshake_complete = False self.close_sent = False + # Rejection state + self.response_started = False + self.conn = wsproto.WSConnection(connection_type=ConnectionType.SERVER) self.read_paused = False @@ -171,6 +176,7 @@ def handle_connect(self, event: events.Request) -> None: "headers": headers, "subprotocols": event.subprotocols, "state": self.app_state.copy(), + "extensions": {"websocket.http.response": {}}, } self.queue.put_nowait({"type": "websocket.connect"}) task = self.loop.create_task(self.run_asgi()) @@ -206,6 +212,8 @@ def handle_ping(self, event: events.Ping) -> None: self.transport.write(self.conn.send(event.response())) def send_500_response(self) -> None: + if self.response_started or self.handshake_complete: + return # we cannot send responses anymore headers = [ (b"content-type", b"text/plain; charset=utf-8"), (b"connection", b"close"), @@ -225,8 +233,7 @@ async def run_asgi(self) -> None: result = await self.app(self.scope, self.receive, self.send) except BaseException: self.logger.exception("Exception in ASGI application\n") - if not self.handshake_complete: - self.send_500_response() + self.send_500_response() self.transport.close() else: if not self.handshake_complete: @@ -282,14 +289,37 @@ async def send(self, message: "ASGISendEvent") -> None: self.transport.write(output) self.transport.close() + elif message_type == "websocket.http.response.start": + message = typing.cast("WebSocketResponseStartEvent", message) + # ensure status code is in the valid range + if not (100 <= message["status"] < 600): + msg = "Invalid HTTP status code '%d' in response." + raise RuntimeError(msg % message["status"]) + self.logger.info( + '%s - "WebSocket %s" %d', + self.scope["client"], + get_path_with_query_string(self.scope), + message["status"], + ) + self.handshake_complete = True + event = events.RejectConnection( + status_code=message["status"], + headers=list(message["headers"]), + has_body=True, + ) + output = self.conn.send(event) + self.transport.write(output) + self.response_started = True + else: msg = ( - "Expected ASGI message 'websocket.accept' or 'websocket.close', " + "Expected ASGI message 'websocket.accept', 'websocket.close' " + "or 'websocket.http.response.start' " "but got '%s'." ) raise RuntimeError(msg % message_type) - elif not self.close_sent: + elif not self.close_sent and not self.response_started: if message_type == "websocket.send": message = typing.cast("WebSocketSendEvent", message) bytes_data = message.get("bytes") @@ -320,6 +350,29 @@ async def send(self, message: "ASGISendEvent") -> None: " but got '%s'." ) raise RuntimeError(msg % message_type) + elif self.response_started: + if message_type == "websocket.http.response.body": + message = typing.cast("WebSocketResponseBodyEvent", message) + body_finished = not message.get("more_body", False) + reject_data = events.RejectData( + data=message["body"], body_finished=body_finished + ) + output = self.conn.send(reject_data) + self.transport.write(output) + + if body_finished: + self.queue.put_nowait( + {"type": "websocket.disconnect", "code": 1006} + ) + self.close_sent = True + self.transport.close() + + else: + msg = ( + "Expected ASGI message 'websocket.http.response.body' " + "but got '%s'." + ) + raise RuntimeError(msg % message_type) else: msg = "Unexpected ASGI message '%s', after sending 'websocket.close'."