Skip to content

Commit

Permalink
Use correct WebSocket error codes (#1753)
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex authored Nov 20, 2022
1 parent 53d7d1e commit 41156aa
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 17 deletions.
30 changes: 19 additions & 11 deletions tests/protocols/test_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,22 +553,22 @@ async def app(scope, receive, send):
@pytest.mark.anyio
@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS)
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_not_accept_on_connection_lost(ws_protocol_cls, http_protocol_cls):
async def test_connection_lost_before_handshake_complete(
ws_protocol_cls, http_protocol_cls
):
send_accept_task = asyncio.Event()
disconnect_message = {}

async def app(scope, receive, send):
while True:
message = await receive()
if message["type"] == "websocket.connect":
await send_accept_task.wait()
await send({"type": "websocket.accept"})
elif message["type"] == "websocket.disconnect":
break
nonlocal disconnect_message
message = await receive()
if message["type"] == "websocket.connect":
await send_accept_task.wait()
await send({"type": "websocket.accept"})
disconnect_message = await receive()

async def websocket_session(uri):
async with websockets.client.connect(uri):
while True:
await asyncio.sleep(0.1)
await websockets.client.connect(uri)

config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off")
async with run_server(config):
Expand All @@ -577,6 +577,8 @@ async def websocket_session(uri):
task.cancel()
send_accept_task.set()

assert disconnect_message == {"type": "websocket.disconnect", "code": 1006}


@pytest.mark.anyio
@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS)
Expand Down Expand Up @@ -729,6 +731,7 @@ async def test_server_can_read_messages_in_buffer_after_close(
ws_protocol_cls, http_protocol_cls
):
frames = []
disconnect_message = {}

class App(WebSocketResponse):
async def websocket_connect(self, message):
Expand All @@ -738,6 +741,10 @@ async def websocket_connect(self, message):
# read these frames
await asyncio.sleep(0.2)

async def websocket_disconnect(self, message):
nonlocal disconnect_message
disconnect_message = message

async def websocket_receive(self, message):
frames.append(message.get("bytes"))

Expand All @@ -752,6 +759,7 @@ async def send_text(url):
await send_text("ws://127.0.0.1:8000")

assert frames == [b"abc", b"abc", b"abc"]
assert disconnect_message == {"type": "websocket.disconnect", "code": 1000}


@pytest.mark.anyio
Expand Down
12 changes: 9 additions & 3 deletions uvicorn/protocols/websockets/websockets_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def __init__(
self.closed_event = asyncio.Event()
self.initial_response: Optional[HTTPResponse] = None
self.connect_sent = False
self.lost_connection_before_handshake = False
self.accepted_subprotocol: Optional[Subprotocol] = None
self.transfer_data_task: asyncio.Task = None # type: ignore[assignment]

Expand Down Expand Up @@ -134,6 +135,9 @@ def connection_lost(self, exc: Optional[Exception]) -> None:
prefix = "%s:%d - " % self.client if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix)

self.lost_connection_before_handshake = (
not self.handshake_completed_event.is_set()
)
self.handshake_completed_event.set()
super().connection_lost(exc)
if exc is None:
Expand Down Expand Up @@ -335,11 +339,13 @@ async def asgi_receive(

await self.handshake_completed_event.wait()

if self.closed_event.is_set():
# If client disconnected, use WebSocketServerProtocol.close_code property.
if self.lost_connection_before_handshake:
# If the handshake failed or the app closed before handshake completion,
# use 1006 Abnormal Closure.
return {"type": "websocket.disconnect", "code": self.close_code or 1006}
return {"type": "websocket.disconnect", "code": 1006}

if self.closed_event.is_set():
return {"type": "websocket.disconnect", "code": 1005}

try:
data = await self.recv()
Expand Down
8 changes: 5 additions & 3 deletions uvicorn/protocols/websockets/wsproto_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,15 @@ def connection_made(self, transport):
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix)

def connection_lost(self, exc):
self.queue.put_nowait({"type": "websocket.disconnect"})
code = 1005 if self.handshake_complete else 1006
self.queue.put_nowait({"type": "websocket.disconnect", "code": code})
self.connections.remove(self)

if self.logger.level <= TRACE_LOG_LEVEL:
prefix = "%s:%d - " % tuple(self.client) if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix)

self.handshake_complete = True
if exc is None:
self.transport.close()

Expand Down Expand Up @@ -232,13 +234,13 @@ async def send(self, message):
self.scope["client"],
get_path_with_query_string(self.scope),
)
self.handshake_complete = True
subprotocol = message.get("subprotocol")
extra_headers = self.default_headers + list(message.get("headers", []))
extensions = []
if self.config.ws_per_message_deflate:
extensions.append(PerMessageDeflate())
if not self.transport.is_closing():
self.handshake_complete = True
output = self.conn.send(
wsproto.events.AcceptConnection(
subprotocol=subprotocol,
Expand All @@ -249,7 +251,7 @@ async def send(self, message):
self.transport.write(output)

elif message_type == "websocket.close":
self.queue.put_nowait({"type": "websocket.disconnect", "code": None})
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006})
self.logger.info(
'%s - "WebSocket %s" 403',
self.scope["client"],
Expand Down

0 comments on commit 41156aa

Please sign in to comment.