diff --git a/src/in_n_out/_store.py b/src/in_n_out/_store.py index eafcad8..fb98f17 100644 --- a/src/in_n_out/_store.py +++ b/src/in_n_out/_store.py @@ -753,28 +753,32 @@ def _exec(*args: P.args, **kwargs: P.kwargs) -> R: _sig = cast("Signature", sig) # mypy thinks sig is still optional - # first, get and call the provider functions for each parameter type: - _kwargs: dict[str, Any] = {} - for param in _sig.parameters.values(): - provided = self.provide(param.annotation) - if provided is not None: - _kwargs[param.name] = provided - # use bind_partial to allow the caller to still provide their own args # if desired. (i.e. the injected deps are only used if not provided) bound = _sig.bind_partial(*args, **kwargs) bound.apply_defaults() + # first, get and call the provider functions for each parameter type: + _injected_names: set[str] = set() + for param in _sig.parameters.values(): + if param.name not in bound.arguments: + provided = self.provide(param.annotation) + if provided is not None: + _injected_names.add(param.name) + bound.arguments[param.name] = provided + # call the function with injected values try: - result = func(**{**_kwargs, **bound.arguments}) + result = func(**bound.arguments) except TypeError as e: if "missing" not in e.args[0]: raise # pragma: no cover # likely a required argument is still missing. # show what was injected and raise _argnames = ( - f"arguments: {set(_kwargs)!r}" if _kwargs else "NO arguments" + f"arguments: {_injected_names!r}" + if _injected_names + else "NO arguments" ) raise TypeError( f"After injecting dependencies for {_argnames}, {e}" diff --git a/tests/test_processors.py b/tests/test_processors.py index 652e599..be11a39 100644 --- a/tests/test_processors.py +++ b/tests/test_processors.py @@ -134,3 +134,30 @@ def f(x: int): ino.register_processor(f) ino.process(1) mock.assert_called_once_with(1) + + +def test_processor_provider_recursion() -> None: + """Make sure to avoid infinte recursion when a provider uses processors.""" + + class Thing: + count = 0 + + # this is both a processor and a provider + @ino.register_provider + @ino.inject_processors + def thing_provider() -> Thing: + return Thing() + + @ino.inject + def add_item(thing: Thing) -> None: + thing.count += 1 + + N = 3 + for _ in range(N): + ino.register_processor(add_item) + + @ino.inject + def func(thing: Thing) -> int: + return thing.count + + assert func() == N