diff --git a/pint/registry_helpers.py b/pint/registry_helpers.py index 6b2f0e0b6..a31836ea6 100644 --- a/pint/registry_helpers.py +++ b/pint/registry_helpers.py @@ -168,14 +168,13 @@ def _converter(ureg, values, strict): return _converter -def _apply_defaults(func, args, kwargs): +def _apply_defaults(sig, args, kwargs): """Apply default keyword arguments. Named keywords may have been left blank. This function applies the default values so that every argument is defined. """ - sig = signature(func) bound_arguments = sig.bind(*args, **kwargs) for param in sig.parameters.values(): if param.name not in bound_arguments.arguments: @@ -254,7 +253,8 @@ def wraps( ret = _to_units_container(ret, ureg) def decorator(func: Callable[..., Any]) -> Callable[..., Quantity]: - count_params = len(signature(func).parameters) + sig = signature(func) + count_params = len(sig.parameters) if len(args) != count_params: raise TypeError( "%s takes %i parameters, but %i units were passed" @@ -270,7 +270,7 @@ def decorator(func: Callable[..., Any]) -> Callable[..., Quantity]: @functools.wraps(func, assigned=assigned, updated=updated) def wrapper(*values, **kw) -> Quantity: - values, kw = _apply_defaults(func, values, kw) + values, kw = _apply_defaults(sig, values, kw) # In principle, the values are used as is # When then extract the magnitudes when needed. @@ -335,7 +335,8 @@ def check( ] def decorator(func): - count_params = len(signature(func).parameters) + sig = signature(func) + count_params = len(sig.parameters) if len(dimensions) != count_params: raise TypeError( "%s takes %i parameters, but %i dimensions were passed" @@ -351,7 +352,7 @@ def decorator(func): @functools.wraps(func, assigned=assigned, updated=updated) def wrapper(*args, **kwargs): - list_args, empty = _apply_defaults(func, args, kwargs) + list_args, empty = _apply_defaults(sig, args, kwargs) for dim, value in zip(dimensions, list_args): if dim is None: diff --git a/pint/testsuite/benchmarks/test_20_quantity.py b/pint/testsuite/benchmarks/test_20_quantity.py index 36c0f92ba..1ec7cbb60 100644 --- a/pint/testsuite/benchmarks/test_20_quantity.py +++ b/pint/testsuite/benchmarks/test_20_quantity.py @@ -53,3 +53,39 @@ def test_op2(benchmark, setup, keys, op): _, data = setup key1, key2 = keys benchmark(op, data[key1], data[key2]) + + +@pytest.mark.parametrize("key", ALL_VALUES_Q) +def test_wrapper(benchmark, setup, key): + ureg, data = setup + value, unit = key.split("_") + + @ureg.wraps(None, (unit,)) + def f(a): + pass + + benchmark(f, data[key]) + + +@pytest.mark.parametrize("key", ALL_VALUES_Q) +def test_wrapper_nonstrict(benchmark, setup, key): + ureg, data = setup + value, unit = key.split("_") + + @ureg.wraps(None, (unit,), strict=False) + def f(a): + pass + + benchmark(f, data[value]) + + +@pytest.mark.parametrize("key", ALL_VALUES_Q) +def test_wrapper_ret(benchmark, setup, key): + ureg, data = setup + value, unit = key.split("_") + + @ureg.wraps(unit, (unit,)) + def f(a): + return a + + benchmark(f, data[key])