Skip to content

Commit

Permalink
Simplify contextvars support
Browse files Browse the repository at this point in the history
Instead of storing a Context in a context variable, just copy out the
changes from the setup task's context into the ambient context, and
reset the changes after running the finalizer task.
  • Loading branch information
bcmills authored and seifertm committed Dec 12, 2024
1 parent 746c114 commit 3004bb7
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 74 deletions.
115 changes: 53 additions & 62 deletions pytest_asyncio/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,12 +319,27 @@ def _asyncgen_fixture_wrapper(request: FixtureRequest, **kwargs: Any):
kwargs.pop(event_loop_fixture_id, None)
gen_obj = func(**_add_kwargs(func, kwargs, event_loop, request))

context = _event_loop_context.get(None)

async def setup():
res = await gen_obj.__anext__() # type: ignore[union-attr]
return res

context = contextvars.copy_context()
setup_task = _create_task_in_context(event_loop, setup(), context)
result = event_loop.run_until_complete(setup_task)

# Copy the context vars set by the setup task back into the ambient
# context for the test.
context_tokens = []
for var in context:
try:
if var.get() is context.get(var):
# Not modified by the fixture, so leave it as-is.
continue
except LookupError:
pass
token = var.set(context.get(var))
context_tokens.append((var, token))

def finalizer() -> None:
"""Yield again, to finalize."""

Expand All @@ -341,14 +356,39 @@ async def async_finalizer() -> None:
task = _create_task_in_context(event_loop, async_finalizer(), context)
event_loop.run_until_complete(task)

setup_task = _create_task_in_context(event_loop, setup(), context)
result = event_loop.run_until_complete(setup_task)
# Since the fixture is now complete, restore any context variables
# it had set back to their original values.
while context_tokens:
(var, token) = context_tokens.pop()
var.reset(token)

request.addfinalizer(finalizer)
return result

fixturedef.func = _asyncgen_fixture_wrapper # type: ignore[misc]


def _create_task_in_context(loop, coro, context):
"""
Return an asyncio task that runs the coro in the specified context,
if possible.
This allows fixture setup and teardown to be run as separate asyncio tasks,
while still being able to use context-manager idioms to maintain context
variables and make those variables visible to test functions.
This is only fully supported on Python 3.11 and newer, as it requires
the API added for https://github.com/python/cpython/issues/91150.
On earlier versions, the returned task will use the default context instead.
"""
if context is not None:
try:
return loop.create_task(coro, context=context)
except TypeError:
pass
return loop.create_task(coro)


def _wrap_async_fixture(fixturedef: FixtureDef) -> None:
fixture = fixturedef.func

Expand All @@ -365,10 +405,11 @@ async def setup():
res = await func(**_add_kwargs(func, kwargs, event_loop, request))
return res

task = _create_task_in_context(
event_loop, setup(), _event_loop_context.get(None)
)
return event_loop.run_until_complete(task)
# Since the fixture doesn't have a cleanup phase, if it set any context
# variables we don't have a good way to clear them again.
# Instead, treat this fixture like an asyncio.Task, which has its own
# independent Context that doesn't affect the caller.
return event_loop.run_until_complete(setup())

fixturedef.func = _async_fixture_wrapper # type: ignore[misc]

Expand Down Expand Up @@ -592,46 +633,6 @@ def pytest_pycollect_makeitem_convert_async_functions_to_subclass(
Session: "session",
}

# _event_loop_context stores the Context in which asyncio tasks on the fixture
# event loop should be run. After fixture setup, individual async test functions
# are run on copies of this context.
_event_loop_context: contextvars.ContextVar[contextvars.Context] = (
contextvars.ContextVar("pytest_asyncio_event_loop_context")
)


@contextlib.contextmanager
def _set_event_loop_context():
"""Set event_loop_context to a copy of the calling thread's current context."""
context = contextvars.copy_context()
token = _event_loop_context.set(context)
try:
yield
finally:
_event_loop_context.reset(token)


def _create_task_in_context(loop, coro, context):
"""
Return an asyncio task that runs the coro in the specified context,
if possible.
This allows fixture setup and teardown to be run as separate asyncio tasks,
while still being able to use context-manager idioms to maintain context
variables and make those variables visible to test functions.
This is only fully supported on Python 3.11 and newer, as it requires
the API added for https://github.com/python/cpython/issues/91150.
On earlier versions, the returned task will use the default context instead.
"""
if context is not None:
try:
return loop.create_task(coro, context=context)
except TypeError:
pass
return loop.create_task(coro)


# A stack used to push package-scoped loops during collection of a package
# and pop those loops during collection of a Module
__package_loop_stack: list[FixtureFunctionMarker | FixtureFunction] = []
Expand Down Expand Up @@ -679,8 +680,7 @@ def scoped_event_loop(
loop = asyncio.new_event_loop()
loop.__pytest_asyncio = True # type: ignore[attr-defined]
asyncio.set_event_loop(loop)
with _set_event_loop_context():
yield loop
yield loop
loop.close()

# @pytest.fixture does not register the fixture anywhere, so pytest doesn't
Expand Down Expand Up @@ -987,16 +987,9 @@ def wrap_in_sync(

@functools.wraps(func)
def inner(*args, **kwargs):
# Give each test its own context based on the loop's main context.
context = _event_loop_context.get(None)
if context is not None:
# We are using our own event loop fixture, so make a new copy of the
# fixture context so that the test won't pollute it.
context = context.copy()

coro = func(*args, **kwargs)
_loop = _get_event_loop_no_warn()
task = _create_task_in_context(_loop, coro, context)
task = asyncio.ensure_future(coro, loop=_loop)
try:
_loop.run_until_complete(task)
except BaseException:
Expand Down Expand Up @@ -1105,8 +1098,7 @@ def event_loop(request: FixtureRequest) -> Iterator[asyncio.AbstractEventLoop]:
# The magic value must be set as part of the function definition, because pytest
# seems to have multiple instances of the same FixtureDef or fixture function
loop.__original_fixture_loop = True # type: ignore[attr-defined]
with _set_event_loop_context():
yield loop
yield loop
loop.close()


Expand All @@ -1119,8 +1111,7 @@ def _session_event_loop(
loop = asyncio.new_event_loop()
loop.__pytest_asyncio = True # type: ignore[attr-defined]
asyncio.set_event_loop(loop)
with _set_event_loop_context():
yield loop
yield loop
loop.close()


Expand Down
47 changes: 35 additions & 12 deletions tests/async_fixtures/test_async_fixtures_contextvars.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,54 @@
from __future__ import annotations

import sys
from contextlib import asynccontextmanager
from contextlib import contextmanager
from contextvars import ContextVar

import pytest

_context_var = ContextVar("context_var")

@asynccontextmanager
async def context_var_manager():
context_var = ContextVar("context_var")
token = context_var.set("value")

@contextmanager
def context_var_manager(value):
token = _context_var.set(value)
try:
yield context_var
yield
finally:
context_var.reset(token)
_context_var.reset(token)


@pytest.fixture(scope="function")
async def no_var_fixture():
with pytest.raises(LookupError):
_context_var.get()
yield
with pytest.raises(LookupError):
_context_var.get()


@pytest.fixture(scope="function")
async def var_fixture(no_var_fixture):
with context_var_manager("value"):
yield


@pytest.fixture(scope="function")
async def var_nop_fixture(var_fixture):
with context_var_manager(_context_var.get()):
yield


@pytest.fixture(scope="function")
async def context_var():
async with context_var_manager() as v:
yield v
def inner_var_fixture(var_nop_fixture):
assert _context_var.get() == "value"
with context_var_manager("value2"):
yield


@pytest.mark.asyncio
@pytest.mark.xfail(
sys.version_info < (3, 11), reason="requires asyncio Task context support"
)
async def test(context_var):
assert context_var.get() == "value"
async def test(inner_var_fixture):
assert _context_var.get() == "value2"

0 comments on commit 3004bb7

Please sign in to comment.