Skip to content

Commit

Permalink
Support recv() after the connection is closed.
Browse files Browse the repository at this point in the history
Fix #1538.
  • Loading branch information
aaugustin committed Nov 11, 2024
1 parent bdfc8cf commit 3034834
Show file tree
Hide file tree
Showing 7 changed files with 142 additions and 43 deletions.
7 changes: 7 additions & 0 deletions docs/project/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@ notice.

.. _14.0:

Bug fixes
.........

* Once the connection is closed, messages previously received and buffered can
be read in the :mod:`asyncio` and :mod:`threading` implementations, just like
in the legacy implementation.

14.0
----

Expand Down
20 changes: 7 additions & 13 deletions src/websockets/asyncio/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,11 @@ def put(self, item: T) -> None:
if self.get_waiter is not None and not self.get_waiter.done():
self.get_waiter.set_result(None)

async def get(self) -> T:
async def get(self, block: bool = True) -> T:
"""Remove and return an item from the queue, waiting if necessary."""
if not self.queue:
if not block:
raise EOFError("stream of frames ended")
assert self.get_waiter is None, "cannot call get() concurrently"
self.get_waiter = self.loop.create_future()
try:
Expand Down Expand Up @@ -133,20 +135,16 @@ async def get(self, decode: bool | None = None) -> Data:
:meth:`get_iter` concurrently.
"""
if self.closed:
raise EOFError("stream of frames ended")

if self.get_in_progress:
raise ConcurrencyError("get() or get_iter() is already running")

self.get_in_progress = True

# Locking with get_in_progress prevents concurrent execution
# until get() fetches a complete message or is cancelled.

try:
# First frame
frame = await self.frames.get()
frame = await self.frames.get(not self.closed)
self.maybe_resume()
assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY
if decode is None:
Expand All @@ -156,7 +154,7 @@ async def get(self, decode: bool | None = None) -> Data:
# Following frames, for fragmented messages
while not frame.fin:
try:
frame = await self.frames.get()
frame = await self.frames.get(not self.closed)
except asyncio.CancelledError:
# Put frames already received back into the queue
# so that future calls to get() can return them.
Expand Down Expand Up @@ -200,12 +198,8 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]:
:meth:`get_iter` concurrently.
"""
if self.closed:
raise EOFError("stream of frames ended")

if self.get_in_progress:
raise ConcurrencyError("get() or get_iter() is already running")

self.get_in_progress = True

# Locking with get_in_progress prevents concurrent execution
Expand All @@ -216,7 +210,7 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]:

# First frame
try:
frame = await self.frames.get()
frame = await self.frames.get(not self.closed)
except asyncio.CancelledError:
self.get_in_progress = False
raise
Expand All @@ -236,7 +230,7 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]:
# previous fragments — we're streaming them. Canceling get_iter()
# here will leave the assembler in a stuck state. Future calls to
# get() or get_iter() will raise ConcurrencyError.
frame = await self.frames.get()
frame = await self.frames.get(not self.closed)
self.maybe_resume()
assert frame.opcode is OP_CONT
if decode:
Expand Down
27 changes: 14 additions & 13 deletions src/websockets/sync/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,16 @@ def __init__(
def get_next_frame(self, timeout: float | None = None) -> Frame:
# Helper to factor out the logic for getting the next frame from the
# queue, while handling timeouts and reaching the end of the stream.
try:
frame = self.frames.get(timeout=timeout)
except queue.Empty:
raise TimeoutError(f"timed out in {timeout:.1f}s") from None
if self.closed:
try:
frame = self.frames.get(block=False)
except queue.Empty:
raise EOFError("stream of frames ended") from None
else:
try:
frame = self.frames.get(block=True, timeout=timeout)
except queue.Empty:
raise TimeoutError(f"timed out in {timeout:.1f}s") from None
if frame is None:
raise EOFError("stream of frames ended")
return frame
Expand All @@ -87,7 +93,7 @@ def reset_queue(self, frames: Iterable[Frame]) -> None:
queued = []
try:
while True:
queued.append(self.frames.get_nowait())
queued.append(self.frames.get(block=False))
except queue.Empty:
pass
for frame in frames:
Expand Down Expand Up @@ -123,9 +129,6 @@ def get(self, timeout: float | None = None, decode: bool | None = None) -> Data:
"""
with self.mutex:
if self.closed:
raise EOFError("stream of frames ended")

if self.get_in_progress:
raise ConcurrencyError("get() or get_iter() is already running")
self.get_in_progress = True
Expand Down Expand Up @@ -194,9 +197,6 @@ def get_iter(self, decode: bool | None = None) -> Iterator[Data]:
"""
with self.mutex:
if self.closed:
raise EOFError("stream of frames ended")

if self.get_in_progress:
raise ConcurrencyError("get() or get_iter() is already running")
self.get_in_progress = True
Expand Down Expand Up @@ -288,5 +288,6 @@ def close(self) -> None:

self.closed = True

# Unblock get() or get_iter().
self.frames.put(None)
if self.get_in_progress:
# Unblock get() or get_iter().
self.frames.put(None)
8 changes: 3 additions & 5 deletions tests/asyncio/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,14 +793,12 @@ async def test_close_timeout_waiting_for_connection_closed(self):
# Remove socket.timeout when dropping Python < 3.10.
self.assertIsInstance(exc.__cause__, (socket.timeout, TimeoutError))

async def test_close_does_not_wait_for_recv(self):
# Closing the connection discards messages buffered in the assembler.
# This is allowed by the RFC:
# > However, there is no guarantee that the endpoint that has already
# > sent a Close frame will continue to process data.
async def test_close_preserves_queued_messages(self):
"""close preserves messages buffered in the assembler."""
await self.remote_connection.send("😀")
await self.connection.close()

self.assertEqual(await self.connection.recv(), "😀")
with self.assertRaises(ConnectionClosedOK) as raised:
await self.connection.recv()

Expand Down
52 changes: 52 additions & 0 deletions tests/asyncio/test_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,58 @@ async def test_get_iter_fails_after_close(self):
async for _ in self.assembler.get_iter():
self.fail("no fragment expected")

async def test_get_queued_message_after_close(self):
"""get returns a message after close is called."""
self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9"))
self.assembler.close()
message = await self.assembler.get()
self.assertEqual(message, "café")

async def test_get_iter_queued_message_after_close(self):
"""get_iter yields a message after close is called."""
self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9"))
self.assembler.close()
fragments = await alist(self.assembler.get_iter())
self.assertEqual(fragments, ["café"])

async def test_get_queued_fragmented_message_after_close(self):
"""get reassembles a fragmented message after close is called."""
self.assembler.put(Frame(OP_BINARY, b"t", fin=False))
self.assembler.put(Frame(OP_CONT, b"e", fin=False))
self.assembler.put(Frame(OP_CONT, b"a"))
self.assembler.close()
self.assembler.close()
message = await self.assembler.get()
self.assertEqual(message, b"tea")

async def test_get_iter_queued_fragmented_message_after_close(self):
"""get_iter yields a fragmented message after close is called."""
self.assembler.put(Frame(OP_BINARY, b"t", fin=False))
self.assembler.put(Frame(OP_CONT, b"e", fin=False))
self.assembler.put(Frame(OP_CONT, b"a"))
self.assembler.close()
fragments = await alist(self.assembler.get_iter())
self.assertEqual(fragments, [b"t", b"e", b"a"])

async def test_get_partially_queued_fragmented_message_after_close(self):
"""get raises EOF on a partial fragmented message after close is called."""
self.assembler.put(Frame(OP_BINARY, b"t", fin=False))
self.assembler.put(Frame(OP_CONT, b"e", fin=False))
self.assembler.close()
with self.assertRaises(EOFError):
await self.assembler.get()

async def test_get_iter_partially_queued_fragmented_message_after_close(self):
"""get_iter yields a partial fragmented message after close is called."""
self.assembler.put(Frame(OP_BINARY, b"t", fin=False))
self.assembler.put(Frame(OP_CONT, b"e", fin=False))
self.assembler.close()
fragments = []
with self.assertRaises(EOFError):
async for fragment in self.assembler.get_iter():
fragments.append(fragment)
self.assertEqual(fragments, [b"t", b"e"])

async def test_put_fails_after_close(self):
"""put raises EOFError after close is called."""
self.assembler.close()
Expand Down
19 changes: 7 additions & 12 deletions tests/sync/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,17 +543,12 @@ def test_close_timeout_waiting_for_connection_closed(self):
# Remove socket.timeout when dropping Python < 3.10.
self.assertIsInstance(exc.__cause__, (socket.timeout, TimeoutError))

def test_close_does_not_wait_for_recv(self):
# Closing the connection discards messages buffered in the assembler.
# This is allowed by the RFC:
# > However, there is no guarantee that the endpoint that has already
# > sent a Close frame will continue to process data.
def test_close_preserves_queued_messages(self):
"""close preserves messages buffered in the assembler."""
self.remote_connection.send("😀")
self.connection.close()

close_thread = threading.Thread(target=self.connection.close)
close_thread.start()

self.assertEqual(self.connection.recv(), "😀")
with self.assertRaises(ConnectionClosedOK) as raised:
self.connection.recv()

Expand All @@ -576,10 +571,10 @@ def test_close_idempotency(self):
def test_close_idempotency_race_condition(self):
"""close waits if the connection is already closing."""

self.connection.close_timeout = 5 * MS
self.connection.close_timeout = 6 * MS

def closer():
with self.delay_frames_rcvd(3 * MS):
with self.delay_frames_rcvd(4 * MS):
self.connection.close()

close_thread = threading.Thread(target=closer)
Expand All @@ -591,14 +586,14 @@ def closer():

# Connection isn't closed yet.
with self.assertRaises(TimeoutError):
self.connection.recv(timeout=0)
self.connection.recv(timeout=MS)

self.connection.close()
self.assertNoFrameSent()

# Connection is closed now.
with self.assertRaises(ConnectionClosedOK):
self.connection.recv(timeout=0)
self.connection.recv(timeout=MS)

close_thread.join()

Expand Down
52 changes: 52 additions & 0 deletions tests/sync/test_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,58 @@ def test_get_iter_fails_after_close(self):
for _ in self.assembler.get_iter():
self.fail("no fragment expected")

def test_get_queued_message_after_close(self):
"""get returns a message after close is called."""
self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9"))
self.assembler.close()
message = self.assembler.get()
self.assertEqual(message, "café")

def test_get_iter_queued_message_after_close(self):
"""get_iter yields a message after close is called."""
self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9"))
self.assembler.close()
fragments = list(self.assembler.get_iter())
self.assertEqual(fragments, ["café"])

def test_get_queued_fragmented_message_after_close(self):
"""get reassembles a fragmented message after close is called."""
self.assembler.put(Frame(OP_BINARY, b"t", fin=False))
self.assembler.put(Frame(OP_CONT, b"e", fin=False))
self.assembler.put(Frame(OP_CONT, b"a"))
self.assembler.close()
self.assembler.close()
message = self.assembler.get()
self.assertEqual(message, b"tea")

def test_get_iter_queued_fragmented_message_after_close(self):
"""get_iter yields a fragmented message after close is called."""
self.assembler.put(Frame(OP_BINARY, b"t", fin=False))
self.assembler.put(Frame(OP_CONT, b"e", fin=False))
self.assembler.put(Frame(OP_CONT, b"a"))
self.assembler.close()
fragments = list(self.assembler.get_iter())
self.assertEqual(fragments, [b"t", b"e", b"a"])

def test_get_partially_queued_fragmented_message_after_close(self):
"""get raises EOF on a partial fragmented message after close is called."""
self.assembler.put(Frame(OP_BINARY, b"t", fin=False))
self.assembler.put(Frame(OP_CONT, b"e", fin=False))
self.assembler.close()
with self.assertRaises(EOFError):
self.assembler.get()

def test_get_iter_partially_queued_fragmented_message_after_close(self):
"""get_iter yields a partial fragmented message after close is called."""
self.assembler.put(Frame(OP_BINARY, b"t", fin=False))
self.assembler.put(Frame(OP_CONT, b"e", fin=False))
self.assembler.close()
fragments = []
with self.assertRaises(EOFError):
for fragment in self.assembler.get_iter():
fragments.append(fragment)
self.assertEqual(fragments, [b"t", b"e"])

def test_put_fails_after_close(self):
"""put raises EOFError after close is called."""
self.assembler.close()
Expand Down

0 comments on commit 3034834

Please sign in to comment.