From f6574c39cb1c10d51ad757726b8e6a9d0eecf01b Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Tue, 31 Dec 2024 08:22:39 +0000 Subject: [PATCH 01/10] make start_soon always start soon - support eager tasks --- src/anyio/_backends/_asyncio.py | 348 +++++++++++++++++++++++++++----- tests/conftest.py | 26 +++ 2 files changed, 328 insertions(+), 46 deletions(-) diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index 5a0aa936..7f6315b4 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -28,8 +28,6 @@ Collection, Coroutine, Iterable, - Iterator, - MutableMapping, Sequence, ) from concurrent.futures import Future @@ -54,6 +52,7 @@ IO, TYPE_CHECKING, Any, + Literal, Optional, TypeVar, cast, @@ -357,6 +356,9 @@ def _task_started(task: asyncio.Task) -> bool: """Return ``True`` if the task has been started and has not finished.""" # The task coro should never be None here, as we never add finished tasks to the # task list + if task.done(): + return False + coro = task.get_coro() assert coro is not None try: @@ -677,31 +679,53 @@ def __init__(self, parent_id: int | None, cancel_scope: CancelScope | None): self.cancel_scope = cancel_scope -class TaskStateStore( - MutableMapping["Coroutine[Any, Any, Any] | asyncio.Task", TaskState] -): +class TaskStateStore: def __init__(self) -> None: self._task_states = WeakKeyDictionary[asyncio.Task, TaskState]() self._preliminary_task_states: dict[Coroutine[Any, Any, Any], TaskState] = {} - def __getitem__(self, key: Coroutine[Any, Any, Any] | asyncio.Task, /) -> TaskState: - task = cast(asyncio.Task, key) + def __getitem__(self, key: asyncio.Task, /) -> TaskState: try: - return self._task_states[task] + return self._task_states[key] except KeyError: - if coro := task.get_coro(): - if state := self._preliminary_task_states.get(coro): - return state + pass + + coro = key.get_coro() + if coro is None: + raise KeyError(key) + + try: + state = self._preliminary_task_states.pop(coro) + except KeyError: + pass + else: + self._task_states[key] = state + return state raise KeyError(key) + def get(self, key: asyncio.Task, /) -> TaskState | None: + try: + return self[key] + except KeyError: + return None + def __setitem__( self, key: asyncio.Task | Coroutine[Any, Any, Any], value: TaskState, / ) -> None: if isinstance(key, Coroutine): self._preliminary_task_states[key] = value - else: - self._task_states[key] = value + return + + self._task_states[key] = value + coro = key.get_coro() + if coro is None: + return + + try: + del self._preliminary_task_states[coro] + except KeyError: + pass def __delitem__(self, key: asyncio.Task | Coroutine[Any, Any, Any], /) -> None: if isinstance(key, Coroutine): @@ -709,13 +733,6 @@ def __delitem__(self, key: asyncio.Task | Coroutine[Any, Any, Any], /) -> None: else: del self._task_states[key] - def __len__(self) -> int: - return len(self._task_states) + len(self._preliminary_task_states) - - def __iter__(self) -> Iterator[Coroutine[Any, Any, Any] | asyncio.Task]: - yield from self._task_states - yield from self._preliminary_task_states - _task_states = TaskStateStore() @@ -745,6 +762,10 @@ def started(self, value: T_contra | None = None) -> None: async def _wait(tasks: Iterable[asyncio.Task[object]]) -> None: tasks = set(tasks) + if not tasks: + await sleep(0) + return + waiter = get_running_loop().create_future() def on_completion(task: asyncio.Task[object]) -> None: @@ -763,12 +784,199 @@ def on_completion(task: asyncio.Task[object]) -> None: tasks.pop().remove_done_callback(on_completion) +_TF_IS_EAGER_MAP: WeakKeyDictionary[Callable, bool] = WeakKeyDictionary() + + +class _ShimEventLoop(asyncio.AbstractEventLoop): + def get_debug(self) -> bool: + return False + + def is_running(self) -> bool: + return True + + def call_exception_handler(self, *args: object, **kwargs: object) -> Any: + pass + + def add_reader(self, *args: object, **kwargs: object) -> Any: + pass + + def add_writer(self, *args: object, **kwargs: object) -> Any: + pass + + def add_signal_handler(self, *args: object, **kwargs: object) -> Any: + pass + + def call_at(self, *args: object, **kwargs: object) -> Any: + pass + + def call_later(self, *args: object, **kwargs: object) -> Any: + pass + + def call_soon(self, *args: object, **kwargs: object) -> Any: + pass + + def call_soon_threadsafe(self, *args: object, **kwargs: object) -> Any: + pass + + def close(self, *args: object, **kwargs: object) -> Any: + pass + + def connect_read_pipe(self, *args: object, **kwargs: object) -> Any: + pass + + def connect_write_pipe(self, *args: object, **kwargs: object) -> Any: + pass + + def create_connection(self, *args: object, **kwargs: object) -> Any: + pass + + def create_datagram_endpoint(self, *args: object, **kwargs: object) -> Any: + pass + + def create_future(self, *args: object, **kwargs: object) -> Any: + pass + + def create_task(self, *args: object, **kwargs: object) -> Any: + pass + + def create_server(self, *args: object, **kwargs: object) -> Any: + pass + + def default_exception_handler(self, *args: object, **kwargs: object) -> Any: + pass + + def get_exception_handler(self, *args: object, **kwargs: object) -> Any: + pass + + def get_task_factory(self, *args: object, **kwargs: object) -> Any: + pass + + def getaddrinfo(self, *args: object, **kwargs: object) -> Any: + pass + + def getnameinfo(self, *args: object, **kwargs: object) -> Any: + pass + + def is_closed(self, *args: object, **kwargs: object) -> Any: + pass + + def remove_reader(self, *args: object, **kwargs: object) -> Any: + pass + + def remove_writer(self, *args: object, **kwargs: object) -> Any: + pass + + def remove_signal_handler(self, *args: object, **kwargs: object) -> Any: + pass + + def run_forever(self, *args: object, **kwargs: object) -> Any: + pass + + def run_in_executor(self, *args: object, **kwargs: object) -> Any: + pass + + def sendfile(self, *args: object, **kwargs: object) -> Any: + pass + + def set_debug(self, *args: object, **kwargs: object) -> Any: + pass + + def run_until_complete(self, *args: object, **kwargs: object) -> Any: + pass + + def set_default_executor(self, *args: object, **kwargs: object) -> Any: + pass + + def set_exception_handler(self, *args: object, **kwargs: object) -> Any: + pass + + def set_task_factory(self, *args: object, **kwargs: object) -> Any: + pass + + def shutdown_asyncgens(self, *args: object, **kwargs: object) -> Any: + pass + + def shutdown_default_executor(self, *args: object, **kwargs: object) -> Any: + pass + + def sock_accept(self, *args: object, **kwargs: object) -> Any: + pass + + def sock_connect(self, *args: object, **kwargs: object) -> Any: + pass + + def sock_recv(self, *args: object, **kwargs: object) -> Any: + pass + + def sock_recv_into(self, *args: object, **kwargs: object) -> Any: + pass + + def sock_recvfrom(self, *args: object, **kwargs: object) -> Any: + pass + + def sock_recvfrom_into(self, *args: object, **kwargs: object) -> Any: + pass + + def sock_sendall(self, *args: object, **kwargs: object) -> Any: + pass + + def sock_sendfile(self, *args: object, **kwargs: object) -> Any: + pass + + def sock_sendto(self, *args: object, **kwargs: object) -> Any: + pass + + def start_tls(self, *args: object, **kwargs: object) -> Any: + pass + + def stop(self, *args: object, **kwargs: object) -> Any: + pass + + def subprocess_exec(self, *args: object, **kwargs: object) -> Any: + pass + + def subprocess_shell(self, *args: object, **kwargs: object) -> Any: + pass + + def time(self, *args: object, **kwargs: object) -> Any: + pass + + +if sys.version_info >= (3, 12): + + def is_eager(tf: Callable[..., Any] | None) -> bool: + if tf is None: + return False + try: + return _TF_IS_EAGER_MAP[tf] + except KeyError: + pass + + ran = False + + async def corofn() -> None: + nonlocal ran + ran = True + + tf(_ShimEventLoop(), corofn()).cancel() + _TF_IS_EAGER_MAP[tf] = ran + return ran +else: + + def is_eager(tf: object) -> Literal[False]: + return False + + +_pending_eager_tasks: RunVar[int] = RunVar("_pending_eager_tasks", 0) + + class TaskGroup(abc.TaskGroup): def __init__(self) -> None: self.cancel_scope: CancelScope = CancelScope() self._active = False self._exceptions: list[BaseException] = [] self._tasks: set[asyncio.Task] = set() + self._pending_eager_tasks = 0 async def __aenter__(self) -> TaskGroup: self.cancel_scope.__enter__() @@ -788,9 +996,9 @@ async def __aexit__( self._exceptions.append(exc_val) try: - if self._tasks: + if self._pending_eager_tasks or self._tasks: with CancelScope() as wait_scope: - while self._tasks: + while self._pending_eager_tasks or self._tasks: try: await _wait(self._tasks) except CancelledError as exc: @@ -835,14 +1043,16 @@ def _spawn( args: tuple[Unpack[PosArgsT]], name: object, task_status_future: asyncio.Future | None = None, - ) -> asyncio.Task: + ) -> Callable[[], Coroutine[Any, Any, Any]]: def task_done(_task: asyncio.Task) -> None: # task_state = _task_states[_task] assert task_state.cancel_scope is not None - assert _task in task_state.cancel_scope._tasks - task_state.cancel_scope._tasks.remove(_task) - self._tasks.remove(task) - del _task_states[_task] + task_state.cancel_scope._tasks.discard(_task) + self._tasks.discard(_task) + try: + del _task_states[_task] + except KeyError: + pass try: exc = _task.exception() @@ -886,7 +1096,7 @@ def task_done(_task: asyncio.Task) -> None: else: parent_id = id(self.cancel_scope._host_task) - coro = func(*args, **kwargs) + coro: Coroutine[Any, Any, Any] = func(*args, **kwargs) # type: ignore[assignment] if not iscoroutine(coro): prefix = f"{func.__module__}." if hasattr(func, "__module__") else "" raise TypeError( @@ -894,27 +1104,71 @@ def task_done(_task: asyncio.Task) -> None: f"the return value ({coro!r}) is not a coroutine object" ) - # Make the spawned task inherit the task group's cancel scope - _task_states[coro] = task_state = TaskState( - parent_id=parent_id, cancel_scope=self.cancel_scope - ) name = get_callable_name(func) if name is None else str(name) - try: - task = create_task(coro, name=name) - finally: - del _task_states[coro] + loop = asyncio.get_running_loop() + + task_state = TaskState(parent_id=parent_id, cancel_scope=self.cancel_scope) - _task_states[task] = task_state - self.cancel_scope._tasks.add(task) - self._tasks.add(task) + task: asyncio.Task | None = None - if task.done(): - # This can happen with eager task factories - task_done(task) + def create() -> None: + assert task_factory is not None + nonlocal task + self._pending_eager_tasks -= 1 + _pending_eager_tasks.set(_pending_eager_tasks.get() - 1) + _task_states[coro] = task_state + try: + task = task_factory(loop, coro, name=name) # type: ignore[assignment, call-arg] + except BaseException: + del _task_states[coro] + raise + + assert task is not None + if task.done(): + try: + del _task_states[coro] + except KeyError: + pass + task_done(task) + else: + # the task state could have changed if the task entered a new CS + new_task_state = _task_states[task] + assert new_task_state.cancel_scope is not None + # Make the spawned task inherit the cancel scope it's in + new_task_state.cancel_scope._tasks.add(task) + self._tasks.add(task) + task.add_done_callback(task_done) + + async def await_task_cancel_and_wait() -> None: + try: + await sleep(0) + finally: + assert task is not None + task.cancel() + await task + + async def cancel_and_wait() -> None: + assert task is not None + task.cancel() + await task + + task_factory = loop.get_task_factory() + if is_eager(task_factory): + self._pending_eager_tasks += 1 + _pending_eager_tasks.set(_pending_eager_tasks.get() + 1) + loop.call_soon(create) + return await_task_cancel_and_wait else: + task = loop.create_task(coro, name=name) + _task_states[task] = task_state + + # Make the spawned task inherit the task group's cancel scope + self.cancel_scope._tasks.add(task) + self._tasks.add(task) + task.add_done_callback(task_done) - return task + return cancel_and_wait def start_soon( self, @@ -928,7 +1182,7 @@ async def start( self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None ) -> Any: future: asyncio.Future = asyncio.Future() - task = self._spawn(func, args, name, future) + cancel_and_wait = self._spawn(func, args, name, future) # If the task raises an exception after sending a start value without a switch # point between, the task group is cancelled and this method never proceeds to @@ -938,9 +1192,8 @@ async def start( return await future except CancelledError: # Cancel the task and wait for it to exit before returning - task.cancel() with CancelScope(shield=True), suppress(CancelledError): - await task + await cancel_and_wait() raise @@ -2834,6 +3087,9 @@ async def wait_all_tasks_blocked(cls) -> None: await cls.checkpoint() this_task = current_task() while True: + if _pending_eager_tasks.get(): + await cls.checkpoint() + for task in all_tasks(): if task is this_task: continue diff --git a/tests/conftest.py b/tests/conftest.py index 9d5acbfa..3b85a3fc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,7 @@ import asyncio import ssl +import sys from collections.abc import Generator from ssl import SSLContext from typing import Any @@ -28,6 +29,26 @@ pytest_plugins = ["pytester", "pytest_mock"] +if sys.version_info < (3, 13): + if sys.platform == "win32": + EventLoop = asyncio.ProactorEventLoop + else: + EventLoop = asyncio.SelectorEventLoop +else: + EventLoop = asyncio.EventLoop + +if sys.version_info >= (3, 12): + + def eager_task_loop_factory() -> EventLoop: + loop = EventLoop() + loop.set_task_factory(asyncio.eager_task_factory) + return loop + + eager_marks: list[pytest.MarkDecorator] = [] +else: + eager_task_loop_factory = EventLoop + eager_marks = [pytest.mark.skip(reason="eager tasks not supported yet")] + @pytest.fixture( params=[ @@ -40,6 +61,11 @@ marks=uvloop_marks, id="asyncio+uvloop", ), + pytest.param( + ("asyncio", {"debug": True, "loop_factory": eager_task_loop_factory}), + marks=eager_marks, + id="asyncio+eager", + ), pytest.param("trio"), ] ) From b65dec837cf94f453994248bb70ba9968b7fff4d Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Wed, 1 Jan 2025 08:36:59 +0000 Subject: [PATCH 02/10] add version history --- docs/versionhistory.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index 4901da4c..01df1f61 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -5,6 +5,9 @@ This library adheres to `Semantic Versioning 2.0 `_. **UNRELEASED** +- Finish support for eager tasks, start_soon behaves the same on trio or asyncio or asyncio + with eager tasks enabled. + (`#851 `_; PR by @graingert) - Configure ``SO_RCVBUF``, ``SO_SNDBUF`` and ``TCP_NODELAY`` on the selector thread waker socket pair. This should improve the performance of ``wait_readable()`` and ``wait_writable()`` when using the ``ProactorEventLoop`` From 6198e2eb37e0fd5a0ed6968c5e1205fca41b6091 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Wed, 1 Jan 2025 11:45:29 +0000 Subject: [PATCH 03/10] test asyncio specific tests with uvloop and eager tasks --- tests/conftest.py | 39 +++++++++++++++++------------------ tests/streams/test_memory.py | 4 +++- tests/test_debugging.py | 4 +++- tests/test_from_thread.py | 4 +++- tests/test_sockets.py | 4 +++- tests/test_synchronization.py | 12 ++++++----- tests/test_taskgroups.py | 38 ++++++++++++++-------------------- tests/test_to_thread.py | 4 +++- 8 files changed, 56 insertions(+), 53 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 3b85a3fc..c0117924 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -49,26 +49,25 @@ def eager_task_loop_factory() -> EventLoop: eager_task_loop_factory = EventLoop eager_marks = [pytest.mark.skip(reason="eager tasks not supported yet")] - -@pytest.fixture( - params=[ - pytest.param( - ("asyncio", {"debug": True, "loop_factory": None}), - id="asyncio", - ), - pytest.param( - ("asyncio", {"debug": True, "loop_factory": uvloop.new_event_loop}), - marks=uvloop_marks, - id="asyncio+uvloop", - ), - pytest.param( - ("asyncio", {"debug": True, "loop_factory": eager_task_loop_factory}), - marks=eager_marks, - id="asyncio+eager", - ), - pytest.param("trio"), - ] -) +asyncio_params = [ + pytest.param( + ("asyncio", {"debug": True, "loop_factory": None}), + id="asyncio", + ), + pytest.param( + ("asyncio", {"debug": True, "loop_factory": uvloop.new_event_loop}), + marks=uvloop_marks, + id="asyncio+uvloop", + ), + pytest.param( + ("asyncio", {"debug": True, "loop_factory": eager_task_loop_factory}), + marks=eager_marks, + id="asyncio+eager", + ), +] + + +@pytest.fixture(params=[*asyncio_params, pytest.param("trio")]) def anyio_backend(request: SubRequest) -> tuple[str, dict[str, Any]]: return request.param diff --git a/tests/streams/test_memory.py b/tests/streams/test_memory.py index 0e6d022a..4a4adbdd 100644 --- a/tests/streams/test_memory.py +++ b/tests/streams/test_memory.py @@ -24,6 +24,8 @@ MemoryObjectSendStream, ) +from ..conftest import asyncio_params + if sys.version_info < (3, 11): from exceptiongroup import ExceptionGroup @@ -486,7 +488,7 @@ async def test_not_closed_warning() -> None: gc.collect() -@pytest.mark.parametrize("anyio_backend", ["asyncio"], indirect=True) +@pytest.mark.parametrize("anyio_backend", asyncio_params) async def test_send_to_natively_cancelled_receiver() -> None: """ Test that if a task waiting on receive.receive() is cancelled and then another diff --git a/tests/test_debugging.py b/tests/test_debugging.py index 72843988..7813eaac 100644 --- a/tests/test_debugging.py +++ b/tests/test_debugging.py @@ -19,6 +19,8 @@ ) from anyio.abc import TaskStatus +from .conftest import asyncio_params + pytestmark = pytest.mark.anyio @@ -127,7 +129,7 @@ def generator_part() -> Generator[object, BaseException, None]: asyncio_event_loop.run_until_complete(native_coro_part()) -@pytest.mark.parametrize("anyio_backend", ["asyncio"]) +@pytest.mark.parametrize("anyio_backend", asyncio_params) async def test_wait_all_tasks_blocked_asend(anyio_backend: str) -> None: """Test that wait_all_tasks_blocked() does not crash on an `asend()` object.""" diff --git a/tests/test_from_thread.py b/tests/test_from_thread.py index e4c29ce0..009edd0f 100644 --- a/tests/test_from_thread.py +++ b/tests/test_from_thread.py @@ -33,6 +33,8 @@ from anyio.from_thread import BlockingPortal, start_blocking_portal from anyio.lowlevel import checkpoint +from .conftest import asyncio_params + if sys.version_info < (3, 11): from exceptiongroup import ExceptionGroup @@ -595,7 +597,7 @@ async def get_var() -> int: assert propagated_value == 6 - @pytest.mark.parametrize("anyio_backend", ["asyncio"]) + @pytest.mark.parametrize("anyio_backend", asyncio_params) async def test_asyncio_run_sync_called(self, caplog: LogCaptureFixture) -> None: """Regression test for #357.""" diff --git a/tests/test_sockets.py b/tests/test_sockets.py index b5143df0..a165c6c2 100644 --- a/tests/test_sockets.py +++ b/tests/test_sockets.py @@ -61,6 +61,8 @@ from anyio.lowlevel import checkpoint from anyio.streams.stapled import MultiListener +from .conftest import asyncio_params + if sys.version_info < (3, 11): from exceptiongroup import ExceptionGroup @@ -488,7 +490,7 @@ def serve() -> None: thread.join() assert thread_exception is None - @pytest.mark.parametrize("anyio_backend", ["asyncio"]) + @pytest.mark.parametrize("anyio_backend", asyncio_params) async def test_unretrieved_future_exception_server_crash( self, family: AnyIPAddressFamily, caplog: LogCaptureFixture ) -> None: diff --git a/tests/test_synchronization.py b/tests/test_synchronization.py index 83758c62..92a7a5a2 100644 --- a/tests/test_synchronization.py +++ b/tests/test_synchronization.py @@ -20,6 +20,8 @@ ) from anyio.abc import CapacityLimiter, TaskStatus +from .conftest import asyncio_params + pytestmark = pytest.mark.anyio @@ -162,7 +164,7 @@ async def waiter() -> None: assert not lock.statistics().locked assert lock.statistics().tasks_waiting == 0 - @pytest.mark.parametrize("anyio_backend", ["asyncio"]) + @pytest.mark.parametrize("anyio_backend", asyncio_params) async def test_asyncio_deadlock(self) -> None: """Regression test for #398.""" lock = Lock() @@ -178,7 +180,7 @@ async def acquire() -> None: task1.cancel() await asyncio.wait_for(task2, 1) - @pytest.mark.parametrize("anyio_backend", ["asyncio"]) + @pytest.mark.parametrize("anyio_backend", asyncio_params) async def test_cancel_after_release(self) -> None: """ Test that a native asyncio cancellation will not cause a lock ownership @@ -565,7 +567,7 @@ async def test_acquire_race(self) -> None: semaphore.release() pytest.raises(WouldBlock, semaphore.acquire_nowait) - @pytest.mark.parametrize("anyio_backend", ["asyncio"]) + @pytest.mark.parametrize("anyio_backend", asyncio_params) async def test_asyncio_deadlock(self) -> None: """Regression test for #398.""" semaphore = Semaphore(1) @@ -581,7 +583,7 @@ async def acquire() -> None: task1.cancel() await asyncio.wait_for(task2, 1) - @pytest.mark.parametrize("anyio_backend", ["asyncio"]) + @pytest.mark.parametrize("anyio_backend", asyncio_params) async def test_cancel_after_release(self) -> None: """ Test that a native asyncio cancellation will not cause a semaphore ownership @@ -731,7 +733,7 @@ async def waiter() -> None: assert limiter.statistics().tasks_waiting == 0 assert limiter.statistics().borrowed_tokens == 0 - @pytest.mark.parametrize("anyio_backend", ["asyncio"]) + @pytest.mark.parametrize("anyio_backend", asyncio_params) async def test_asyncio_deadlock(self) -> None: """Regression test for #398.""" limiter = CapacityLimiter(1) diff --git a/tests/test_taskgroups.py b/tests/test_taskgroups.py index d20565c1..d6680b57 100644 --- a/tests/test_taskgroups.py +++ b/tests/test_taskgroups.py @@ -32,6 +32,8 @@ from anyio.abc import TaskGroup, TaskStatus from anyio.lowlevel import checkpoint +from .conftest import asyncio_params + if sys.version_info < (3, 11): from exceptiongroup import BaseExceptionGroup, ExceptionGroup @@ -200,7 +202,7 @@ async def taskfunc(*, task_status: TaskStatus) -> None: assert not finished -@pytest.mark.parametrize("anyio_backend", ["asyncio"]) +@pytest.mark.parametrize("anyio_backend", asyncio_params) async def test_start_native_host_cancelled() -> None: started = finished = False @@ -224,7 +226,7 @@ async def start_another() -> None: assert not finished -@pytest.mark.parametrize("anyio_backend", ["asyncio"]) +@pytest.mark.parametrize("anyio_backend", asyncio_params) async def test_start_native_child_cancelled() -> None: task = None finished = False @@ -248,7 +250,7 @@ async def start_another() -> None: assert not finished -@pytest.mark.parametrize("anyio_backend", ["asyncio"]) +@pytest.mark.parametrize("anyio_backend", asyncio_params) async def test_propagate_native_cancellation_from_taskgroup() -> None: async def taskfunc() -> None: async with create_task_group() as tg: @@ -261,7 +263,7 @@ async def taskfunc() -> None: await task -@pytest.mark.parametrize("anyio_backend", ["asyncio"]) +@pytest.mark.parametrize("anyio_backend", asyncio_params) async def test_cancel_with_nested_task_groups(mocker: MockerFixture) -> None: """Regression test for #695.""" @@ -691,7 +693,7 @@ async def test_shielded_cleanup_after_cancel() -> None: assert get_current_task().has_pending_cancellation() -@pytest.mark.parametrize("anyio_backend", ["asyncio"]) +@pytest.mark.parametrize("anyio_backend", asyncio_params) async def test_cleanup_after_native_cancel() -> None: """Regression test for #832.""" # See also https://github.com/python/cpython/pull/102815. @@ -791,7 +793,7 @@ async def outer_task() -> None: assert outer_task_ran -@pytest.mark.parametrize("anyio_backend", ["asyncio"]) +@pytest.mark.parametrize("anyio_backend", asyncio_params) async def test_cancel_host_asyncgen() -> None: done = False @@ -1146,7 +1148,7 @@ def generator_part() -> Generator[object, BaseException, None]: @pytest.mark.filterwarnings( 'ignore:"@coroutine" decorator is deprecated:DeprecationWarning' ) -@pytest.mark.parametrize("anyio_backend", ["asyncio"]) +@pytest.mark.parametrize("anyio_backend", asyncio_params) async def test_schedule_old_style_coroutine_func() -> None: """ Test that we give a sensible error when a user tries to spawn a task from a @@ -1169,7 +1171,7 @@ def corofunc() -> Generator[Any, Any, None]: tg.start_soon(corofunc) -@pytest.mark.parametrize("anyio_backend", ["asyncio"]) +@pytest.mark.parametrize("anyio_backend", asyncio_params) async def test_cancel_native_future_tasks() -> None: async def wait_native_future() -> None: loop = asyncio.get_running_loop() @@ -1180,7 +1182,7 @@ async def wait_native_future() -> None: tg.cancel_scope.cancel() -@pytest.mark.parametrize("anyio_backend", ["asyncio"]) +@pytest.mark.parametrize("anyio_backend", asyncio_params) async def test_cancel_native_future_tasks_cancel_scope() -> None: async def wait_native_future() -> None: with anyio.CancelScope(): @@ -1192,7 +1194,7 @@ async def wait_native_future() -> None: tg.cancel_scope.cancel() -@pytest.mark.parametrize("anyio_backend", ["asyncio"]) +@pytest.mark.parametrize("anyio_backend", asyncio_params) async def test_cancel_completed_task() -> None: loop = asyncio.get_running_loop() old_exception_handler = loop.get_exception_handler() @@ -1288,7 +1290,7 @@ async def test_cancelscope_exit_before_enter() -> None: @pytest.mark.parametrize( - "anyio_backend", ["asyncio"] + "anyio_backend", asyncio_params ) # trio does not check for this yet async def test_cancelscope_exit_in_wrong_task() -> None: async def enter_scope(scope: CancelScope) -> None: @@ -1403,7 +1405,7 @@ async def starter_task() -> None: sys.version_info < (3, 11), reason="Task uncancelling is only supported on Python 3.11", ) -@pytest.mark.parametrize("anyio_backend", ["asyncio"]) +@pytest.mark.parametrize("anyio_backend", asyncio_params) class TestUncancel: async def test_uncancel_after_native_cancel(self) -> None: task = cast(asyncio.Task, asyncio.current_task()) @@ -1759,28 +1761,18 @@ async def typetest_optional_status( task_status.started(1) -@pytest.mark.skipif( - sys.version_info < (3, 12), - reason="Eager task factories require Python 3.12", -) -@pytest.mark.parametrize("anyio_backend", ["asyncio"]) async def test_eager_task_factory(request: FixtureRequest) -> None: async def sync_coro() -> None: # This should trigger fetching the task state with CancelScope(): # noqa: ASYNC100 pass - loop = asyncio.get_running_loop() - old_task_factory = loop.get_task_factory() - loop.set_task_factory(asyncio.eager_task_factory) - request.addfinalizer(lambda: loop.set_task_factory(old_task_factory)) - async with create_task_group() as tg: tg.start_soon(sync_coro) tg.cancel_scope.cancel() -@pytest.mark.parametrize("anyio_backend", ["asyncio"]) +@pytest.mark.parametrize("anyio_backend", asyncio_params) async def test_patched_asyncio_task(monkeypatch: MonkeyPatch) -> None: monkeypatch.setattr( asyncio, diff --git a/tests/test_to_thread.py b/tests/test_to_thread.py index 9b80de2d..caffa275 100644 --- a/tests/test_to_thread.py +++ b/tests/test_to_thread.py @@ -23,6 +23,8 @@ ) from anyio.from_thread import BlockingPortalProvider +from .conftest import asyncio_params + pytestmark = pytest.mark.anyio @@ -159,7 +161,7 @@ async def test_asynclib_detection() -> None: await to_thread.run_sync(sniffio.current_async_library) -@pytest.mark.parametrize("anyio_backend", ["asyncio"]) +@pytest.mark.parametrize("anyio_backend", asyncio_params) async def test_asyncio_cancel_native_task() -> None: task: asyncio.Task[None] | None = None From b75efe70302e85d6d754e13ebfc769d886555827 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Wed, 1 Jan 2025 12:05:54 +0000 Subject: [PATCH 04/10] fast track asyncio.eager_task_factory detection --- src/anyio/_backends/_asyncio.py | 11 +++-- tests/test_taskgroups.py | 85 ++++++++++++++++++++++++++++++++- 2 files changed, 91 insertions(+), 5 deletions(-) diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index 7f6315b4..ce73d160 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -784,9 +784,6 @@ def on_completion(task: asyncio.Task[object]) -> None: tasks.pop().remove_done_callback(on_completion) -_TF_IS_EAGER_MAP: WeakKeyDictionary[Callable, bool] = WeakKeyDictionary() - - class _ShimEventLoop(asyncio.AbstractEventLoop): def get_debug(self) -> bool: return False @@ -943,6 +940,10 @@ def time(self, *args: object, **kwargs: object) -> Any: if sys.version_info >= (3, 12): + _TF_IS_EAGER_MAP: WeakKeyDictionary[Callable, bool] = WeakKeyDictionary() + _TF_IS_EAGER_MAP[asyncio.eager_task_factory] = True + + _TF_CODE = asyncio.eager_task_factory.__code__ def is_eager(tf: Callable[..., Any] | None) -> bool: if tf is None: @@ -952,6 +953,10 @@ def is_eager(tf: Callable[..., Any] | None) -> bool: except KeyError: pass + if getattr(tf, "__code__", object()) is _TF_CODE: + _TF_IS_EAGER_MAP[tf] = True + return True + ran = False async def corofn() -> None: diff --git a/tests/test_taskgroups.py b/tests/test_taskgroups.py index d6680b57..1d8904d5 100644 --- a/tests/test_taskgroups.py +++ b/tests/test_taskgroups.py @@ -1,12 +1,13 @@ from __future__ import annotations import asyncio +import contextvars import gc import math import sys import time from asyncio import CancelledError -from collections.abc import AsyncGenerator, Coroutine, Generator +from collections.abc import AsyncGenerator, Callable, Coroutine, Generator from typing import Any, NoReturn, cast import pytest @@ -32,7 +33,7 @@ from anyio.abc import TaskGroup, TaskStatus from anyio.lowlevel import checkpoint -from .conftest import asyncio_params +from .conftest import EventLoop, asyncio_params if sys.version_info < (3, 11): from exceptiongroup import BaseExceptionGroup, ExceptionGroup @@ -1762,13 +1763,93 @@ async def typetest_optional_status( async def test_eager_task_factory(request: FixtureRequest) -> None: + ran = False + + async def sync_coro() -> None: + nonlocal ran + ran = True + # This should trigger fetching the task state + with CancelScope(): # noqa: ASYNC100 + pass + + async with create_task_group() as tg: + tg.start_soon(sync_coro) + assert not ran + tg.cancel_scope.cancel() + + +if sys.version_info >= (3, 12): + + def task_factory_loop_factory_factory( + task_factory: Callable[..., Any], + ) -> Callable[[], EventLoop]: + def factory() -> EventLoop: + loop = EventLoop() + loop.set_task_factory(task_factory) + return loop + + return factory + + def create_eager_task_factory( + custom_task_constructor: Callable[..., Any], + ) -> Callable[..., Any]: + def factory( + loop: Any, + coro: Any, + *, + name: str | None = None, + context: contextvars.Context | None = None, + ) -> Any: + return custom_task_constructor( + coro, loop=loop, name=name, context=context, eager_start=True + ) + + return factory + + custom_params = [ + pytest.param( + ( + "asyncio", + { + "debug": True, + "loop_factory": task_factory_loop_factory_factory( + asyncio.create_eager_task_factory(asyncio.Task) + ), + }, + ), + id="asyncio+stdlib-custom-eager", + ), + pytest.param( + ( + "asyncio", + { + "debug": True, + "loop_factory": task_factory_loop_factory_factory( + create_eager_task_factory(asyncio.Task) + ), + }, + ), + id="asyncio+my-custom-eager", + ), + ] +else: + custom_params = [] + + +@pytest.mark.parametrize("anyio_backend", asyncio_params + custom_params) +async def test_various_custom_task_factories(request: FixtureRequest) -> None: + ran = False + async def sync_coro() -> None: + nonlocal ran + ran = True # This should trigger fetching the task state with CancelScope(): # noqa: ASYNC100 pass async with create_task_group() as tg: tg.start_soon(sync_coro) + assert not ran tg.cancel_scope.cancel() From f701e3b6a2ccf700af6edd1ae1ab085cc5930d39 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Wed, 1 Jan 2025 14:51:37 +0200 Subject: [PATCH 05/10] Simplified the backend setup in conftest --- tests/conftest.py | 43 +++++++++++++--------------------------- tests/test_taskgroups.py | 8 ++++---- 2 files changed, 18 insertions(+), 33 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index c0117924..43004595 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -29,42 +29,27 @@ pytest_plugins = ["pytester", "pytest_mock"] -if sys.version_info < (3, 13): - if sys.platform == "win32": - EventLoop = asyncio.ProactorEventLoop - else: - EventLoop = asyncio.SelectorEventLoop -else: - EventLoop = asyncio.EventLoop - -if sys.version_info >= (3, 12): - - def eager_task_loop_factory() -> EventLoop: - loop = EventLoop() - loop.set_task_factory(asyncio.eager_task_factory) - return loop - - eager_marks: list[pytest.MarkDecorator] = [] -else: - eager_task_loop_factory = EventLoop - eager_marks = [pytest.mark.skip(reason="eager tasks not supported yet")] - asyncio_params = [ - pytest.param( - ("asyncio", {"debug": True, "loop_factory": None}), - id="asyncio", - ), + pytest.param(("asyncio", {"debug": True}), id="asyncio"), pytest.param( ("asyncio", {"debug": True, "loop_factory": uvloop.new_event_loop}), marks=uvloop_marks, id="asyncio+uvloop", ), - pytest.param( - ("asyncio", {"debug": True, "loop_factory": eager_task_loop_factory}), - marks=eager_marks, - id="asyncio+eager", - ), ] +if sys.version_info >= (3, 12): + + def eager_task_loop_factory() -> asyncio.AbstractEventLoop: + loop = asyncio.new_event_loop() + loop.set_task_factory(asyncio.eager_task_factory) + return loop + + asyncio_params.append( + pytest.param( + ("asyncio", {"debug": True, "loop_factory": eager_task_loop_factory}), + id="asyncio+eager", + ), + ) @pytest.fixture(params=[*asyncio_params, pytest.param("trio")]) diff --git a/tests/test_taskgroups.py b/tests/test_taskgroups.py index 1d8904d5..d77340f3 100644 --- a/tests/test_taskgroups.py +++ b/tests/test_taskgroups.py @@ -33,7 +33,7 @@ from anyio.abc import TaskGroup, TaskStatus from anyio.lowlevel import checkpoint -from .conftest import EventLoop, asyncio_params +from .conftest import asyncio_params if sys.version_info < (3, 11): from exceptiongroup import BaseExceptionGroup, ExceptionGroup @@ -1782,9 +1782,9 @@ async def sync_coro() -> None: def task_factory_loop_factory_factory( task_factory: Callable[..., Any], - ) -> Callable[[], EventLoop]: - def factory() -> EventLoop: - loop = EventLoop() + ) -> Callable[[], asyncio.AbstractEventLoop]: + def factory() -> asyncio.AbstractEventLoop: + loop = asyncio.new_event_loop() loop.set_task_factory(task_factory) return loop From 53bcb8978ed0cd4a8cfcb165637048cda42b8965 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Wed, 1 Jan 2025 18:03:23 +0000 Subject: [PATCH 06/10] remove now redundant test --- tests/test_taskgroups.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/tests/test_taskgroups.py b/tests/test_taskgroups.py index d77340f3..ae154803 100644 --- a/tests/test_taskgroups.py +++ b/tests/test_taskgroups.py @@ -1762,22 +1762,6 @@ async def typetest_optional_status( task_status.started(1) -async def test_eager_task_factory(request: FixtureRequest) -> None: - ran = False - - async def sync_coro() -> None: - nonlocal ran - ran = True - # This should trigger fetching the task state - with CancelScope(): # noqa: ASYNC100 - pass - - async with create_task_group() as tg: - tg.start_soon(sync_coro) - assert not ran - tg.cancel_scope.cancel() - - if sys.version_info >= (3, 12): def task_factory_loop_factory_factory( From c1c4fe8a28d3c1b5a134a298ecf82a8ae099d5b8 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Thu, 2 Jan 2025 07:48:54 +0000 Subject: [PATCH 07/10] get rid of _pending_eager_tasks --- src/anyio/_backends/_asyncio.py | 99 ++++++++++++++++++--------------- 1 file changed, 53 insertions(+), 46 deletions(-) diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index ce73d160..4e34c938 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -972,16 +972,12 @@ def is_eager(tf: object) -> Literal[False]: return False -_pending_eager_tasks: RunVar[int] = RunVar("_pending_eager_tasks", 0) - - class TaskGroup(abc.TaskGroup): def __init__(self) -> None: self.cancel_scope: CancelScope = CancelScope() self._active = False self._exceptions: list[BaseException] = [] self._tasks: set[asyncio.Task] = set() - self._pending_eager_tasks = 0 async def __aenter__(self) -> TaskGroup: self.cancel_scope.__enter__() @@ -1001,9 +997,9 @@ async def __aexit__( self._exceptions.append(exc_val) try: - if self._pending_eager_tasks or self._tasks: + if self._tasks: with CancelScope() as wait_scope: - while self._pending_eager_tasks or self._tasks: + while self._tasks: try: await _wait(self._tasks) except CancelledError as exc: @@ -1049,39 +1045,53 @@ def _spawn( name: object, task_status_future: asyncio.Future | None = None, ) -> Callable[[], Coroutine[Any, Any, Any]]: + def get_exception(_task: asyncio.Task) -> BaseException: + try: + return _task.exception() + except CancelledError as e: + while isinstance(e.__context__, CancelledError): + e = e.__context__ + + return e + + def process_exception(exc: BaseException) -> None: + # The future can only be in the cancelled state if the host task was + # cancelled, so return immediately instead of adding one more + # CancelledError to the exceptions list + if task_status_future is not None and task_status_future.cancelled(): + return + + if task_status_future is None or task_status_future.done(): + if not isinstance(exc, CancelledError): + self._exceptions.append(exc) + + if not self.cancel_scope._effectively_cancelled: + self.cancel_scope.cancel() + else: + task_status_future.set_exception(exc) + + + def simple_task_done(_task: asyncio.Task) -> None: + self._tasks.discard(_task) + exc = get_exception(_task) + + if exc is not None: + process_exception(exc) + def task_done(_task: asyncio.Task) -> None: - # task_state = _task_states[_task] + self._tasks.discard(_task) assert task_state.cancel_scope is not None task_state.cancel_scope._tasks.discard(_task) - self._tasks.discard(_task) + try: del _task_states[_task] except KeyError: pass - try: - exc = _task.exception() - except CancelledError as e: - while isinstance(e.__context__, CancelledError): - e = e.__context__ - - exc = e + exc = get_exception(_task) if exc is not None: - # The future can only be in the cancelled state if the host task was - # cancelled, so return immediately instead of adding one more - # CancelledError to the exceptions list - if task_status_future is not None and task_status_future.cancelled(): - return - - if task_status_future is None or task_status_future.done(): - if not isinstance(exc, CancelledError): - self._exceptions.append(exc) - - if not self.cancel_scope._effectively_cancelled: - self.cancel_scope.cancel() - else: - task_status_future.set_exception(exc) + process_exception(exc) elif task_status_future is not None and not task_status_future.done(): task_status_future.set_exception( RuntimeError("Child exited without calling task_status.started()") @@ -1116,11 +1126,10 @@ def task_done(_task: asyncio.Task) -> None: task: asyncio.Task | None = None - def create() -> None: + async def create() -> None: + await sleep(0) assert task_factory is not None nonlocal task - self._pending_eager_tasks -= 1 - _pending_eager_tasks.set(_pending_eager_tasks.get() - 1) _task_states[coro] = task_state try: task = task_factory(loop, coro, name=name) # type: ignore[assignment, call-arg] @@ -1135,14 +1144,15 @@ def create() -> None: except KeyError: pass task_done(task) - else: - # the task state could have changed if the task entered a new CS - new_task_state = _task_states[task] - assert new_task_state.cancel_scope is not None - # Make the spawned task inherit the cancel scope it's in - new_task_state.cancel_scope._tasks.add(task) - self._tasks.add(task) - task.add_done_callback(task_done) + return + + # the task state could have changed if the task entered a new CS + new_task_state = _task_states[task] + assert new_task_state.cancel_scope is not None + # Make the spawned task inherit the cancel scope it's in + new_task_state.cancel_scope._tasks.add(task) + self._tasks.add(task) + task.add_done_callback(task_done) async def await_task_cancel_and_wait() -> None: try: @@ -1159,9 +1169,9 @@ async def cancel_and_wait() -> None: task_factory = loop.get_task_factory() if is_eager(task_factory): - self._pending_eager_tasks += 1 - _pending_eager_tasks.set(_pending_eager_tasks.get() + 1) - loop.call_soon(create) + spawn_task: asyncio.Task = task_factory(loop, create(), name=f"spawn {name}") # type: ignore[assignment, misc, call-arg] + self._tasks.add(spawn_task) + spawn_task.add_done_callback(simple_task_done) return await_task_cancel_and_wait else: task = loop.create_task(coro, name=name) @@ -3092,9 +3102,6 @@ async def wait_all_tasks_blocked(cls) -> None: await cls.checkpoint() this_task = current_task() while True: - if _pending_eager_tasks.get(): - await cls.checkpoint() - for task in all_tasks(): if task is this_task: continue From 1454455beaf5de2d229a9096c682f7fe52795d72 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 2 Jan 2025 07:49:12 +0000 Subject: [PATCH 08/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/anyio/_backends/_asyncio.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index 4e34c938..a6663a2f 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -1070,7 +1070,6 @@ def process_exception(exc: BaseException) -> None: else: task_status_future.set_exception(exc) - def simple_task_done(_task: asyncio.Task) -> None: self._tasks.discard(_task) exc = get_exception(_task) @@ -1169,7 +1168,9 @@ async def cancel_and_wait() -> None: task_factory = loop.get_task_factory() if is_eager(task_factory): - spawn_task: asyncio.Task = task_factory(loop, create(), name=f"spawn {name}") # type: ignore[assignment, misc, call-arg] + spawn_task: asyncio.Task = task_factory( + loop, create(), name=f"spawn {name}" + ) # type: ignore[assignment, misc, call-arg] self._tasks.add(spawn_task) spawn_task.add_done_callback(simple_task_done) return await_task_cancel_and_wait From c13ec699a6ec8bffa499d587e96362a5730459fe Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Thu, 2 Jan 2025 07:50:28 +0000 Subject: [PATCH 09/10] fix type --- src/anyio/_backends/_asyncio.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index a6663a2f..142f72d5 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -1045,7 +1045,7 @@ def _spawn( name: object, task_status_future: asyncio.Future | None = None, ) -> Callable[[], Coroutine[Any, Any, Any]]: - def get_exception(_task: asyncio.Task) -> BaseException: + def get_exception(_task: asyncio.Task) -> BaseException | None: try: return _task.exception() except CancelledError as e: From b7b6e5ba41f015fc4aed82a6e5776bcb90194a7d Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Thu, 2 Jan 2025 07:52:32 +0000 Subject: [PATCH 10/10] Update src/anyio/_backends/_asyncio.py --- src/anyio/_backends/_asyncio.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index 142f72d5..dbb4d3aa 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -762,10 +762,6 @@ def started(self, value: T_contra | None = None) -> None: async def _wait(tasks: Iterable[asyncio.Task[object]]) -> None: tasks = set(tasks) - if not tasks: - await sleep(0) - return - waiter = get_running_loop().create_future() def on_completion(task: asyncio.Task[object]) -> None: