Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug in WebSocketProtocol.asgi_receive #1619

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 50 additions & 1 deletion tests/protocols/test_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,7 @@ async def websocket_session(url):
@pytest.mark.anyio
@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS)
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_server_can_read_messages_in_buffer_after_close(
async def test_server_can_read_messages_in_buffer_after_client_close(
ws_protocol_cls, http_protocol_cls
):
frames = []
Expand All @@ -658,3 +658,52 @@ async def send_text(url):
await send_text("ws://127.0.0.1:8000")

assert frames == [b"abc", b"abc", b"abc"]


@pytest.mark.anyio
@pytest.mark.parametrize("ws_protocol_cls", [WebSocketProtocol])
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_server_can_read_messages_in_buffer_after_server_close(
ws_protocol_cls, http_protocol_cls
):
"""
Note: this doesn't work with WSProtocol.
"""

frames = []

class App(WebSocketResponse):
async def websocket_connect(self, message):
await self.send({"type": "websocket.accept"})
# Ensure server doesn't start reading frames from read buffer until
# after server has sent close frame, but server is still able to
# read these frames
await asyncio.sleep(0.2)
await self.send({"type": "websocket.close"})

async def websocket_receive(self, message):
frames.append(message.get("bytes"))

async def send_text(url):
async with websockets.connect(url) as websocket:
await websocket.send(b"abc")
await websocket.send(b"abc")
await websocket.send(b"abc")

# Wait until after server has sent close frame
await asyncio.sleep(0.3)

try:
# Client will fail to send this frame
await websocket.send(b"abc")
except websockets.exceptions.ConnectionClosed:
pass
else:
raise AssertionError("connection was not closed")

config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off")
async with run_server(config):
await send_text("ws://127.0.0.1:8000")

# Client tried to send 4 frames, could only send 3
assert frames == [b"abc", b"abc", b"abc"]
7 changes: 6 additions & 1 deletion uvicorn/protocols/websockets/websockets_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def __init__(
self.closed_event = asyncio.Event()
self.initial_response: Optional[HTTPResponse] = None
self.connect_sent = False
self.close_frame_sent = False
self.accepted_subprotocol: Optional[Subprotocol] = None
self.transfer_data_task: asyncio.Task = None # type: ignore[assignment]

Expand Down Expand Up @@ -303,6 +304,7 @@ async def asgi_send(self, message: "ASGISendEvent") -> None:
reason = message.get("reason", "") or ""
await self.close(code, reason)
self.closed_event.set()
self.close_frame_sent = True

else:
msg = (
Expand All @@ -326,10 +328,13 @@ async def asgi_receive(

await self.handshake_completed_event.wait()

if self.closed_event.is_set():
if self.closed_event.is_set() and not self.close_frame_sent:
# If client disconnected, use WebSocketServerProtocol.close_code property.
# If the handshake failed or the app closed before handshake completion,
# use 1006 Abnormal Closure.
#
# If server disconnected, make sure server can still read remaining
# messages from websockets read queue.
return {"type": "websocket.disconnect", "code": self.close_code or 1006}

try:
Expand Down