Skip to content

Commit

Permalink
Pass along the received item to the next receiver if the task was can…
Browse files Browse the repository at this point in the history
…celled
  • Loading branch information
agronholm committed Aug 16, 2020
1 parent 1d548ca commit bd9a310
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 19 deletions.
59 changes: 41 additions & 18 deletions src/anyio/streams/memory.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from collections import deque, OrderedDict
from dataclasses import dataclass, field
from typing import TypeVar, Generic, List, Deque
from typing import TypeVar, Generic, List, Deque, Tuple

import anyio
from .. import get_cancelled_exc_class
from .._core._lowlevel import checkpoint
from .._core._synchronization import create_event
from ..abc.synchronization import Event
from ..abc.streams import ObjectSendStream, ObjectReceiveStream
from ..exceptions import ClosedResourceError, BrokenResourceError, WouldBlock, EndOfStream
Expand All @@ -16,8 +18,7 @@ class MemoryObjectStreamState(Generic[T_Item]):
buffer: Deque[T_Item] = field(init=False, default_factory=deque)
open_send_channels: int = field(init=False, default=0)
open_receive_channels: int = field(init=False, default=0)
waiting_receivers: 'OrderedDict[Event, List[T_Item]]' = field(init=False,
default_factory=OrderedDict)
waiting_receivers: Deque[Tuple[Event, List[T_Item]]] = field(init=False, default_factory=deque)
waiting_senders: 'OrderedDict[Event, T_Item]' = field(init=False, default_factory=OrderedDict)


Expand Down Expand Up @@ -58,20 +59,41 @@ async def receive_nowait(self) -> T_Item:
raise WouldBlock

async def receive(self) -> T_Item:
# anyio.check_cancelled()
await checkpoint()
try:
return await self.receive_nowait()
except WouldBlock:
# Add ourselves in the queue
receive_event = anyio.create_event()
receive_event = create_event()
container: List[T_Item] = []
self._state.waiting_receivers[receive_event] = container
ticket = receive_event, container
self._state.waiting_receivers.append(ticket)

try:
await receive_event.wait()
except BaseException:
self._state.waiting_receivers.pop(receive_event, None)
except get_cancelled_exc_class():
# If we already received an item in the container, pass it to the next receiver in
# line
index = self._state.waiting_receivers.index(ticket) + 1
if container:
item = container[0]
while index < len(self._state.waiting_receivers):
receive_event, container = self._state.waiting_receivers[index]
if container:
item, container[0] = container[0], item
else:
# Found an untriggered receiver
container.append(item)
await receive_event.set()
break
else:
# Could not find an untriggered receiver, so in order to not lose any
# items, put it in the buffer, even if it exceeds the maximum buffer size
self._state.buffer.append(item)

raise
finally:
self._state.waiting_receivers.remove(ticket)

if container:
return container[0]
Expand Down Expand Up @@ -129,22 +151,24 @@ async def send_nowait(self, item: T_Item) -> None:
if not self._state.open_receive_channels:
raise BrokenResourceError

if self._state.waiting_receivers:
receive_event, container = self._state.waiting_receivers.popitem(last=False)
container.append(item)
await receive_event.set()
elif len(self._state.buffer) < self._state.max_buffer_size:
for receive_event, container in self._state.waiting_receivers:
if not container:
container.append(item)
await receive_event.set()
return

if len(self._state.buffer) < self._state.max_buffer_size:
self._state.buffer.append(item)
else:
raise WouldBlock

async def send(self, item: T_Item) -> None:
# await check_cancelled()
await checkpoint()
try:
await self.send_nowait(item)
except WouldBlock:
# Wait until there's someone on the receiving end
send_event = anyio.create_event()
send_event = create_event()
self._state.waiting_senders[send_event] = item
try:
await send_event.wait()
Expand Down Expand Up @@ -175,7 +199,6 @@ async def aclose(self) -> None:
self._closed = True
self._state.open_send_channels -= 1
if self._state.open_send_channels == 0:
receive_events = list(self._state.waiting_receivers.keys())
self._state.waiting_receivers.clear()
receive_events = [event for event, container in self._state.waiting_receivers]
for event in receive_events:
await event.set()
99 changes: 98 additions & 1 deletion tests/streams/test_memory.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import pytest

from anyio import (
create_task_group, wait_all_tasks_blocked, create_memory_object_stream, fail_after)
create_task_group, wait_all_tasks_blocked, create_memory_object_stream, fail_after,
open_cancel_scope)
from anyio.exceptions import EndOfStream, ClosedResourceError, BrokenResourceError, WouldBlock

pytestmark = pytest.mark.anyio
Expand Down Expand Up @@ -177,3 +178,99 @@ async def test_receive_after_send_closed():
await send.send('hello')
await send.aclose()
assert await receive.receive() == 'hello'


async def test_receive_when_cancelled():
"""
Test that calling receive() in a cancelled scope prevents it from going through with the
operation.
"""
send, receive = create_memory_object_stream()
async with create_task_group() as tg:
await tg.spawn(send.send, 'hello')
await wait_all_tasks_blocked()
await tg.spawn(send.send, 'world')
await wait_all_tasks_blocked()

async with open_cancel_scope() as scope:
await scope.cancel()
await receive.receive()

assert await receive.receive() == 'hello'
assert await receive.receive() == 'world'


async def test_send_when_cancelled():
"""
Test that calling send() in a cancelled scope prevents it from going through with the
operation.
"""
async def receiver():
received.append(await receive.receive())

received = []
send, receive = create_memory_object_stream()
async with create_task_group() as tg:
await tg.spawn(receiver)
async with open_cancel_scope() as scope:
await scope.cancel()
await send.send('hello')

await send.send('world')

assert received == ['world']


async def test_cancel_during_receive():
"""
Test that cancelling a pending receive() operation does not cause an item in the stream to be
lost.
"""
async def scoped_receiver():
nonlocal receiver_scope
async with open_cancel_scope() as receiver_scope:
await receive.receive()

async def receiver():
received.append(await receive.receive())

receiver_scope = None
received = []
send, receive = create_memory_object_stream()
async with create_task_group() as tg:
await tg.spawn(scoped_receiver)
await wait_all_tasks_blocked()
await tg.spawn(receiver)
await receiver_scope.cancel()
await send.send('hello')

assert received == ['hello']


async def test_cancel_during_receive_last_receiver():
"""
Test that cancelling a pending receive() operation does not cause an item in the stream to be
lost, even if there are no other receivers waiting.
"""
async def scoped_receiver():
nonlocal receiver_scope
async with open_cancel_scope() as receiver_scope:
await receive.receive()
pytest.fail('This point should never be reached')

receiver_scope = None
send, receive = create_memory_object_stream()
async with create_task_group() as tg:
await tg.spawn(scoped_receiver)
await wait_all_tasks_blocked()
await receiver_scope.cancel()
await send.send_nowait('hello')

with pytest.raises(WouldBlock):
await send.send_nowait('world')

assert await receive.receive_nowait() == 'hello'

0 comments on commit bd9a310

Please sign in to comment.