Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Send code 1012 on shutdown for websockets #1816

Merged
merged 5 commits into from
Jan 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ plugins =

[coverage:report]
precision = 2
fail_under = 98.50
fail_under = 98.80
show_missing = true
skip_covered = true
exclude_lines =
Expand Down
34 changes: 28 additions & 6 deletions tests/protocols/test_websocket.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import typing

import httpx
import pytest
Expand Down Expand Up @@ -713,11 +714,22 @@ async def app(scope, receive, send):
message = await receive()
if message["type"] == "websocket.connect":
await send_accept_task.wait()
await send({"type": "websocket.accept"})
disconnect_message = await receive()

response: typing.Optional[httpx.Response] = None

async def websocket_session(uri):
await websockets.client.connect(uri)
nonlocal response
async with httpx.AsyncClient() as client:
response = await client.get(
f"http://127.0.0.1:{unused_tcp_port}",
headers={
"upgrade": "websocket",
"connection": "upgrade",
"sec-websocket-version": "13",
"sec-websocket-key": "dGhlIHNhbXBsZSBub25jZQ==",
},
)

config = Config(
app=app,
Expand All @@ -731,9 +743,12 @@ async def websocket_session(uri):
websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")
)
await asyncio.sleep(0.1)
task.cancel()
send_accept_task.set()

task.cancel()
assert response is not None
assert response.status_code == 500, response.text
assert response.text == "Internal Server Error"
Comment on lines +749 to +751
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are now testing what the client receives, instead of only the ASGI application.

assert disconnect_message == {"type": "websocket.disconnect", "code": 1006}
Copy link
Member Author

@Kludex Kludex Dec 26, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm... On a second thought... This doesn't make sense, does it? Why are we even sending a websocket.disconnect when the handshake was not even completed? 🤔

I think this was on purpose because then the application could receive a websocket event, but thinking about it again, does it make sense?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I need to check the issues/PRs about this decision. It shouldn't be a blocker for this PR tho.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According to django/asgiref#364 it's fine.



Expand All @@ -744,6 +759,7 @@ async def test_send_close_on_server_shutdown(
ws_protocol_cls, http_protocol_cls, unused_tcp_port: int
):
disconnect_message = {}
server_shutdown_event = asyncio.Event()

async def app(scope, receive, send):
nonlocal disconnect_message
Expand All @@ -755,10 +771,13 @@ async def app(scope, receive, send):
disconnect_message = message
break

websocket: typing.Optional[websockets.client.WebSocketClientProtocol] = None

async def websocket_session(uri):
async with websockets.client.connect(uri):
while True:
await asyncio.sleep(0.1)
nonlocal websocket
async with websockets.client.connect(uri) as ws_connection:
websocket = ws_connection
await server_shutdown_event.wait()

config = Config(
app=app,
Expand All @@ -773,7 +792,10 @@ async def websocket_session(uri):
)
await asyncio.sleep(0.1)
disconnect_message_before_shutdown = disconnect_message
server_shutdown_event.set()

assert websocket is not None
assert websocket.close_code == 1012
assert disconnect_message_before_shutdown == {}
assert disconnect_message == {"type": "websocket.disconnect", "code": 1012}
task.cancel()
Expand Down
5 changes: 4 additions & 1 deletion uvicorn/protocols/websockets/websockets_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ def __init__(
self.connect_sent = False
self.lost_connection_before_handshake = False
self.accepted_subprotocol: Optional[Subprotocol] = None
self.transfer_data_task: asyncio.Task = None # type: ignore[assignment]
Copy link
Member Author

@Kludex Kludex Dec 26, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is necessary to call fail_connection(), because internally there's a check like if hasattr(self, "transfer_data_task"), which should succeed.


self.ws_server: Server = Server() # type: ignore[assignment]

Expand Down Expand Up @@ -145,6 +144,10 @@ def connection_lost(self, exc: Optional[Exception]) -> None:

def shutdown(self) -> None:
self.ws_server.closing = True
if self.handshake_completed_event.is_set():
self.fail_connection(1012)
else:
self.send_500_response()
Comment on lines +147 to +150
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the change that motivated this PR.

We were only sending 1006 to the client, even when the handshake was not completed.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why self.fail_connection(1012) instead of just calling self.close(1012)?

I've been look at websockets server.close() and procotol.close() to see why uvicorn isn't calling `close(), but haven't gotten terribly far yet.

Copy link
Member Author

@Kludex Kludex Dec 26, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because close() is async.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahh, missed the lack of async here.

Looking deeper at the difference, it seems like the primary difference is fail_connection will proactively cancel the data transfer task but close will simply send the close frame and then wait.

Is there any consequence to that distinction here? My very, very basic testing seems to suggest with uvicorn tasks are given a chance to shutdown cleanly, but I'm still testing.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, this code still waits for it to finish. 👍

self.transport.close()

def on_task_complete(self, task: asyncio.Task) -> None:
Expand Down
18 changes: 7 additions & 11 deletions uvicorn/protocols/websockets/wsproto_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import typing
from urllib.parse import unquote

import h11
import wsproto
from wsproto import ConnectionType, events
from wsproto.connection import ConnectionState
Expand Down Expand Up @@ -232,17 +231,14 @@ def send_500_response(self) -> None:
(b"content-type", b"text/plain; charset=utf-8"),
(b"connection", b"close"),
]
if self.conn.connection is None:
output = self.conn.send(wsproto.events.RejectConnection(status_code=500))
else:
msg = h11.Response(
status_code=500, headers=headers, reason="Internal Server Error"
output = self.conn.send(
wsproto.events.RejectConnection(
status_code=500, headers=headers, has_body=True
)
output = self.conn.send(msg)
msg = h11.Data(data=b"Internal Server Error")
output += self.conn.send(msg)
msg = h11.EndOfMessage()
output += self.conn.send(msg)
)
output += self.conn.send(
wsproto.events.RejectData(data=b"Internal Server Error")
)
Comment on lines -235 to +241
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't be afraid here. I'll explain what's happening.

  1. We were sending RejectConnection(status_code=500) without a body, but on websockets implementation we were sending the "Internal Server Error" body on the analogous behavior. The RejectData matches the behavior.
  2. We are removing the conditional because it's never reached, and the reason for it is that we only call send_500_response is we didn't complete the handshake - which makes a lot of sense to remove it.

Hope it's clear.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Am I missing EndOfMessage? I need to check this.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, I'm not missing it. wsproto internally adds the h11.EndOfMessage. 🙏

self.transport.write(output)

async def run_asgi(self) -> None:
Expand Down