Skip to content

Commit

Permalink
unittest: make obj work more like Function/Class
Browse files Browse the repository at this point in the history
Previously, the `obj` of a `TestCaseFunction` (the unittest plugin item
type) was the unbound method. This is unlike regular `Class` where the
`obj` is a bound method to a fresh instance.

This difference necessitated several special cases in in places outside
of the unittest plugin, such as `FixtureDef` and `FixtureRequest`, and
made things a bit harder to understand.

Instead, match how the python plugin does it, including collecting
fixtures from a fresh instance.

The downside is that now this instance for fixture-collection is kept
around in memory, but it's the same as `Class` so nothing new. Users
should only initialize stuff in `setUp`/`setUpClass` and similar
methods, and not in `__init__` which is generally off-limits in
`TestCase` subclasses.

I am not sure why there was a difference in the first place, though I
will say the previous unittest approach is probably the preferable one,
but first let's get consistency.
  • Loading branch information
bluetech committed Mar 8, 2024
1 parent 03e5471 commit 1a5e0eb
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 70 deletions.
8 changes: 3 additions & 5 deletions src/_pytest/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ def getfuncargnames(
function: Callable[..., object],
*,
name: str = "",
is_method: bool = False,
cls: type | None = None,
) -> tuple[str, ...]:
"""Return the names of a function's mandatory arguments.
Expand All @@ -97,9 +96,8 @@ def getfuncargnames(
* Aren't bound with functools.partial.
* Aren't replaced with mocks.
The is_method and cls arguments indicate that the function should
be treated as a bound method even though it's not unless, only in
the case of cls, the function is a static method.
The cls arguments indicate that the function should be treated as a bound
method even though it's not unless the function is a static method.
The name parameter should be the original name in which the function was collected.
"""
Expand Down Expand Up @@ -137,7 +135,7 @@ def getfuncargnames(
# If this function should be treated as a bound method even though
# it's passed as an unbound method or function, remove the first
# parameter name.
if is_method or (
if (
# Not using `getattr` because we don't want to resolve the staticmethod.
# Not using `cls.__dict__` because we want to check the entire MRO.
cls
Expand Down
57 changes: 17 additions & 40 deletions src/_pytest/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,12 +462,8 @@ def cls(self):
@property
def instance(self):
"""Instance (can be None) on which test function was collected."""
# unittest support hack, see _pytest.unittest.TestCaseFunction.
try:
return self._pyfuncitem._testcase # type: ignore[attr-defined]
except AttributeError:
function = getattr(self, "function", None)
return getattr(function, "__self__", None)
function = getattr(self, "function", None)
return getattr(function, "__self__", None)

@property
def module(self):
Expand Down Expand Up @@ -965,7 +961,6 @@ def __init__(
func: "_FixtureFunc[FixtureValue]",
scope: Union[Scope, _ScopeName, Callable[[str, Config], _ScopeName], None],
params: Optional[Sequence[object]],
unittest: bool = False,
ids: Optional[
Union[Tuple[Optional[object], ...], Callable[[Any], Optional[object]]]
] = None,
Expand Down Expand Up @@ -1011,9 +1006,7 @@ def __init__(
# a parameter value.
self.ids: Final = ids
# The names requested by the fixtures.
self.argnames: Final = getfuncargnames(func, name=argname, is_method=unittest)
# Whether the fixture was collected from a unittest TestCase class.
self.unittest: Final = unittest
self.argnames: Final = getfuncargnames(func, name=argname)
# If the fixture was executed, the current value of the fixture.
# Can change if the fixture is executed with different parameters.
self.cached_result: Optional[_FixtureCachedResult[FixtureValue]] = None
Expand Down Expand Up @@ -1092,25 +1085,20 @@ def resolve_fixture_function(
"""Get the actual callable that can be called to obtain the fixture
value, dealing with unittest-specific instances and bound methods."""
fixturefunc = fixturedef.func
if fixturedef.unittest:
if request.instance is not None:
# Bind the unbound method to the TestCase instance.
fixturefunc = fixturedef.func.__get__(request.instance) # type: ignore[union-attr]
else:
# The fixture function needs to be bound to the actual
# request.instance so that code working with "fixturedef" behaves
# as expected.
if request.instance is not None:
# Handle the case where fixture is defined not in a test class, but some other class
# (for example a plugin class with a fixture), see #2270.
if hasattr(fixturefunc, "__self__") and not isinstance(
request.instance,
fixturefunc.__self__.__class__, # type: ignore[union-attr]
):
return fixturefunc
fixturefunc = getimfunc(fixturedef.func)
if fixturefunc != fixturedef.func:
fixturefunc = fixturefunc.__get__(request.instance) # type: ignore[union-attr]
# The fixture function needs to be bound to the actual
# request.instance so that code working with "fixturedef" behaves
# as expected.
if request.instance is not None:
# Handle the case where fixture is defined not in a test class, but some other class
# (for example a plugin class with a fixture), see #2270.
if hasattr(fixturefunc, "__self__") and not isinstance(
request.instance,
fixturefunc.__self__.__class__, # type: ignore[union-attr]
):
return fixturefunc
fixturefunc = getimfunc(fixturedef.func)
if fixturefunc != fixturedef.func:
fixturefunc = fixturefunc.__get__(request.instance) # type: ignore[union-attr]
return fixturefunc


Expand Down Expand Up @@ -1614,7 +1602,6 @@ def _register_fixture(
Union[Tuple[Optional[object], ...], Callable[[Any], Optional[object]]]
] = None,
autouse: bool = False,
unittest: bool = False,
) -> None:
"""Register a fixture
Expand All @@ -1635,8 +1622,6 @@ def _register_fixture(
The fixture's IDs.
:param autouse:
Whether this is an autouse fixture.
:param unittest:
Set this if this is a unittest fixture.
"""
fixture_def = FixtureDef(
config=self.config,
Expand All @@ -1645,7 +1630,6 @@ def _register_fixture(
func=func,
scope=scope,
params=params,
unittest=unittest,
ids=ids,
_ispytest=True,
)
Expand All @@ -1667,8 +1651,6 @@ def _register_fixture(
def parsefactories(
self,
node_or_obj: nodes.Node,
*,
unittest: bool = ...,
) -> None:
raise NotImplementedError()

Expand All @@ -1677,17 +1659,13 @@ def parsefactories(
self,
node_or_obj: object,
nodeid: Optional[str],
*,
unittest: bool = ...,
) -> None:
raise NotImplementedError()

def parsefactories(
self,
node_or_obj: Union[nodes.Node, object],
nodeid: Union[str, NotSetType, None] = NOTSET,
*,
unittest: bool = False,
) -> None:
"""Collect fixtures from a collection node or object.
Expand Down Expand Up @@ -1739,7 +1717,6 @@ def parsefactories(
func=func,
scope=marker.scope,
params=marker.params,
unittest=unittest,
ids=marker.ids,
autouse=marker.autouse,
)
Expand Down
1 change: 0 additions & 1 deletion src/_pytest/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -1314,7 +1314,6 @@ def parametrize(
func=get_direct_param_fixture_func,
scope=scope_,
params=None,
unittest=False,
ids=None,
_ispytest=True,
)
Expand Down
50 changes: 27 additions & 23 deletions src/_pytest/unittest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from typing import Union

import _pytest._code
from _pytest.compat import getimfunc
from _pytest.compat import is_async_function
from _pytest.config import hookimpl
from _pytest.fixtures import FixtureRequest
Expand Down Expand Up @@ -63,6 +62,14 @@ class UnitTestCase(Class):
# to declare that our children do not support funcargs.
nofuncargs = True

def newinstance(self):
# TestCase __init__ takes the method (test) name. The TestCase
# constructor treats the name "runTest" as a special no-op, so it can be
# used when a dummy instance is needed. While unittest.TestCase has a
# default, some subclasses omit the default (#9610), so always supply
# it.
return self.obj("runTest")

def collect(self) -> Iterable[Union[Item, Collector]]:
from unittest import TestLoader

Expand All @@ -76,15 +83,15 @@ def collect(self) -> Iterable[Union[Item, Collector]]:
self._register_unittest_setup_class_fixture(cls)
self._register_setup_class_fixture()

self.session._fixturemanager.parsefactories(self, unittest=True)
self.session._fixturemanager.parsefactories(self.newinstance(), self.nodeid)

loader = TestLoader()
foundsomething = False
for name in loader.getTestCaseNames(self.obj):
x = getattr(self.obj, name)
if not getattr(x, "__test__", True):
continue
funcobj = getimfunc(x)
yield TestCaseFunction.from_parent(self, name=name, callobj=funcobj)
yield TestCaseFunction.from_parent(self, name=name)
foundsomething = True

if not foundsomething:
Expand Down Expand Up @@ -169,31 +176,28 @@ def unittest_setup_method_fixture(
class TestCaseFunction(Function):
nofuncargs = True
_excinfo: Optional[List[_pytest._code.ExceptionInfo[BaseException]]] = None
_testcase: Optional["unittest.TestCase"] = None

def _getobj(self):
assert self.parent is not None
# Unlike a regular Function in a Class, where `item.obj` returns
# a *bound* method (attached to an instance), TestCaseFunction's
# `obj` returns an *unbound* method (not attached to an instance).
# This inconsistency is probably not desirable, but needs some
# consideration before changing.
return getattr(self.parent.obj, self.originalname) # type: ignore[attr-defined]
assert isinstance(self.parent, UnitTestCase)
testcase = self.parent.obj(self.name)
return getattr(testcase, self.name)

# Backward compat for pytest-django; can be removed after pytest-django
# updates + some slack.
@property
def _testcase(self):
return self._obj.__self__

Check warning on line 189 in src/_pytest/unittest.py

View check run for this annotation

Codecov / codecov/patch

src/_pytest/unittest.py#L189

Added line #L189 was not covered by tests

def setup(self) -> None:
# A bound method to be called during teardown() if set (see 'runtest()').
self._explicit_tearDown: Optional[Callable[[], None]] = None
assert self.parent is not None
self._testcase = self.parent.obj(self.name) # type: ignore[attr-defined]
self._obj = getattr(self._testcase, self.name)
super().setup()

def teardown(self) -> None:
super().teardown()
if self._explicit_tearDown is not None:
self._explicit_tearDown()
self._explicit_tearDown = None
self._testcase = None
self._obj = None

def startTest(self, testcase: "unittest.TestCase") -> None:
Expand Down Expand Up @@ -292,14 +296,14 @@ def addDuration(self, testcase: "unittest.TestCase", elapsed: float) -> None:
def runtest(self) -> None:
from _pytest.debugging import maybe_wrap_pytest_function_for_tracing

assert self._testcase is not None
testcase = self.obj.__self__

maybe_wrap_pytest_function_for_tracing(self)

# Let the unittest framework handle async functions.
if is_async_function(self.obj):
# Type ignored because self acts as the TestResult, but is not actually one.
self._testcase(result=self) # type: ignore[arg-type]
testcase(result=self) # type: ignore[arg-type]
else:
# When --pdb is given, we want to postpone calling tearDown() otherwise
# when entering the pdb prompt, tearDown() would have probably cleaned up
Expand All @@ -311,16 +315,16 @@ def runtest(self) -> None:
assert isinstance(self.parent, UnitTestCase)
skipped = _is_skipped(self.obj) or _is_skipped(self.parent.obj)
if self.config.getoption("usepdb") and not skipped:
self._explicit_tearDown = self._testcase.tearDown
setattr(self._testcase, "tearDown", lambda *args: None)
self._explicit_tearDown = testcase.tearDown
setattr(testcase, "tearDown", lambda *args: None)

# We need to update the actual bound method with self.obj, because
# wrap_pytest_function_for_tracing replaces self.obj by a wrapper.
setattr(self._testcase, self.name, self.obj)
setattr(testcase, self.name, self.obj)
try:
self._testcase(result=self) # type: ignore[arg-type]
testcase(result=self) # type: ignore[arg-type]
finally:
delattr(self._testcase, self.name)
delattr(testcase, self.name)

def _traceback_filter(
self, excinfo: _pytest._code.ExceptionInfo[BaseException]
Expand Down
6 changes: 5 additions & 1 deletion testing/test_unittest.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,10 +208,14 @@ def test_demo(self):
"""
)

pytester.inline_run("-s", testpath)
gc.collect()

# Either already destroyed, or didn't run setUp.
for obj in gc.get_objects():
assert type(obj).__name__ != "TestCaseObjectsShouldBeCleanedUp"
if type(obj).__name__ == "TestCaseObjectsShouldBeCleanedUp":
assert not hasattr(obj, "an_expensive_obj")


def test_unittest_skip_issue148(pytester: Pytester) -> None:
Expand Down

0 comments on commit 1a5e0eb

Please sign in to comment.