diff --git a/Lib/test/test_contextlib.py b/Lib/test/test_contextlib.py index 093f2593a4a464..0a79b3617defe8 100644 --- a/Lib/test/test_contextlib.py +++ b/Lib/test/test_contextlib.py @@ -162,6 +162,15 @@ def whoo(): # The "gen" attribute is an implementation detail. self.assertFalse(ctx.gen.gi_suspended) + def test_contextmanager_trap_no_yield(self): + @contextmanager + def whoo(): + if False: + yield + ctx = whoo() + with self.assertRaises(RuntimeError): + ctx.__enter__() + def test_contextmanager_trap_second_yield(self): @contextmanager def whoo(): @@ -175,6 +184,19 @@ def whoo(): # The "gen" attribute is an implementation detail. self.assertFalse(ctx.gen.gi_suspended) + def test_contextmanager_non_normalised(self): + @contextmanager + def whoo(): + try: + yield + except RuntimeError: + raise SyntaxError + + ctx = whoo() + ctx.__enter__() + with self.assertRaises(SyntaxError): + ctx.__exit__(RuntimeError, None, None) + def test_contextmanager_except(self): state = [] @contextmanager @@ -254,6 +276,25 @@ def test_issue29692(): self.assertEqual(ex.args[0], 'issue29692:Unchained') self.assertIsNone(ex.__cause__) + def test_contextmanager_wrap_runtimeerror(self): + @contextmanager + def woohoo(): + try: + yield + except Exception as exc: + raise RuntimeError(f'caught {exc}') from exc + + with self.assertRaises(RuntimeError): + with woohoo(): + 1 / 0 + + # If the context manager wrapped StopIteration in a RuntimeError, + # we also unwrap it, because we can't tell whether the wrapping was + # done by the generator machinery or by the generator itself. + with self.assertRaises(StopIteration): + with woohoo(): + raise StopIteration + def _create_contextmanager_attribs(self): def attribs(**kw): def decorate(func): @@ -265,6 +306,7 @@ def decorate(func): @attribs(foo='bar') def baz(spam): """Whee!""" + yield return baz def test_contextmanager_attribs(self): @@ -321,8 +363,11 @@ def woohoo(a, *, b): def test_recursive(self): depth = 0 + ncols = 0 @contextmanager def woohoo(): + nonlocal ncols + ncols += 1 nonlocal depth before = depth depth += 1 @@ -336,6 +381,7 @@ def recursive(): recursive() recursive() + self.assertEqual(ncols, 10) self.assertEqual(depth, 0)