Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gh-104745: Limit starting a patcher more than once without stopping it #126649

Merged
merged 11 commits into from
Nov 13, 2024
52 changes: 50 additions & 2 deletions Lib/test/test_unittest/testmock/testpatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,6 +745,54 @@ def test_stop_idempotent(self):
self.assertIsNone(patcher.stop())


def test_exit_idempotent(self):
patcher = patch(foo_name, 'bar', 3)
with patcher:
patcher.stop()


def test_second_start_failure(self):
patcher = patch(foo_name, 'bar', 3)
patcher.start()
try:
self.assertRaises(RuntimeError, patcher.start)
finally:
patcher.stop()


def test_second_enter_failure(self):
patcher = patch(foo_name, 'bar', 3)
with patcher:
self.assertRaises(RuntimeError, patcher.start)


def test_second_start_after_stop(self):
patcher = patch(foo_name, 'bar', 3)
patcher.start()
patcher.stop()
patcher.start()
patcher.stop()


def test_property_setters(self):
mock_object = Mock()
mock_bar = mock_object.bar
patcher = patch.object(mock_object, 'bar', 'x')
with patcher:
self.assertEqual(patcher.is_local, False)
self.assertIs(patcher.target, mock_object)
self.assertEqual(patcher.temp_original, mock_bar)
patcher.is_local = True
patcher.target = mock_bar
patcher.temp_original = mock_object
self.assertEqual(patcher.is_local, True)
self.assertIs(patcher.target, mock_bar)
self.assertEqual(patcher.temp_original, mock_object)
# if changes are left intact, they may lead to disruption as shown below (it might be what someone needs though)
self.assertEqual(mock_bar.bar, mock_object)
self.assertEqual(mock_object.bar, 'x')


def test_patchobject_start_stop(self):
original = something
patcher = patch.object(PTModule, 'something', 'foo')
Expand Down Expand Up @@ -1098,7 +1146,7 @@ def test_new_callable_patch(self):

self.assertIsNot(m1, m2)
for mock in m1, m2:
self.assertNotCallable(m1)
self.assertNotCallable(mock)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice spot!



def test_new_callable_patch_object(self):
Expand All @@ -1111,7 +1159,7 @@ def test_new_callable_patch_object(self):

self.assertIsNot(m1, m2)
for mock in m1, m2:
self.assertNotCallable(m1)
self.assertNotCallable(mock)


def test_new_callable_keyword_arguments(self):
Expand Down
84 changes: 69 additions & 15 deletions Lib/unittest/mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@


import asyncio
from collections import namedtuple
import contextlib
import io
import inspect
Expand Down Expand Up @@ -1320,6 +1321,9 @@ def _check_spec_arg_typos(kwargs_to_check):
)


_PatchContext = namedtuple("_PatchContext", "exit_stack is_local original target")


class _patch(object):

attribute_name = None
Expand Down Expand Up @@ -1360,6 +1364,7 @@ def __init__(
self.autospec = autospec
self.kwargs = kwargs
self.additional_patchers = []
self._context = None


def copy(self):
Expand Down Expand Up @@ -1469,13 +1474,58 @@ def get_original(self):
)
return original, local

@property
def is_started(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previously, those were writable attributes and now it's no more the case. Could there be some code in the wild assuming so? (for instance pytest which makes quite hacky things, though I don't know if they do hacky things with this specific part of CPython).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess so... I have committed property setters just in case. Will write tests if needed when we figure out what to do with temp_original: do we preserve it and somehow deprecate or anything else

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed them. A bit ugly but anyway these setters exist not for an intended usecase but for backwards compatibility only

return self._context is not None

@property
def is_local(self):
return self._context.is_local

@property
def target(self):
return self._context.target

@property
def temp_original(self):
return self._context.original

@is_local.setter
def is_local(self, value):
self._context = _PatchContext(
exit_stack=self._context.exit_stack,
is_local=value,
original=self._context.original,
target=self._context.target,
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Urgh, I forgot that you cannot change the value of namedtuples. Ok, my suggestion using namedtuples was wrong. To reduce memory footprint, we can use __slots__ in a regular class instead like you had before. That way, we save an from collections import namedtuple as well and simplify the property's setter. WDYT? (again sorry for this bad suggestion).


@target.setter
def target(self, value):
self._context = _PatchContext(
exit_stack=self._context.exit_stack,
is_local=self._context.is_local,
original=self._context.original,
target=value,
)

@temp_original.setter
def temp_original(self, value):
self._context = _PatchContext(
exit_stack=self._context.exit_stack,
is_local=self._context.is_local,
original=value,
target=self._context.target,
)

def __enter__(self):
"""Perform the patch."""
if self.is_started:
raise RuntimeError("Patch is already started")

new, spec, spec_set = self.new, self.spec, self.spec_set
autospec, kwargs = self.autospec, self.kwargs
new_callable = self.new_callable
self.target = self.getter()
target = self.getter()

# normalise False to None
if spec is False:
Expand All @@ -1491,7 +1541,7 @@ def __enter__(self):
spec_set not in (True, None)):
raise TypeError("Can't provide explicit spec_set *and* spec or autospec")

original, local = self.get_original()
original, is_local = self.get_original()
cjw296 marked this conversation as resolved.
Show resolved Hide resolved

if new is DEFAULT and autospec is None:
inherit = False
Expand Down Expand Up @@ -1579,17 +1629,17 @@ def __enter__(self):
if autospec is True:
autospec = original

if _is_instance_mock(self.target):
if _is_instance_mock(target):
raise InvalidSpecError(
f'Cannot autospec attr {self.attribute!r} as the patch '
f'target has already been mocked out. '
f'[target={self.target!r}, attr={autospec!r}]')
f'[target={target!r}, attr={autospec!r}]')
if _is_instance_mock(autospec):
target_name = getattr(self.target, '__name__', self.target)
target_name = getattr(target, '__name__', target)
raise InvalidSpecError(
f'Cannot autospec attr {self.attribute!r} from target '
f'{target_name!r} as it has already been mocked out. '
f'[target={self.target!r}, attr={autospec!r}]')
f'[target={target!r}, attr={autospec!r}]')

new = create_autospec(autospec, spec_set=spec_set,
_name=self.attribute, **kwargs)
Expand All @@ -1600,17 +1650,21 @@ def __enter__(self):

new_attr = new

self.temp_original = original
self.is_local = local
self._exit_stack = contextlib.ExitStack()
exit_stack = contextlib.ExitStack()
self._context = _PatchContext(
exit_stack=exit_stack,
is_local=is_local,
original=original,
target=target,
)
try:
setattr(self.target, self.attribute, new_attr)
if self.attribute_name is not None:
extra_args = {}
if self.new is DEFAULT:
extra_args[self.attribute_name] = new
for patching in self.additional_patchers:
arg = self._exit_stack.enter_context(patching)
arg = exit_stack.enter_context(patching)
if patching.new is DEFAULT:
extra_args.update(arg)
return extra_args
Expand All @@ -1622,6 +1676,9 @@ def __enter__(self):

def __exit__(self, *exc_info):
"""Undo the patch."""
if not self.is_started:
return

if self.is_local and self.temp_original is not DEFAULT:
setattr(self.target, self.attribute, self.temp_original)
else:
Expand All @@ -1633,11 +1690,8 @@ def __exit__(self, *exc_info):
# needed for proxy objects like django settings
setattr(self.target, self.attribute, self.temp_original)

del self.temp_original
del self.is_local
del self.target
exit_stack = self._exit_stack
del self._exit_stack
exit_stack = self._context.exit_stack
self._context = None
return exit_stack.__exit__(*exc_info)


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Limit starting a patcher (from :func:`unittest.mock.patch` or
:func:`unittest.mock.patch.object`) more than
once without stopping it
Loading