Skip to content

Commit

Permalink
Improve termination of Application. Don't ever suppress CancelledError.
Browse files Browse the repository at this point in the history
This fixes a race condition when the prompt_toolkit Application gets cancelled
while waiting for the background tasks to complete. Catching `CancelledError`
at this point caused any code following the `Application.run_async` to
continue, instead of being cancelled.

In the future, we should probably adapt task groups (from anyio or Python
3.11), but until then, this is sufficient.
  • Loading branch information
jonathanslenders committed Nov 21, 2022
1 parent 0ddf173 commit 427f4bc
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 18 deletions.
69 changes: 56 additions & 13 deletions src/prompt_toolkit/application/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import time
from asyncio import (
AbstractEventLoop,
CancelledError,
Future,
Task,
ensure_future,
Expand All @@ -32,6 +31,7 @@
Iterator,
List,
Optional,
Set,
Tuple,
Type,
TypeVar,
Expand Down Expand Up @@ -433,7 +433,7 @@ def reset(self) -> None:

self.exit_style = ""

self.background_tasks: List[Task[None]] = []
self._background_tasks: Set[Task[None]] = set()

self.renderer.reset()
self.key_processor.reset()
Expand Down Expand Up @@ -1066,32 +1066,75 @@ def create_background_task(
the `Application` terminates, unfinished background tasks will be
cancelled.
If asyncio had nurseries like Trio, we would create a nursery in
`Application.run_async`, and run the given coroutine in that nursery.
Given that we still support Python versions before 3.11, we can't use
task groups (and exception groups), because of that, these background
tasks are not allowed to raise exceptions. If they do, we'll call the
default exception handler from the event loop.
Not threadsafe.
If at some point, we have Python 3.11 as the minimum supported Python
version, then we can use a `TaskGroup` (with the lifetime of
`Application.run_async()`, and run run the background tasks in there.
This is not threadsafe.
"""
task: asyncio.Task[None] = get_event_loop().create_task(coroutine)
self.background_tasks.append(task)
self._background_tasks.add(task)

task.add_done_callback(self._on_background_task_done)
return task

def _on_background_task_done(self, task: "asyncio.Task[None]") -> None:
"""
Called when a background task completes. Remove it from
`_background_tasks`, and handle exceptions if any.
"""
self._background_tasks.discard(task)

if task.cancelled():
return

exc = task.exception()
if exc is not None:
get_event_loop().call_exception_handler(
{
"message": f"prompt_toolkit.Application background task {task!r} "
"raised an unexpected exception.",
"exception": exc,
"task": task,
}
)

async def cancel_and_wait_for_background_tasks(self) -> None:
"""
Cancel all background tasks, and wait for the cancellation to be done.
Cancel all background tasks, and wait for the cancellation to complete.
If any of the background tasks raised an exception, this will also
propagate the exception.
(If we had nurseries like Trio, this would be the `__aexit__` of a
nursery.)
"""
for task in self.background_tasks:
for task in self._background_tasks:
task.cancel()

for task in self.background_tasks:
try:
await task
except CancelledError:
pass
# Wait until the cancellation of the background tasks completes.
# `asyncio.wait()` does not propagate exceptions raised within any of
# these tasks, which is what we want. Otherwise, we can't distinguish
# between a `CancelledError` raised in this task because it got
# cancelled, and a `CancelledError` raised on this `await` checkpoint,
# because *we* got cancelled during the teardown of the application.
# (If we get cancelled here, then it's important to not suppress the
# `CancelledError`, and have it propagate.)
# NOTE: Currently, if we get cancelled at this point then we can't wait
# for the cancellation to complete (in the future, we should be
# using anyio or Python's 3.11 TaskGroup.)
# Also, if we had exception groups, we could propagate an
# `ExceptionGroup` if something went wrong here. Right now, we
# don't propagate exceptions, but have them printed in
# `_on_background_task_done`.
if len(self._background_tasks) > 0:
await asyncio.wait(
self._background_tasks, timeout=None, return_when=asyncio.ALL_COMPLETED
)

async def _poll_output_size(self) -> None:
"""
Expand Down
13 changes: 8 additions & 5 deletions src/prompt_toolkit/contrib/telnet/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,11 +323,14 @@ async def stop(self) -> None:
for t in self._application_tasks:
t.cancel()

for t in self._application_tasks:
try:
await t
except asyncio.CancelledError:
logger.debug("Task %s cancelled", str(t))
# (This is similar to
# `Application.cancel_and_wait_for_background_tasks`. We wait for the
# background tasks to complete, but don't propagate exceptions, because
# we can't use `ExceptionGroup` yet.)
if len(self._application_tasks) > 0:
await asyncio.wait(
self._application_tasks, timeout=None, return_when=asyncio.ALL_COMPLETED
)

def _accept(self) -> None:
"""
Expand Down

0 comments on commit 427f4bc

Please sign in to comment.