Skip to content
This repository has been archived by the owner on Apr 24, 2023. It is now read-only.

Commit

Permalink
fix: Support passing name and context kwarg to create_task() (#10)
Browse files Browse the repository at this point in the history
  • Loading branch information
achimnol authored Apr 16, 2023
1 parent 75eb96a commit 88bc66f
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 0 deletions.
6 changes: 6 additions & 0 deletions aiomonitor/monitor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import contextvars
import functools
import logging
import os
Expand Down Expand Up @@ -318,6 +319,9 @@ def _create_task(
self,
loop: asyncio.AbstractEventLoop,
coro: Coroutine[Any, Any, T_co] | Generator[Any, None, T_co],
*,
name: str | None = None,
context: contextvars.Context | None = None,
) -> asyncio.Future[T_co]:
assert loop is self._monitored_loop
try:
Expand All @@ -331,6 +335,8 @@ def _create_task(
cancellation_chain_queue=self._cancellation_chain_queue.sync_q,
persistent=persistent,
loop=self._monitored_loop,
name=name, # since Python 3.8
context=context, # since Python 3.11
)
task._orig_coro = cast(Coroutine[Any, Any, T_co], coro)
self._created_tracebacks[task] = _extract_stack_from_frame(sys._getframe())[
Expand Down
27 changes: 27 additions & 0 deletions tests/test_monitor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import contextvars
import telnetlib
import threading
import time
Expand Down Expand Up @@ -142,6 +143,32 @@ def test_basic_monitor(monitor, tn_client, loop):
assert "No task 123" in resp


myvar = contextvars.ContextVar("myvar", default=42)


def test_monitor_task_factory():
ctx = contextvars.Context()
# This context is bound at the outermost scope,
# and inside it the initial value of myvar is kept intact.

async def do():
await asyncio.sleep(0)
assert myvar.get() == 42 # we are referring the outer context
myself = asyncio.current_task()
assert myself is not None
assert myself.get_name() == "mytask"

async def main():
myvar.set(99) # override in the current task's context
loop = asyncio.get_running_loop()
with Monitor(loop, console_enabled=False, hook_task_factory=True):
t = asyncio.create_task(do(), name="mytask", context=ctx)
await t
assert myvar.get() == 99

asyncio.run(main())


def test_cancel_where_tasks(monitor, tn_client, loop):
tn = tn_client

Expand Down

0 comments on commit 88bc66f

Please sign in to comment.