From c518300593f2411196fd492d917269dcb6f7682b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Wed, 18 Dec 2024 00:07:30 +0200 Subject: [PATCH] Fixed AssertionError when using nest_asyncio (#841) This stems from the incorrect placement of `nest_asyncio.apply()`, as it should be called before `asyncio.run()`. Fixes #840. --- docs/versionhistory.rst | 2 ++ src/anyio/_backends/_asyncio.py | 32 +++++++++++++++++--------------- tests/test_taskgroups.py | 13 ++++++++++++- 3 files changed, 31 insertions(+), 16 deletions(-) diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index 518aef88..4901da4c 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -9,6 +9,8 @@ This library adheres to `Semantic Versioning 2.0 `_. thread waker socket pair. This should improve the performance of ``wait_readable()`` and ``wait_writable()`` when using the ``ProactorEventLoop`` (`#836 `_; PR by @graingert) +- Fixed ``AssertionError`` when using ``nest-asyncio`` + (`#840 `_) **4.7.0** diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index 0b7479d2..5a0aa936 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -677,40 +677,42 @@ def __init__(self, parent_id: int | None, cancel_scope: CancelScope | None): self.cancel_scope = cancel_scope -class TaskStateStore(MutableMapping["Awaitable[Any] | asyncio.Task", TaskState]): +class TaskStateStore( + MutableMapping["Coroutine[Any, Any, Any] | asyncio.Task", TaskState] +): def __init__(self) -> None: self._task_states = WeakKeyDictionary[asyncio.Task, TaskState]() - self._preliminary_task_states: dict[Awaitable[Any], TaskState] = {} + self._preliminary_task_states: dict[Coroutine[Any, Any, Any], TaskState] = {} - def __getitem__(self, key: Awaitable[Any] | asyncio.Task, /) -> TaskState: - assert isinstance(key, asyncio.Task) + def __getitem__(self, key: Coroutine[Any, Any, Any] | asyncio.Task, /) -> TaskState: + task = cast(asyncio.Task, key) try: - return self._task_states[key] + return self._task_states[task] except KeyError: - if coro := key.get_coro(): + if coro := task.get_coro(): if state := self._preliminary_task_states.get(coro): return state raise KeyError(key) def __setitem__( - self, key: asyncio.Task | Awaitable[Any], value: TaskState, / + self, key: asyncio.Task | Coroutine[Any, Any, Any], value: TaskState, / ) -> None: - if isinstance(key, asyncio.Task): - self._task_states[key] = value - else: + if isinstance(key, Coroutine): self._preliminary_task_states[key] = value - - def __delitem__(self, key: asyncio.Task | Awaitable[Any], /) -> None: - if isinstance(key, asyncio.Task): - del self._task_states[key] else: + self._task_states[key] = value + + def __delitem__(self, key: asyncio.Task | Coroutine[Any, Any, Any], /) -> None: + if isinstance(key, Coroutine): del self._preliminary_task_states[key] + else: + del self._task_states[key] def __len__(self) -> int: return len(self._task_states) + len(self._preliminary_task_states) - def __iter__(self) -> Iterator[Awaitable[Any] | asyncio.Task]: + def __iter__(self) -> Iterator[Coroutine[Any, Any, Any] | asyncio.Task]: yield from self._task_states yield from self._preliminary_task_states diff --git a/tests/test_taskgroups.py b/tests/test_taskgroups.py index 1f536940..d20565c1 100644 --- a/tests/test_taskgroups.py +++ b/tests/test_taskgroups.py @@ -11,7 +11,7 @@ import pytest from exceptiongroup import catch -from pytest import FixtureRequest +from pytest import FixtureRequest, MonkeyPatch from pytest_mock import MockerFixture import anyio @@ -1778,3 +1778,14 @@ async def sync_coro() -> None: async with create_task_group() as tg: tg.start_soon(sync_coro) tg.cancel_scope.cancel() + + +@pytest.mark.parametrize("anyio_backend", ["asyncio"]) +async def test_patched_asyncio_task(monkeypatch: MonkeyPatch) -> None: + monkeypatch.setattr( + asyncio, + "Task", + asyncio.tasks._PyTask, # type: ignore[attr-defined] + ) + async with create_task_group() as tg: + tg.start_soon(sleep, 0)