From 36da9c6171b9666f78569f51b5be77777edc4e7b Mon Sep 17 00:00:00 2001 From: Aviram Hassan Date: Fri, 30 Jul 2021 12:27:23 +0300 Subject: [PATCH 1/3] Don't abruptly shutdown active websockets on graceful shutdown. --- uvicorn/protocols/websockets/websockets_impl.py | 1 - uvicorn/protocols/websockets/wsproto_impl.py | 7 +++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/uvicorn/protocols/websockets/websockets_impl.py b/uvicorn/protocols/websockets/websockets_impl.py index ddc4aa536..b6db8d2ea 100644 --- a/uvicorn/protocols/websockets/websockets_impl.py +++ b/uvicorn/protocols/websockets/websockets_impl.py @@ -98,7 +98,6 @@ def connection_lost(self, exc): def shutdown(self): self.ws_server.closing = True - self.transport.close() def on_task_complete(self, task): self.tasks.discard(task) diff --git a/uvicorn/protocols/websockets/wsproto_impl.py b/uvicorn/protocols/websockets/wsproto_impl.py index a870f1710..9f516234c 100644 --- a/uvicorn/protocols/websockets/wsproto_impl.py +++ b/uvicorn/protocols/websockets/wsproto_impl.py @@ -129,10 +129,9 @@ def resume_writing(self): self.writable.set() def shutdown(self): - self.queue.put_nowait({"type": "websocket.disconnect", "code": 1012}) - output = self.conn.send(wsproto.events.CloseConnection(code=1012)) - self.transport.write(output) - self.transport.close() + """ + Don't do anything - `run_asgi ` closes all resources after answering a request + """ def on_task_complete(self, task): self.tasks.discard(task) From 443414d8a71c399ef6153f4cfd1fdca0ccf07109 Mon Sep 17 00:00:00 2001 From: Aviram Hassan Date: Tue, 31 Aug 2021 14:56:43 +0300 Subject: [PATCH 2/3] fix force shutdown issues with ws --- uvicorn/_handlers/http.py | 8 ++++++-- uvicorn/protocols/websockets/websockets_impl.py | 6 ++++++ uvicorn/protocols/websockets/wsproto_impl.py | 6 ++++++ 3 files changed, 18 insertions(+), 2 deletions(-) diff --git a/uvicorn/_handlers/http.py b/uvicorn/_handlers/http.py index e810f6626..0805d9a35 100644 --- a/uvicorn/_handlers/http.py +++ b/uvicorn/_handlers/http.py @@ -39,7 +39,11 @@ async def handle_http( protocol = config.http_protocol_class( # type: ignore[call-arg, operator] config=config, server_state=server_state, - on_connection_lost=lambda: connection_lost.set_result(True), + on_connection_lost=lambda: ( + connection_lost.set_result(True) + if not connection_lost.cancelled() + else None + ), ) transport = writer.transport transport.set_protocol(protocol) @@ -55,7 +59,7 @@ async def handle_http( @task.add_done_callback def retrieve_exception(task: asyncio.Task) -> None: - exc = task.exception() + exc = task.exception() if not task.cancelled() else None if exc is None: return diff --git a/uvicorn/protocols/websockets/websockets_impl.py b/uvicorn/protocols/websockets/websockets_impl.py index 702c16aec..1024fe603 100644 --- a/uvicorn/protocols/websockets/websockets_impl.py +++ b/uvicorn/protocols/websockets/websockets_impl.py @@ -1,4 +1,5 @@ import asyncio +import concurrent import http import inspect import logging @@ -200,6 +201,11 @@ async def run_asgi(self): """ try: result = await self.app(self.scope, self.asgi_receive, self.asgi_send) + except (concurrent.futures.CancelledError, asyncio.CancelledError): + self.logger.error("ASGI callable was cancelled while running") + if not self.handshake_started_event.is_set(): + self.send_500_response() + self.transport.close() except BaseException as exc: self.closed_event.set() msg = "Exception in ASGI application\n" diff --git a/uvicorn/protocols/websockets/wsproto_impl.py b/uvicorn/protocols/websockets/wsproto_impl.py index 9f516234c..735ee6879 100644 --- a/uvicorn/protocols/websockets/wsproto_impl.py +++ b/uvicorn/protocols/websockets/wsproto_impl.py @@ -1,4 +1,5 @@ import asyncio +import concurrent import logging from typing import Callable from urllib.parse import unquote @@ -225,6 +226,11 @@ def send_500_response(self): async def run_asgi(self): try: result = await self.app(self.scope, self.receive, self.send) + except (concurrent.futures.CancelledError, asyncio.CancelledError): + self.logger.error("ASGI callable was cancelled while running") + if not self.handshake_complete: + self.send_500_response() + self.transport.close() except BaseException as exc: msg = "Exception in ASGI application\n" self.logger.error(msg, exc_info=exc) From dfcc9446ad98477f3e2fae25536c496541380a31 Mon Sep 17 00:00:00 2001 From: Aviram Hassan Date: Tue, 19 Oct 2021 09:23:07 +0300 Subject: [PATCH 3/3] add test for graceful shutdown --- tests/protocols/test_websocket.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index d21be54de..76d067db5 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -1,3 +1,5 @@ +import asyncio + import httpx import pytest @@ -5,6 +7,7 @@ from tests.utils import run_server from uvicorn.config import Config from uvicorn.protocols.websockets.wsproto_impl import WSProtocol +from uvicorn.server import Server try: import websockets @@ -539,3 +542,31 @@ async def send_text(url): with pytest.raises(websockets.ConnectionClosedError) as e: data = await send_text("ws://127.0.0.1:8000") assert e.value.code == expected_result + + +@pytest.mark.asyncio +@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS) +@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) +async def test_graceful_shutdown(ws_protocol_cls, http_protocol_cls): + class App(WebSocketResponse): + async def websocket_connect(self, message): + await self.send({"type": "websocket.accept"}) + + async def websocket_receive(self, message): + _bytes = message.get("bytes") + await self.send({"type": "websocket.send", "bytes": _bytes}) + + config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off") + server = Server(config=config) + cancel_handle = asyncio.ensure_future(server.serve()) + await asyncio.sleep(0.1) + + async with websockets.connect("ws://127.0.0.1:8000") as websocket: + await websocket.ping() + shutdown = asyncio.ensure_future(server.shutdown()) + data = b"abc" + await websocket.send(data) + resp = await websocket.recv() + assert data == resp + await shutdown + cancel_handle.cancel()