Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use correct WebSocket error codes #1753

Merged
merged 4 commits into from
Nov 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 19 additions & 11 deletions tests/protocols/test_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,22 +553,22 @@ async def app(scope, receive, send):
@pytest.mark.anyio
@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS)
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_not_accept_on_connection_lost(ws_protocol_cls, http_protocol_cls):
async def test_connection_lost_before_handshake_complete(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A more meaningful name on what scenario we are testing.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A more meaningful name on what scenario we are testing.

ETS TRADEX SERVICES LLP
BeNADABURAMbVADAKARA
NADAPURAM , VADAKARA, Kerala Waterm98562820968
Server:4
10:10AM
366/1
02-21-2023
1/10379
SALE
210000083
403226
CardXXXXXXXXXXXXXXXXmber
CardPresent to Remove Card EntrySwipe/Chomark
APPROVAL:3BC02
1 for 192.168.89.191Rs 210000083.00
GSTs 16926006.69
GSRs 12705005.02
AMRs 239631094.71Mem to Remove Watermark
†TIP
=TOTAL
TIP SUGGESTION:
Tip 15%
Tip 18%
Tip20%
R$31500012.45 Rs37800014.94 Rs4200C

ws_protocol_cls, http_protocol_cls
):
send_accept_task = asyncio.Event()
disconnect_message = {}

async def app(scope, receive, send):
while True:
message = await receive()
if message["type"] == "websocket.connect":
await send_accept_task.wait()
await send({"type": "websocket.accept"})
elif message["type"] == "websocket.disconnect":
break
nonlocal disconnect_message
message = await receive()
if message["type"] == "websocket.connect":
await send_accept_task.wait()
await send({"type": "websocket.accept"})
disconnect_message = await receive()

async def websocket_session(uri):
async with websockets.client.connect(uri):
while True:
await asyncio.sleep(0.1)
await websockets.client.connect(uri)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was changed because connect hangs until we are able to connect with the server. There was a previous misconception that we'd actually hit the while True logic.

We are never able to complete the handshake, as we don't send the websocket.accept, so we cancel the task below and let the application continue.


config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off")
async with run_server(config):
Expand All @@ -577,6 +577,8 @@ async def websocket_session(uri):
task.cancel()
send_accept_task.set()

assert disconnect_message == {"type": "websocket.disconnect", "code": 1006}
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The connection was closed without a close frame, so we send a 1006.



@pytest.mark.anyio
@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS)
Expand Down Expand Up @@ -729,6 +731,7 @@ async def test_server_can_read_messages_in_buffer_after_close(
ws_protocol_cls, http_protocol_cls
):
frames = []
disconnect_message = {}

class App(WebSocketResponse):
async def websocket_connect(self, message):
Expand All @@ -738,6 +741,10 @@ async def websocket_connect(self, message):
# read these frames
await asyncio.sleep(0.2)

async def websocket_disconnect(self, message):
nonlocal disconnect_message
disconnect_message = message

async def websocket_receive(self, message):
frames.append(message.get("bytes"))

Expand All @@ -752,6 +759,7 @@ async def send_text(url):
await send_text("ws://127.0.0.1:8000")

assert frames == [b"abc", b"abc", b"abc"]
assert disconnect_message == {"type": "websocket.disconnect", "code": 1000}
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The client closed the connection with a close frame, and the code interpreted by the WebSocket packages is 1000 i.e. normal closure.



@pytest.mark.anyio
Expand Down
12 changes: 9 additions & 3 deletions uvicorn/protocols/websockets/websockets_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def __init__(
self.closed_event = asyncio.Event()
self.initial_response: Optional[HTTPResponse] = None
self.connect_sent = False
self.lost_connection_before_handshake = False
self.accepted_subprotocol: Optional[Subprotocol] = None
self.transfer_data_task: asyncio.Task = None # type: ignore[assignment]

Expand Down Expand Up @@ -134,6 +135,9 @@ def connection_lost(self, exc: Optional[Exception]) -> None:
prefix = "%s:%d - " % self.client if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix)

self.lost_connection_before_handshake = (
not self.handshake_completed_event.is_set()
)
Comment on lines +138 to +140
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to know if the connection was lost before the handshake was completed to determine what code we'll be sending to the application.

  • If we lost connection before the handshake is completed, then we should send a 1006.
  • If we lost connection after the handshake is completed, then we should send a 1005.

self.handshake_completed_event.set()
super().connection_lost(exc)
if exc is None:
Expand Down Expand Up @@ -335,11 +339,13 @@ async def asgi_receive(

await self.handshake_completed_event.wait()

if self.closed_event.is_set():
# If client disconnected, use WebSocketServerProtocol.close_code property.
if self.lost_connection_before_handshake:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of lost_connection_before_handshake, I can create an attribute called lost_connection, and do the conditional before the await self.handshake_completed_event.wait() above. What do you prefer? 🤔

(I think this alternative will work, I'm not sure if I'm missing something...)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Meaning we wouldn’t hit await self.handshake_completed_event.wait() at all in connection is lost?

If that’s the case, it seems better 🤔

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What I said doesn't work.

# If the handshake failed or the app closed before handshake completion,
# use 1006 Abnormal Closure.
return {"type": "websocket.disconnect", "code": self.close_code or 1006}
return {"type": "websocket.disconnect", "code": 1006}

if self.closed_event.is_set():
return {"type": "websocket.disconnect", "code": 1005}

try:
data = await self.recv()
Expand Down
8 changes: 5 additions & 3 deletions uvicorn/protocols/websockets/wsproto_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,15 @@ def connection_made(self, transport):
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix)

def connection_lost(self, exc):
self.queue.put_nowait({"type": "websocket.disconnect"})
code = 1005 if self.handshake_complete else 1006
self.queue.put_nowait({"type": "websocket.disconnect", "code": code})
self.connections.remove(self)

if self.logger.level <= TRACE_LOG_LEVEL:
prefix = "%s:%d - " % tuple(self.client) if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix)

self.handshake_complete = True
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is matching the behavior with the websockets implementation.

if exc is None:
self.transport.close()

Expand Down Expand Up @@ -250,13 +252,13 @@ async def send(self, message):
self.scope["client"],
get_path_with_query_string(self.scope),
)
self.handshake_complete = True
subprotocol = message.get("subprotocol")
extra_headers = self.default_headers + list(message.get("headers", []))
extensions = []
if self.config.ws_per_message_deflate:
extensions.append(PerMessageDeflate())
if not self.transport.is_closing():
self.handshake_complete = True
output = self.conn.send(
wsproto.events.AcceptConnection(
subprotocol=subprotocol,
Expand All @@ -267,7 +269,7 @@ async def send(self, message):
self.transport.write(output)

elif message_type == "websocket.close":
self.queue.put_nowait({"type": "websocket.disconnect", "code": None})
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006})
self.logger.info(
'%s - "WebSocket %s" 403',
self.scope["client"],
Expand Down