Skip to content

Commit

Permalink
allow overriding in injections (#80)
Browse files Browse the repository at this point in the history
* allow overriding in injections
  • Loading branch information
lesnik512 authored Sep 11, 2024
1 parent c863270 commit 2635c20
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 68 deletions.
17 changes: 11 additions & 6 deletions tests/test_injection.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,20 @@ async def test_injection(
assert fixture_one == 1


async def test_wrong_injection() -> None:
async def test_injection_with_overriding() -> None:
@inject
async def inner(
_: container.SimpleFactory = Provide[container.DIContainer.simple_factory],
arg1: bool,
arg2: container.SimpleFactory = Provide[container.DIContainer.simple_factory],
) -> None:
"""Do nothing."""

with pytest.raises(RuntimeError, match="Injected arguments must not be redefined"):
await inner(_=container.SimpleFactory(dep1="1", dep2=2))
_ = arg1
original_obj = await container.DIContainer.simple_factory()
assert arg2.dep1 != original_obj.dep1
assert arg2.dep2 != original_obj.dep2

await inner(arg1=True, arg2=container.SimpleFactory(dep1="1", dep2=2))
await inner(True, container.SimpleFactory(dep1="1", dep2=2))
await inner(True, arg2=container.SimpleFactory(dep1="1", dep2=2))


async def test_empty_injection() -> None:
Expand Down
8 changes: 5 additions & 3 deletions that_depends/injection.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,15 @@ def _inject_to_async(
@functools.wraps(func)
async def inner(*args: P.args, **kwargs: P.kwargs) -> T:
injected = False
for field_name, field_value in signature.parameters.items():
for i, (field_name, field_value) in enumerate(signature.parameters.items()):
if i < len(args):
continue

if not isinstance(field_value.default, AbstractProvider):
continue

if field_name in kwargs:
msg = f"Injected arguments must not be redefined, {field_name=}"
raise RuntimeError(msg)
continue

kwargs[field_name] = await field_value.default()
injected = True
Expand Down
Loading

0 comments on commit 2635c20

Please sign in to comment.