Skip to content

Commit

Permalink
Added unit test and support for calling getfixturevalue on a synchron…
Browse files Browse the repository at this point in the history
…ous fixture definition
  • Loading branch information
jhominal committed Sep 29, 2023
1 parent c9e331b commit d4977b0
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 2 deletions.
24 changes: 22 additions & 2 deletions src/anyio/pytest_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from collections.abc import Iterator
from contextlib import contextmanager
from contextvars import Context, copy_context
from contextvars import Context, ContextVar, copy_context
from inspect import isasyncgenfunction, iscoroutinefunction
from typing import Any, Dict, Tuple, cast

Expand All @@ -15,17 +15,35 @@
from .abc import TestRunner

_current_runner: TestRunner | None = None
_current_reentrancy_token = ContextVar[object]("anyio.pytest_plugin.reentrancy_token")
contextvars_context_key: StashKey[Context] = StashKey()
_test_context_like_key: StashKey[ContextLike] = StashKey()


class _TestContext(ContextLike):
"""This class manages transmission of sniffio.current_async_library_cvar"""
"""Manages reentrancy and transmission of sniffio.current_async_library_cvar"""

def __init__(self, context: Context):
self._context = context
self._reentrancy_token = object()

def _is_already_in_context(self) -> bool:
# if context var is not set to the token, we are in another context
if _current_reentrancy_token.get(None) is not self._reentrancy_token:
return False

# Token value is the same, but we may be in a copy of self._context
test_value = object()
reset_reentrancy = _current_reentrancy_token.set(test_value)
try:
return self._context[_current_reentrancy_token] is test_value
finally:
_current_reentrancy_token.reset(reset_reentrancy)

def run(self, func: Any, /, *args: Any, **kwargs: Any) -> Any:
if self._is_already_in_context():
return func(*args, **kwargs)

return self._context.run(
self._set_context_and_run,
sniffio.current_async_library_cvar.get(None),
Expand All @@ -37,6 +55,7 @@ def run(self, func: Any, /, *args: Any, **kwargs: Any) -> Any:
def _set_context_and_run(
self, current_async_library: str | None, func: Any, /, *args: Any, **kwargs: Any
) -> Any:
reset_reentrancy = _current_reentrancy_token.set(self._reentrancy_token)
reset_sniffio = None
if current_async_library is not None:
reset_sniffio = sniffio.current_async_library_cvar.set(
Expand All @@ -46,6 +65,7 @@ def _set_context_and_run(
try:
return func(*args, **kwargs)
finally:
_current_reentrancy_token.reset(reset_reentrancy)
if reset_sniffio is not None:
sniffio.current_async_library_cvar.reset(reset_sniffio)

Expand Down
58 changes: 58 additions & 0 deletions tests/test_pytest_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,3 +439,61 @@ def test_sync_func_sync_then_async_fixture(

result = testdir.runpytest(*pytest_args)
result.assert_outcomes(passed=4 * len(get_all_backends()))


def test_sync_getfixturevalue(testdir: Pytester) -> None:
testdir.makepyfile(
"""
from __future__ import annotations
from contextvars import ContextVar
import pytest
var = ContextVar("var")
@pytest.fixture
def function_fixture():
return "function"
@pytest.fixture
def generator_fixture():
yield "generator"
@pytest.fixture
def set_var():
value = object()
reset = var.set(value)
yield value
var.reset(reset)
@pytest.mark.parametrize("prefix", ["function", "generator"])
def test_getfixturevalue_from_sync(request, prefix):
assert request.getfixturevalue(f"{prefix}_fixture") == prefix
@pytest.mark.anyio
@pytest.mark.parametrize("prefix", ["function", "generator"])
async def test_getfixturevalue_from_async(request, prefix):
assert request.getfixturevalue(f"{prefix}_fixture") == prefix
def test_getfixturevalue_with_context_from_sync(request):
value = request.getfixturevalue("set_var")
assert var.get(None) is value
@pytest.mark.anyio
async def test_getfixturevalue_with_context_from_async(request):
value = request.getfixturevalue("set_var")
assert var.get(None) is value
"""
)

result = testdir.runpytest(*pytest_args)
result.assert_outcomes(passed=3 * len(get_all_backends()) + 3)

0 comments on commit d4977b0

Please sign in to comment.