diff --git a/src/anyio/streams/memory.py b/src/anyio/streams/memory.py index 9b8e8ee2..916dfa0a 100644 --- a/src/anyio/streams/memory.py +++ b/src/anyio/streams/memory.py @@ -1,8 +1,10 @@ from collections import deque, OrderedDict from dataclasses import dataclass, field -from typing import TypeVar, Generic, List, Deque +from typing import TypeVar, Generic, List, Deque, Tuple -import anyio +from .. import get_cancelled_exc_class +from .._core._lowlevel import checkpoint +from .._core._synchronization import create_event from ..abc.synchronization import Event from ..abc.streams import ObjectSendStream, ObjectReceiveStream from ..exceptions import ClosedResourceError, BrokenResourceError, WouldBlock, EndOfStream @@ -16,8 +18,7 @@ class MemoryObjectStreamState(Generic[T_Item]): buffer: Deque[T_Item] = field(init=False, default_factory=deque) open_send_channels: int = field(init=False, default=0) open_receive_channels: int = field(init=False, default=0) - waiting_receivers: 'OrderedDict[Event, List[T_Item]]' = field(init=False, - default_factory=OrderedDict) + waiting_receivers: Deque[Tuple[Event, List[T_Item]]] = field(init=False, default_factory=deque) waiting_senders: 'OrderedDict[Event, T_Item]' = field(init=False, default_factory=OrderedDict) @@ -58,20 +59,41 @@ async def receive_nowait(self) -> T_Item: raise WouldBlock async def receive(self) -> T_Item: - # anyio.check_cancelled() + await checkpoint() try: return await self.receive_nowait() except WouldBlock: # Add ourselves in the queue - receive_event = anyio.create_event() + receive_event = create_event() container: List[T_Item] = [] - self._state.waiting_receivers[receive_event] = container + ticket = receive_event, container + self._state.waiting_receivers.append(ticket) try: await receive_event.wait() - except BaseException: - self._state.waiting_receivers.pop(receive_event, None) + except get_cancelled_exc_class(): + # If we already received an item in the container, pass it to the next receiver in + # line + index = self._state.waiting_receivers.index(ticket) + 1 + if container: + item = container[0] + while index < len(self._state.waiting_receivers): + receive_event, container = self._state.waiting_receivers[index] + if container: + item, container[0] = container[0], item + else: + # Found an untriggered receiver + container.append(item) + await receive_event.set() + break + else: + # Could not find an untriggered receiver, so in order to not lose any + # items, put it in the buffer, even if it exceeds the maximum buffer size + self._state.buffer.append(item) + raise + finally: + self._state.waiting_receivers.remove(ticket) if container: return container[0] @@ -129,22 +151,24 @@ async def send_nowait(self, item: T_Item) -> None: if not self._state.open_receive_channels: raise BrokenResourceError - if self._state.waiting_receivers: - receive_event, container = self._state.waiting_receivers.popitem(last=False) - container.append(item) - await receive_event.set() - elif len(self._state.buffer) < self._state.max_buffer_size: + for receive_event, container in self._state.waiting_receivers: + if not container: + container.append(item) + await receive_event.set() + return + + if len(self._state.buffer) < self._state.max_buffer_size: self._state.buffer.append(item) else: raise WouldBlock async def send(self, item: T_Item) -> None: - # await check_cancelled() + await checkpoint() try: await self.send_nowait(item) except WouldBlock: # Wait until there's someone on the receiving end - send_event = anyio.create_event() + send_event = create_event() self._state.waiting_senders[send_event] = item try: await send_event.wait() @@ -175,7 +199,6 @@ async def aclose(self) -> None: self._closed = True self._state.open_send_channels -= 1 if self._state.open_send_channels == 0: - receive_events = list(self._state.waiting_receivers.keys()) - self._state.waiting_receivers.clear() + receive_events = [event for event, container in self._state.waiting_receivers] for event in receive_events: await event.set() diff --git a/tests/streams/test_memory.py b/tests/streams/test_memory.py index 61778341..4b95c351 100644 --- a/tests/streams/test_memory.py +++ b/tests/streams/test_memory.py @@ -1,7 +1,8 @@ import pytest from anyio import ( - create_task_group, wait_all_tasks_blocked, create_memory_object_stream, fail_after) + create_task_group, wait_all_tasks_blocked, create_memory_object_stream, fail_after, + open_cancel_scope) from anyio.exceptions import EndOfStream, ClosedResourceError, BrokenResourceError, WouldBlock pytestmark = pytest.mark.anyio @@ -177,3 +178,99 @@ async def test_receive_after_send_closed(): await send.send('hello') await send.aclose() assert await receive.receive() == 'hello' + + +async def test_receive_when_cancelled(): + """ + Test that calling receive() in a cancelled scope prevents it from going through with the + operation. + + """ + send, receive = create_memory_object_stream() + async with create_task_group() as tg: + await tg.spawn(send.send, 'hello') + await wait_all_tasks_blocked() + await tg.spawn(send.send, 'world') + await wait_all_tasks_blocked() + + async with open_cancel_scope() as scope: + await scope.cancel() + await receive.receive() + + assert await receive.receive() == 'hello' + assert await receive.receive() == 'world' + + +async def test_send_when_cancelled(): + """ + Test that calling send() in a cancelled scope prevents it from going through with the + operation. + + """ + async def receiver(): + received.append(await receive.receive()) + + received = [] + send, receive = create_memory_object_stream() + async with create_task_group() as tg: + await tg.spawn(receiver) + async with open_cancel_scope() as scope: + await scope.cancel() + await send.send('hello') + + await send.send('world') + + assert received == ['world'] + + +async def test_cancel_during_receive(): + """ + Test that cancelling a pending receive() operation does not cause an item in the stream to be + lost. + + """ + async def scoped_receiver(): + nonlocal receiver_scope + async with open_cancel_scope() as receiver_scope: + await receive.receive() + + async def receiver(): + received.append(await receive.receive()) + + receiver_scope = None + received = [] + send, receive = create_memory_object_stream() + async with create_task_group() as tg: + await tg.spawn(scoped_receiver) + await wait_all_tasks_blocked() + await tg.spawn(receiver) + await receiver_scope.cancel() + await send.send('hello') + + assert received == ['hello'] + + +async def test_cancel_during_receive_last_receiver(): + """ + Test that cancelling a pending receive() operation does not cause an item in the stream to be + lost, even if there are no other receivers waiting. + + """ + async def scoped_receiver(): + nonlocal receiver_scope + async with open_cancel_scope() as receiver_scope: + await receive.receive() + pytest.fail('This point should never be reached') + + receiver_scope = None + send, receive = create_memory_object_stream() + async with create_task_group() as tg: + await tg.spawn(scoped_receiver) + await wait_all_tasks_blocked() + await receiver_scope.cancel() + await send.send_nowait('hello') + + with pytest.raises(WouldBlock): + await send.send_nowait('world') + + assert await receive.receive_nowait() == 'hello'