From 1973569a0ded18b5e6720c664dfdfc0303e720d8 Mon Sep 17 00:00:00 2001 From: azawlocki Date: Fri, 21 May 2021 11:21:52 +0200 Subject: [PATCH] Make `AsyncWrapper` an async context manager --- tests/executor/test_async_wrapper.py | 129 +++++++++++++---------- tests/executor/test_payment_platforms.py | 4 + yapapi/executor/__init__.py | 93 ++++++++-------- yapapi/executor/utils.py | 39 ++++--- 4 files changed, 148 insertions(+), 117 deletions(-) diff --git a/tests/executor/test_async_wrapper.py b/tests/executor/test_async_wrapper.py index 0d7332b4b..9f2a42823 100644 --- a/tests/executor/test_async_wrapper.py +++ b/tests/executor/test_async_wrapper.py @@ -5,35 +5,54 @@ from yapapi.executor.utils import AsyncWrapper -def test_keyboard_interrupt(event_loop): +def test_async_wrapper_ordering(): + """Test if AsyncWrapper preserves order of calls.""" + + input_ = list(range(10)) + output = [] + + def func(n): + output.append(n) + + async def main(): + async with AsyncWrapper(func) as wrapper: + for n in input_: + wrapper.async_call(n) + + asyncio.get_event_loop().run_until_complete(main()) + assert output == input_ + + +def test_keyboard_interrupt(): """Test if AsyncWrapper handles KeyboardInterrupt by passing it to the event loop.""" def func(interrupt): if interrupt: raise KeyboardInterrupt - wrapper = AsyncWrapper(func, event_loop) - async def main(): - for _ in range(100): - wrapper.async_call(False) - # This will raise KeyboardInterrupt in the wrapper's worker task - wrapper.async_call(True) - await asyncio.sleep(0.01) + async with AsyncWrapper(func) as wrapper: + for _ in range(100): + wrapper.async_call(False) + # This will raise KeyboardInterrupt in the wrapper's worker task + wrapper.async_call(True) + await asyncio.sleep(0.01) - task = event_loop.create_task(main()) + loop = asyncio.get_event_loop() + task = loop.create_task(main()) with pytest.raises(KeyboardInterrupt): - event_loop.run_until_complete(task) + loop.run_until_complete(task) # Make sure the main task did not get KeyboardInterrupt assert not task.done() - # Make sure the wrapper can still make calls, it's worker task shouldn't exit - wrapper.async_call(False) + with pytest.raises(asyncio.CancelledError): + task.cancel() + loop.run_until_complete(task) -def test_stop_doesnt_deadlock(event_loop): - """Test if the AsyncWrapper.stop() coroutine completes after an AsyncWrapper is interrupted. +def test_aexit_doesnt_deadlock(): + """Test if the AsyncWrapper.__aexit__() completes after an AsyncWrapper is interrupted. See https://github.com/golemfactory/yapapi/issues/238. """ @@ -46,55 +65,57 @@ def func(interrupt): async def main(): """"This coroutine mimics how an AsyncWrapper is used in an Executor.""" - wrapper = AsyncWrapper(func, event_loop) - try: - # Queue some calls - for _ in range(10): - wrapper.async_call(False) - wrapper.async_call(True) - for _ in range(10): - wrapper.async_call(False) - # Sleep until cancelled - await asyncio.sleep(30) - assert False, "Sleep should be cancelled" - except asyncio.CancelledError: - # This call should exit without timeout - await asyncio.wait_for(wrapper.stop(), timeout=30.0) - - task = event_loop.create_task(main()) + async with AsyncWrapper(func) as wrapper: + try: + # Queue some calls + for _ in range(10): + wrapper.async_call(False) + wrapper.async_call(True) + for _ in range(10): + wrapper.async_call(False) + # Sleep until cancelled + await asyncio.sleep(30) + assert False, "Sleep should be cancelled" + except asyncio.CancelledError: + pass + + loop = asyncio.get_event_loop() + task = loop.create_task(main()) try: - event_loop.run_until_complete(task) + loop.run_until_complete(task) assert False, "Expected KeyboardInterrupt" except KeyboardInterrupt: task.cancel() - event_loop.run_until_complete(task) + loop.run_until_complete(task) -def test_stop_doesnt_wait(event_loop): - """Test if the AsyncWrapper.stop() coroutine prevents new calls from be queued.""" +def test_cancel_doesnt_wait(): + """Test if the AsyncWrapper stops processing calls when it's cancelled.""" - def func(): - time.sleep(0.1) - pass + num_calls = 0 - wrapper = AsyncWrapper(func, event_loop) + def func(d): + print("Calling func()") + nonlocal num_calls + num_calls += 1 + time.sleep(d) async def main(): - with pytest.raises(RuntimeError): - for n in range(100): - wrapper.async_call() - await asyncio.sleep(0.01) - # wrapper should be stopped before all calls are made - assert False, "Should raise RuntimeError" - - async def stop(): - await asyncio.sleep(0.1) - await wrapper.stop() + try: + async with AsyncWrapper(func) as wrapper: + for _ in range(10): + wrapper.async_call(0.1) + except asyncio.CancelledError: + pass - task = event_loop.create_task(main()) - event_loop.create_task(stop()) - try: - event_loop.run_until_complete(task) - except KeyboardInterrupt: + async def cancel(): + await asyncio.sleep(0.05) + print("Cancelling!") task.cancel() - event_loop.run_until_complete(task) + + loop = asyncio.get_event_loop() + task = loop.create_task(main()) + loop.create_task(cancel()) + loop.run_until_complete(task) + + assert num_calls < 10 diff --git a/tests/executor/test_payment_platforms.py b/tests/executor/test_payment_platforms.py index 5c2c83efd..ac3f8b57b 100644 --- a/tests/executor/test_payment_platforms.py +++ b/tests/executor/test_payment_platforms.py @@ -123,7 +123,11 @@ async def mock_create_allocation(_self, model): create_allocation_args.append(model) return mock.Mock() + async def mock_release_allocation(*args, **kwargs): + pass + monkeypatch.setattr(RequestorApi, "create_allocation", mock_create_allocation) + monkeypatch.setattr(RequestorApi, "release_allocation", mock_release_allocation) with pytest.raises(_StopExecutor): async with Executor( diff --git a/yapapi/executor/__init__.py b/yapapi/executor/__init__.py index 803bec07b..ab5b56f4d 100644 --- a/yapapi/executor/__init__.py +++ b/yapapi/executor/__init__.py @@ -176,11 +176,7 @@ def __init__( # Add buffering to the provided event emitter to make sure # that emitting events will not block - # TODO: make AsyncWrapper an AsyncContextManager and start it in - # in __aenter__(); if it's started here then there's no guarantee that - # it will be cancelled properly - self._wrapped_consumer: Optional[AsyncWrapper] = None - self._event_consumer = event_consumer + self._wrapped_consumer = AsyncWrapper(event_consumer) self._stream_output = stream_output @@ -241,7 +237,15 @@ async def __aenter__(self) -> "Golem": try: stack = self._stack - self._wrapped_consumer = AsyncWrapper(self._event_consumer) + await stack.enter_async_context(self._wrapped_consumer) + + def report_shutdown(*exc_info): + if any(item for item in exc_info): + self.emit(events.ShutdownFinished(exc_info=exc_info)) # noqa + else: + self.emit(events.ShutdownFinished()) + + stack.push(report_shutdown) market_client = await stack.enter_async_context(self._api_config.market()) self._market_api = rest.Market(market_client) @@ -272,53 +276,48 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): # Importing this at the beginning would cause circular dependencies from ..log import pluralize - logger.debug("Golem is shutting down...") - # Wait until all computations are finished - await asyncio.gather(*[job.finished.wait() for job in self._jobs]) + try: + logger.debug("Golem is shutting down...") + # Wait until all computations are finished + await asyncio.gather(*[job.finished.wait() for job in self._jobs]) - self._payment_closing = True + self._payment_closing = True - for task in self._services: - if task is not self._process_invoices_job: - task.cancel() + for task in self._services: + if task is not self._process_invoices_job: + task.cancel() - if self._process_invoices_job and not any( - True for job in self._jobs if job.agreements_pool.confirmed > 0 - ): - logger.debug("No need to wait for invoices.") - self._process_invoices_job.cancel() + if self._process_invoices_job and not any( + True for job in self._jobs if job.agreements_pool.confirmed > 0 + ): + logger.debug("No need to wait for invoices.") + self._process_invoices_job.cancel() - try: - logger.info("Waiting for Golem services to finish...") - _, pending = await asyncio.wait( - self._services, timeout=10, return_when=asyncio.ALL_COMPLETED - ) - if pending: - logger.debug("%s still running: %s", pluralize(len(pending), "service"), pending) - except Exception: - # TODO: add message - logger.debug("TODO", exc_info=True) - - if self._agreements_to_pay and self._process_invoices_job: - logger.info( - "%s still unpaid, waiting for invoices...", - pluralize(len(self._agreements_to_pay), "agreement"), - ) - await asyncio.wait( - {self._process_invoices_job}, timeout=30, return_when=asyncio.ALL_COMPLETED - ) - if self._agreements_to_pay: - logger.warning("Unpaid agreements: %s", self._agreements_to_pay) + try: + logger.info("Waiting for Golem services to finish...") + _, pending = await asyncio.wait( + self._services, timeout=10, return_when=asyncio.ALL_COMPLETED + ) + if pending: + logger.debug( + "%s still running: %s", pluralize(len(pending), "service"), pending + ) + except Exception: + logger.debug("Got error when waiting for services to finish", exc_info=True) + + if self._agreements_to_pay and self._process_invoices_job: + logger.info( + "%s still unpaid, waiting for invoices...", + pluralize(len(self._agreements_to_pay), "agreement"), + ) + await asyncio.wait( + {self._process_invoices_job}, timeout=30, return_when=asyncio.ALL_COMPLETED + ) + if self._agreements_to_pay: + logger.warning("Unpaid agreements: %s", self._agreements_to_pay) - # TODO: prevent new computations at this point (if it's even possible to start one) - try: - await self._stack.aclose() - self.emit(events.ShutdownFinished()) - except Exception: - self.emit(events.ShutdownFinished(exc_info=sys.exc_info())) finally: - if self._wrapped_consumer: - await self._wrapped_consumer.stop() + await self._stack.aclose() async def _create_allocations(self) -> rest.payment.MarketDecoration: diff --git a/yapapi/executor/utils.py b/yapapi/executor/utils.py index 04d70b2bd..13622705b 100644 --- a/yapapi/executor/utils.py +++ b/yapapi/executor/utils.py @@ -1,13 +1,13 @@ """Utility functions and classes used within the `yapapi.executor` package.""" import asyncio import logging -from typing import Callable, Optional +from typing import AsyncContextManager, Callable, Optional logger = logging.getLogger(__name__) -class AsyncWrapper: +class AsyncWrapper(AsyncContextManager): """Wraps a given callable to provide asynchronous calls. Example usage: @@ -25,11 +25,27 @@ class AsyncWrapper: _args_buffer: asyncio.Queue _task: Optional[asyncio.Task] - def __init__(self, wrapped: Callable, event_loop: Optional[asyncio.AbstractEventLoop] = None): + def __init__(self, wrapped: Callable): self._wrapped = wrapped # type: ignore # suppress mypy issue #708 self._args_buffer = asyncio.Queue() - loop = event_loop or asyncio.get_event_loop() - self._task = loop.create_task(self._worker()) + self._loop = asyncio.get_event_loop() + self._task = None + + async def __aenter__(self) -> "AsyncWrapper": + self._task = self._loop.create_task(self._worker()) + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> Optional[bool]: + """Stop the wrapper, process queued calls but do not accept any new ones.""" + if self._task: + # Set self._task to None so we don't accept any more calls in `async_call()` + worker_task = self._task + self._task = None + await self._args_buffer.join() + worker_task.cancel() + await asyncio.gather(worker_task, return_exceptions=True) + # Don't suppress the exception (if any), so return a non-True value + return None async def _worker(self) -> None: while True: @@ -39,6 +55,7 @@ async def _worker(self) -> None: self._wrapped(*args, **kwargs) finally: self._args_buffer.task_done() + await asyncio.sleep(0) except KeyboardInterrupt as ke: # Don't stop on KeyboardInterrupt, but pass it to the event loop logger.debug("Caught KeybordInterrupt in AsyncWrapper's worker task") @@ -46,23 +63,13 @@ async def _worker(self) -> None: def raise_interrupt(ke_): raise ke_ - asyncio.get_event_loop().call_soon(raise_interrupt, ke) + self._loop.call_soon(raise_interrupt, ke) except asyncio.CancelledError: logger.debug("AsyncWrapper's worker task cancelled") break except Exception: logger.exception("Unhandled exception in wrapped callable") - async def stop(self) -> None: - """Stop the wrapper, process queued calls but do not accept any new ones.""" - if self._task: - # Set self._task to None so we don't accept any more calls in `async_call()` - worker_task = self._task - self._task = None - await self._args_buffer.join() - worker_task.cancel() - await asyncio.gather(worker_task, return_exceptions=True) - def async_call(self, *args, **kwargs) -> None: """Schedule an asynchronous call to the wrapped callable.""" if not self._task or self._task.done():