Skip to content

Commit

Permalink
Refactor SmartQueue to work correctly with async task iterators
Browse files Browse the repository at this point in the history
  • Loading branch information
azawlocki committed May 25, 2021
1 parent 197c569 commit 5aa3388
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 78 deletions.
73 changes: 67 additions & 6 deletions tests/executor/test_smartq.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,28 @@ async def test_smart_queue_empty():

@pytest.mark.asyncio
async def test_unassigned_items():
"""Test the `SmartQueue.has_unassigned_items()` method."""
q = SmartQueue(async_iter([1, 2, 3]))

with q.new_consumer() as c:
await asyncio.sleep(0.1)
assert q.has_unassigned_items()

async for handle in c:
assert await q.has_new_items() == await q.has_unassigned_items()
if not await q.has_unassigned_items():
assert not q.finished()
await asyncio.sleep(0.1)
if not q.has_unassigned_items():
assert handle.data == 3
break
assert not await q.has_unassigned_items()
# All items are in progress, `has_unassigned_items()` should return `False`
assert not q.has_unassigned_items()
await q.reschedule_all(c)
assert await q.has_unassigned_items()
assert not await q.has_new_items()
# Now the items are unassigned again
assert q.has_unassigned_items()
# Queue is still not finished
assert not q.finished()

await q.close()


@pytest.mark.asyncio
Expand Down Expand Up @@ -109,11 +120,12 @@ async def invalid_worker(q):
print("w end")

assert outputs == {1, 2, 3}
await q.close()


@pytest.mark.asyncio
async def test_has_unassigned_items_doesnt_block():
"""Test if `has_unassigned_items` does not block if there are rescheduled items."""
"""Check that the queue does not block waiting for new items if there are rescheduled items."""

loop = asyncio.get_event_loop()

Expand All @@ -139,3 +151,52 @@ async def worker():
worker_1 = loop.create_task(worker())
worker_2 = loop.create_task(worker())
await asyncio.gather(worker_1, worker_2)
await q.close()


@pytest.mark.parametrize(
"task_iterator_interval, worker_interval, executor_interval",
[
(0.3, 0.0, 0.1),
(0.1, 0.0, 0.3),
(0.2, 0.3, 0.1),
(0.1, 0.3, 0.2),
],
)
@pytest.mark.asyncio
async def test_async_task_iterator(task_iterator_interval, worker_interval, executor_interval):
"""Check that the queue waits until new items appear."""

inputs = list(range(5))
yielded = []

async def tasks():
for n in inputs:
print(f"Yielding {n}")
yield n
yielded.append(n)
await asyncio.sleep(task_iterator_interval)

q = SmartQueue(tasks())
loop = asyncio.get_event_loop()

# A sample worker that accepts a task quickly and then exits
async def worker():
print("Started new worker")
with q.new_consumer() as consumer:
async for handle in consumer:
await asyncio.sleep(worker_interval)
await q.mark_done(handle)
print(f"Exiting after task {handle.data}")
return

# Simulate how Executor works: spawn new workers until all items are handled
worker_task = None
done_task = loop.create_task(q.wait_until_done())
while not done_task.done():
if (worker_task is None or worker_task.done()) and q.has_unassigned_items():
worker_task = loop.create_task(worker())
await asyncio.sleep(executor_interval)

assert yielded == inputs
await q.close()
2 changes: 1 addition & 1 deletion yapapi/executor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1030,7 +1030,7 @@ async def worker_starter() -> None:
while True:
await asyncio.sleep(2)
await job.agreements_pool.cycle()
if len(workers) < self._max_workers and await work_queue.has_unassigned_items():
if len(workers) < self._max_workers and work_queue.has_unassigned_items():
new_task = None
try:
new_task = await job.agreements_pool.use_agreement(
Expand Down
110 changes: 39 additions & 71 deletions yapapi/executor/_smartq.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,52 +55,15 @@ def data(self) -> Item:
return self._data


class PeekableAsyncIterator(AsyncIterator[Item]):
"""An AsyncIterator with an additional `has_next()` method."""

def __init__(self, base: AsyncIterator[Item]):
self._base: AsyncIterator[Item] = base
self._first_item: Optional[Item] = None
self._first_set: bool = False
self._lock = asyncio.Lock()

async def _get_first(self) -> None:

# Both __anext__() and has_next() may call this method, and both
# may be called concurrently. We need to ensure that _base.__anext__()
# is not called concurrently or else it may raise a RuntimeError.
async with self._lock:
if self._first_set:
return
self._first_item = await self._base.__anext__()
self._first_set = True

async def __anext__(self) -> Item:
await self._get_first()
item = self._first_item
self._first_item = None
self._first_set = False
return item # type: ignore

async def has_next(self) -> bool:
"""Return `True` if and only this iterator has more elements.
In other words, `has_next()` returns `True` iff the next call to `__anext__()`
will return normally, without raising `StopAsyncIteration`.
"""
try:
await self._get_first()
except StopAsyncIteration:
return False
return True


class SmartQueue(Generic[Item]):
def __init__(self, items: AsyncIterator[Item]):
"""
:param items: the items to be iterated over
"""
self._items: PeekableAsyncIterator[Item] = PeekableAsyncIterator(items)

self._buffer: "asyncio.Queue[Item]" = asyncio.Queue(maxsize=1)
self._incoming_finished = False
self._buffer_task = asyncio.get_event_loop().create_task(self._fill_buffer(items))

"""The items scheduled for reassignment to another consumer"""
self._rescheduled_items: Set[Handle[Item]] = set()
Expand All @@ -113,33 +76,36 @@ def __init__(self, items: AsyncIterator[Item]):
self._new_items = Condition(lock=self._lock)
self._eof = Condition(lock=self._lock)

async def has_new_items(self) -> bool:
"""Check whether this queue has any items that were not retrieved by any consumer yet."""
return await self._items.has_next()

async def has_unassigned_items(self) -> bool:
"""Check whether this queue has any unassigned items.
An item is _unassigned_ if it's new (hasn't been retrieved yet by any consumer)
or it has been rescheduled and is not in progress.
async def _fill_buffer(self, incoming: AsyncIterator[Item]):
try:
async for item in incoming:
await self._buffer.put(item)
async with self._lock:
self._new_items.notify_all()
self._incoming_finished = True
async with self._lock:
self._eof.notify_all()
self._new_items.notify_all()
except asyncio.CancelledError:
pass

async def close(self):
if self._buffer_task:
self._buffer_task.cancel()
await self._buffer_task

def finished(self):
return (
not self.has_unassigned_items() and not (self._in_progress) and self._incoming_finished
)

A queue has unassigned items iff the next call to `get()` will immediately return
some item, without waiting for an item that is currently "in progress" to be rescheduled.
"""
while True:
if self._rescheduled_items:
return True
try:
return await asyncio.wait_for(self.has_new_items(), 1.0)
except asyncio.TimeoutError:
pass
def has_unassigned_items(self) -> bool:
"""Check if this queue has a new or rescheduled item immediately available."""
return bool(self._rescheduled_items) or bool(self._buffer.qsize())

def new_consumer(self) -> "Consumer[Item]":
return Consumer(self)

async def __has_data(self):
return await self.has_unassigned_items() or bool(self._in_progress)

def __find_rescheduled_item(self, consumer: "Consumer[Item]") -> Optional[Handle[Item]]:
return next(
(
Expand All @@ -153,7 +119,7 @@ def __find_rescheduled_item(self, consumer: "Consumer[Item]") -> Optional[Handle
async def get(self, consumer: "Consumer[Item]") -> Handle[Item]:
"""Get a handle to the next item to be processed (either a new one or rescheduled)."""
async with self._lock:
while await self.__has_data():
while not self.finished():

handle = self.__find_rescheduled_item(consumer)
if handle:
Expand All @@ -162,8 +128,8 @@ async def get(self, consumer: "Consumer[Item]") -> Handle[Item]:
handle.assign_consumer(consumer)
return handle

if await self.has_new_items():
next_elem = await self._items.__anext__()
if self._buffer.qsize():
next_elem = await self._buffer.get()
handle = Handle(next_elem, consumer=consumer)
self._in_progress.add(handle)
return handle
Expand All @@ -180,8 +146,9 @@ async def mark_done(self, handle: Handle[Item]) -> None:
self._eof.notify_all()
self._new_items.notify_all()
if _logger.isEnabledFor(logging.DEBUG):
stats = self.stats()
_logger.debug(
f"status in-progress={len(self._in_progress)}, have_item={bool(self._items)}"
"status: " + ", ".join(f"{key}: {val}" for key, val in self.stats().items())
)

async def reschedule(self, handle: Handle[Item]) -> None:
Expand All @@ -204,15 +171,16 @@ async def reschedule_all(self, consumer: "Consumer[Item]"):
def stats(self) -> Dict:
return {
"locked": self._lock.locked(),
"items": bool(self._items),
"in-progress": len(self._in_progress),
"rescheduled-items": len(self._rescheduled_items),
"in progress": len(self._in_progress),
"rescheduled": len(self._rescheduled_items),
"in buffer": self._buffer.qsize(),
"incoming finished": self._incoming_finished,
}

async def wait_until_done(self) -> None:
"""Wait until all items in the queue are processed."""
async with self._lock:
while await self.__has_data():
while not self.finished():
await self._eof.wait()


Expand Down

0 comments on commit 5aa3388

Please sign in to comment.