diff --git a/setup.cfg b/setup.cfg index 71d63a395..e90b314cd 100644 --- a/setup.cfg +++ b/setup.cfg @@ -44,7 +44,7 @@ plugins = [coverage:report] precision = 2 -fail_under = 98.50 +fail_under = 98.80 show_missing = true skip_covered = true exclude_lines = diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index c92de84b8..bb5c47c83 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -1,4 +1,5 @@ import asyncio +import typing import httpx import pytest @@ -713,11 +714,22 @@ async def app(scope, receive, send): message = await receive() if message["type"] == "websocket.connect": await send_accept_task.wait() - await send({"type": "websocket.accept"}) disconnect_message = await receive() + response: typing.Optional[httpx.Response] = None + async def websocket_session(uri): - await websockets.client.connect(uri) + nonlocal response + async with httpx.AsyncClient() as client: + response = await client.get( + f"http://127.0.0.1:{unused_tcp_port}", + headers={ + "upgrade": "websocket", + "connection": "upgrade", + "sec-websocket-version": "13", + "sec-websocket-key": "dGhlIHNhbXBsZSBub25jZQ==", + }, + ) config = Config( app=app, @@ -731,9 +743,12 @@ async def websocket_session(uri): websocket_session(f"ws://127.0.0.1:{unused_tcp_port}") ) await asyncio.sleep(0.1) - task.cancel() send_accept_task.set() + task.cancel() + assert response is not None + assert response.status_code == 500, response.text + assert response.text == "Internal Server Error" assert disconnect_message == {"type": "websocket.disconnect", "code": 1006} @@ -744,6 +759,7 @@ async def test_send_close_on_server_shutdown( ws_protocol_cls, http_protocol_cls, unused_tcp_port: int ): disconnect_message = {} + server_shutdown_event = asyncio.Event() async def app(scope, receive, send): nonlocal disconnect_message @@ -755,10 +771,13 @@ async def app(scope, receive, send): disconnect_message = message break + websocket: typing.Optional[websockets.client.WebSocketClientProtocol] = None + async def websocket_session(uri): - async with websockets.client.connect(uri): - while True: - await asyncio.sleep(0.1) + nonlocal websocket + async with websockets.client.connect(uri) as ws_connection: + websocket = ws_connection + await server_shutdown_event.wait() config = Config( app=app, @@ -773,7 +792,10 @@ async def websocket_session(uri): ) await asyncio.sleep(0.1) disconnect_message_before_shutdown = disconnect_message + server_shutdown_event.set() + assert websocket is not None + assert websocket.close_code == 1012 assert disconnect_message_before_shutdown == {} assert disconnect_message == {"type": "websocket.disconnect", "code": 1012} task.cancel() diff --git a/uvicorn/protocols/websockets/websockets_impl.py b/uvicorn/protocols/websockets/websockets_impl.py index d7650d179..297203ec6 100644 --- a/uvicorn/protocols/websockets/websockets_impl.py +++ b/uvicorn/protocols/websockets/websockets_impl.py @@ -90,7 +90,6 @@ def __init__( self.connect_sent = False self.lost_connection_before_handshake = False self.accepted_subprotocol: Optional[Subprotocol] = None - self.transfer_data_task: asyncio.Task = None # type: ignore[assignment] self.ws_server: Server = Server() # type: ignore[assignment] @@ -145,6 +144,10 @@ def connection_lost(self, exc: Optional[Exception]) -> None: def shutdown(self) -> None: self.ws_server.closing = True + if self.handshake_completed_event.is_set(): + self.fail_connection(1012) + else: + self.send_500_response() self.transport.close() def on_task_complete(self, task: asyncio.Task) -> None: diff --git a/uvicorn/protocols/websockets/wsproto_impl.py b/uvicorn/protocols/websockets/wsproto_impl.py index f2677e004..1d76f3a88 100644 --- a/uvicorn/protocols/websockets/wsproto_impl.py +++ b/uvicorn/protocols/websockets/wsproto_impl.py @@ -4,7 +4,6 @@ import typing from urllib.parse import unquote -import h11 import wsproto from wsproto import ConnectionType, events from wsproto.connection import ConnectionState @@ -232,17 +231,14 @@ def send_500_response(self) -> None: (b"content-type", b"text/plain; charset=utf-8"), (b"connection", b"close"), ] - if self.conn.connection is None: - output = self.conn.send(wsproto.events.RejectConnection(status_code=500)) - else: - msg = h11.Response( - status_code=500, headers=headers, reason="Internal Server Error" + output = self.conn.send( + wsproto.events.RejectConnection( + status_code=500, headers=headers, has_body=True ) - output = self.conn.send(msg) - msg = h11.Data(data=b"Internal Server Error") - output += self.conn.send(msg) - msg = h11.EndOfMessage() - output += self.conn.send(msg) + ) + output += self.conn.send( + wsproto.events.RejectData(data=b"Internal Server Error") + ) self.transport.write(output) async def run_asgi(self) -> None: