From 3004bb76050593463819bb809249af76542e54ba Mon Sep 17 00:00:00 2001 From: "Bryan C. Mills" Date: Fri, 6 Dec 2024 23:26:38 -0500 Subject: [PATCH] Simplify contextvars support Instead of storing a Context in a context variable, just copy out the changes from the setup task's context into the ambient context, and reset the changes after running the finalizer task. --- pytest_asyncio/plugin.py | 115 ++++++++---------- .../test_async_fixtures_contextvars.py | 47 +++++-- 2 files changed, 88 insertions(+), 74 deletions(-) diff --git a/pytest_asyncio/plugin.py b/pytest_asyncio/plugin.py index 762c58ae..ac756a2a 100644 --- a/pytest_asyncio/plugin.py +++ b/pytest_asyncio/plugin.py @@ -319,12 +319,27 @@ def _asyncgen_fixture_wrapper(request: FixtureRequest, **kwargs: Any): kwargs.pop(event_loop_fixture_id, None) gen_obj = func(**_add_kwargs(func, kwargs, event_loop, request)) - context = _event_loop_context.get(None) - async def setup(): res = await gen_obj.__anext__() # type: ignore[union-attr] return res + context = contextvars.copy_context() + setup_task = _create_task_in_context(event_loop, setup(), context) + result = event_loop.run_until_complete(setup_task) + + # Copy the context vars set by the setup task back into the ambient + # context for the test. + context_tokens = [] + for var in context: + try: + if var.get() is context.get(var): + # Not modified by the fixture, so leave it as-is. + continue + except LookupError: + pass + token = var.set(context.get(var)) + context_tokens.append((var, token)) + def finalizer() -> None: """Yield again, to finalize.""" @@ -341,14 +356,39 @@ async def async_finalizer() -> None: task = _create_task_in_context(event_loop, async_finalizer(), context) event_loop.run_until_complete(task) - setup_task = _create_task_in_context(event_loop, setup(), context) - result = event_loop.run_until_complete(setup_task) + # Since the fixture is now complete, restore any context variables + # it had set back to their original values. + while context_tokens: + (var, token) = context_tokens.pop() + var.reset(token) + request.addfinalizer(finalizer) return result fixturedef.func = _asyncgen_fixture_wrapper # type: ignore[misc] +def _create_task_in_context(loop, coro, context): + """ + Return an asyncio task that runs the coro in the specified context, + if possible. + + This allows fixture setup and teardown to be run as separate asyncio tasks, + while still being able to use context-manager idioms to maintain context + variables and make those variables visible to test functions. + + This is only fully supported on Python 3.11 and newer, as it requires + the API added for https://github.com/python/cpython/issues/91150. + On earlier versions, the returned task will use the default context instead. + """ + if context is not None: + try: + return loop.create_task(coro, context=context) + except TypeError: + pass + return loop.create_task(coro) + + def _wrap_async_fixture(fixturedef: FixtureDef) -> None: fixture = fixturedef.func @@ -365,10 +405,11 @@ async def setup(): res = await func(**_add_kwargs(func, kwargs, event_loop, request)) return res - task = _create_task_in_context( - event_loop, setup(), _event_loop_context.get(None) - ) - return event_loop.run_until_complete(task) + # Since the fixture doesn't have a cleanup phase, if it set any context + # variables we don't have a good way to clear them again. + # Instead, treat this fixture like an asyncio.Task, which has its own + # independent Context that doesn't affect the caller. + return event_loop.run_until_complete(setup()) fixturedef.func = _async_fixture_wrapper # type: ignore[misc] @@ -592,46 +633,6 @@ def pytest_pycollect_makeitem_convert_async_functions_to_subclass( Session: "session", } -# _event_loop_context stores the Context in which asyncio tasks on the fixture -# event loop should be run. After fixture setup, individual async test functions -# are run on copies of this context. -_event_loop_context: contextvars.ContextVar[contextvars.Context] = ( - contextvars.ContextVar("pytest_asyncio_event_loop_context") -) - - -@contextlib.contextmanager -def _set_event_loop_context(): - """Set event_loop_context to a copy of the calling thread's current context.""" - context = contextvars.copy_context() - token = _event_loop_context.set(context) - try: - yield - finally: - _event_loop_context.reset(token) - - -def _create_task_in_context(loop, coro, context): - """ - Return an asyncio task that runs the coro in the specified context, - if possible. - - This allows fixture setup and teardown to be run as separate asyncio tasks, - while still being able to use context-manager idioms to maintain context - variables and make those variables visible to test functions. - - This is only fully supported on Python 3.11 and newer, as it requires - the API added for https://github.com/python/cpython/issues/91150. - On earlier versions, the returned task will use the default context instead. - """ - if context is not None: - try: - return loop.create_task(coro, context=context) - except TypeError: - pass - return loop.create_task(coro) - - # A stack used to push package-scoped loops during collection of a package # and pop those loops during collection of a Module __package_loop_stack: list[FixtureFunctionMarker | FixtureFunction] = [] @@ -679,8 +680,7 @@ def scoped_event_loop( loop = asyncio.new_event_loop() loop.__pytest_asyncio = True # type: ignore[attr-defined] asyncio.set_event_loop(loop) - with _set_event_loop_context(): - yield loop + yield loop loop.close() # @pytest.fixture does not register the fixture anywhere, so pytest doesn't @@ -987,16 +987,9 @@ def wrap_in_sync( @functools.wraps(func) def inner(*args, **kwargs): - # Give each test its own context based on the loop's main context. - context = _event_loop_context.get(None) - if context is not None: - # We are using our own event loop fixture, so make a new copy of the - # fixture context so that the test won't pollute it. - context = context.copy() - coro = func(*args, **kwargs) _loop = _get_event_loop_no_warn() - task = _create_task_in_context(_loop, coro, context) + task = asyncio.ensure_future(coro, loop=_loop) try: _loop.run_until_complete(task) except BaseException: @@ -1105,8 +1098,7 @@ def event_loop(request: FixtureRequest) -> Iterator[asyncio.AbstractEventLoop]: # The magic value must be set as part of the function definition, because pytest # seems to have multiple instances of the same FixtureDef or fixture function loop.__original_fixture_loop = True # type: ignore[attr-defined] - with _set_event_loop_context(): - yield loop + yield loop loop.close() @@ -1119,8 +1111,7 @@ def _session_event_loop( loop = asyncio.new_event_loop() loop.__pytest_asyncio = True # type: ignore[attr-defined] asyncio.set_event_loop(loop) - with _set_event_loop_context(): - yield loop + yield loop loop.close() diff --git a/tests/async_fixtures/test_async_fixtures_contextvars.py b/tests/async_fixtures/test_async_fixtures_contextvars.py index 25bb8106..3f58be54 100644 --- a/tests/async_fixtures/test_async_fixtures_contextvars.py +++ b/tests/async_fixtures/test_async_fixtures_contextvars.py @@ -6,31 +6,54 @@ from __future__ import annotations import sys -from contextlib import asynccontextmanager +from contextlib import contextmanager from contextvars import ContextVar import pytest +_context_var = ContextVar("context_var") -@asynccontextmanager -async def context_var_manager(): - context_var = ContextVar("context_var") - token = context_var.set("value") + +@contextmanager +def context_var_manager(value): + token = _context_var.set(value) try: - yield context_var + yield finally: - context_var.reset(token) + _context_var.reset(token) + + +@pytest.fixture(scope="function") +async def no_var_fixture(): + with pytest.raises(LookupError): + _context_var.get() + yield + with pytest.raises(LookupError): + _context_var.get() + + +@pytest.fixture(scope="function") +async def var_fixture(no_var_fixture): + with context_var_manager("value"): + yield + + +@pytest.fixture(scope="function") +async def var_nop_fixture(var_fixture): + with context_var_manager(_context_var.get()): + yield @pytest.fixture(scope="function") -async def context_var(): - async with context_var_manager() as v: - yield v +def inner_var_fixture(var_nop_fixture): + assert _context_var.get() == "value" + with context_var_manager("value2"): + yield @pytest.mark.asyncio @pytest.mark.xfail( sys.version_info < (3, 11), reason="requires asyncio Task context support" ) -async def test(context_var): - assert context_var.get() == "value" +async def test(inner_var_fixture): + assert _context_var.get() == "value2"