From 71b0541b48a5c64cccfc6fc332d634ebbd29092f Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Tue, 1 Nov 2022 08:39:37 +0100 Subject: [PATCH] Use correct WebSocket error codes --- tests/protocols/test_websocket.py | 20 ++++++++++++-------- uvicorn/protocols/websockets/wsproto_impl.py | 4 ++-- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index d495f5108..add66aade 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -555,15 +555,15 @@ async def app(scope, receive, send): @pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_not_accept_on_connection_lost(ws_protocol_cls, http_protocol_cls): send_accept_task = asyncio.Event() + disconnect_message = {} async def app(scope, receive, send): - while True: - message = await receive() - if message["type"] == "websocket.connect": - await send_accept_task.wait() - await send({"type": "websocket.accept"}) - elif message["type"] == "websocket.disconnect": - break + nonlocal disconnect_message + message = await receive() + if message["type"] == "websocket.connect": + await send_accept_task.wait() + await send({"type": "websocket.accept"}) + disconnect_message = await receive() async def websocket_session(uri): async with websockets.client.connect(uri): @@ -577,6 +577,8 @@ async def websocket_session(uri): task.cancel() send_accept_task.set() + assert disconnect_message == {"type": "websocket.disconnect", "code": 1006} + @pytest.mark.anyio @pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS) @@ -729,6 +731,7 @@ async def test_server_can_read_messages_in_buffer_after_close( ws_protocol_cls, http_protocol_cls ): frames = [] + client_close_connection = asyncio.Event() class App(WebSocketResponse): async def websocket_connect(self, message): @@ -736,7 +739,7 @@ async def websocket_connect(self, message): # Ensure server doesn't start reading frames from read buffer until # after client has sent close frame, but server is still able to # read these frames - await asyncio.sleep(0.2) + await client_close_connection.wait() async def websocket_receive(self, message): frames.append(message.get("bytes")) @@ -750,6 +753,7 @@ async def send_text(url): config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off") async with run_server(config): await send_text("ws://127.0.0.1:8000") + client_close_connection.set() assert frames == [b"abc", b"abc", b"abc"] diff --git a/uvicorn/protocols/websockets/wsproto_impl.py b/uvicorn/protocols/websockets/wsproto_impl.py index a97766ff5..93d1d0483 100644 --- a/uvicorn/protocols/websockets/wsproto_impl.py +++ b/uvicorn/protocols/websockets/wsproto_impl.py @@ -70,7 +70,7 @@ def connection_made(self, transport): self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix) def connection_lost(self, exc): - self.queue.put_nowait({"type": "websocket.disconnect"}) + self.queue.put_nowait({"type": "websocket.disconnect", "code": 1005}) self.connections.remove(self) if self.logger.level <= TRACE_LOG_LEVEL: @@ -267,7 +267,7 @@ async def send(self, message): self.transport.write(output) elif message_type == "websocket.close": - self.queue.put_nowait({"type": "websocket.disconnect", "code": None}) + self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006}) self.logger.info( '%s - "WebSocket %s" 403', self.scope["client"],