diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index 1157abd2..5a0aa936 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -677,32 +677,33 @@ def __init__(self, parent_id: int | None, cancel_scope: CancelScope | None): self.cancel_scope = cancel_scope -class TaskStateStore(MutableMapping["Awaitable[Any] | asyncio.Task", TaskState]): +class TaskStateStore( + MutableMapping["Coroutine[Any, Any, Any] | asyncio.Task", TaskState] +): def __init__(self) -> None: - self._task_states = WeakKeyDictionary[ - "asyncio.Task | Awaitable[Any]", TaskState - ]() - self._preliminary_task_states: dict[Awaitable[Any], TaskState] = {} + self._task_states = WeakKeyDictionary[asyncio.Task, TaskState]() + self._preliminary_task_states: dict[Coroutine[Any, Any, Any], TaskState] = {} - def __getitem__(self, key: Awaitable[Any] | asyncio.Task, /) -> TaskState: + def __getitem__(self, key: Coroutine[Any, Any, Any] | asyncio.Task, /) -> TaskState: + task = cast(asyncio.Task, key) try: - return self._task_states[key] + return self._task_states[task] except KeyError: - if coro := cast(asyncio.Task, key).get_coro(): + if coro := task.get_coro(): if state := self._preliminary_task_states.get(coro): return state raise KeyError(key) def __setitem__( - self, key: asyncio.Task | Awaitable[Any], value: TaskState, / + self, key: asyncio.Task | Coroutine[Any, Any, Any], value: TaskState, / ) -> None: if isinstance(key, Coroutine): self._preliminary_task_states[key] = value else: self._task_states[key] = value - def __delitem__(self, key: asyncio.Task | Awaitable[Any], /) -> None: + def __delitem__(self, key: asyncio.Task | Coroutine[Any, Any, Any], /) -> None: if isinstance(key, Coroutine): del self._preliminary_task_states[key] else: @@ -711,7 +712,7 @@ def __delitem__(self, key: asyncio.Task | Awaitable[Any], /) -> None: def __len__(self) -> int: return len(self._task_states) + len(self._preliminary_task_states) - def __iter__(self) -> Iterator[Awaitable[Any] | asyncio.Task]: + def __iter__(self) -> Iterator[Coroutine[Any, Any, Any] | asyncio.Task]: yield from self._task_states yield from self._preliminary_task_states