diff --git a/changelog/5701.bugfix.rst b/changelog/5701.bugfix.rst new file mode 100644 index 00000000000..b654e74479a --- /dev/null +++ b/changelog/5701.bugfix.rst @@ -0,0 +1 @@ +Fix collection of ``staticmethod`` objects defined with ``functools.partial``. diff --git a/src/_pytest/compat.py b/src/_pytest/compat.py index 52ffc36bc98..97c06e3ffc6 100644 --- a/src/_pytest/compat.py +++ b/src/_pytest/compat.py @@ -78,7 +78,7 @@ def num_mock_patch_args(function): ) -def getfuncargnames(function, is_method=False, cls=None): +def getfuncargnames(function, *, name: str = "", is_method=False, cls=None): """Returns the names of a function's mandatory arguments. This should return the names of all function arguments that: @@ -91,11 +91,12 @@ def getfuncargnames(function, is_method=False, cls=None): be treated as a bound method even though it's not unless, only in the case of cls, the function is a static method. + The name parameter should be the original name in which the function was collected. + @RonnyPfannschmidt: This function should be refactored when we revisit fixtures. The fixture mechanism should ask the node for the fixture names, and not try to obtain directly from the function object well after collection has occurred. - """ # The parameters attribute of a Signature object contains an # ordered mapping of parameter names to Parameter instances. This @@ -118,11 +119,14 @@ def getfuncargnames(function, is_method=False, cls=None): ) and p.default is Parameter.empty ) + if not name: + name = function.__name__ + # If this function should be treated as a bound method even though # it's passed as an unbound method or function, remove the first # parameter name. if is_method or ( - cls and not isinstance(cls.__dict__.get(function.__name__, None), staticmethod) + cls and not isinstance(cls.__dict__.get(name, None), staticmethod) ): arg_names = arg_names[1:] # Remove any names that will be replaced with mocks. @@ -245,7 +249,7 @@ def get_real_method(obj, holder): try: is_method = hasattr(obj, "__func__") obj = get_real_func(obj) - except Exception: + except Exception: # pragma: no cover return obj if is_method and hasattr(obj, "__get__") and callable(obj.__get__): obj = obj.__get__(holder) diff --git a/src/_pytest/fixtures.py b/src/_pytest/fixtures.py index 965a2e6e9e2..6a3e82907fc 100644 --- a/src/_pytest/fixtures.py +++ b/src/_pytest/fixtures.py @@ -828,7 +828,7 @@ def __init__( where=baseid, ) self.params = params - self.argnames = getfuncargnames(func, is_method=unittest) + self.argnames = getfuncargnames(func, name=argname, is_method=unittest) self.unittest = unittest self.ids = ids self._finalizers = [] @@ -1143,7 +1143,7 @@ def _get_direct_parametrize_args(self, node): def getfixtureinfo(self, node, func, cls, funcargs=True): if funcargs and not getattr(node, "nofuncargs", False): - argnames = getfuncargnames(func, cls=cls) + argnames = getfuncargnames(func, name=node.name, cls=cls) else: argnames = () diff --git a/testing/python/collect.py b/testing/python/collect.py index 61fc5857998..e6dd3e87088 100644 --- a/testing/python/collect.py +++ b/testing/python/collect.py @@ -1143,52 +1143,6 @@ class Test(object): assert result.ret == ExitCode.NO_TESTS_COLLECTED -def test_collect_functools_partial(testdir): - """ - Test that collection of functools.partial object works, and arguments - to the wrapped functions are dealt with correctly (see #811). - """ - testdir.makepyfile( - """ - import functools - import pytest - - @pytest.fixture - def fix1(): - return 'fix1' - - @pytest.fixture - def fix2(): - return 'fix2' - - def check1(i, fix1): - assert i == 2 - assert fix1 == 'fix1' - - def check2(fix1, i): - assert i == 2 - assert fix1 == 'fix1' - - def check3(fix1, i, fix2): - assert i == 2 - assert fix1 == 'fix1' - assert fix2 == 'fix2' - - test_ok_1 = functools.partial(check1, i=2) - test_ok_2 = functools.partial(check1, i=2, fix1='fix1') - test_ok_3 = functools.partial(check1, 2) - test_ok_4 = functools.partial(check2, i=2) - test_ok_5 = functools.partial(check3, i=2) - test_ok_6 = functools.partial(check3, i=2, fix1='fix1') - - test_fail_1 = functools.partial(check2, 2) - test_fail_2 = functools.partial(check3, 2) - """ - ) - result = testdir.inline_run() - result.assertoutcome(passed=6, failed=2) - - @pytest.mark.filterwarnings("default") def test_dont_collect_non_function_callable(testdir): """Test for issue https://github.com/pytest-dev/pytest/issues/331 diff --git a/testing/python/fixtures.py b/testing/python/fixtures.py index c0c230ccf2b..a85a0d73175 100644 --- a/testing/python/fixtures.py +++ b/testing/python/fixtures.py @@ -10,7 +10,9 @@ from _pytest.warnings import SHOW_PYTEST_WARNINGS_ARG -def test_getfuncargnames(): +def test_getfuncargnames_functions(): + """Test getfuncargnames for normal functions""" + def f(): pass @@ -31,18 +33,56 @@ def j(arg1, arg2, arg3="hello"): assert fixtures.getfuncargnames(j) == ("arg1", "arg2") + +def test_getfuncargnames_methods(): + """Test getfuncargnames for normal methods""" + class A: def f(self, arg1, arg2="hello"): pass + assert fixtures.getfuncargnames(A().f) == ("arg1",) + + +def test_getfuncargnames_staticmethod(): + """Test getfuncargnames for staticmethods""" + + class A: @staticmethod - def static(arg1, arg2): + def static(arg1, arg2, x=1): pass - assert fixtures.getfuncargnames(A().f) == ("arg1",) assert fixtures.getfuncargnames(A.static, cls=A) == ("arg1", "arg2") +def test_getfuncargnames_partial(): + """Check getfuncargnames for methods defined with functools.partial (#5701)""" + import functools + + def check(arg1, arg2, i): + pass + + class T: + test_ok = functools.partial(check, i=2) + + values = fixtures.getfuncargnames(T().test_ok, name="test_ok") + assert values == ("arg1", "arg2") + + +def test_getfuncargnames_staticmethod_partial(): + """Check getfuncargnames for staticmethods defined with functools.partial (#5701)""" + import functools + + def check(arg1, arg2, i): + pass + + class T: + test_ok = staticmethod(functools.partial(check, i=2)) + + values = fixtures.getfuncargnames(T().test_ok, name="test_ok") + assert values == ("arg1", "arg2") + + @pytest.mark.pytester_example_path("fixtures/fill_fixtures") class TestFillFixtures: def test_fillfuncargs_exposed(self): diff --git a/testing/test_compat.py b/testing/test_compat.py index 9e7d05c5df5..fb2470d07c6 100644 --- a/testing/test_compat.py +++ b/testing/test_compat.py @@ -1,4 +1,5 @@ import sys +from functools import partial from functools import wraps import pytest @@ -72,6 +73,16 @@ def func(): assert get_real_func(wrapped_func2) is wrapped_func +def test_get_real_func_partial(): + """Test get_real_func handles partial instances correctly""" + + def foo(x): + return x + + assert get_real_func(foo) is foo + assert get_real_func(partial(foo)) is foo + + def test_is_generator_asyncio(testdir): testdir.makepyfile( """