Skip to content

Commit

Permalink
Fixed AssertionError when using nest_asyncio (#841)
Browse files Browse the repository at this point in the history
This stems from the incorrect placement of `nest_asyncio.apply()`, as it should be called before `asyncio.run()`.

Fixes #840.
  • Loading branch information
agronholm authored Dec 17, 2024
1 parent d14f005 commit c518300
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 16 deletions.
2 changes: 2 additions & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.
thread waker socket pair. This should improve the performance of ``wait_readable()``
and ``wait_writable()`` when using the ``ProactorEventLoop``
(`#836 <https://github.com/agronholm/anyio/pull/836>`_; PR by @graingert)
- Fixed ``AssertionError`` when using ``nest-asyncio``
(`#840 <https://github.com/agronholm/anyio/issues/840>`_)

**4.7.0**

Expand Down
32 changes: 17 additions & 15 deletions src/anyio/_backends/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,40 +677,42 @@ def __init__(self, parent_id: int | None, cancel_scope: CancelScope | None):
self.cancel_scope = cancel_scope


class TaskStateStore(MutableMapping["Awaitable[Any] | asyncio.Task", TaskState]):
class TaskStateStore(
MutableMapping["Coroutine[Any, Any, Any] | asyncio.Task", TaskState]
):
def __init__(self) -> None:
self._task_states = WeakKeyDictionary[asyncio.Task, TaskState]()
self._preliminary_task_states: dict[Awaitable[Any], TaskState] = {}
self._preliminary_task_states: dict[Coroutine[Any, Any, Any], TaskState] = {}

def __getitem__(self, key: Awaitable[Any] | asyncio.Task, /) -> TaskState:
assert isinstance(key, asyncio.Task)
def __getitem__(self, key: Coroutine[Any, Any, Any] | asyncio.Task, /) -> TaskState:
task = cast(asyncio.Task, key)
try:
return self._task_states[key]
return self._task_states[task]
except KeyError:
if coro := key.get_coro():
if coro := task.get_coro():
if state := self._preliminary_task_states.get(coro):
return state

raise KeyError(key)

def __setitem__(
self, key: asyncio.Task | Awaitable[Any], value: TaskState, /
self, key: asyncio.Task | Coroutine[Any, Any, Any], value: TaskState, /
) -> None:
if isinstance(key, asyncio.Task):
self._task_states[key] = value
else:
if isinstance(key, Coroutine):
self._preliminary_task_states[key] = value

def __delitem__(self, key: asyncio.Task | Awaitable[Any], /) -> None:
if isinstance(key, asyncio.Task):
del self._task_states[key]
else:
self._task_states[key] = value

def __delitem__(self, key: asyncio.Task | Coroutine[Any, Any, Any], /) -> None:
if isinstance(key, Coroutine):
del self._preliminary_task_states[key]
else:
del self._task_states[key]

def __len__(self) -> int:
return len(self._task_states) + len(self._preliminary_task_states)

def __iter__(self) -> Iterator[Awaitable[Any] | asyncio.Task]:
def __iter__(self) -> Iterator[Coroutine[Any, Any, Any] | asyncio.Task]:
yield from self._task_states
yield from self._preliminary_task_states

Expand Down
13 changes: 12 additions & 1 deletion tests/test_taskgroups.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import pytest
from exceptiongroup import catch
from pytest import FixtureRequest
from pytest import FixtureRequest, MonkeyPatch
from pytest_mock import MockerFixture

import anyio
Expand Down Expand Up @@ -1778,3 +1778,14 @@ async def sync_coro() -> None:
async with create_task_group() as tg:
tg.start_soon(sync_coro)
tg.cancel_scope.cancel()


@pytest.mark.parametrize("anyio_backend", ["asyncio"])
async def test_patched_asyncio_task(monkeypatch: MonkeyPatch) -> None:
monkeypatch.setattr(
asyncio,
"Task",
asyncio.tasks._PyTask, # type: ignore[attr-defined]
)
async with create_task_group() as tg:
tg.start_soon(sleep, 0)

0 comments on commit c518300

Please sign in to comment.