Skip to content

Commit

Permalink
Wraps benchmark (hgrecco#1862)
Browse files Browse the repository at this point in the history
- Add wrapper benchmark
  • Loading branch information
Saelyos authored Oct 24, 2023
1 parent a01a0bf commit f9e139e
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 6 deletions.
13 changes: 7 additions & 6 deletions pint/registry_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,14 +168,13 @@ def _converter(ureg, values, strict):
return _converter


def _apply_defaults(func, args, kwargs):
def _apply_defaults(sig, args, kwargs):
"""Apply default keyword arguments.
Named keywords may have been left blank. This function applies the default
values so that every argument is defined.
"""

sig = signature(func)
bound_arguments = sig.bind(*args, **kwargs)
for param in sig.parameters.values():
if param.name not in bound_arguments.arguments:
Expand Down Expand Up @@ -254,7 +253,8 @@ def wraps(
ret = _to_units_container(ret, ureg)

def decorator(func: Callable[..., Any]) -> Callable[..., Quantity]:
count_params = len(signature(func).parameters)
sig = signature(func)
count_params = len(sig.parameters)
if len(args) != count_params:
raise TypeError(
"%s takes %i parameters, but %i units were passed"
Expand All @@ -270,7 +270,7 @@ def decorator(func: Callable[..., Any]) -> Callable[..., Quantity]:

@functools.wraps(func, assigned=assigned, updated=updated)
def wrapper(*values, **kw) -> Quantity:
values, kw = _apply_defaults(func, values, kw)
values, kw = _apply_defaults(sig, values, kw)

# In principle, the values are used as is
# When then extract the magnitudes when needed.
Expand Down Expand Up @@ -335,7 +335,8 @@ def check(
]

def decorator(func):
count_params = len(signature(func).parameters)
sig = signature(func)
count_params = len(sig.parameters)
if len(dimensions) != count_params:
raise TypeError(
"%s takes %i parameters, but %i dimensions were passed"
Expand All @@ -351,7 +352,7 @@ def decorator(func):

@functools.wraps(func, assigned=assigned, updated=updated)
def wrapper(*args, **kwargs):
list_args, empty = _apply_defaults(func, args, kwargs)
list_args, empty = _apply_defaults(sig, args, kwargs)

for dim, value in zip(dimensions, list_args):
if dim is None:
Expand Down
36 changes: 36 additions & 0 deletions pint/testsuite/benchmarks/test_20_quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,39 @@ def test_op2(benchmark, setup, keys, op):
_, data = setup
key1, key2 = keys
benchmark(op, data[key1], data[key2])


@pytest.mark.parametrize("key", ALL_VALUES_Q)
def test_wrapper(benchmark, setup, key):
ureg, data = setup
value, unit = key.split("_")

@ureg.wraps(None, (unit,))
def f(a):
pass

benchmark(f, data[key])


@pytest.mark.parametrize("key", ALL_VALUES_Q)
def test_wrapper_nonstrict(benchmark, setup, key):
ureg, data = setup
value, unit = key.split("_")

@ureg.wraps(None, (unit,), strict=False)
def f(a):
pass

benchmark(f, data[value])


@pytest.mark.parametrize("key", ALL_VALUES_Q)
def test_wrapper_ret(benchmark, setup, key):
ureg, data = setup
value, unit = key.split("_")

@ureg.wraps(unit, (unit,))
def f(a):
return a

benchmark(f, data[key])

0 comments on commit f9e139e

Please sign in to comment.