diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index 931cc43c8..311a6875e 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -1171,15 +1171,7 @@ async def websocket_session(url): with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info: async with websockets.client.connect(url): pass # pragma: no cover - if ws_protocol_cls == WSProtocol: - # ws protocol has started to send the response when it - # fails with the subsequent invalid message so it cannot - # undo that, we will get the initial 404 response - assert exc_info.value.status_code == 404 - else: - # websockets protocol sends its response in one chunk - # and can override the already started response with a 500 - assert exc_info.value.status_code == 500 + assert exc_info.value.status_code == 500 config = Config( app=app, @@ -1218,10 +1210,7 @@ async def websocket_session(url): with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info: async with websockets.client.connect(url): pass # pragma: no cover - if ws_protocol_cls == WSProtocol: - assert exc_info.value.status_code == 404 - else: - assert exc_info.value.status_code == 500 + assert exc_info.value.status_code == 500 config = Config( app=app, diff --git a/uvicorn/protocols/websockets/wsproto_impl.py b/uvicorn/protocols/websockets/wsproto_impl.py index 0e33178a3..f2985599b 100644 --- a/uvicorn/protocols/websockets/wsproto_impl.py +++ b/uvicorn/protocols/websockets/wsproto_impl.py @@ -70,7 +70,10 @@ def __init__( self.queue: asyncio.Queue["WebSocketEvent"] = asyncio.Queue() self.handshake_complete = False self.close_sent = False - self.response_started = False + + # Rejection state + self.reject_event: typing.Optional[typing.Any] = None + self.response_started: bool = False # we have sent response start self.conn = wsproto.WSConnection(connection_type=ConnectionType.SERVER) @@ -224,6 +227,8 @@ def handle_ping(self, event: events.Ping) -> None: self.transport.write(self.conn.send(event.response())) def send_500_response(self) -> None: + if self.response_started or self.handshake_complete: + return # we cannot send responses anymore headers = [ (b"content-type", b"text/plain; charset=utf-8"), (b"connection", b"close"), @@ -243,15 +248,13 @@ async def run_asgi(self) -> None: result = await self.app(self.scope, self.receive, self.send) except BaseException: self.logger.exception("Exception in ASGI application\n") - if not self.response_started: - self.send_500_response() + self.send_500_response() self.transport.close() else: if not self.handshake_complete: msg = "ASGI callable returned without completing handshake." self.logger.error(msg) - if not self.response_started: - self.send_500_response() + self.send_500_response() self.transport.close() elif result is not None: msg = "ASGI callable should return None, but returned '%s'." @@ -264,7 +267,8 @@ async def send(self, message: "ASGISendEvent") -> None: message_type = message["type"] if not self.handshake_complete: - if not self.response_started: + if not (self.response_started or self.reject_event): + # a rejection event has not been sent yet if message_type == "websocket.accept": message = typing.cast("WebSocketAcceptEvent", message) self.logger.info( @@ -319,9 +323,10 @@ async def send(self, message: "ASGISendEvent") -> None: headers=list(message["headers"]), has_body=True, ) - output = self.conn.send(event) - self.transport.write(output) - self.response_started = True + # Create the event here but do not send it, the ASGI spec + # suggest that we wait for the body event before sending. + # https://asgi.readthedocs.io/en/latest/specs/www.html#response-start-send-event + self.reject_event = event else: msg = ( @@ -331,14 +336,23 @@ async def send(self, message: "ASGISendEvent") -> None: ) raise RuntimeError(msg % message_type) else: + # we have started a rejection process with http.response.start if message_type == "websocket.http.response.body": message = typing.cast("WebSocketResponseBodyEvent", message) body_finished = not message.get("more_body", False) reject_data = events.RejectData( data=message["body"], body_finished=body_finished ) + if self.reject_event is not None: + # Prepend with the reject event now that we have a body event. + output = self.conn.send(self.reject_event) + self.transport.write(output) + self.reject_event = None + self.response_started = True + output = self.conn.send(reject_data) self.transport.write(output) + if body_finished: self.queue.put_nowait( {"type": "websocket.disconnect", "code": 1006}