-
-
Notifications
You must be signed in to change notification settings - Fork 754
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Check if handshake is completed before sending frame on wsproto shutdown #1737
Changes from 5 commits
0cc0e01
608e38d
c4bfc3d
e4aae0b
61f4483
6f3d8dc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -528,7 +528,6 @@ async def app(scope, receive, send): | |
while True: | ||
message = await receive() | ||
if message["type"] == "websocket.connect": | ||
print("accepted") | ||
await send({"type": "websocket.accept"}) | ||
elif message["type"] == "websocket.disconnect": | ||
break | ||
|
@@ -551,6 +550,66 @@ async def app(scope, receive, send): | |
assert got_disconnect_event_before_shutdown is True | ||
|
||
|
||
@pytest.mark.anyio | ||
@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS) | ||
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) | ||
async def test_not_accept_on_connection_lost(ws_protocol_cls, http_protocol_cls): | ||
send_accept_task = asyncio.Event() | ||
|
||
async def app(scope, receive, send): | ||
while True: | ||
message = await receive() | ||
if message["type"] == "websocket.connect": | ||
await send_accept_task.wait() | ||
await send({"type": "websocket.accept"}) | ||
elif message["type"] == "websocket.disconnect": | ||
break | ||
|
||
async def websocket_session(uri): | ||
async with websockets.client.connect(uri): | ||
while True: | ||
await asyncio.sleep(0.1) | ||
|
||
config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off") | ||
async with run_server(config): | ||
task = asyncio.create_task(websocket_session("ws://127.0.0.1:8000")) | ||
await asyncio.sleep(0.1) | ||
task.cancel() | ||
send_accept_task.set() | ||
Comment on lines
+553
to
+578
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This test actually fails on the This fails before and after my changes, so it is actually not this PR who broke it. There's a single small change on the |
||
|
||
|
||
@pytest.mark.anyio | ||
@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS) | ||
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) | ||
async def test_send_close_on_server_shutdown(ws_protocol_cls, http_protocol_cls): | ||
disconnect_message = {} | ||
|
||
async def app(scope, receive, send): | ||
nonlocal disconnect_message | ||
while True: | ||
message = await receive() | ||
if message["type"] == "websocket.connect": | ||
await send({"type": "websocket.accept"}) | ||
elif message["type"] == "websocket.disconnect": | ||
disconnect_message = message | ||
break | ||
|
||
async def websocket_session(uri): | ||
async with websockets.client.connect(uri): | ||
while True: | ||
await asyncio.sleep(0.1) | ||
|
||
config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off") | ||
async with run_server(config): | ||
task = asyncio.create_task(websocket_session("ws://127.0.0.1:8000")) | ||
await asyncio.sleep(0.1) | ||
disconnect_message_before_shutdown = disconnect_message | ||
|
||
assert disconnect_message_before_shutdown == {} | ||
assert disconnect_message == {"type": "websocket.disconnect", "code": 1012} | ||
task.cancel() | ||
|
||
|
||
@pytest.mark.anyio | ||
@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS) | ||
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -345,6 +345,8 @@ async def asgi_receive( | |
data = await self.recv() | ||
except ConnectionClosed as exc: | ||
self.closed_event.set() | ||
if self.ws_server.closing: | ||
return {"type": "websocket.disconnect", "code": 1012} | ||
Comment on lines
+348
to
+349
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is to match the behavior on both implementations. If we close the connection on |
||
return {"type": "websocket.disconnect", "code": exc.code} | ||
|
||
msg: WebSocketReceiveEvent = { # type: ignore[typeddict-item] | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -125,9 +125,12 @@ def resume_writing(self): | |
self.writable.set() | ||
|
||
def shutdown(self): | ||
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1012}) | ||
output = self.conn.send(wsproto.events.CloseConnection(code=1012)) | ||
self.transport.write(output) | ||
if self.handshake_complete: | ||
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1012}) | ||
output = self.conn.send(wsproto.events.CloseConnection(code=1012)) | ||
self.transport.write(output) | ||
Comment on lines
+128
to
+131
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is the core here. We want to check if the handshake was completed, because This is the error these lines solve: wsproto.utilities.LocalProtocolError: Event CloseConnection(code=1012, reason=None) cannot be sent during the handshake More about it on #596. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Makes sense to me. 👍 |
||
else: | ||
self.send_500_response() | ||
Comment on lines
+132
to
+133
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The thing is... If the above is Without these lines, the client will have a: curl: (52) Empty reply from server With these lines:
The curl command used was: curl --include \
--no-buffer \
--header "Connection: Upgrade" \
--header "Upgrade: websocket" \
--header "Host: example.com:80" \
--header "Origin: http://example.com:80" \
--header "Sec-WebSocket-Key: SGVsbG8sIHdvcmxkIQ==" \
--header "Sec-WebSocket-Version: 13" \
http://localhost:8000/ There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Much neater. 😌 I did spend some time trying to figure out if we should we also include a textual description here, but I think it's probably okay as it currently stands. |
||
self.transport.close() | ||
|
||
def on_task_complete(self, task): | ||
|
@@ -222,9 +225,8 @@ def send_500_response(self): | |
async def run_asgi(self): | ||
try: | ||
result = await self.app(self.scope, self.receive, self.send) | ||
except BaseException as exc: | ||
msg = "Exception in ASGI application\n" | ||
self.logger.error(msg, exc_info=exc) | ||
except BaseException: | ||
self.logger.exception("Exception in ASGI application\n") | ||
Comment on lines
+228
to
+229
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These are analogous. Just making it more readable. |
||
if not self.handshake_complete: | ||
self.send_500_response() | ||
self.transport.close() | ||
|
@@ -257,14 +259,15 @@ async def send(self, message): | |
extensions = [] | ||
if self.config.ws_per_message_deflate: | ||
extensions.append(PerMessageDeflate()) | ||
output = self.conn.send( | ||
wsproto.events.AcceptConnection( | ||
subprotocol=subprotocol, | ||
extensions=extensions, | ||
extra_headers=extra_headers, | ||
if not self.transport.is_closing(): | ||
output = self.conn.send( | ||
wsproto.events.AcceptConnection( | ||
subprotocol=subprotocol, | ||
extensions=extensions, | ||
extra_headers=extra_headers, | ||
) | ||
) | ||
) | ||
self.transport.write(output) | ||
self.transport.write(output) | ||
Comment on lines
-260
to
+270
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is necessary on this scenario, because the transport was closed on the |
||
|
||
elif message_type == "websocket.close": | ||
self.queue.put_nowait({"type": "websocket.disconnect", "code": None}) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yey! 😎 👍