Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented contextvars isolation in test runner using distinct tasks for each scope #616

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 81 additions & 35 deletions src/anyio/_backends/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -1808,8 +1808,49 @@ def _create_task_info(task: asyncio.Task) -> TaskInfo:
return TaskInfo(id(task), parent_id, task.get_name(), task.get_coro())


class TestRunner(abc.TestRunner):
class _TaskManager:
_send_stream: MemoryObjectSendStream[tuple[Awaitable[Any], asyncio.Future[Any]]]
_task: asyncio.Task
_loop: asyncio.AbstractEventLoop

@staticmethod
async def _run_coroutines(
receive_stream: MemoryObjectReceiveStream[
tuple[Awaitable[Any], asyncio.Future[Any]]
],
) -> None:
with receive_stream:
async for coro, future in receive_stream:
try:
retval = await coro
except BaseException as exc:
if not future.cancelled():
future.set_exception(exc)
else:
if not future.cancelled():
future.set_result(retval)

def __init__(self, loop: asyncio.AbstractEventLoop):
self._loop = loop
self._send_stream, receive_stream = create_memory_object_stream[
Tuple[Awaitable[Any], asyncio.Future]
](1)
self._task = loop.create_task(self._run_coroutines(receive_stream))

async def call_in_task(
self, func: Callable[..., Awaitable[T_Retval]], *args: object, **kwargs: object
) -> T_Retval:
coro = func(*args, **kwargs)
future: asyncio.Future[T_Retval] = self._loop.create_future()
self._send_stream.send_nowait((coro, future))
return await future

def shutdown(self) -> None:
self._send_stream.close()


class TestRunner(abc.TestRunner):
scopes = ("function", "class", "module", "package", "session")

def __init__(
self,
Expand All @@ -1825,7 +1866,7 @@ def __init__(

self._runner = Runner(debug=debug, loop_factory=loop_factory)
self._exceptions: list[BaseException] = []
self._runner_task: asyncio.Task | None = None
self._scope_task_managers: dict[str, _TaskManager] = {}

def __enter__(self) -> TestRunner:
self._runner.__enter__()
Expand Down Expand Up @@ -1862,55 +1903,55 @@ def _raise_async_exceptions(self) -> None:
"Multiple exceptions occurred in asynchronous callbacks", exceptions
)

@staticmethod
async def _run_tests_and_fixtures(
receive_stream: MemoryObjectReceiveStream[
tuple[Awaitable[T_Retval], asyncio.Future[T_Retval]]
],
) -> None:
with receive_stream:
async for coro, future in receive_stream:
try:
retval = await coro
except BaseException as exc:
if not future.cancelled():
future.set_exception(exc)
else:
if not future.cancelled():
future.set_result(retval)
async def _create_task(self) -> _TaskManager:
return _TaskManager(self.get_loop())

async def _get_task_manager(self, scope: str) -> _TaskManager:
if scope not in self.scopes:
raise ValueError(f"Unknown scope '{scope}'")
if scope in self._scope_task_managers:
return self._scope_task_managers[scope]

parent_task: _TaskManager | None = None
for parent_scope in self.scopes[self.scopes.index(scope) + 1 :]:
parent_task = self._scope_task_managers.get(parent_scope)
if parent_task is not None:
break

if parent_task is None:
result = await self._create_task()
else:
result = await parent_task.call_in_task(self._create_task)
self._scope_task_managers[scope] = result
return result

async def _call_in_runner_task(
self, func: Callable[..., Awaitable[T_Retval]], *args: object, **kwargs: object
self,
scope: str,
func: Callable[..., Awaitable[T_Retval]],
*args: object,
**kwargs: object,
) -> T_Retval:
if not self._runner_task:
self._send_stream, receive_stream = create_memory_object_stream[
Tuple[Awaitable[Any], asyncio.Future]
](1)
self._runner_task = self.get_loop().create_task(
self._run_tests_and_fixtures(receive_stream)
)

coro = func(*args, **kwargs)
future: asyncio.Future[T_Retval] = self.get_loop().create_future()
self._send_stream.send_nowait((coro, future))
return await future
task_manager = await self._get_task_manager(scope)
return await task_manager.call_in_task(func, *args, **kwargs)

def run_asyncgen_fixture(
self,
fixture_func: Callable[..., AsyncGenerator[T_Retval, Any]],
kwargs: dict[str, Any],
scope: str = "function",
) -> Iterable[T_Retval]:
asyncgen = fixture_func(**kwargs)
fixturevalue: T_Retval = self.get_loop().run_until_complete(
self._call_in_runner_task(asyncgen.asend, None)
self._call_in_runner_task(scope, asyncgen.asend, None)
)
self._raise_async_exceptions()

yield fixturevalue

try:
self.get_loop().run_until_complete(
self._call_in_runner_task(asyncgen.asend, None)
self._call_in_runner_task(scope, asyncgen.asend, None)
)
except StopAsyncIteration:
self._raise_async_exceptions()
Expand All @@ -1922,9 +1963,10 @@ def run_fixture(
self,
fixture_func: Callable[..., Coroutine[Any, Any, T_Retval]],
kwargs: dict[str, Any],
scope: str = "function",
) -> T_Retval:
retval = self.get_loop().run_until_complete(
self._call_in_runner_task(fixture_func, **kwargs)
self._call_in_runner_task(scope, fixture_func, **kwargs)
)
self._raise_async_exceptions()
return retval
Expand All @@ -1934,13 +1976,17 @@ def run_test(
) -> None:
try:
self.get_loop().run_until_complete(
self._call_in_runner_task(test_func, **kwargs)
self._call_in_runner_task("function", test_func, **kwargs)
)
except Exception as exc:
self._exceptions.append(exc)

self._raise_async_exceptions()

def close_scope(self, scope: str) -> None:
if scope in self._scope_task_managers:
self._scope_task_managers.pop(scope).shutdown()


class AsyncIOBackend(AsyncBackend):
@classmethod
Expand Down
125 changes: 98 additions & 27 deletions src/anyio/_backends/_trio.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import array
import contextvars
import math
import socket
import sys
Expand All @@ -25,8 +26,10 @@
Coroutine,
Generic,
Mapping,
MutableSequence,
NoReturn,
Sequence,
Tuple,
TypeVar,
cast,
overload,
Expand Down Expand Up @@ -61,7 +64,7 @@
from .._core._tasks import CancelScope as BaseCancelScope
from ..abc import IPSockAddrType, UDPPacketType, UNIXDatagramPacketType
from ..abc._eventloop import AsyncBackend
from ..streams.memory import MemoryObjectSendStream
from ..streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream

if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup
Expand Down Expand Up @@ -719,54 +722,116 @@ async def __anext__(self) -> Signals:
#


class _TaskManager:
_send_stream: MemoryObjectSendStream[
tuple[Awaitable[Any], MutableSequence[Outcome]]
]
task: trio.lowlevel.Task

@staticmethod
async def _run_coroutines(
receive_stream: MemoryObjectReceiveStream[
tuple[Awaitable[Any], MutableSequence[Outcome]]
],
) -> None:
with receive_stream:
async for coro, outcome_holder in receive_stream:
try:
retval = await coro
except BaseException as exc:
outcome_holder.append(Error(exc))
else:
outcome_holder.append(Value(retval))

def __init__(self, context: contextvars.Context) -> None:
self._send_stream, receive_stream = create_memory_object_stream[
Tuple[Awaitable[Any], MutableSequence[Outcome]]
](1)
self.task = trio.lowlevel.spawn_system_task(
self._run_coroutines,
receive_stream,
context=context, # type: ignore[call-arg] # missing from trio-typing
)

def call_in_task(
self, func: Callable[..., Awaitable[Any]], *args: object, **kwargs: object
) -> list[Outcome]:
outcome_holder: list[Outcome] = []
self._send_stream.send_nowait((func(*args, **kwargs), outcome_holder))
return outcome_holder

def shutdown(self) -> None:
self._send_stream.close()


class TestRunner(abc.TestRunner):
scopes = ("function", "class", "module", "package", "session")

def __init__(self, **options: Any) -> None:
from queue import Queue

self._call_queue: Queue[Callable[..., object]] = Queue()
self._send_stream: MemoryObjectSendStream | None = None
self._end_event: trio.Event | None = None
self._options = options
self._scope_task_managers: dict[str, _TaskManager] = {}

def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: types.TracebackType | None,
) -> None:
if self._send_stream:
self._send_stream.close()
while self._send_stream is not None:
if self._end_event:
self._end_event.set()
while self._end_event is not None:
self._call_queue.get()()

async def _run_tests_and_fixtures(self) -> None:
self._send_stream, receive_stream = create_memory_object_stream(1)
with receive_stream:
async for coro, outcome_holder in receive_stream:
try:
retval = await coro
except BaseException as exc:
outcome_holder.append(Error(exc))
else:
outcome_holder.append(Value(retval))
async def _wait_for_end(self) -> None:
self._end_event = trio.Event()
await self._end_event.wait()

def _main_task_finished(self, outcome: object) -> None:
self._send_stream = None
self._end_event = None

def _get_task_manager(self, scope: str) -> _TaskManager:
if scope not in self.scopes:
raise ValueError(f"Unknown scope '{scope}'")
if scope in self._scope_task_managers:
return self._scope_task_managers[scope]

context: contextvars.Context
for parent_scope in self.scopes[self.scopes.index(scope) + 1 :]:
parent = self._scope_task_managers.get(parent_scope)
if parent is not None:
context = parent.task.context.copy()
break
else:
context = contextvars.copy_context()

result = _TaskManager(context=context)
self._scope_task_managers[scope] = result
return result

def _call_in_runner_task(
self, func: Callable[..., Awaitable[T_Retval]], *args: object, **kwargs: object
self,
scope: str,
func: Callable[..., Awaitable[T_Retval]],
*args: object,
**kwargs: object,
) -> T_Retval:
if self._send_stream is None:
if self._end_event is None:
trio.lowlevel.start_guest_run(
self._run_tests_and_fixtures,
self._wait_for_end,
run_sync_soon_threadsafe=self._call_queue.put,
done_callback=self._main_task_finished,
**self._options,
)
while self._send_stream is None:
while self._end_event is None:
self._call_queue.get()()

outcome_holder: list[Outcome] = []
self._send_stream.send_nowait((func(*args, **kwargs), outcome_holder))
task_manager = self._get_task_manager(scope)

outcome_holder = task_manager.call_in_task(func, *args, **kwargs)
while not outcome_holder:
self._call_queue.get()()

Expand All @@ -776,31 +841,37 @@ def run_asyncgen_fixture(
self,
fixture_func: Callable[..., AsyncGenerator[T_Retval, Any]],
kwargs: dict[str, Any],
scope: str = "function",
) -> Iterable[T_Retval]:
asyncgen = fixture_func(**kwargs)
fixturevalue: T_Retval = self._call_in_runner_task(asyncgen.asend, None)
fixturevalue: T_Retval = self._call_in_runner_task(scope, asyncgen.asend, None)

yield fixturevalue

try:
self._call_in_runner_task(asyncgen.asend, None)
self._call_in_runner_task(scope, asyncgen.asend, None)
except StopAsyncIteration:
pass
else:
self._call_in_runner_task(asyncgen.aclose)
self._call_in_runner_task(scope, asyncgen.aclose)
raise RuntimeError("Async generator fixture did not stop")

def run_fixture(
self,
fixture_func: Callable[..., Coroutine[Any, Any, T_Retval]],
kwargs: dict[str, Any],
scope: str = "function",
) -> T_Retval:
return self._call_in_runner_task(fixture_func, **kwargs)
return self._call_in_runner_task(scope, fixture_func, **kwargs)

def run_test(
self, test_func: Callable[..., Coroutine[Any, Any, Any]], kwargs: dict[str, Any]
) -> None:
self._call_in_runner_task(test_func, **kwargs)
self._call_in_runner_task("function", test_func, **kwargs)

def close_scope(self, scope: str) -> None:
if scope in self._scope_task_managers:
self._scope_task_managers.pop(scope).shutdown()


class TrioBackend(AsyncBackend):
Expand Down
Loading