From 718df5ce794c95a143aa76ed43f9837332df9cc8 Mon Sep 17 00:00:00 2001 From: Joshua Oreman Date: Sun, 3 May 2020 07:45:14 +0000 Subject: [PATCH 1/8] Clean up handling of Handles --- trio_asyncio/_async.py | 4 +- trio_asyncio/_base.py | 124 +++++++--------------- trio_asyncio/_handles.py | 222 ++++++++++++++++++--------------------- trio_asyncio/_sync.py | 7 +- 4 files changed, 145 insertions(+), 212 deletions(-) diff --git a/trio_asyncio/_async.py b/trio_asyncio/_async.py index 5a9a8c64..e46d10e9 100644 --- a/trio_asyncio/_async.py +++ b/trio_asyncio/_async.py @@ -1,7 +1,7 @@ import trio +import asyncio from ._base import BaseTrioEventLoop -from ._handles import Handle class TrioEventLoop(BaseTrioEventLoop): @@ -69,7 +69,7 @@ def stop_me(): if self._stopped.is_set(): waiter.set() else: - self._queue_handle(Handle(stop_me, (), self, context=None, is_sync=True)) + self._queue_handle(asyncio.Handle(stop_me, (), self)) return waiter def _close(self): diff --git a/trio_asyncio/_base.py b/trio_asyncio/_base.py index 8453b577..708a9521 100644 --- a/trio_asyncio/_base.py +++ b/trio_asyncio/_base.py @@ -9,7 +9,7 @@ import warnings import concurrent.futures -from ._handles import Handle, TimerHandle +from ._handles import ScopedHandle, AsyncHandle from ._util import run_aio_future, run_aio_generator from ._deprecate import deprecated, deprecated_alias from . import _util @@ -39,28 +39,6 @@ def clear(self): pass -def _h_raise(handle, exc): - """ - Convince a handle to raise an error. - - trio-asyncio enhanced handles have a method to do this - but asyncio's native handles don't. Thus we need to fudge things. - """ - if hasattr(handle, '_raise'): - handle._raise(exc) - return - - def _raise(exc): - raise exc - - cb, handle._callback = handle._callback, _raise - ar, handle._args = handle._args, (exc,) - try: - handle._run() - finally: - handle._callback, handle._args = cb, ar - - class _TrioSelector(_BaseSelectorImpl): """A selector that hooks into a ``TrioEventLoop``. @@ -241,25 +219,6 @@ async def run_aio_coroutine(self, coro): finally: sniffio.current_async_library_cvar.reset(t) - async def __run_trio(self, h): - """Helper for copying the result of a Trio task to an asyncio future""" - f, proc, *args = h._args - if f.cancelled(): # pragma: no cover - return - try: - with trio.CancelScope() as scope: - h._scope = scope - res = await proc(*args) - if scope.cancelled_caught: - f.cancel() - return - except BaseException as exc: - if not f.cancelled(): # pragma: no branch - f.set_exception(exc) - else: - if not f.cancelled(): # pragma: no branch - f.set_result(res) - def trio_as_future(self, proc, *args): """Start a new Trio task to run ``await proc(*args)`` asynchronously. Return an `asyncio.Future` that will resolve to the value or exception @@ -292,14 +251,7 @@ def trio_as_future(self, proc, *args): an `asyncio.Future` which will resolve to the result of the call to *proc* """ f = asyncio.Future(loop=self) - h = Handle( - self.__run_trio, ( - f, - proc, - ) + args, self, context=None, is_sync=None - ) - self._queue_handle(h) - f.add_done_callback(h._cb_future_cancel) + self._queue_handle(AsyncHandle(proc, args, self, result_future=f)) return f def run_trio_task(self, proc, *args): @@ -314,7 +266,7 @@ def run_trio_task(self, proc, *args): Returns: an `asyncio.Handle` which can be used to cancel the background task """ - return self._queue_handle(Handle(proc, args, self, is_sync=False)) + return self._queue_handle(AsyncHandle(proc, args, self)) # Callback handling # @@ -331,7 +283,7 @@ def _queue_handle(self, handle): def _call_soon(self, *arks, **kwargs): raise RuntimeError("_call_soon() should not have been called") - def call_later(self, delay, callback, *args, context=None): + def call_later(self, delay, callback, *args, **context): """asyncio's timer-based delay Note that the callback is a sync function. @@ -342,36 +294,36 @@ def call_later(self, delay, callback, *args, context=None): """ self._check_callback(callback, 'call_later') assert delay >= 0, delay - h = TimerHandle(delay + self.time(), callback, args, self, context=context, is_sync=True) + h = asyncio.TimerHandle(delay + self.time(), callback, args, self, **context) self._queue_handle(h) return h - def call_at(self, when, callback, *args, context=None): + def call_at(self, when, callback, *args, **context): """asyncio's time-based delay Note that the callback is a sync function. """ self._check_callback(callback, 'call_at') return self._queue_handle( - TimerHandle(when, callback, args, self, context=context, is_sync=True) + asyncio.TimerHandle(when, callback, args, self, **context) ) - def call_soon(self, callback, *args, context=None): + def call_soon(self, callback, *args, **context): """asyncio's defer-to-mainloop callback executor. Note that the callback is a sync function. """ self._check_callback(callback, 'call_soon') - return self._queue_handle(Handle(callback, args, self, context=context, is_sync=True)) + return self._queue_handle(asyncio.Handle(callback, args, self, **context)) - def call_soon_threadsafe(self, callback, *args, context=None): + def call_soon_threadsafe(self, callback, *args, **context): """asyncio's thread-safe defer-to-mainloop Note that the callback is a sync function. """ self._check_callback(callback, 'call_soon_threadsafe') self._check_closed() - h = Handle(callback, args, self, context=context, is_sync=True) + h = asyncio.Handle(callback, args, self, **context) self._token.run_sync_soon(self._q_send.send_nowait, h) # drop all timers @@ -471,7 +423,7 @@ async def synchronize(self): """ w = trio.Event() - self._queue_handle(Handle(w.set, (), self, is_sync=True)) + self._queue_handle(asyncio.Handle(w.set, (), self)) await w.wait() # Signal handling # @@ -488,7 +440,7 @@ def add_signal_handler(self, sig, callback, *args): self._check_signal(sig) if sig == signal.SIGKILL: raise RuntimeError("SIGKILL cannot be caught") - h = Handle(callback, args, self, context=None, is_sync=True) + h = asyncio.Handle(callback, args, self) assert sig not in self._signal_handlers, \ "Signal %d is already being caught" % (sig,) self._orig_signals[sig] = signal.signal(sig, self._handle_sig) @@ -528,7 +480,7 @@ def add_reader(self, fd, callback, *args): def _add_reader(self, fd, callback, *args): self._check_closed() - handle = Handle(callback, args, self, context=None, is_sync=True) + handle = ScopedHandle(callback, args, self) reader = self._set_read_handle(fd, handle) if reader is not None: reader.cancel() @@ -547,20 +499,17 @@ def _set_read_handle(self, fd, handle): self._selector.modify(fd, mask | EVENT_READ, (handle, writer)) return reader - async def _reader_loop(self, fd, handle, task_status=trio.TASK_STATUS_IGNORED): - task_status.started() - with trio.CancelScope() as scope: - handle._scope = scope + async def _reader_loop(self, fd, handle): + with handle._scope: try: - while not handle._cancelled: # pragma: no branch + while True: await _wait_readable(fd) - handle._call_sync() + if handle._cancelled: + break + handle._run() await self.synchronize() except Exception as exc: - _h_raise(handle, exc) - return - finally: - handle._scope = None + handle._raise(exc) # writing to a file descriptor @@ -583,7 +532,7 @@ def add_writer(self, fd, callback, *args): def _add_writer(self, fd, callback, *args): self._check_closed() - handle = Handle(callback, args, self, context=None, is_sync=True) + handle = ScopedHandle(callback, args, self) writer = self._set_write_handle(fd, handle) if writer is not None: writer.cancel() @@ -601,20 +550,17 @@ def _set_write_handle(self, fd, handle): self._selector.modify(fd, mask | EVENT_WRITE, (reader, handle)) return writer - async def _writer_loop(self, fd, handle, task_status=trio.TASK_STATUS_IGNORED): - with trio.CancelScope() as scope: - handle._scope = scope - task_status.started() + async def _writer_loop(self, fd, handle): + with handle._scope: try: - while not handle._cancelled: # pragma: no branch + while True: await _wait_writable(fd) - handle._call_sync() + if handle._cancelled: + break + handle._run() await self.synchronize() except Exception as exc: - _h_raise(handle, exc) - return - finally: - handle._scope = None + handle._raise(exc) def autoclose(self, fd): """ @@ -717,7 +663,7 @@ async def _main_loop_one(self, no_wait=False): # so restart from the beginning. return - if isinstance(obj, TimerHandle): + if isinstance(obj, asyncio.TimerHandle): # A TimerHandle is added to the list of timers. heapq.heappush(self._timers, obj) return @@ -732,13 +678,17 @@ async def _main_loop_one(self, no_wait=False): # Don't go through the expensive nursery dance # if this is a sync function. - if getattr(obj, '_is_sync', True): + if isinstance(obj, AsyncHandle): + if hasattr(obj, '_context'): + obj._context.run(self._nursery.start_soon, obj._run, name=obj._callback) + else: + self._nursery.start_soon(obj._run, name=obj._callback) + await obj._started.wait() + else: if hasattr(obj, '_context'): obj._context.run(obj._callback, *obj._args) else: obj._callback(*obj._args) - else: - await self._nursery.start(obj._call_async) async def _main_loop_exit(self): """Finalize the loop. It may not be re-entered.""" diff --git a/trio_asyncio/_handles.py b/trio_asyncio/_handles.py index ddf80782..0c620526 100644 --- a/trio_asyncio/_handles.py +++ b/trio_asyncio/_handles.py @@ -16,59 +16,31 @@ def _format_callback_source(func, args): return func_repr -async def _set_sniff(proc, *args): - sniffio.current_async_library_cvar.set("trio") - return await proc(*args) +class ScopedHandle(asyncio.Handle): + """An asyncio.Handle that cancels a trio.CancelScope when the Handle is cancelled. - -class _TrioHandle: - """ - This extends asyncio.Handle by providing: - * a way to cancel an async callback - * a way to declare the type of the callback function - - ``is_sync`` may be - * True: sync function, use _call_sync() - * False: async function, use _call_async() - * None: also async, but the callback function accepts - the handle as its sole argument - - The caller is responsible for checking whether the handle - has been cancelled before invoking ``call_[a]sync()``. + This is used to manage installed readers and writers, so that the Trio call to + wait_readable() or wait_writable() can be cancelled when the handle is. """ - def _init(self, is_sync): - """Secondary init. - """ - self._is_sync = is_sync - self._scope = None + __slots__ = ("_scope",) + + def __init__(self, *args, **kw): + super().__init__(*args, **kw) + self._scope = trio.CancelScope() def cancel(self): - try: - task = self._callback.__self__ - except AttributeError: - task = None - else: - if not isinstance(task, asyncio.Task): - task = None super().cancel() - if self._scope is not None: - self._scope.cancel() - elif task is not None: - task.cancel() - - def _cb_future_cancel(self, f): - """If a Trio task completes an asyncio Future, - add this callback to the future - and set ``_scope`` to the Trio cancel scope - so that the task is terminated when the future gets canceled. + self._scope.cancel() - """ - if f.cancelled(): - self.cancel() + def _repr_info(self): + return super()._repr_info() + ["scope={!r}".format(self._scope)] def _raise(self, exc): - """This is a copy of the exception handling in asyncio.events.Handle._run() + """This is a copy of the exception handling in asyncio.events.Handle._run(). + It's used to report exceptions that arise when waiting for readability + or writability, and exceptions in async tasks managed by our subclass + AsyncHandle. """ cb = _format_callback_source(self._callback, self._args) msg = 'Exception in callback {}'.format(cb) @@ -81,86 +53,96 @@ def _raise(self, exc): context['source_traceback'] = self._source_traceback self._loop.call_exception_handler(context) - def _repr_info(self): - info = [self.__class__.__name__] - if self._cancelled: - info.append('cancelled') - if self._callback is not None: - info.append(_format_callback_source(self._callback, self._args)) - if self._source_traceback: - frame = self._source_traceback[-1] - info.append('created at %s:%s' % (frame[0], frame[1])) - if self._scope is not None: - info.append('scope=%s' % repr(self._scope)) - return info - - def _call_sync(self): - assert self._is_sync + +class AsyncHandle(ScopedHandle): + """A ScopedHandle associated with the execution of an async function. + If the handle is cancelled, the cancel scope surrounding the async function + will be cancelled too. It is also possible to link a future to the result + of the async function. If you do that, the future will evaluate to the + result of the function, and cancelling the future will cancel the handle too. + """ + + __slots__ = ("_fut", "_started") + + def __init__(self, *args, result_future=None, **kw): + super().__init__(*args, **kw) + self._fut = result_future + self._started = trio.Event() + if self._fut is not None: + @self._fut.add_done_callback + def propagate_cancel(f): + if f.cancelled(): + self.cancel() + + async def _run(self): + sniffio.current_async_library_cvar.set("trio") + self._started.set() if self._cancelled: return - self._run() - if sys.version_info >= (3, 7): - - async def _call_async(self, task_status=trio.TASK_STATUS_IGNORED): - assert not self._is_sync - if self._cancelled: - return - task_status.started() + def report_exception(exc): + if not isinstance(exc, Exception): + # Let BaseExceptions such as Cancelled escape without being noted. + return exc try: - with trio.CancelScope() as scope: - self._scope = scope - if self._is_sync is None: - await self._context.run(_set_sniff, self._callback, self) - else: - await self._context.run(_set_sniff, self._callback, *self._args) - except Exception as exc: + orig_tb = exc.__traceback__ self._raise(exc) - finally: - self._scope = None + # If self._raise() just logged something, suppress the exception. + return None + except BaseException as other_exc: + if other_exc is exc: + # If _raise() reraised its argument, remove the _raise and + # default_exception_handler frames that it added to the traceback. + return exc.with_traceback(orig_tb) + # If _raise() raised a different exception, don't mess with + # the traceback. + return other_exc + + def remove_cancelled(exc): + if isinstance(exc, trio.Cancelled): + return None + return exc + + def only_cancelled(exc): + if isinstance(exc, trio.Cancelled): + return exc + return None - else: # no contextvars - - async def _call_async(self, task_status=trio.TASK_STATUS_IGNORED): - assert not self._is_sync - if self._cancelled: - return - task_status.started() - try: - with trio.CancelScope() as scope: - self._scope = scope - if self._is_sync is None: - await self._callback(self) + try: + # Run the callback + with self._scope: + res = await self._callback(*self._args) + + if self._fut: + # Propagate result or just-this-handle cancellation to the Future + if self._scope.cancelled_caught: + self._fut.cancel() + elif not self._fut.cancelled(): + self._fut.set_result(res) + + except BaseException as exc: + if not self._fut: + # Pass Exceptions through the fallback exception handler since + # they have nowhere better to go. Let BaseExceptions escape so + # that Cancelled and SystemExit work reasonably. + with trio.MultiError.catch(handle_exc): + raise + else: + # The result future gets all the non-Cancelled + # exceptions. Any Cancelled need to keep propagating + # out of this stack frame in order to reach the cancel + # scope for which they're intended. This would be a + # great place for ExceptionGroup.split() if we had it. + cancelled = trio.MultiError.filter(only_cancelled, exc) + rest = trio.MultiError.filter(remove_cancelled, exc) + if not self._fut.cancelled(): + if rest: + self._fut.set_exception(rest) else: - await self._callback(*self._args) - except Exception as exc: - self._raise(exc) - finally: - self._scope = None - - -if sys.version_info >= (3, 7): - - class Handle(_TrioHandle, asyncio.Handle): - def __init__(self, callback, args, loop, context=None, is_sync=True): - assert not isinstance(context, bool) - super().__init__(callback, args, loop, context=context) - self._init(is_sync) - - class TimerHandle(_TrioHandle, asyncio.TimerHandle): - def __init__(self, when, callback, args, loop, context=None, is_sync=True): - assert not isinstance(context, bool) - super().__init__(when, callback, args, loop, context=context) - self._init(is_sync) - -else: - - class Handle(_TrioHandle, asyncio.Handle): - def __init__(self, callback, args, loop, context=None, is_sync=True): - super().__init__(callback, args, loop) - self._init(is_sync) - - class TimerHandle(_TrioHandle, asyncio.TimerHandle): - def __init__(self, when, callback, args, loop, context=None, is_sync=True): - super().__init__(when, callback, args, loop) - self._init(is_sync) + self._fut.cancel() + if cancelled: + raise cancelled + finally: + # asyncio says this is needed to break cycles when an exception occurs. + # I'm not so sure, but it doesn't seem to do any harm. + self = None diff --git a/trio_asyncio/_sync.py b/trio_asyncio/_sync.py index be418527..8650b8f9 100644 --- a/trio_asyncio/_sync.py +++ b/trio_asyncio/_sync.py @@ -1,3 +1,5 @@ +# coding: utf-8 + import trio import queue import asyncio @@ -5,7 +7,6 @@ import outcome from ._base import BaseTrioEventLoop -from ._handles import Handle async def _sync(proc, *args): @@ -54,7 +55,7 @@ def do_stop(): # async def stop_me(): # def kick_(): # raise StopAsyncIteration -# self._queue_handle(Handle(kick_, (), self, context=None, is_sync=True)) +# self._queue_handle(asyncio.Handle(kick_, (), self)) # await self._main_loop() # if threading.current_thread() != self._thread: # self.__run_in_thread(stop_me) @@ -62,7 +63,7 @@ def do_stop(): if self._thread_running and not self._stop_pending: self._stop_pending = True - self._queue_handle(Handle(do_stop, (), self, context=None, is_sync=True)) + self._queue_handle(asyncio.Handle(do_stop, (), self)) def _queue_handle(self, handle): self._check_closed() From 8956ccab075695767e3e08ada582843ead323820 Mon Sep 17 00:00:00 2001 From: Joshua Oreman Date: Sun, 3 May 2020 08:11:00 +0000 Subject: [PATCH 2/8] Work around deprecation warning on pytest 5.4+ --- ci/test-requirements.txt | 2 +- tests/python/conftest.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/ci/test-requirements.txt b/ci/test-requirements.txt index cbf2eefb..7779065b 100644 --- a/ci/test-requirements.txt +++ b/ci/test-requirements.txt @@ -1,4 +1,4 @@ -pytest >= 3.1 # fixes a bug in handling async def test_* and turn warnings into errors. +pytest >= 5.4 # for the Node.from_parent() transition pytest-cov pytest-trio outcome diff --git a/tests/python/conftest.py b/tests/python/conftest.py index 956995ed..6fee94a1 100644 --- a/tests/python/conftest.py +++ b/tests/python/conftest.py @@ -69,7 +69,10 @@ def pytest_pycollect_makemodule(path, parent): ) if candidate == expected: fspath = py.path.local(test_asyncio.__file__) - return UnittestOnlyPackage(fspath, parent, nodeid=aio_test_nodeid(fspath)) + node = UnittestOnlyPackage.from_parent(parent, fspath=fspath) + # This keeps all test names from showing as "." + node._nodeid = aio_test_nodeid(fspath) + return node def pytest_collection_modifyitems(items): From 5a1f7d69bacd5e593c15733472b9601dd349dc13 Mon Sep 17 00:00:00 2001 From: Joshua Oreman Date: Tue, 5 May 2020 23:29:58 +0000 Subject: [PATCH 3/8] yapf --- trio_asyncio/_base.py | 4 +--- trio_asyncio/_handles.py | 1 + 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/trio_asyncio/_base.py b/trio_asyncio/_base.py index 708a9521..be88aecf 100644 --- a/trio_asyncio/_base.py +++ b/trio_asyncio/_base.py @@ -304,9 +304,7 @@ def call_at(self, when, callback, *args, **context): Note that the callback is a sync function. """ self._check_callback(callback, 'call_at') - return self._queue_handle( - asyncio.TimerHandle(when, callback, args, self, **context) - ) + return self._queue_handle(asyncio.TimerHandle(when, callback, args, self, **context)) def call_soon(self, callback, *args, **context): """asyncio's defer-to-mainloop callback executor. diff --git a/trio_asyncio/_handles.py b/trio_asyncio/_handles.py index 0c620526..a4ca7014 100644 --- a/trio_asyncio/_handles.py +++ b/trio_asyncio/_handles.py @@ -69,6 +69,7 @@ def __init__(self, *args, result_future=None, **kw): self._fut = result_future self._started = trio.Event() if self._fut is not None: + @self._fut.add_done_callback def propagate_cancel(f): if f.cancelled(): From c394a71470d156c5fba04c26fdabe08db16aef7c Mon Sep 17 00:00:00 2001 From: Joshua Oreman Date: Wed, 6 May 2020 00:15:31 +0000 Subject: [PATCH 4/8] Add newsfragment, document AsyncHandle callback color --- newsfragments/76.bugfix.rst | 6 ++++++ trio_asyncio/_handles.py | 5 ++++- 2 files changed, 10 insertions(+), 1 deletion(-) create mode 100644 newsfragments/76.bugfix.rst diff --git a/newsfragments/76.bugfix.rst b/newsfragments/76.bugfix.rst new file mode 100644 index 00000000..7494fdc4 --- /dev/null +++ b/newsfragments/76.bugfix.rst @@ -0,0 +1,6 @@ +On Python versions with native contextvars support (3.7+), a Trio task +started from asyncio context (using :func:`trio_as_aio`, +:meth:`~BaseTrioEventLoop.trio_as_future`, etc) will now properly +inherit the contextvars of its caller. Also, if the entire +trio-asyncio loop is cancelled, such tasks will no longer let +`trio.Cancelled` exceptions leak into their asyncio caller. diff --git a/trio_asyncio/_handles.py b/trio_asyncio/_handles.py index a4ca7014..67186282 100644 --- a/trio_asyncio/_handles.py +++ b/trio_asyncio/_handles.py @@ -55,11 +55,14 @@ def _raise(self, exc): class AsyncHandle(ScopedHandle): - """A ScopedHandle associated with the execution of an async function. + """A ScopedHandle associated with the execution of a Trio-flavored + async function. + If the handle is cancelled, the cancel scope surrounding the async function will be cancelled too. It is also possible to link a future to the result of the async function. If you do that, the future will evaluate to the result of the function, and cancelling the future will cancel the handle too. + """ __slots__ = ("_fut", "_started") From ac783043dcdc5664662e62332ffa0619fb93c3dd Mon Sep 17 00:00:00 2001 From: Joshua Oreman Date: Wed, 6 May 2020 00:42:35 +0000 Subject: [PATCH 5/8] Add tests to verify expected bugfixes --- tests/test_misc.py | 66 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/tests/test_misc.py b/tests/test_misc.py index 3d4fbcdf..40546ac1 100644 --- a/tests/test_misc.py +++ b/tests/test_misc.py @@ -351,3 +351,69 @@ async def run_asyncio_loop(nursery, *, task_status=trio.TASK_STATUS_IGNORED): await nursery.start(run_asyncio_loop, nursery) # Trigger KeyboardInterrupt that should propagate accross the coroutines signal.pthread_kill(threading.get_ident(), signal.SIGINT) + + +@pytest.mark.trio +@pytest.mark.parametrize("throw_another", [False, True]) +async def test_cancel_loop(throw_another): + """Regression test for #76: ensure that cancelling a trio-asyncio loop + does not cause any of the tasks running within it to yield a + result of Cancelled. + """ + async def manage_loop(task_status): + try: + with trio.CancelScope() as scope: + async with trio_asyncio.open_loop() as loop: + task_status.started((loop, scope)) + await trio.sleep_forever() + finally: + assert scope.cancelled_caught + + # Trio-flavored async function. Runs as a trio-aio loop task + # and gets cancelled when the loop does. + async def trio_task(): + async with trio.open_nursery() as nursery: + nursery.start_soon(trio.sleep_forever) + try: + await trio.sleep_forever() + except trio.Cancelled: + if throw_another: + # This will combine with the Cancelled from the + # background sleep_forever task to create a + # MultiError escaping from trio_task + raise ValueError("hi") + + async with trio.open_nursery() as nursery: + loop, scope = await nursery.start(manage_loop) + fut = loop.trio_as_future(trio_task) + await trio.testing.wait_all_tasks_blocked() + scope.cancel() + assert fut.done() + if throw_another: + with pytest.raises(ValueError, match="hi"): + fut.result() + else: + assert fut.cancelled() + + +@pytest.mark.trio +@pytest.mark.skipif(sys.version_info < (3, 7), reason="needs asyncio contextvars") +async def test_contextvars(): + import contextvars + + cvar = contextvars.ContextVar("test_cvar") + cvar.set("outer") + + async def fudge_in_aio(): + assert cvar.get() == "outer" + cvar.set("middle") + await trio_asyncio.trio_as_aio(fudge_in_trio)() + assert cvar.get() == "middle" + + async def fudge_in_trio(): + assert cvar.get() == "middle" + cvar.set("inner") + + async with trio_asyncio.open_loop() as loop: + await trio_asyncio.aio_as_trio(fudge_in_aio)() + assert cvar.get() == "outer" From 111dd16d6397f74f64a3bb78d22a28cf75dd31b8 Mon Sep 17 00:00:00 2001 From: Joshua Oreman Date: Wed, 6 May 2020 09:24:04 +0000 Subject: [PATCH 6/8] Streamline handling of async loop exit --- trio_asyncio/_base.py | 36 ++++++++++++++++++++++++------------ trio_asyncio/_loop.py | 15 ++++----------- 2 files changed, 28 insertions(+), 23 deletions(-) diff --git a/trio_asyncio/_base.py b/trio_asyncio/_base.py index be88aecf..90cbf5d7 100644 --- a/trio_asyncio/_base.py +++ b/trio_asyncio/_base.py @@ -631,8 +631,16 @@ async def _main_loop(self, task_status=trio.TASK_STATUS_IGNORED): sniffio.current_async_library_cvar.set("asyncio") try: - while not self._stopped.is_set(): - await self._main_loop_one() + # The shield here ensures that if the context surrounding + # the loop is cancelled, we keep processing callbacks + # until we reach the callback inserted by stop(). + # There's a call to stop() in the finally block of + # open_loop(), and we're not shielding the body of the + # open_loop() context, so this should be safe against + # deadlocks. + with trio.CancelScope(shield=True): + while not self._stopped.is_set(): + await self._main_loop_one() except StopAsyncIteration: # raised by .stop_me() to interrupt the loop pass @@ -693,16 +701,20 @@ async def _main_loop_exit(self): if self._closed: return - self.stop() - await self.wait_stopped() - - while True: - try: - await self._main_loop_one(no_wait=True) - except trio.WouldBlock: - break - except StopAsyncIteration: - pass + with trio.CancelScope(shield=True): + self.stop() + await self.wait_stopped() + + # Drain all remaining callbacks, even those after an initial + # call to stop(). This avoids a deadlock if stop() was called + # again during unwinding. + while True: + try: + await self._main_loop_one(no_wait=True) + except trio.WouldBlock: + break + except StopAsyncIteration: + pass # Kill off unprocessed work self._cancel_fds() diff --git a/trio_asyncio/_loop.py b/trio_asyncio/_loop.py index abf0f157..a1db1487 100644 --- a/trio_asyncio/_loop.py +++ b/trio_asyncio/_loop.py @@ -390,10 +390,6 @@ async def async_main(*args): # TODO: make sure that there is no asyncio loop already running - def _main_loop_exit(self): - super()._main_loop_exit() - self._thread = None - async with trio.open_nursery() as nursery: loop = TrioEventLoop(queue_len=queue_len) old_loop = current_loop.set(loop) @@ -404,14 +400,11 @@ def _main_loop_exit(self): await yield_(loop) finally: try: - await loop.stop().wait() + await loop._main_loop_exit() finally: - try: - await loop._main_loop_exit() - finally: - loop.close() - nursery.cancel_scope.cancel() - current_loop.reset(old_loop) + loop.close() + nursery.cancel_scope.cancel() + current_loop.reset(old_loop) def run(proc, *args, queue_len=None): From e5a4d6d633f7040771cd6d710bcc840324bb8135 Mon Sep 17 00:00:00 2001 From: Joshua Oreman Date: Wed, 6 May 2020 09:24:45 +0000 Subject: [PATCH 7/8] Add tests for run_trio_task, fix code to satisfy them --- tests/test_misc.py | 69 ++++++++++++++++++++++++++++++++++++++++ trio_asyncio/_handles.py | 19 +++-------- 2 files changed, 74 insertions(+), 14 deletions(-) diff --git a/tests/test_misc.py b/tests/test_misc.py index 40546ac1..31cdc9e4 100644 --- a/tests/test_misc.py +++ b/tests/test_misc.py @@ -396,6 +396,75 @@ async def trio_task(): assert fut.cancelled() +@pytest.mark.trio +async def test_trio_as_fut_throws_after_cancelled(): + """If a trio_as_future() future is cancelled, any exception + thrown by the Trio task as it unwinds is ignored. (This is + somewhat infelicitous, but the asyncio Future API doesn't allow + a future to go from cancelled to some other outcome.) + """ + + async def trio_task(): + try: + await trio.sleep_forever() + finally: + raise ValueError("hi") + + async with trio_asyncio.open_loop() as loop: + fut = loop.trio_as_future(trio_task) + await trio.testing.wait_all_tasks_blocked() + fut.cancel() + with pytest.raises(asyncio.CancelledError): + await fut + + +@pytest.mark.trio +async def test_run_trio_task_errors(monkeypatch): + async with trio_asyncio.open_loop() as loop: + # Test never getting to start the task + handle = loop.run_trio_task(trio.sleep_forever) + handle.cancel() + + # Test cancelling the task + handle = loop.run_trio_task(trio.sleep_forever) + await trio.testing.wait_all_tasks_blocked() + handle.cancel() + + # Helper for the rest of this test, which covers cases where + # the Trio task raises an exception + async def raise_in_aio_loop(exc): + async def raise_it(): + raise exc + + async with trio_asyncio.open_loop() as loop: + loop.run_trio_task(raise_it) + + # We temporarily modify the default exception handler to collect + # the exceptions instead of logging or raising them + + exceptions = [] + + def collect_exceptions(loop, context): + if context.get("exception"): + exceptions.append(context["exception"]) + else: + exceptions.append(RuntimeError(context.get("message") or "unknown")) + + monkeypatch.setattr( + trio_asyncio.TrioEventLoop, "default_exception_handler", collect_exceptions + ) + expected = [ + ValueError("hi"), ValueError("lo"), KeyError(), IndexError() + ] + await raise_in_aio_loop(expected[0]) + with pytest.raises(SystemExit): + await raise_in_aio_loop(SystemExit(0)) + with pytest.raises(SystemExit): + await raise_in_aio_loop(trio.MultiError([expected[1], SystemExit()])) + await raise_in_aio_loop(trio.MultiError(expected[2:])) + assert exceptions == expected + + @pytest.mark.trio @pytest.mark.skipif(sys.version_info < (3, 7), reason="needs asyncio contextvars") async def test_contextvars(): diff --git a/trio_asyncio/_handles.py b/trio_asyncio/_handles.py index 67186282..4c97471b 100644 --- a/trio_asyncio/_handles.py +++ b/trio_asyncio/_handles.py @@ -88,19 +88,10 @@ def report_exception(exc): if not isinstance(exc, Exception): # Let BaseExceptions such as Cancelled escape without being noted. return exc - try: - orig_tb = exc.__traceback__ - self._raise(exc) - # If self._raise() just logged something, suppress the exception. - return None - except BaseException as other_exc: - if other_exc is exc: - # If _raise() reraised its argument, remove the _raise and - # default_exception_handler frames that it added to the traceback. - return exc.with_traceback(orig_tb) - # If _raise() raised a different exception, don't mess with - # the traceback. - return other_exc + # Otherwise defer to the asyncio exception handler. (In an async loop + # this will still raise the exception out of the loop, terminating it.) + self._raise(exc) + return None def remove_cancelled(exc): if isinstance(exc, trio.Cancelled): @@ -129,7 +120,7 @@ def only_cancelled(exc): # Pass Exceptions through the fallback exception handler since # they have nowhere better to go. Let BaseExceptions escape so # that Cancelled and SystemExit work reasonably. - with trio.MultiError.catch(handle_exc): + with trio.MultiError.catch(report_exception): raise else: # The result future gets all the non-Cancelled From f36955c61cede8988b55822b2e65f5b31801d837 Mon Sep 17 00:00:00 2001 From: Joshua Oreman Date: Wed, 6 May 2020 10:02:50 +0000 Subject: [PATCH 8/8] yapf --- trio_asyncio/_loop.py | 1 - 1 file changed, 1 deletion(-) diff --git a/trio_asyncio/_loop.py b/trio_asyncio/_loop.py index a1db1487..1660000e 100644 --- a/trio_asyncio/_loop.py +++ b/trio_asyncio/_loop.py @@ -389,7 +389,6 @@ async def async_main(*args): """ # TODO: make sure that there is no asyncio loop already running - async with trio.open_nursery() as nursery: loop = TrioEventLoop(queue_len=queue_len) old_loop = current_loop.set(loop)