diff --git a/Lib/contextlib.py b/Lib/contextlib.py index 8343d7e5196713..1384d8903d17bf 100644 --- a/Lib/contextlib.py +++ b/Lib/contextlib.py @@ -191,6 +191,14 @@ class _AsyncGeneratorContextManager( ): """Helper for @asynccontextmanager decorator.""" + def __call__(self, func): + @wraps(func) + async def inner(*args, **kwds): + async with self.__class__(self.func, self.args, self.kwds): + return await func(*args, **kwds) + + return inner + async def __aenter__(self): # do not keep args and kwds alive unnecessarily # they are only needed for recreation, which is not possible anymore diff --git a/Lib/test/test_contextlib_async.py b/Lib/test/test_contextlib_async.py index 74fddef3f34ec5..c738bf3c0bdfeb 100644 --- a/Lib/test/test_contextlib_async.py +++ b/Lib/test/test_contextlib_async.py @@ -318,6 +318,82 @@ async def recursive(): self.assertEqual(ncols, 10) self.assertEqual(depth, 0) + @_async_test + async def test_decorator(self): + entered = False + + @asynccontextmanager + async def context(): + nonlocal entered + entered = True + yield + entered = False + + @context() + async def test(): + self.assertTrue(entered) + + self.assertFalse(entered) + await test() + self.assertFalse(entered) + + @_async_test + async def test_decorator_with_exception(self): + entered = False + + @asynccontextmanager + async def context(): + nonlocal entered + try: + entered = True + yield + finally: + entered = False + + @context() + async def test(): + self.assertTrue(entered) + raise NameError('foo') + + self.assertFalse(entered) + with self.assertRaisesRegex(NameError, 'foo'): + await test() + self.assertFalse(entered) + + @_async_test + async def test_decorating_method(self): + + @asynccontextmanager + async def context(): + yield + + + class Test(object): + + @context() + async def method(self, a, b, c=None): + self.a = a + self.b = b + self.c = c + + # these tests are for argument passing when used as a decorator + test = Test() + await test.method(1, 2) + self.assertEqual(test.a, 1) + self.assertEqual(test.b, 2) + self.assertEqual(test.c, None) + + test = Test() + await test.method('a', 'b', 'c') + self.assertEqual(test.a, 'a') + self.assertEqual(test.b, 'b') + self.assertEqual(test.c, 'c') + + test = Test() + await test.method(a=1, b=2) + self.assertEqual(test.a, 1) + self.assertEqual(test.b, 2) + class AclosingTestCase(unittest.TestCase): diff --git a/Misc/NEWS.d/next/Library/2019-10-08-14-08-59.bpo-38415.N1bUw6.rst b/Misc/NEWS.d/next/Library/2019-10-08-14-08-59.bpo-38415.N1bUw6.rst new file mode 100644 index 00000000000000..f99bf0d19b1f8e --- /dev/null +++ b/Misc/NEWS.d/next/Library/2019-10-08-14-08-59.bpo-38415.N1bUw6.rst @@ -0,0 +1,3 @@ +Added missing behavior to :func:`contextlib.asynccontextmanager` to match +:func:`contextlib.contextmanager` so decorated functions can themselves be +decorators.