Skip to content

Commit

Permalink
Refactored types
Browse files Browse the repository at this point in the history
  • Loading branch information
agronholm committed Dec 15, 2024
1 parent e8a917f commit ebed14d
Showing 1 changed file with 12 additions and 11 deletions.
23 changes: 12 additions & 11 deletions src/anyio/_backends/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,32 +677,33 @@ def __init__(self, parent_id: int | None, cancel_scope: CancelScope | None):
self.cancel_scope = cancel_scope


class TaskStateStore(MutableMapping["Awaitable[Any] | asyncio.Task", TaskState]):
class TaskStateStore(
MutableMapping["Coroutine[Any, Any, Any] | asyncio.Task", TaskState]
):
def __init__(self) -> None:
self._task_states = WeakKeyDictionary[
"asyncio.Task | Awaitable[Any]", TaskState
]()
self._preliminary_task_states: dict[Awaitable[Any], TaskState] = {}
self._task_states = WeakKeyDictionary[asyncio.Task, TaskState]()
self._preliminary_task_states: dict[Coroutine[Any, Any, Any], TaskState] = {}

def __getitem__(self, key: Awaitable[Any] | asyncio.Task, /) -> TaskState:
def __getitem__(self, key: Coroutine[Any, Any, Any] | asyncio.Task, /) -> TaskState:
task = cast(asyncio.Task, key)
try:
return self._task_states[key]
return self._task_states[task]
except KeyError:
if coro := cast(asyncio.Task, key).get_coro():
if coro := task.get_coro():
if state := self._preliminary_task_states.get(coro):
return state

raise KeyError(key)

def __setitem__(
self, key: asyncio.Task | Awaitable[Any], value: TaskState, /
self, key: asyncio.Task | Coroutine[Any, Any, Any], value: TaskState, /
) -> None:
if isinstance(key, Coroutine):
self._preliminary_task_states[key] = value
else:
self._task_states[key] = value

def __delitem__(self, key: asyncio.Task | Awaitable[Any], /) -> None:
def __delitem__(self, key: asyncio.Task | Coroutine[Any, Any, Any], /) -> None:
if isinstance(key, Coroutine):
del self._preliminary_task_states[key]
else:
Expand All @@ -711,7 +712,7 @@ def __delitem__(self, key: asyncio.Task | Awaitable[Any], /) -> None:
def __len__(self) -> int:
return len(self._task_states) + len(self._preliminary_task_states)

def __iter__(self) -> Iterator[Awaitable[Any] | asyncio.Task]:
def __iter__(self) -> Iterator[Coroutine[Any, Any, Any] | asyncio.Task]:
yield from self._task_states
yield from self._preliminary_task_states

Expand Down

0 comments on commit ebed14d

Please sign in to comment.