diff --git a/src/exceptiongroup/_exceptions.py b/src/exceptiongroup/_exceptions.py index 2986fb6..a45d25d 100644 --- a/src/exceptiongroup/_exceptions.py +++ b/src/exceptiongroup/_exceptions.py @@ -42,6 +42,17 @@ def get_condition_filter( raise TypeError("expected a function, exception type or tuple of exception types") +def _derive_and_copy_attributes(self, excs): + eg = self.derive(excs) + eg.__cause__ = self.__cause__ + eg.__context__ = self.__context__ + eg.__traceback__ = self.__traceback__ + if hasattr(self, "__notes__"): + # Create a new list so that add_note() only affects one exceptiongroup + eg.__notes__ = list(self.__notes__) + return eg + + class BaseExceptionGroup(BaseException, Generic[_BaseExceptionT_co]): """A combination of multiple unrelated exceptions.""" @@ -154,10 +165,7 @@ def subgroup( if not modified: return self elif exceptions: - group = self.derive(exceptions) - group.__cause__ = self.__cause__ - group.__context__ = self.__context__ - group.__traceback__ = self.__traceback__ + group = _derive_and_copy_attributes(self, exceptions) return group else: return None @@ -230,17 +238,13 @@ def split( matching_group: _BaseExceptionGroupSelf | None = None if matching_exceptions: - matching_group = self.derive(matching_exceptions) - matching_group.__cause__ = self.__cause__ - matching_group.__context__ = self.__context__ - matching_group.__traceback__ = self.__traceback__ + matching_group = _derive_and_copy_attributes(self, matching_exceptions) nonmatching_group: _BaseExceptionGroupSelf | None = None if nonmatching_exceptions: - nonmatching_group = self.derive(nonmatching_exceptions) - nonmatching_group.__cause__ = self.__cause__ - nonmatching_group.__context__ = self.__context__ - nonmatching_group.__traceback__ = self.__traceback__ + nonmatching_group = _derive_and_copy_attributes( + self, nonmatching_exceptions + ) return matching_group, nonmatching_group @@ -257,12 +261,7 @@ def derive( def derive( self, __excs: Sequence[_BaseExceptionT] ) -> BaseExceptionGroup[_BaseExceptionT]: - eg = BaseExceptionGroup(self.message, __excs) - if hasattr(self, "__notes__"): - # Create a new list so that add_note() only affects one exceptiongroup - eg.__notes__ = list(self.__notes__) - - return eg + return BaseExceptionGroup(self.message, __excs) def __str__(self) -> str: suffix = "" if len(self._exceptions) == 1 else "s" diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index bd20f15..65c1c67 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -205,6 +205,13 @@ def test_notes_is_list_of_strings_if_it_exists(self): eg.add_note(note) self.assertEqual(eg.__notes__, [note]) + def test_derive_doesn_copy_notes(self): + eg = create_simple_eg() + eg.add_note("hello") + assert eg.__notes__ == ["hello"] + eg2 = eg.derive([ValueError()]) + assert not hasattr(eg2, "__notes__") + class ExceptionGroupTestBase(unittest.TestCase): def assertMatchesTemplate(self, exc, exc_type, template): @@ -786,6 +793,7 @@ def derive(self, excs): except ValueError as ve: raise EG("eg", [ve, nested], 42) except EG as e: + e.add_note("hello") eg = e self.assertMatchesTemplate(eg, EG, [ValueError(1), [TypeError(2)]]) @@ -796,29 +804,35 @@ def derive(self, excs): self.assertMatchesTemplate(rest, EG, [ValueError(1), [TypeError(2)]]) self.assertEqual(rest.code, 42) self.assertEqual(rest.exceptions[1].code, 101) + self.assertEqual(rest.__notes__, ["hello"]) # Match Everything match, rest = self.split_exception_group(eg, (ValueError, TypeError)) self.assertMatchesTemplate(match, EG, [ValueError(1), [TypeError(2)]]) self.assertEqual(match.code, 42) self.assertEqual(match.exceptions[1].code, 101) + self.assertEqual(match.__notes__, ["hello"]) self.assertIsNone(rest) # Match ValueErrors match, rest = self.split_exception_group(eg, ValueError) self.assertMatchesTemplate(match, EG, [ValueError(1)]) self.assertEqual(match.code, 42) + self.assertEqual(match.__notes__, ["hello"]) self.assertMatchesTemplate(rest, EG, [[TypeError(2)]]) self.assertEqual(rest.code, 42) self.assertEqual(rest.exceptions[0].code, 101) + self.assertEqual(rest.__notes__, ["hello"]) # Match TypeErrors match, rest = self.split_exception_group(eg, TypeError) self.assertMatchesTemplate(match, EG, [[TypeError(2)]]) self.assertEqual(match.code, 42) self.assertEqual(match.exceptions[0].code, 101) + self.assertEqual(match.__notes__, ["hello"]) self.assertMatchesTemplate(rest, EG, [ValueError(1)]) self.assertEqual(rest.code, 42) + self.assertEqual(rest.__notes__, ["hello"]) def test_repr():