Skip to content

Commit

Permalink
Ensure names are cleaned up before reassignment attempt
Browse files Browse the repository at this point in the history
  • Loading branch information
bswck committed Oct 24, 2024
1 parent f835a6c commit eb9ae87
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 8 deletions.
23 changes: 17 additions & 6 deletions injection/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

from contextlib import suppress
from dataclasses import dataclass
from threading import RLock, get_ident
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, overload
from threading import Lock, RLock, get_ident
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, TypeVar, overload

from injection.compat import get_frame

Expand All @@ -21,6 +21,8 @@
"Injection",
"ObjectState",
"inject",
"lenient_recursion_guard",
"strict_recursion_guard",
)


Expand Down Expand Up @@ -62,7 +64,11 @@ def __hash__(self) -> int:
return self.hash


def default_recursion_guard(early: EarlyObject[object]) -> Never:
def lenient_recursion_guard(early: EarlyObject[object]) -> Never:
pass


def strict_recursion_guard(early: EarlyObject[object]) -> Never:
msg = f"{early} requested itself"
raise RecursionError(msg)

Expand All @@ -73,9 +79,11 @@ class Injection(Generic[Object_co]):
pass_scope: bool = False
cache: bool = False
cache_per_alias: bool = False
recursion_guard: Callable[[EarlyObject[Any]], object] = default_recursion_guard
recursion_guard: Callable[[EarlyObject[Any]], object] = lenient_recursion_guard
debug_info: str | None = None

_reassignment_lock: ClassVar[Lock] = Lock()

def _call_factory(self, scope: Locals) -> Object_co:
if self.pass_scope:
return self.factory(scope)
Expand Down Expand Up @@ -118,7 +126,10 @@ def assign_to(self, *aliases: str, scope: Locals) -> None:
debug_info=debug_info,
)
key = InjectionKey(alias, early)
scope[key] = early

with self._reassignment_lock:
scope.pop(key, None)
scope[key] = early


SENTINEL = object()
Expand Down Expand Up @@ -247,7 +258,7 @@ def inject( # noqa: PLR0913
pass_scope: bool = False,
cache: bool = False,
cache_per_alias: bool = False,
recursion_guard: Callable[[EarlyObject[Any]], object] = default_recursion_guard,
recursion_guard: Callable[[EarlyObject[Any]], object] = strict_recursion_guard,
debug_info: str | None = None,
) -> None:
"""
Expand Down
12 changes: 10 additions & 2 deletions tests/unit_tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import pytest

from injection import Injection, inject
from injection import Injection, inject, lenient_recursion_guard


def test_injection_basic() -> None:
Expand Down Expand Up @@ -211,10 +211,18 @@ def test_injection_recursive_guard() -> None:
def factory() -> str:
return scope.get("my_alias", "default_value")

inject("my_alias", into=scope, factory=factory)
inject(
"my_alias", into=scope, factory=factory, recursion_guard=lenient_recursion_guard
)

obj = scope["my_alias"]
assert obj == "default_value"
del scope["my_alias"]

inject("my_alias", into=scope, factory=factory) # strict

with pytest.raises(RecursionError, match="requested itself"):
obj = scope["my_alias"]


def test_injection_with_no_aliases() -> None:
Expand Down

0 comments on commit eb9ae87

Please sign in to comment.