Skip to content

Commit

Permalink
Check if handshake is completed before sending frame on wsproto shutd…
Browse files Browse the repository at this point in the history
…own (#1737)

* Check if handshake is completed before sending frame on wsproto shutdown

* Add test for connection lost before handshake is completed

* Add test for close on shutdown

* Increase fail-under to 97.87

* Increase coverage

* Apply suggestions from code review
  • Loading branch information
Kludex authored Oct 31, 2022
1 parent 697588d commit ec3aac3
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 15 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ plugins =

[coverage:report]
precision = 2
fail_under = 97.82
fail_under = 97.92
show_missing = true
skip_covered = true
exclude_lines =
Expand Down
61 changes: 60 additions & 1 deletion tests/protocols/test_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,6 @@ async def app(scope, receive, send):
while True:
message = await receive()
if message["type"] == "websocket.connect":
print("accepted")
await send({"type": "websocket.accept"})
elif message["type"] == "websocket.disconnect":
break
Expand All @@ -551,6 +550,66 @@ async def app(scope, receive, send):
assert got_disconnect_event_before_shutdown is True


@pytest.mark.anyio
@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS)
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_not_accept_on_connection_lost(ws_protocol_cls, http_protocol_cls):
send_accept_task = asyncio.Event()

async def app(scope, receive, send):
while True:
message = await receive()
if message["type"] == "websocket.connect":
await send_accept_task.wait()
await send({"type": "websocket.accept"})
elif message["type"] == "websocket.disconnect":
break

async def websocket_session(uri):
async with websockets.client.connect(uri):
while True:
await asyncio.sleep(0.1)

config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off")
async with run_server(config):
task = asyncio.create_task(websocket_session("ws://127.0.0.1:8000"))
await asyncio.sleep(0.1)
task.cancel()
send_accept_task.set()


@pytest.mark.anyio
@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS)
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_send_close_on_server_shutdown(ws_protocol_cls, http_protocol_cls):
disconnect_message = {}

async def app(scope, receive, send):
nonlocal disconnect_message
while True:
message = await receive()
if message["type"] == "websocket.connect":
await send({"type": "websocket.accept"})
elif message["type"] == "websocket.disconnect":
disconnect_message = message
break

async def websocket_session(uri):
async with websockets.client.connect(uri):
while True:
await asyncio.sleep(0.1)

config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off")
async with run_server(config):
task = asyncio.create_task(websocket_session("ws://127.0.0.1:8000"))
await asyncio.sleep(0.1)
disconnect_message_before_shutdown = disconnect_message

assert disconnect_message_before_shutdown == {}
assert disconnect_message == {"type": "websocket.disconnect", "code": 1012}
task.cancel()


@pytest.mark.anyio
@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS)
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
Expand Down
2 changes: 2 additions & 0 deletions uvicorn/protocols/websockets/websockets_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,8 @@ async def asgi_receive(
data = await self.recv()
except ConnectionClosed as exc:
self.closed_event.set()
if self.ws_server.closing:
return {"type": "websocket.disconnect", "code": 1012}
return {"type": "websocket.disconnect", "code": exc.code}

msg: WebSocketReceiveEvent = { # type: ignore[typeddict-item]
Expand Down
29 changes: 16 additions & 13 deletions uvicorn/protocols/websockets/wsproto_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,12 @@ def resume_writing(self):
self.writable.set()

def shutdown(self):
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1012})
output = self.conn.send(wsproto.events.CloseConnection(code=1012))
self.transport.write(output)
if self.handshake_complete:
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1012})
output = self.conn.send(wsproto.events.CloseConnection(code=1012))
self.transport.write(output)
else:
self.send_500_response()
self.transport.close()

def on_task_complete(self, task):
Expand Down Expand Up @@ -219,9 +222,8 @@ def send_500_response(self):
async def run_asgi(self):
try:
result = await self.app(self.scope, self.receive, self.send)
except BaseException as exc:
msg = "Exception in ASGI application\n"
self.logger.error(msg, exc_info=exc)
except BaseException:
self.logger.exception("Exception in ASGI application\n")
if not self.handshake_complete:
self.send_500_response()
self.transport.close()
Expand Down Expand Up @@ -254,14 +256,15 @@ async def send(self, message):
extensions = []
if self.config.ws_per_message_deflate:
extensions.append(PerMessageDeflate())
output = self.conn.send(
wsproto.events.AcceptConnection(
subprotocol=subprotocol,
extensions=extensions,
extra_headers=extra_headers,
if not self.transport.is_closing():
output = self.conn.send(
wsproto.events.AcceptConnection(
subprotocol=subprotocol,
extensions=extensions,
extra_headers=extra_headers,
)
)
)
self.transport.write(output)
self.transport.write(output)

elif message_type == "websocket.close":
self.queue.put_nowait({"type": "websocket.disconnect", "code": None})
Expand Down

0 comments on commit ec3aac3

Please sign in to comment.