diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst
index 518aef88..4901da4c 100644
--- a/docs/versionhistory.rst
+++ b/docs/versionhistory.rst
@@ -9,6 +9,8 @@ This library adheres to `Semantic Versioning 2.0 `_.
thread waker socket pair. This should improve the performance of ``wait_readable()``
and ``wait_writable()`` when using the ``ProactorEventLoop``
(`#836 `_; PR by @graingert)
+- Fixed ``AssertionError`` when using ``nest-asyncio``
+ (`#840 `_)
**4.7.0**
diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py
index 0b7479d2..5a0aa936 100644
--- a/src/anyio/_backends/_asyncio.py
+++ b/src/anyio/_backends/_asyncio.py
@@ -677,40 +677,42 @@ def __init__(self, parent_id: int | None, cancel_scope: CancelScope | None):
self.cancel_scope = cancel_scope
-class TaskStateStore(MutableMapping["Awaitable[Any] | asyncio.Task", TaskState]):
+class TaskStateStore(
+ MutableMapping["Coroutine[Any, Any, Any] | asyncio.Task", TaskState]
+):
def __init__(self) -> None:
self._task_states = WeakKeyDictionary[asyncio.Task, TaskState]()
- self._preliminary_task_states: dict[Awaitable[Any], TaskState] = {}
+ self._preliminary_task_states: dict[Coroutine[Any, Any, Any], TaskState] = {}
- def __getitem__(self, key: Awaitable[Any] | asyncio.Task, /) -> TaskState:
- assert isinstance(key, asyncio.Task)
+ def __getitem__(self, key: Coroutine[Any, Any, Any] | asyncio.Task, /) -> TaskState:
+ task = cast(asyncio.Task, key)
try:
- return self._task_states[key]
+ return self._task_states[task]
except KeyError:
- if coro := key.get_coro():
+ if coro := task.get_coro():
if state := self._preliminary_task_states.get(coro):
return state
raise KeyError(key)
def __setitem__(
- self, key: asyncio.Task | Awaitable[Any], value: TaskState, /
+ self, key: asyncio.Task | Coroutine[Any, Any, Any], value: TaskState, /
) -> None:
- if isinstance(key, asyncio.Task):
- self._task_states[key] = value
- else:
+ if isinstance(key, Coroutine):
self._preliminary_task_states[key] = value
-
- def __delitem__(self, key: asyncio.Task | Awaitable[Any], /) -> None:
- if isinstance(key, asyncio.Task):
- del self._task_states[key]
else:
+ self._task_states[key] = value
+
+ def __delitem__(self, key: asyncio.Task | Coroutine[Any, Any, Any], /) -> None:
+ if isinstance(key, Coroutine):
del self._preliminary_task_states[key]
+ else:
+ del self._task_states[key]
def __len__(self) -> int:
return len(self._task_states) + len(self._preliminary_task_states)
- def __iter__(self) -> Iterator[Awaitable[Any] | asyncio.Task]:
+ def __iter__(self) -> Iterator[Coroutine[Any, Any, Any] | asyncio.Task]:
yield from self._task_states
yield from self._preliminary_task_states
diff --git a/tests/test_taskgroups.py b/tests/test_taskgroups.py
index 1f536940..d20565c1 100644
--- a/tests/test_taskgroups.py
+++ b/tests/test_taskgroups.py
@@ -11,7 +11,7 @@
import pytest
from exceptiongroup import catch
-from pytest import FixtureRequest
+from pytest import FixtureRequest, MonkeyPatch
from pytest_mock import MockerFixture
import anyio
@@ -1778,3 +1778,14 @@ async def sync_coro() -> None:
async with create_task_group() as tg:
tg.start_soon(sync_coro)
tg.cancel_scope.cancel()
+
+
+@pytest.mark.parametrize("anyio_backend", ["asyncio"])
+async def test_patched_asyncio_task(monkeypatch: MonkeyPatch) -> None:
+ monkeypatch.setattr(
+ asyncio,
+ "Task",
+ asyncio.tasks._PyTask, # type: ignore[attr-defined]
+ )
+ async with create_task_group() as tg:
+ tg.start_soon(sleep, 0)