From ed7d669eadf044d9cd04a0974c497ca8fe688317 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Mon, 26 Dec 2022 21:31:26 +0100 Subject: [PATCH] Improve tests --- setup.cfg | 2 +- tests/protocols/test_websocket.py | 36 +++++++++++++++---- .../protocols/websockets/websockets_impl.py | 5 ++- uvicorn/protocols/websockets/wsproto_impl.py | 7 +++- 4 files changed, 41 insertions(+), 9 deletions(-) diff --git a/setup.cfg b/setup.cfg index 71d63a395..ff6cb1264 100644 --- a/setup.cfg +++ b/setup.cfg @@ -44,7 +44,7 @@ plugins = [coverage:report] precision = 2 -fail_under = 98.50 +fail_under = 98.70 show_missing = true skip_covered = true exclude_lines = diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index c92de84b8..84bb2c82e 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -1,4 +1,5 @@ import asyncio +import typing import httpx import pytest @@ -713,11 +714,23 @@ async def app(scope, receive, send): message = await receive() if message["type"] == "websocket.connect": await send_accept_task.wait() - await send({"type": "websocket.accept"}) disconnect_message = await receive() + response: typing.Optional[httpx.Response] = None + async def websocket_session(uri): - await websockets.client.connect(uri) + # await websockets.client.connect(uri) + nonlocal response + async with httpx.AsyncClient() as client: + response = await client.get( + f"http://127.0.0.1:{unused_tcp_port}", + headers={ + "upgrade": "websocket", + "connection": "upgrade", + "sec-websocket-version": "13", + "sec-websocket-key": "dGhlIHNhbXBsZSBub25jZQ==", + }, + ) config = Config( app=app, @@ -731,9 +744,13 @@ async def websocket_session(uri): websocket_session(f"ws://127.0.0.1:{unused_tcp_port}") ) await asyncio.sleep(0.1) - task.cancel() send_accept_task.set() + task.cancel() + + assert response is not None + assert response.status_code == 500, response.text + assert response.text == "Internal Server Error" assert disconnect_message == {"type": "websocket.disconnect", "code": 1006} @@ -744,6 +761,7 @@ async def test_send_close_on_server_shutdown( ws_protocol_cls, http_protocol_cls, unused_tcp_port: int ): disconnect_message = {} + server_shutdown_event = asyncio.Event() async def app(scope, receive, send): nonlocal disconnect_message @@ -755,10 +773,13 @@ async def app(scope, receive, send): disconnect_message = message break + websocket: typing.Optional[websockets.client.WebSocketClientProtocol] = None + async def websocket_session(uri): - async with websockets.client.connect(uri): - while True: - await asyncio.sleep(0.1) + nonlocal websocket + async with websockets.client.connect(uri) as ws_connection: + websocket = ws_connection + await server_shutdown_event.wait() config = Config( app=app, @@ -773,7 +794,10 @@ async def websocket_session(uri): ) await asyncio.sleep(0.1) disconnect_message_before_shutdown = disconnect_message + server_shutdown_event.set() + assert websocket is not None + assert websocket.close_code == 1012 assert disconnect_message_before_shutdown == {} assert disconnect_message == {"type": "websocket.disconnect", "code": 1012} task.cancel() diff --git a/uvicorn/protocols/websockets/websockets_impl.py b/uvicorn/protocols/websockets/websockets_impl.py index ed9cf1372..297203ec6 100644 --- a/uvicorn/protocols/websockets/websockets_impl.py +++ b/uvicorn/protocols/websockets/websockets_impl.py @@ -144,8 +144,11 @@ def connection_lost(self, exc: Optional[Exception]) -> None: def shutdown(self) -> None: self.ws_server.closing = True - if not self.transport.is_closing(): + if self.handshake_completed_event.is_set(): self.fail_connection(1012) + else: + self.send_500_response() + self.transport.close() def on_task_complete(self, task: asyncio.Task) -> None: self.tasks.discard(task) diff --git a/uvicorn/protocols/websockets/wsproto_impl.py b/uvicorn/protocols/websockets/wsproto_impl.py index f2677e004..9723145ab 100644 --- a/uvicorn/protocols/websockets/wsproto_impl.py +++ b/uvicorn/protocols/websockets/wsproto_impl.py @@ -233,7 +233,12 @@ def send_500_response(self) -> None: (b"connection", b"close"), ] if self.conn.connection is None: - output = self.conn.send(wsproto.events.RejectConnection(status_code=500)) + output = self.conn.send( + wsproto.events.RejectConnection(status_code=500, has_body=True) + ) + output += self.conn.send( + wsproto.events.RejectData(data=b"Internal Server Error") + ) else: msg = h11.Response( status_code=500, headers=headers, reason="Internal Server Error"