Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix cyclic garbage that keeps traceback frames alive in taskgroup exceptions #806

Merged
merged 13 commits into from
Oct 13, 2024
3 changes: 3 additions & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.
- Fixed an async fixture's ``self`` being different than the test's ``self`` in
class-based tests (`#633 <https://github.com/agronholm/anyio/issues/633>`_)
(PR by @agronholm and @graingert)
- Fixed TaskGroup and CancelScope producing cyclic references in tracebacks
when raising exceptions (`#806 <https://github.com/agronholm/anyio/pull/806>`_)
(PR by @graingert)

**4.6.0**

Expand Down
175 changes: 102 additions & 73 deletions src/anyio/_backends/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,8 @@ def __exit__(
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> bool | None:
del exc_tb

if not self._active:
raise RuntimeError("This cancel scope is not active")
if current_task() is not self._host_task:
Expand All @@ -441,42 +443,46 @@ def __exit__(
"current cancel scope"
)

self._active = False
if self._timeout_handle:
self._timeout_handle.cancel()
self._timeout_handle = None

self._tasks.remove(self._host_task)
if self._parent_scope is not None:
self._parent_scope._child_scopes.remove(self)
self._parent_scope._tasks.add(self._host_task)

host_task_state.cancel_scope = self._parent_scope

# Undo all cancellations done by this scope
if self._cancelling is not None:
while self._cancel_calls:
self._cancel_calls -= 1
if self._host_task.uncancel() <= self._cancelling:
break
try:
self._active = False
if self._timeout_handle:
self._timeout_handle.cancel()
self._timeout_handle = None

# We only swallow the exception iff it was an AnyIO CancelledError, either
# directly as exc_val or inside an exception group and there are no cancelled
# parent cancel scopes visible to us here
not_swallowed_exceptions = 0
swallow_exception = False
if exc_val is not None:
for exc in iterate_exceptions(exc_val):
if self._cancel_called and isinstance(exc, CancelledError):
if not (swallow_exception := self._uncancel(exc)):
self._tasks.remove(self._host_task)
if self._parent_scope is not None:
self._parent_scope._child_scopes.remove(self)
self._parent_scope._tasks.add(self._host_task)

host_task_state.cancel_scope = self._parent_scope

# Undo all cancellations done by this scope
if self._cancelling is not None:
while self._cancel_calls:
self._cancel_calls -= 1
if self._host_task.uncancel() <= self._cancelling:
break

# We only swallow the exception iff it was an AnyIO CancelledError, either
# directly as exc_val or inside an exception group and there are no cancelled
# parent cancel scopes visible to us here
not_swallowed_exceptions = 0
swallow_exception = False
if exc_val is not None:
for exc in iterate_exceptions(exc_val):
if self._cancel_called and isinstance(exc, CancelledError):
if not (swallow_exception := self._uncancel(exc)):
not_swallowed_exceptions += 1
else:
not_swallowed_exceptions += 1
else:
not_swallowed_exceptions += 1

# Restart the cancellation effort in the closest visible, cancelled parent
# scope if necessary
self._restart_cancellation_in_parent()
return swallow_exception and not not_swallowed_exceptions
# Restart the cancellation effort in the closest visible, cancelled parent
# scope if necessary
self._restart_cancellation_in_parent()
return swallow_exception and not not_swallowed_exceptions
finally:
self._host_task = None
del exc_val

@property
def _effectively_cancelled(self) -> bool:
Expand Down Expand Up @@ -683,6 +689,26 @@ def started(self, value: T_contra | None = None) -> None:
_task_states[task].parent_id = self._parent_id


async def _wait(tasks: Iterable[asyncio.Task[object]]) -> None:
tasks = set(tasks)
waiter = get_running_loop().create_future()

def on_completion(task: asyncio.Task[object]) -> None:
tasks.discard(task)
if not tasks and not waiter.done():
waiter.set_result(None)

for task in tasks:
task.add_done_callback(on_completion)
del task

try:
await waiter
finally:
while tasks:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tasks.pop().remove_done_callback(on_completion)


class TaskGroup(abc.TaskGroup):
def __init__(self) -> None:
self.cancel_scope: CancelScope = CancelScope()
Expand All @@ -701,50 +727,53 @@ async def __aexit__(
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> bool | None:
if exc_val is not None:
self.cancel_scope.cancel()
if not isinstance(exc_val, CancelledError):
self._exceptions.append(exc_val)

try:
if self._tasks:
with CancelScope() as wait_scope:
while self._tasks:
try:
await asyncio.wait(self._tasks)
except CancelledError as exc:
# Shield the scope against further cancellation attempts,
# as they're not productive (#695)
wait_scope.shield = True
self.cancel_scope.cancel()

# Set exc_val from the cancellation exception if it was
# previously unset. However, we should not replace a native
# cancellation exception with one raised by a cancel scope.
if exc_val is None or (
isinstance(exc_val, CancelledError)
and not is_anyio_cancellation(exc)
):
exc_val = exc
else:
# If there are no child tasks to wait on, run at least one checkpoint
# anyway
await AsyncIOBackend.cancel_shielded_checkpoint()
if exc_val is not None:
self.cancel_scope.cancel()
if not isinstance(exc_val, CancelledError):
self._exceptions.append(exc_val)

self._active = False
if self._exceptions:
raise BaseExceptionGroup(
"unhandled errors in a TaskGroup", self._exceptions
)
elif exc_val:
raise exc_val
except BaseException as exc:
if self.cancel_scope.__exit__(type(exc), exc, exc.__traceback__):
return True
try:
if self._tasks:
with CancelScope() as wait_scope:
while self._tasks:
try:
await _wait(self._tasks)
except CancelledError as exc:
# Shield the scope against further cancellation attempts,
# as they're not productive (#695)
wait_scope.shield = True
self.cancel_scope.cancel()

# Set exc_val from the cancellation exception if it was
# previously unset. However, we should not replace a native
# cancellation exception with one raised by a cancel scope.
if exc_val is None or (
isinstance(exc_val, CancelledError)
and not is_anyio_cancellation(exc)
):
exc_val = exc
else:
# If there are no child tasks to wait on, run at least one checkpoint
# anyway
await AsyncIOBackend.cancel_shielded_checkpoint()

raise
self._active = False
if self._exceptions:
raise BaseExceptionGroup(
"unhandled errors in a TaskGroup", self._exceptions
)
elif exc_val:
raise exc_val
except BaseException as exc:
if self.cancel_scope.__exit__(type(exc), exc, exc.__traceback__):
return True

raise

return self.cancel_scope.__exit__(exc_type, exc_val, exc_tb)
return self.cancel_scope.__exit__(exc_type, exc_val, exc_tb)
finally:
del exc_val, exc_tb, self._exceptions

def _spawn(
self,
Expand Down
7 changes: 3 additions & 4 deletions src/anyio/_backends/_trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,13 +186,12 @@ async def __aexit__(
try:
return await self._nursery_manager.__aexit__(exc_type, exc_val, exc_tb)
except BaseExceptionGroup as exc:
_, rest = exc.split(trio.Cancelled)
if not rest:
cancelled_exc = trio.Cancelled._create()
raise cancelled_exc from exc
if not exc.split(trio.Cancelled)[1]:
raise trio.Cancelled._create() from exc

raise
finally:
del exc_val, exc_tb
self._active = False

def start_soon(
Expand Down
122 changes: 121 additions & 1 deletion tests/test_taskgroups.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import gc
import math
import sys
import time
Expand All @@ -9,7 +10,7 @@
from typing import Any, NoReturn, cast

import pytest
from exceptiongroup import catch
from exceptiongroup import ExceptionGroup, catch
from pytest_mock import MockerFixture

import anyio
Expand Down Expand Up @@ -1548,6 +1549,125 @@ async def in_task_group(task_status: TaskStatus[None]) -> None:
assert not tg.cancel_scope.cancel_called


# ``sys.version_info`` is a five-element tuple, so comparing it with ``<= (3, 11)``
# is false on every 3.11.x release; ``< (3, 11)`` states the same check without
# that trap: Python 3.10 and earlier take the first branch.
if sys.version_info < (3, 11):

    def no_other_refs() -> list[object]:
        """
        Return the referrers that ``gc.get_referrers()`` is expected to report.

        On these Python versions the expected list contains the calling frame.
        """
        return [sys._getframe(1)]

else:

    def no_other_refs() -> list[object]:
        """
        Return the referrers that ``gc.get_referrers()`` is expected to report.

        On these Python versions no referrers are expected.
        """
        return []


@pytest.mark.skipif(
    sys.implementation.name == "pypy",
    reason=(
        "gc.get_referrers is broken on PyPy see "
        "https://github.com/pypy/pypy/issues/5075"
    ),
)
class TestRefcycles:
    """
    Tests checking that exceptions raised out of a task group have no unexpected
    referrers.

    Every test captures an exception that came out of a task group and compares
    ``gc.get_referrers()`` for it against ``no_other_refs()``.
    """

    async def test_exception_refcycles_direct(self) -> None:
        """
        Test that TaskGroup doesn't keep a reference to the raised ExceptionGroup.

        Note: This test never failed on anyio, but it is kept to align with the
        tests from CPython.
        """
        tg = create_task_group()
        exc = None

        class _Done(Exception):
            pass

        try:
            async with tg:
                raise _Done
        except ExceptionGroup as e:
            # Keep the group itself; ``e`` is unbound when the except block ends
            exc = e

        assert exc is not None
        assert gc.get_referrers(exc) == no_other_refs()

    async def test_exception_refcycles_errors(self) -> None:
        """Test that TaskGroup deletes self._exceptions and the __aexit__ args."""
        tg = create_task_group()
        exc = None

        class _Done(Exception):
            pass

        try:
            async with tg:
                raise _Done
        except ExceptionGroup as excs:
            # Inspect the inner exception rather than the wrapping group
            exc = excs.exceptions[0]

        assert isinstance(exc, _Done)
        assert gc.get_referrers(exc) == no_other_refs()

    async def test_exception_refcycles_parent_task(self) -> None:
        """Test that TaskGroup's cancel_scope deletes self._host_task."""
        tg = create_task_group()
        exc = None

        class _Done(Exception):
            pass

        async def coro_fn() -> None:
            async with tg:
                raise _Done

        try:
            # Run the failing task group inside a child task of another group
            async with anyio.create_task_group() as tg2:
                tg2.start_soon(coro_fn)
        except ExceptionGroup as excs:
            # Two levels of nesting: the outer group (tg2) wraps the inner one (tg)
            exc = excs.exceptions[0].exceptions[0]

        assert isinstance(exc, _Done)
        assert gc.get_referrers(exc) == no_other_refs()

    async def test_exception_refcycles_propagate_cancellation_error(self) -> None:
        """Test that TaskGroup deletes cancelled_exc."""
        tg = anyio.create_task_group()
        exc = None

        with CancelScope() as cs:
            # The scope is cancelled up front; the test expects the cancellation
            # exception to surface from the task group block below
            cs.cancel()
            try:
                async with tg:
                    await checkpoint()
            except get_cancelled_exc_class() as e:
                exc = e
                # Re-raise into the enclosing cancel scope; the assertions below
                # still run, so the scope is expected to absorb it
                raise

        assert isinstance(exc, get_cancelled_exc_class())
        assert gc.get_referrers(exc) == no_other_refs()

    async def test_exception_refcycles_base_error(self) -> None:
        """
        Test for BaseExceptions.

        anyio doesn't treat these differently, so this test is redundant, but it
        was copied from CPython's asyncio.TaskGroup tests for completeness.
        """

        class MyKeyboardInterrupt(KeyboardInterrupt):
            pass

        tg = create_task_group()
        exc = None

        try:
            async with tg:
                raise MyKeyboardInterrupt
        except BaseExceptionGroup as excs:
            # KeyboardInterrupt is not an Exception subclass, hence the
            # BaseExceptionGroup here instead of ExceptionGroup
            exc = excs.exceptions[0]

        assert isinstance(exc, MyKeyboardInterrupt)
        assert gc.get_referrers(exc) == no_other_refs()


class TestTaskStatusTyping:
"""
These tests do not do anything at run time, but since the test suite is also checked
Expand Down