diff --git a/Doc/library/asyncio-eventloop.rst b/Doc/library/asyncio-eventloop.rst index 4776853b5a56d8..4f0f8c06fee787 100644 --- a/Doc/library/asyncio-eventloop.rst +++ b/Doc/library/asyncio-eventloop.rst @@ -330,7 +330,7 @@ Creating Futures and Tasks .. versionadded:: 3.5.2 -.. method:: loop.create_task(coro, *, name=None) +.. method:: loop.create_task(coro, *, name=None, context=None) Schedule the execution of a :ref:`coroutine`. Return a :class:`Task` object. @@ -342,9 +342,16 @@ Creating Futures and Tasks If the *name* argument is provided and not ``None``, it is set as the name of the task using :meth:`Task.set_name`. + An optional keyword-only *context* argument allows specifying a + custom :class:`contextvars.Context` for the *coro* to run in. + The current context copy is created when no *context* is provided. + .. versionchanged:: 3.8 Added the *name* parameter. + .. versionchanged:: 3.11 + Added the *context* parameter. + .. method:: loop.set_task_factory(factory) Set a task factory that will be used by @@ -352,7 +359,7 @@ Creating Futures and Tasks If *factory* is ``None`` the default task factory will be set. Otherwise, *factory* must be a *callable* with the signature matching - ``(loop, coro)``, where *loop* is a reference to the active + ``(loop, coro, context=None)``, where *loop* is a reference to the active event loop, and *coro* is a coroutine object. The callable must return a :class:`asyncio.Future`-compatible object. diff --git a/Doc/library/asyncio-task.rst b/Doc/library/asyncio-task.rst index b30b2894277a2a..faf5910124f9b7 100644 --- a/Doc/library/asyncio-task.rst +++ b/Doc/library/asyncio-task.rst @@ -244,7 +244,7 @@ Running an asyncio Program Creating Tasks ============== -.. function:: create_task(coro, *, name=None) +.. function:: create_task(coro, *, name=None, context=None) Wrap the *coro* :ref:`coroutine ` into a :class:`Task` and schedule its execution. Return the Task object. @@ -252,6 +252,10 @@ Creating Tasks If *name* is not ``None``, it is set as the name of the task using :meth:`Task.set_name`. + An optional keyword-only *context* argument allows specifying a + custom :class:`contextvars.Context` for the *coro* to run in. + The current context copy is created when no *context* is provided. + The task is executed in the loop returned by :func:`get_running_loop`, :exc:`RuntimeError` is raised if there is no running loop in current thread. @@ -281,6 +285,9 @@ Creating Tasks .. versionchanged:: 3.8 Added the *name* parameter. + .. versionchanged:: 3.11 + Added the *context* parameter. + Sleeping ======== diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py index 51c4e664d74e9d..5eea1658df8f6f 100644 --- a/Lib/asyncio/base_events.py +++ b/Lib/asyncio/base_events.py @@ -426,18 +426,23 @@ def create_future(self): """Create a Future object attached to the loop.""" return futures.Future(loop=self) - def create_task(self, coro, *, name=None): + def create_task(self, coro, *, name=None, context=None): """Schedule a coroutine object. Return a task object. """ self._check_closed() if self._task_factory is None: - task = tasks.Task(coro, loop=self, name=name) + task = tasks.Task(coro, loop=self, name=name, context=context) if task._source_traceback: del task._source_traceback[-1] else: - task = self._task_factory(self, coro) + if context is None: + # Use legacy API if context is not needed + task = self._task_factory(self, coro) + else: + task = self._task_factory(self, coro, context=context) + tasks._set_task_name(task, name) return task diff --git a/Lib/asyncio/events.py b/Lib/asyncio/events.py index e682a192a887f2..0d26ea545baa5d 100644 --- a/Lib/asyncio/events.py +++ b/Lib/asyncio/events.py @@ -274,7 +274,7 @@ def create_future(self): # Method scheduling a coroutine object: create a task. - def create_task(self, coro, *, name=None): + def create_task(self, coro, *, name=None, context=None): raise NotImplementedError # Methods for interacting with threads. diff --git a/Lib/asyncio/taskgroups.py b/Lib/asyncio/taskgroups.py index c3ce94a4dd0a95..6af21f3a15d93a 100644 --- a/Lib/asyncio/taskgroups.py +++ b/Lib/asyncio/taskgroups.py @@ -138,12 +138,15 @@ async def __aexit__(self, et, exc, tb): me = BaseExceptionGroup('unhandled errors in a TaskGroup', errors) raise me from None - def create_task(self, coro, *, name=None): + def create_task(self, coro, *, name=None, context=None): if not self._entered: raise RuntimeError(f"TaskGroup {self!r} has not been entered") if self._exiting and self._unfinished_tasks == 0: raise RuntimeError(f"TaskGroup {self!r} is finished") - task = self._loop.create_task(coro) + if context is None: + task = self._loop.create_task(coro) + else: + task = self._loop.create_task(coro, context=context) tasks._set_task_name(task, name) task.add_done_callback(self._on_task_done) self._unfinished_tasks += 1 diff --git a/Lib/asyncio/tasks.py b/Lib/asyncio/tasks.py index e604298e5efc01..b4f1eed91a9321 100644 --- a/Lib/asyncio/tasks.py +++ b/Lib/asyncio/tasks.py @@ -93,7 +93,7 @@ class Task(futures._PyFuture): # Inherit Python Task implementation # status is still pending _log_destroy_pending = True - def __init__(self, coro, *, loop=None, name=None): + def __init__(self, coro, *, loop=None, name=None, context=None): super().__init__(loop=loop) if self._source_traceback: del self._source_traceback[-1] @@ -112,7 +112,10 @@ def __init__(self, coro, *, loop=None, name=None): self._must_cancel = False self._fut_waiter = None self._coro = coro - self._context = contextvars.copy_context() + if context is None: + self._context = contextvars.copy_context() + else: + self._context = context self._loop.call_soon(self.__step, context=self._context) _register_task(self) @@ -360,13 +363,18 @@ def __wakeup(self, future): Task = _CTask = _asyncio.Task -def create_task(coro, *, name=None): +def create_task(coro, *, name=None, context=None): """Schedule the execution of a coroutine object in a spawn task. Return a Task object. """ loop = events.get_running_loop() - task = loop.create_task(coro) + if context is None: + # Use legacy API if context is not needed + task = loop.create_task(coro) + else: + task = loop.create_task(coro, context=context) + _set_task_name(task, name) return task diff --git a/Lib/test/test_asyncio/test_taskgroups.py b/Lib/test/test_asyncio/test_taskgroups.py index df51528e107939..dea5d6de524204 100644 --- a/Lib/test/test_asyncio/test_taskgroups.py +++ b/Lib/test/test_asyncio/test_taskgroups.py @@ -2,6 +2,7 @@ import asyncio +import contextvars from asyncio import taskgroups import unittest @@ -708,6 +709,23 @@ async def coro(): t = g.create_task(coro(), name="yolo") self.assertEqual(t.get_name(), "yolo") + async def test_taskgroup_task_context(self): + cvar = contextvars.ContextVar('cvar') + + async def coro(val): + await asyncio.sleep(0) + cvar.set(val) + + async with taskgroups.TaskGroup() as g: + ctx = contextvars.copy_context() + self.assertIsNone(ctx.get(cvar)) + t1 = g.create_task(coro(1), context=ctx) + await t1 + self.assertEqual(1, ctx.get(cvar)) + t2 = g.create_task(coro(2), context=ctx) + await t2 + self.assertEqual(2, ctx.get(cvar)) + if __name__ == "__main__": unittest.main() diff --git a/Lib/test/test_asyncio/test_tasks.py b/Lib/test/test_asyncio/test_tasks.py index 95fabf728818bb..b6ef62725166dc 100644 --- a/Lib/test/test_asyncio/test_tasks.py +++ b/Lib/test/test_asyncio/test_tasks.py @@ -95,8 +95,8 @@ class BaseTaskTests: Task = None Future = None - def new_task(self, loop, coro, name='TestTask'): - return self.__class__.Task(coro, loop=loop, name=name) + def new_task(self, loop, coro, name='TestTask', context=None): + return self.__class__.Task(coro, loop=loop, name=name, context=context) def new_future(self, loop): return self.__class__.Future(loop=loop) @@ -2527,6 +2527,90 @@ async def main(): self.assertEqual(cvar.get(), -1) + def test_context_4(self): + cvar = contextvars.ContextVar('cvar') + + async def coro(val): + await asyncio.sleep(0) + cvar.set(val) + + async def main(): + ret = [] + ctx = contextvars.copy_context() + ret.append(ctx.get(cvar)) + t1 = self.new_task(loop, coro(1), context=ctx) + await t1 + ret.append(ctx.get(cvar)) + t2 = self.new_task(loop, coro(2), context=ctx) + await t2 + ret.append(ctx.get(cvar)) + return ret + + loop = asyncio.new_event_loop() + try: + task = self.new_task(loop, main()) + ret = loop.run_until_complete(task) + finally: + loop.close() + + self.assertEqual([None, 1, 2], ret) + + def test_context_5(self): + cvar = contextvars.ContextVar('cvar') + + async def coro(val): + await asyncio.sleep(0) + cvar.set(val) + + async def main(): + ret = [] + ctx = contextvars.copy_context() + ret.append(ctx.get(cvar)) + t1 = asyncio.create_task(coro(1), context=ctx) + await t1 + ret.append(ctx.get(cvar)) + t2 = asyncio.create_task(coro(2), context=ctx) + await t2 + ret.append(ctx.get(cvar)) + return ret + + loop = asyncio.new_event_loop() + try: + task = self.new_task(loop, main()) + ret = loop.run_until_complete(task) + finally: + loop.close() + + self.assertEqual([None, 1, 2], ret) + + def test_context_6(self): + cvar = contextvars.ContextVar('cvar') + + async def coro(val): + await asyncio.sleep(0) + cvar.set(val) + + async def main(): + ret = [] + ctx = contextvars.copy_context() + ret.append(ctx.get(cvar)) + t1 = loop.create_task(coro(1), context=ctx) + await t1 + ret.append(ctx.get(cvar)) + t2 = loop.create_task(coro(2), context=ctx) + await t2 + ret.append(ctx.get(cvar)) + return ret + + loop = asyncio.new_event_loop() + try: + task = loop.create_task(main()) + ret = loop.run_until_complete(task) + finally: + loop.close() + + self.assertEqual([None, 1, 2], ret) + def test_get_coro(self): loop = asyncio.new_event_loop() coro = coroutine_function() diff --git a/Lib/unittest/async_case.py b/Lib/unittest/async_case.py index 3c57bb5cda2c03..25adc3deff63d1 100644 --- a/Lib/unittest/async_case.py +++ b/Lib/unittest/async_case.py @@ -1,4 +1,5 @@ import asyncio +import contextvars import inspect import warnings @@ -34,7 +35,7 @@ class IsolatedAsyncioTestCase(TestCase): def __init__(self, methodName='runTest'): super().__init__(methodName) self._asyncioTestLoop = None - self._asyncioCallsQueue = None + self._asyncioTestContext = contextvars.copy_context() async def asyncSetUp(self): pass @@ -58,7 +59,7 @@ def addAsyncCleanup(self, func, /, *args, **kwargs): self.addCleanup(*(func, *args), **kwargs) def _callSetUp(self): - self.setUp() + self._asyncioTestContext.run(self.setUp) self._callAsync(self.asyncSetUp) def _callTestMethod(self, method): @@ -68,47 +69,30 @@ def _callTestMethod(self, method): def _callTearDown(self): self._callAsync(self.asyncTearDown) - self.tearDown() + self._asyncioTestContext.run(self.tearDown) def _callCleanup(self, function, *args, **kwargs): self._callMaybeAsync(function, *args, **kwargs) def _callAsync(self, func, /, *args, **kwargs): assert self._asyncioTestLoop is not None, 'asyncio test loop is not initialized' - ret = func(*args, **kwargs) - assert inspect.isawaitable(ret), f'{func!r} returned non-awaitable' - fut = self._asyncioTestLoop.create_future() - self._asyncioCallsQueue.put_nowait((fut, ret)) - return self._asyncioTestLoop.run_until_complete(fut) + assert inspect.iscoroutinefunction(func), f'{func!r} is not an async function' + task = self._asyncioTestLoop.create_task( + func(*args, **kwargs), + context=self._asyncioTestContext, + ) + return self._asyncioTestLoop.run_until_complete(task) def _callMaybeAsync(self, func, /, *args, **kwargs): assert self._asyncioTestLoop is not None, 'asyncio test loop is not initialized' - ret = func(*args, **kwargs) - if inspect.isawaitable(ret): - fut = self._asyncioTestLoop.create_future() - self._asyncioCallsQueue.put_nowait((fut, ret)) - return self._asyncioTestLoop.run_until_complete(fut) + if inspect.iscoroutinefunction(func): + task = self._asyncioTestLoop.create_task( + func(*args, **kwargs), + context=self._asyncioTestContext, + ) + return self._asyncioTestLoop.run_until_complete(task) else: - return ret - - async def _asyncioLoopRunner(self, fut): - self._asyncioCallsQueue = queue = asyncio.Queue() - fut.set_result(None) - while True: - query = await queue.get() - queue.task_done() - if query is None: - return - fut, awaitable = query - try: - ret = await awaitable - if not fut.cancelled(): - fut.set_result(ret) - except (SystemExit, KeyboardInterrupt): - raise - except (BaseException, asyncio.CancelledError) as ex: - if not fut.cancelled(): - fut.set_exception(ex) + return self._asyncioTestContext.run(func, *args, **kwargs) def _setupAsyncioLoop(self): assert self._asyncioTestLoop is None, 'asyncio test loop already initialized' @@ -116,16 +100,11 @@ def _setupAsyncioLoop(self): asyncio.set_event_loop(loop) loop.set_debug(True) self._asyncioTestLoop = loop - fut = loop.create_future() - self._asyncioCallsTask = loop.create_task(self._asyncioLoopRunner(fut)) - loop.run_until_complete(fut) def _tearDownAsyncioLoop(self): assert self._asyncioTestLoop is not None, 'asyncio test loop is not initialized' loop = self._asyncioTestLoop self._asyncioTestLoop = None - self._asyncioCallsQueue.put_nowait(None) - loop.run_until_complete(self._asyncioCallsQueue.join()) try: # cancel all tasks diff --git a/Lib/unittest/test/test_async_case.py b/Lib/unittest/test/test_async_case.py index 3717486b26563e..7dc8a6bffa019e 100644 --- a/Lib/unittest/test/test_async_case.py +++ b/Lib/unittest/test/test_async_case.py @@ -1,4 +1,5 @@ import asyncio +import contextvars import unittest from test import support @@ -11,6 +12,9 @@ def tearDownModule(): asyncio.set_event_loop_policy(None) +VAR = contextvars.ContextVar('VAR', default=()) + + class TestAsyncCase(unittest.TestCase): maxDiff = None @@ -24,22 +28,26 @@ class Test(unittest.IsolatedAsyncioTestCase): def setUp(self): self.assertEqual(events, []) events.append('setUp') + VAR.set(VAR.get() + ('setUp',)) async def asyncSetUp(self): self.assertEqual(events, ['setUp']) events.append('asyncSetUp') + VAR.set(VAR.get() + ('asyncSetUp',)) self.addAsyncCleanup(self.on_cleanup1) async def test_func(self): self.assertEqual(events, ['setUp', 'asyncSetUp']) events.append('test') + VAR.set(VAR.get() + ('test',)) self.addAsyncCleanup(self.on_cleanup2) async def asyncTearDown(self): self.assertEqual(events, ['setUp', 'asyncSetUp', 'test']) + VAR.set(VAR.get() + ('asyncTearDown',)) events.append('asyncTearDown') def tearDown(self): @@ -48,6 +56,7 @@ def tearDown(self): 'test', 'asyncTearDown']) events.append('tearDown') + VAR.set(VAR.get() + ('tearDown',)) async def on_cleanup1(self): self.assertEqual(events, ['setUp', @@ -57,6 +66,9 @@ async def on_cleanup1(self): 'tearDown', 'cleanup2']) events.append('cleanup1') + VAR.set(VAR.get() + ('cleanup1',)) + nonlocal cvar + cvar = VAR.get() async def on_cleanup2(self): self.assertEqual(events, ['setUp', @@ -65,8 +77,10 @@ async def on_cleanup2(self): 'asyncTearDown', 'tearDown']) events.append('cleanup2') + VAR.set(VAR.get() + ('cleanup2',)) events = [] + cvar = () test = Test("test_func") result = test.run() self.assertEqual(result.errors, []) @@ -74,13 +88,17 @@ async def on_cleanup2(self): expected = ['setUp', 'asyncSetUp', 'test', 'asyncTearDown', 'tearDown', 'cleanup2', 'cleanup1'] self.assertEqual(events, expected) + self.assertEqual(cvar, tuple(expected)) events = [] + cvar = () test = Test("test_func") test.debug() self.assertEqual(events, expected) + self.assertEqual(cvar, tuple(expected)) test.doCleanups() self.assertEqual(events, expected) + self.assertEqual(cvar, tuple(expected)) def test_exception_in_setup(self): class Test(unittest.IsolatedAsyncioTestCase): diff --git a/Misc/NEWS.d/next/Library/2022-03-12-12-34-13.bpo-46994.d7hPdz.rst b/Misc/NEWS.d/next/Library/2022-03-12-12-34-13.bpo-46994.d7hPdz.rst new file mode 100644 index 00000000000000..765936f1efb594 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2022-03-12-12-34-13.bpo-46994.d7hPdz.rst @@ -0,0 +1,2 @@ +Accept explicit contextvars.Context in :func:`asyncio.create_task` and +:meth:`asyncio.loop.create_task`. diff --git a/Modules/_asynciomodule.c b/Modules/_asynciomodule.c index 2a6c0b335ccfb0..4b12744e625e19 100644 --- a/Modules/_asynciomodule.c +++ b/Modules/_asynciomodule.c @@ -2003,14 +2003,16 @@ _asyncio.Task.__init__ * loop: object = None name: object = None + context: object = None A coroutine wrapped in a Future. [clinic start generated code]*/ static int _asyncio_Task___init___impl(TaskObj *self, PyObject *coro, PyObject *loop, - PyObject *name) -/*[clinic end generated code: output=88b12b83d570df50 input=352a3137fe60091d]*/ + PyObject *name, PyObject *context) +/*[clinic end generated code: output=49ac96fe33d0e5c7 input=924522490c8ce825]*/ + { if (future_init((FutureObj*)self, loop)) { return -1; @@ -2028,9 +2030,13 @@ _asyncio_Task___init___impl(TaskObj *self, PyObject *coro, PyObject *loop, return -1; } - Py_XSETREF(self->task_context, PyContext_CopyCurrent()); - if (self->task_context == NULL) { - return -1; + if (context == Py_None) { + Py_XSETREF(self->task_context, PyContext_CopyCurrent()); + if (self->task_context == NULL) { + return -1; + } + } else { + self->task_context = Py_NewRef(context); } Py_CLEAR(self->task_fut_waiter); diff --git a/Modules/clinic/_asynciomodule.c.h b/Modules/clinic/_asynciomodule.c.h index 2b84ef0a477c71..4a90dfa67c22b2 100644 --- a/Modules/clinic/_asynciomodule.c.h +++ b/Modules/clinic/_asynciomodule.c.h @@ -310,28 +310,29 @@ _asyncio_Future__repr_info(FutureObj *self, PyObject *Py_UNUSED(ignored)) } PyDoc_STRVAR(_asyncio_Task___init____doc__, -"Task(coro, *, loop=None, name=None)\n" +"Task(coro, *, loop=None, name=None, context=None)\n" "--\n" "\n" "A coroutine wrapped in a Future."); static int _asyncio_Task___init___impl(TaskObj *self, PyObject *coro, PyObject *loop, - PyObject *name); + PyObject *name, PyObject *context); static int _asyncio_Task___init__(PyObject *self, PyObject *args, PyObject *kwargs) { int return_value = -1; - static const char * const _keywords[] = {"coro", "loop", "name", NULL}; + static const char * const _keywords[] = {"coro", "loop", "name", "context", NULL}; static _PyArg_Parser _parser = {NULL, _keywords, "Task", 0}; - PyObject *argsbuf[3]; + PyObject *argsbuf[4]; PyObject * const *fastargs; Py_ssize_t nargs = PyTuple_GET_SIZE(args); Py_ssize_t noptargs = nargs + (kwargs ? PyDict_GET_SIZE(kwargs) : 0) - 1; PyObject *coro; PyObject *loop = Py_None; PyObject *name = Py_None; + PyObject *context = Py_None; fastargs = _PyArg_UnpackKeywords(_PyTuple_CAST(args)->ob_item, nargs, kwargs, NULL, &_parser, 1, 1, 0, argsbuf); if (!fastargs) { @@ -347,9 +348,15 @@ _asyncio_Task___init__(PyObject *self, PyObject *args, PyObject *kwargs) goto skip_optional_kwonly; } } - name = fastargs[2]; + if (fastargs[2]) { + name = fastargs[2]; + if (!--noptargs) { + goto skip_optional_kwonly; + } + } + context = fastargs[3]; skip_optional_kwonly: - return_value = _asyncio_Task___init___impl((TaskObj *)self, coro, loop, name); + return_value = _asyncio_Task___init___impl((TaskObj *)self, coro, loop, name, context); exit: return return_value; @@ -917,4 +924,4 @@ _asyncio__leave_task(PyObject *module, PyObject *const *args, Py_ssize_t nargs, exit: return return_value; } -/*[clinic end generated code: output=344927e9b6016ad7 input=a9049054013a1b77]*/ +/*[clinic end generated code: output=540ed3caf5a4d57d input=a9049054013a1b77]*/