Skip to content

Commit

Permalink
Make AsyncWrapper an async context manager
Browse files Browse the repository at this point in the history
  • Loading branch information
azawlocki committed May 21, 2021
1 parent 37e4c01 commit 1973569
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 117 deletions.
129 changes: 75 additions & 54 deletions tests/executor/test_async_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,35 +5,54 @@
from yapapi.executor.utils import AsyncWrapper


def test_keyboard_interrupt(event_loop):
def test_async_wrapper_ordering():
"""Test if AsyncWrapper preserves order of calls."""

input_ = list(range(10))
output = []

def func(n):
output.append(n)

async def main():
async with AsyncWrapper(func) as wrapper:
for n in input_:
wrapper.async_call(n)

asyncio.get_event_loop().run_until_complete(main())
assert output == input_


def test_keyboard_interrupt():
"""Test if AsyncWrapper handles KeyboardInterrupt by passing it to the event loop."""

def func(interrupt):
if interrupt:
raise KeyboardInterrupt

wrapper = AsyncWrapper(func, event_loop)

async def main():
for _ in range(100):
wrapper.async_call(False)
# This will raise KeyboardInterrupt in the wrapper's worker task
wrapper.async_call(True)
await asyncio.sleep(0.01)
async with AsyncWrapper(func) as wrapper:
for _ in range(100):
wrapper.async_call(False)
# This will raise KeyboardInterrupt in the wrapper's worker task
wrapper.async_call(True)
await asyncio.sleep(0.01)

task = event_loop.create_task(main())
loop = asyncio.get_event_loop()
task = loop.create_task(main())
with pytest.raises(KeyboardInterrupt):
event_loop.run_until_complete(task)
loop.run_until_complete(task)

# Make sure the main task did not get KeyboardInterrupt
assert not task.done()

# Make sure the wrapper can still make calls, it's worker task shouldn't exit
wrapper.async_call(False)
with pytest.raises(asyncio.CancelledError):
task.cancel()
loop.run_until_complete(task)


def test_stop_doesnt_deadlock(event_loop):
"""Test if the AsyncWrapper.stop() coroutine completes after an AsyncWrapper is interrupted.
def test_aexit_doesnt_deadlock():
"""Test if the AsyncWrapper.__aexit__() completes after an AsyncWrapper is interrupted.
See https://github.com/golemfactory/yapapi/issues/238.
"""
Expand All @@ -46,55 +65,57 @@ def func(interrupt):
async def main():
""""This coroutine mimics how an AsyncWrapper is used in an Executor."""

wrapper = AsyncWrapper(func, event_loop)
try:
# Queue some calls
for _ in range(10):
wrapper.async_call(False)
wrapper.async_call(True)
for _ in range(10):
wrapper.async_call(False)
# Sleep until cancelled
await asyncio.sleep(30)
assert False, "Sleep should be cancelled"
except asyncio.CancelledError:
# This call should exit without timeout
await asyncio.wait_for(wrapper.stop(), timeout=30.0)

task = event_loop.create_task(main())
async with AsyncWrapper(func) as wrapper:
try:
# Queue some calls
for _ in range(10):
wrapper.async_call(False)
wrapper.async_call(True)
for _ in range(10):
wrapper.async_call(False)
# Sleep until cancelled
await asyncio.sleep(30)
assert False, "Sleep should be cancelled"
except asyncio.CancelledError:
pass

loop = asyncio.get_event_loop()
task = loop.create_task(main())
try:
event_loop.run_until_complete(task)
loop.run_until_complete(task)
assert False, "Expected KeyboardInterrupt"
except KeyboardInterrupt:
task.cancel()
event_loop.run_until_complete(task)
loop.run_until_complete(task)


def test_stop_doesnt_wait(event_loop):
"""Test if the AsyncWrapper.stop() coroutine prevents new calls from be queued."""
def test_cancel_doesnt_wait():
"""Test if the AsyncWrapper stops processing calls when it's cancelled."""

def func():
time.sleep(0.1)
pass
num_calls = 0

wrapper = AsyncWrapper(func, event_loop)
def func(d):
print("Calling func()")
nonlocal num_calls
num_calls += 1
time.sleep(d)

async def main():
with pytest.raises(RuntimeError):
for n in range(100):
wrapper.async_call()
await asyncio.sleep(0.01)
# wrapper should be stopped before all calls are made
assert False, "Should raise RuntimeError"

async def stop():
await asyncio.sleep(0.1)
await wrapper.stop()
try:
async with AsyncWrapper(func) as wrapper:
for _ in range(10):
wrapper.async_call(0.1)
except asyncio.CancelledError:
pass

task = event_loop.create_task(main())
event_loop.create_task(stop())
try:
event_loop.run_until_complete(task)
except KeyboardInterrupt:
async def cancel():
await asyncio.sleep(0.05)
print("Cancelling!")
task.cancel()
event_loop.run_until_complete(task)

loop = asyncio.get_event_loop()
task = loop.create_task(main())
loop.create_task(cancel())
loop.run_until_complete(task)

assert num_calls < 10
4 changes: 4 additions & 0 deletions tests/executor/test_payment_platforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,11 @@ async def mock_create_allocation(_self, model):
create_allocation_args.append(model)
return mock.Mock()

async def mock_release_allocation(*args, **kwargs):
pass

monkeypatch.setattr(RequestorApi, "create_allocation", mock_create_allocation)
monkeypatch.setattr(RequestorApi, "release_allocation", mock_release_allocation)

with pytest.raises(_StopExecutor):
async with Executor(
Expand Down
93 changes: 46 additions & 47 deletions yapapi/executor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,11 +176,7 @@ def __init__(

# Add buffering to the provided event emitter to make sure
# that emitting events will not block
# TODO: make AsyncWrapper an AsyncContextManager and start it in
# in __aenter__(); if it's started here then there's no guarantee that
# it will be cancelled properly
self._wrapped_consumer: Optional[AsyncWrapper] = None
self._event_consumer = event_consumer
self._wrapped_consumer = AsyncWrapper(event_consumer)

self._stream_output = stream_output

Expand Down Expand Up @@ -241,7 +237,15 @@ async def __aenter__(self) -> "Golem":
try:
stack = self._stack

self._wrapped_consumer = AsyncWrapper(self._event_consumer)
await stack.enter_async_context(self._wrapped_consumer)

def report_shutdown(*exc_info):
if any(item for item in exc_info):
self.emit(events.ShutdownFinished(exc_info=exc_info)) # noqa
else:
self.emit(events.ShutdownFinished())

stack.push(report_shutdown)

market_client = await stack.enter_async_context(self._api_config.market())
self._market_api = rest.Market(market_client)
Expand Down Expand Up @@ -272,53 +276,48 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
# Importing this at the beginning would cause circular dependencies
from ..log import pluralize

logger.debug("Golem is shutting down...")
# Wait until all computations are finished
await asyncio.gather(*[job.finished.wait() for job in self._jobs])
try:
logger.debug("Golem is shutting down...")
# Wait until all computations are finished
await asyncio.gather(*[job.finished.wait() for job in self._jobs])

self._payment_closing = True
self._payment_closing = True

for task in self._services:
if task is not self._process_invoices_job:
task.cancel()
for task in self._services:
if task is not self._process_invoices_job:
task.cancel()

if self._process_invoices_job and not any(
True for job in self._jobs if job.agreements_pool.confirmed > 0
):
logger.debug("No need to wait for invoices.")
self._process_invoices_job.cancel()
if self._process_invoices_job and not any(
True for job in self._jobs if job.agreements_pool.confirmed > 0
):
logger.debug("No need to wait for invoices.")
self._process_invoices_job.cancel()

try:
logger.info("Waiting for Golem services to finish...")
_, pending = await asyncio.wait(
self._services, timeout=10, return_when=asyncio.ALL_COMPLETED
)
if pending:
logger.debug("%s still running: %s", pluralize(len(pending), "service"), pending)
except Exception:
# TODO: add message
logger.debug("TODO", exc_info=True)

if self._agreements_to_pay and self._process_invoices_job:
logger.info(
"%s still unpaid, waiting for invoices...",
pluralize(len(self._agreements_to_pay), "agreement"),
)
await asyncio.wait(
{self._process_invoices_job}, timeout=30, return_when=asyncio.ALL_COMPLETED
)
if self._agreements_to_pay:
logger.warning("Unpaid agreements: %s", self._agreements_to_pay)
try:
logger.info("Waiting for Golem services to finish...")
_, pending = await asyncio.wait(
self._services, timeout=10, return_when=asyncio.ALL_COMPLETED
)
if pending:
logger.debug(
"%s still running: %s", pluralize(len(pending), "service"), pending
)
except Exception:
logger.debug("Got error when waiting for services to finish", exc_info=True)

if self._agreements_to_pay and self._process_invoices_job:
logger.info(
"%s still unpaid, waiting for invoices...",
pluralize(len(self._agreements_to_pay), "agreement"),
)
await asyncio.wait(
{self._process_invoices_job}, timeout=30, return_when=asyncio.ALL_COMPLETED
)
if self._agreements_to_pay:
logger.warning("Unpaid agreements: %s", self._agreements_to_pay)

# TODO: prevent new computations at this point (if it's even possible to start one)
try:
await self._stack.aclose()
self.emit(events.ShutdownFinished())
except Exception:
self.emit(events.ShutdownFinished(exc_info=sys.exc_info()))
finally:
if self._wrapped_consumer:
await self._wrapped_consumer.stop()
await self._stack.aclose()

async def _create_allocations(self) -> rest.payment.MarketDecoration:

Expand Down
39 changes: 23 additions & 16 deletions yapapi/executor/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
"""Utility functions and classes used within the `yapapi.executor` package."""
import asyncio
import logging
from typing import Callable, Optional
from typing import AsyncContextManager, Callable, Optional


logger = logging.getLogger(__name__)


class AsyncWrapper:
class AsyncWrapper(AsyncContextManager):
"""Wraps a given callable to provide asynchronous calls.
Example usage:
Expand All @@ -25,11 +25,27 @@ class AsyncWrapper:
_args_buffer: asyncio.Queue
_task: Optional[asyncio.Task]

def __init__(self, wrapped: Callable, event_loop: Optional[asyncio.AbstractEventLoop] = None):
def __init__(self, wrapped: Callable):
self._wrapped = wrapped # type: ignore # suppress mypy issue #708
self._args_buffer = asyncio.Queue()
loop = event_loop or asyncio.get_event_loop()
self._task = loop.create_task(self._worker())
self._loop = asyncio.get_event_loop()
self._task = None

async def __aenter__(self) -> "AsyncWrapper":
self._task = self._loop.create_task(self._worker())
return self

async def __aexit__(self, exc_type, exc_val, exc_tb) -> Optional[bool]:
"""Stop the wrapper, process queued calls but do not accept any new ones."""
if self._task:
# Set self._task to None so we don't accept any more calls in `async_call()`
worker_task = self._task
self._task = None
await self._args_buffer.join()
worker_task.cancel()
await asyncio.gather(worker_task, return_exceptions=True)
# Don't suppress the exception (if any), so return a non-True value
return None

async def _worker(self) -> None:
while True:
Expand All @@ -39,30 +55,21 @@ async def _worker(self) -> None:
self._wrapped(*args, **kwargs)
finally:
self._args_buffer.task_done()
await asyncio.sleep(0)
except KeyboardInterrupt as ke:
# Don't stop on KeyboardInterrupt, but pass it to the event loop
logger.debug("Caught KeybordInterrupt in AsyncWrapper's worker task")

def raise_interrupt(ke_):
raise ke_

asyncio.get_event_loop().call_soon(raise_interrupt, ke)
self._loop.call_soon(raise_interrupt, ke)
except asyncio.CancelledError:
logger.debug("AsyncWrapper's worker task cancelled")
break
except Exception:
logger.exception("Unhandled exception in wrapped callable")

async def stop(self) -> None:
"""Stop the wrapper, process queued calls but do not accept any new ones."""
if self._task:
# Set self._task to None so we don't accept any more calls in `async_call()`
worker_task = self._task
self._task = None
await self._args_buffer.join()
worker_task.cancel()
await asyncio.gather(worker_task, return_exceptions=True)

def async_call(self, *args, **kwargs) -> None:
"""Schedule an asynchronous call to the wrapped callable."""
if not self._task or self._task.done():
Expand Down

0 comments on commit 1973569

Please sign in to comment.