diff --git a/tests/executor/test_peekable_async_iterator.py b/tests/executor/test_peekable_async_iterator.py deleted file mode 100644 index b2cbed0ba..000000000 --- a/tests/executor/test_peekable_async_iterator.py +++ /dev/null @@ -1,34 +0,0 @@ -from typing import AsyncIterable - -import pytest - -from yapapi.executor._smartq import PeekableAsyncIterator - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "input", - [ - [], - [1], - [1, 2], - [1, 2, 3], - [1, 2, 3, 4], - ], -) -async def test_iterator(input): - async def iterator(): - for item in input: - yield item - - it = PeekableAsyncIterator(iterator()) - assert (await it.has_next()) == bool(input) - - output = [] - - async for item in it: - output.append(item) - assert (await it.has_next()) == (len(output) < len(input)) - - assert not await it.has_next() - assert input == output diff --git a/tests/executor/test_smartq.py b/tests/executor/test_smartq.py index bad82a4f7..8b0753c8c 100644 --- a/tests/executor/test_smartq.py +++ b/tests/executor/test_smartq.py @@ -61,17 +61,28 @@ async def test_smart_queue_empty(): @pytest.mark.asyncio async def test_unassigned_items(): + """Test the `SmartQueue.has_unassigned_items()` method.""" q = SmartQueue(async_iter([1, 2, 3])) + with q.new_consumer() as c: + await asyncio.sleep(0.1) + assert q.has_unassigned_items() + async for handle in c: - assert await q.has_new_items() == await q.has_unassigned_items() - if not await q.has_unassigned_items(): + assert not q.finished() + await asyncio.sleep(0.1) + if not q.has_unassigned_items(): assert handle.data == 3 break - assert not await q.has_unassigned_items() + # All items are in progress, `has_unassigned_items()` should return `False` + assert not q.has_unassigned_items() await q.reschedule_all(c) - assert await q.has_unassigned_items() - assert not await q.has_new_items() + # Now the items are unassigned again + assert q.has_unassigned_items() + # Queue is still not finished + assert not q.finished() + + await q.close() @pytest.mark.asyncio @@ -109,11 +120,12 @@ async def invalid_worker(q): print("w end") assert outputs == {1, 2, 3} + await q.close() @pytest.mark.asyncio async def test_has_unassigned_items_doesnt_block(): - """Test if `has_unassigned_items` does not block if there are rescheduled items.""" + """Check that the queue does not block waiting for new items if there are rescheduled items.""" loop = asyncio.get_event_loop() @@ -139,3 +151,52 @@ async def worker(): worker_1 = loop.create_task(worker()) worker_2 = loop.create_task(worker()) await asyncio.gather(worker_1, worker_2) + await q.close() + + +@pytest.mark.parametrize( + "task_iterator_interval, worker_interval, executor_interval", + [ + (0.3, 0.0, 0.1), + (0.1, 0.0, 0.3), + (0.2, 0.3, 0.1), + (0.1, 0.3, 0.2), + ], +) +@pytest.mark.asyncio +async def test_async_task_iterator(task_iterator_interval, worker_interval, executor_interval): + """Check that the queue waits until new items appear.""" + + inputs = list(range(5)) + yielded = [] + + async def tasks(): + for n in inputs: + print(f"Yielding {n}") + yield n + yielded.append(n) + await asyncio.sleep(task_iterator_interval) + + q = SmartQueue(tasks()) + loop = asyncio.get_event_loop() + + # A sample worker that accepts a task quickly and then exits + async def worker(): + print("Started new worker") + with q.new_consumer() as consumer: + async for handle in consumer: + await asyncio.sleep(worker_interval) + await q.mark_done(handle) + print(f"Exiting after task {handle.data}") + return + + # Simulate how Executor works: spawn new workers until all items are handled + worker_task = None + done_task = loop.create_task(q.wait_until_done()) + while not done_task.done(): + if (worker_task is None or worker_task.done()) and q.has_unassigned_items(): + worker_task = loop.create_task(worker()) + await asyncio.sleep(executor_interval) + + assert yielded == inputs + await q.close() diff --git a/yapapi/executor/__init__.py b/yapapi/executor/__init__.py index 681a1c3f5..302d4d0c0 100644 --- a/yapapi/executor/__init__.py +++ b/yapapi/executor/__init__.py @@ -1030,7 +1030,7 @@ async def worker_starter() -> None: while True: await asyncio.sleep(2) await job.agreements_pool.cycle() - if len(workers) < self._max_workers and await work_queue.has_unassigned_items(): + if len(workers) < self._max_workers and work_queue.has_unassigned_items(): new_task = None try: new_task = await job.agreements_pool.use_agreement( @@ -1113,6 +1113,9 @@ async def worker_starter() -> None: cancelled = True finally: + + await work_queue.close() + # Importing this at the beginning would cause circular dependencies from ..log import pluralize diff --git a/yapapi/executor/_smartq.py b/yapapi/executor/_smartq.py index bd5e592d5..52e671c07 100644 --- a/yapapi/executor/_smartq.py +++ b/yapapi/executor/_smartq.py @@ -55,52 +55,15 @@ def data(self) -> Item: return self._data -class PeekableAsyncIterator(AsyncIterator[Item]): - """An AsyncIterator with an additional `has_next()` method.""" - - def __init__(self, base: AsyncIterator[Item]): - self._base: AsyncIterator[Item] = base - self._first_item: Optional[Item] = None - self._first_set: bool = False - self._lock = asyncio.Lock() - - async def _get_first(self) -> None: - - # Both __anext__() and has_next() may call this method, and both - # may be called concurrently. We need to ensure that _base.__anext__() - # is not called concurrently or else it may raise a RuntimeError. - async with self._lock: - if self._first_set: - return - self._first_item = await self._base.__anext__() - self._first_set = True - - async def __anext__(self) -> Item: - await self._get_first() - item = self._first_item - self._first_item = None - self._first_set = False - return item # type: ignore - - async def has_next(self) -> bool: - """Return `True` if and only this iterator has more elements. - - In other words, `has_next()` returns `True` iff the next call to `__anext__()` - will return normally, without raising `StopAsyncIteration`. - """ - try: - await self._get_first() - except StopAsyncIteration: - return False - return True - - class SmartQueue(Generic[Item]): def __init__(self, items: AsyncIterator[Item]): """ :param items: the items to be iterated over """ - self._items: PeekableAsyncIterator[Item] = PeekableAsyncIterator(items) + + self._buffer: "asyncio.Queue[Item]" = asyncio.Queue(maxsize=1) + self._incoming_finished = False + self._buffer_task = asyncio.get_event_loop().create_task(self._fill_buffer(items)) """The items scheduled for reassignment to another consumer""" self._rescheduled_items: Set[Handle[Item]] = set() @@ -113,33 +76,37 @@ def __init__(self, items: AsyncIterator[Item]): self._new_items = Condition(lock=self._lock) self._eof = Condition(lock=self._lock) - async def has_new_items(self) -> bool: - """Check whether this queue has any items that were not retrieved by any consumer yet.""" - return await self._items.has_next() - - async def has_unassigned_items(self) -> bool: - """Check whether this queue has any unassigned items. - - An item is _unassigned_ if it's new (hasn't been retrieved yet by any consumer) - or it has been rescheduled and is not in progress. + async def _fill_buffer(self, incoming: AsyncIterator[Item]): + try: + async for item in incoming: + await self._buffer.put(item) + async with self._lock: + self._new_items.notify_all() + self._incoming_finished = True + async with self._lock: + self._eof.notify_all() + self._new_items.notify_all() + except asyncio.CancelledError: + pass + + async def close(self): + if self._buffer_task: + self._buffer_task.cancel() + await self._buffer_task + self._buffer_task = None + + def finished(self): + return ( + not self.has_unassigned_items() and not (self._in_progress) and self._incoming_finished + ) - A queue has unassigned items iff the next call to `get()` will immediately return - some item, without waiting for an item that is currently "in progress" to be rescheduled. - """ - while True: - if self._rescheduled_items: - return True - try: - return await asyncio.wait_for(self.has_new_items(), 1.0) - except asyncio.TimeoutError: - pass + def has_unassigned_items(self) -> bool: + """Check if this queue has a new or rescheduled item immediately available.""" + return bool(self._rescheduled_items) or bool(self._buffer.qsize()) def new_consumer(self) -> "Consumer[Item]": return Consumer(self) - async def __has_data(self): - return await self.has_unassigned_items() or bool(self._in_progress) - def __find_rescheduled_item(self, consumer: "Consumer[Item]") -> Optional[Handle[Item]]: return next( ( @@ -153,7 +120,7 @@ def __find_rescheduled_item(self, consumer: "Consumer[Item]") -> Optional[Handle async def get(self, consumer: "Consumer[Item]") -> Handle[Item]: """Get a handle to the next item to be processed (either a new one or rescheduled).""" async with self._lock: - while await self.__has_data(): + while not self.finished(): handle = self.__find_rescheduled_item(consumer) if handle: @@ -162,8 +129,8 @@ async def get(self, consumer: "Consumer[Item]") -> Handle[Item]: handle.assign_consumer(consumer) return handle - if await self.has_new_items(): - next_elem = await self._items.__anext__() + if self._buffer.qsize(): + next_elem = await self._buffer.get() handle = Handle(next_elem, consumer=consumer) self._in_progress.add(handle) return handle @@ -180,8 +147,9 @@ async def mark_done(self, handle: Handle[Item]) -> None: self._eof.notify_all() self._new_items.notify_all() if _logger.isEnabledFor(logging.DEBUG): + stats = self.stats() _logger.debug( - f"status in-progress={len(self._in_progress)}, have_item={bool(self._items)}" + "status: " + ", ".join(f"{key}: {val}" for key, val in self.stats().items()) ) async def reschedule(self, handle: Handle[Item]) -> None: @@ -204,15 +172,16 @@ async def reschedule_all(self, consumer: "Consumer[Item]"): def stats(self) -> Dict: return { "locked": self._lock.locked(), - "items": bool(self._items), - "in-progress": len(self._in_progress), - "rescheduled-items": len(self._rescheduled_items), + "in progress": len(self._in_progress), + "rescheduled": len(self._rescheduled_items), + "in buffer": self._buffer.qsize(), + "incoming finished": self._incoming_finished, } async def wait_until_done(self) -> None: """Wait until all items in the queue are processed.""" async with self._lock: - while await self.__has_data(): + while not self.finished(): await self._eof.wait()