diff --git a/pyproject.toml b/pyproject.toml index 379263f..fb42c05 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -106,6 +106,22 @@ ignore = "D100,D213,D401,D413,D107" minversion = "6.0" testpaths = ["tests"] filterwarnings = ["error"] +addopts = ["--cov"] + +[tool.coverage.run] +source = ['src/in_n_out'] +command_line = "-m pytest" + +# https://coverage.readthedocs.io/en/6.4/config.html +[tool.coverage.report] +exclude_lines = [ + "pragma: no cover", + "if TYPE_CHECKING:", + "@overload", + "except ImportError", +] +show_missing = true +skip_covered = true # https://mypy.readthedocs.io/en/stable/config_file.html [tool.mypy] @@ -120,16 +136,6 @@ pretty = true modules = ['tests.*'] disallow_untyped_defs = false - -# https://coverage.readthedocs.io/en/6.4/config.html -[tool.coverage.report] -exclude_lines = [ - "pragma: no cover", - "if TYPE_CHECKING:", - "@overload", - "except ImportError", -] - # https://github.com/cruft/cruft [tool.cruft] skip = ["tests"] diff --git a/src/in_n_out/_inject.py b/src/in_n_out/_inject.py index 18c2394..3d6fb45 100644 --- a/src/in_n_out/_inject.py +++ b/src/in_n_out/_inject.py @@ -3,10 +3,11 @@ import warnings from functools import wraps from inspect import isgeneratorfunction -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, Union, cast, overload from ._processors import get_processor from ._providers import get_provider +from ._store import Store from ._type_resolution import type_resolved_signature if TYPE_CHECKING: @@ -20,13 +21,38 @@ RaiseWarnReturnIgnore = Literal["raise", "warn", "return", "ignore"] +@overload def inject_dependencies( func: Callable[P, R], *, localns: Optional[dict] = None, + store: Union[str, Store, None] = None, on_unresolved_required_args: RaiseWarnReturnIgnore = "raise", on_unannotated_required_args: RaiseWarnReturnIgnore = "warn", ) -> Callable[P, R]: + ... + + +@overload +def inject_dependencies( + func: Literal[None] = None, + *, + localns: Optional[dict] = None, + store: Union[str, Store, None] = None, + on_unresolved_required_args: RaiseWarnReturnIgnore = "raise", + on_unannotated_required_args: RaiseWarnReturnIgnore = "warn", +) -> Callable[[Callable[P, R]], Callable[P, R]]: + ... + + +def inject_dependencies( + func: Callable[P, R] = None, + *, + localns: Optional[dict] = None, + store: Union[str, Store, None] = None, + on_unresolved_required_args: RaiseWarnReturnIgnore = "raise", + on_unannotated_required_args: RaiseWarnReturnIgnore = "warn", +) -> Union[Callable[P, R], Callable[[Callable[P, R]], Callable[P, R]]]: """Decorator returns func that can access/process objects based on type hints. This is form of dependency injection, and result processing. It does 2 things: @@ -46,6 +72,9 @@ def inject_dependencies( a function with type hints localns : Optional[dict] Optional local namespace for name resolution, by default None + store : Union[str, Store, None] + Optional store to use when retrieving providers and processors, + by default the global store will be used. on_unresolved_required_args : RaiseWarnReturnIgnore What to do when a required parameter (one without a default) is encountered with an unresolvable type annotation. @@ -73,85 +102,93 @@ def inject_dependencies( Callable A function with dependencies injected """ - # if the function takes no arguments and has no return annotation - # there's nothing to be done - if not func.__code__.co_argcount and "return" not in getattr( - func, "__annotations__", {} - ): - return func - - # get a signature object with all type annotations resolved - # this may result in a NameError if a required argument is unresolveable. - # There may also be unannotated required arguments, which will likely fail - # when the function is called later. We break this out into a seperate - # function to handle notifying the user on these cases. - sig = _resolve_sig_or_inform( - func, - localns, - on_unresolved_required_args, - on_unannotated_required_args, - ) - if sig is None: # something went wrong, and the user was notified. - return func - process_return = sig.return_annotation is not sig.empty - - # get provider functions for each required parameter - @wraps(func) - def _exec(*args: P.args, **kwargs: P.kwargs) -> R: - # sourcery skip: use-named-expression - # we're actually calling the "injected function" now - - _sig = cast("Signature", sig) - # first, get and call the provider functions for each parameter type: - _kwargs = {} - for param in _sig.parameters.values(): - provider: Optional[Callable] = get_provider(param.annotation) - if provider: - _kwargs[param.name] = provider() - - # use bind_partial to allow the caller to still provide their own arguments - # if desired. (i.e. the injected deps are only used if not provided) - bound = _sig.bind_partial(*args, **kwargs) - bound.apply_defaults() - _kwargs.update(**bound.arguments) - - try: # call the function with injected values - result = func(**_kwargs) # type: ignore [arg-type] - except TypeError as e: - # likely a required argument is still missing. - raise TypeError( - f"After injecting dependencies for arguments {set(_kwargs)}, {e}" - ) from e - - if process_return: - processor = get_processor(_sig.return_annotation) - if processor: - processor(result) - - return result - - out = _exec - - # if it came in as a generatorfunction, it needs to go out as one. - if isgeneratorfunction(func): + _store = store if isinstance(store, Store) else Store.get_store(store) + + # inner decorator, allows for optional decorator arguments + def _inner(func: Callable[P, R]) -> Callable[P, R]: + # if the function takes no arguments and has no return annotation + # there's nothing to be done + if not func.__code__.co_argcount and "return" not in getattr( + func, "__annotations__", {} + ): + return func + + # get a signature object with all type annotations resolved + # this may result in a NameError if a required argument is unresolveable. + # There may also be unannotated required arguments, which will likely fail + # when the function is called later. We break this out into a seperate + # function to handle notifying the user on these cases. + sig = _resolve_sig_or_inform( + func, + localns={**_store.namespace, **(localns or {})}, + on_unresolved_required_args=on_unresolved_required_args, + on_unannotated_required_args=on_unannotated_required_args, + ) + if sig is None: # something went wrong, and the user was notified. + return func + process_return = sig.return_annotation is not sig.empty + # get provider functions for each required parameter @wraps(func) - def _gexec(*args: P.args, **kwargs: P.kwargs) -> R: # type: ignore [misc] - yield from _exec(*args, **kwargs) # type: ignore [misc] - - out = _gexec - - # update some metadata on the decorated function. - out.__signature__ = sig # type: ignore [attr-defined] - out.__annotations__ = { - **{p.name: p.annotation for p in sig.parameters.values()}, - "return": sig.return_annotation, - } - out.__doc__ = ( - out.__doc__ or "" - ) + "\n\n*This function will inject dependencies when called.*" - out._dependencies_injected = True # type: ignore [attr-defined] - return out + def _exec(*args: P.args, **kwargs: P.kwargs) -> R: + # sourcery skip: use-named-expression + # we're actually calling the "injected function" now + + _sig = cast("Signature", sig) + # first, get and call the provider functions for each parameter type: + _kwargs = {} + for param in _sig.parameters.values(): + provider: Optional[Callable] = get_provider( + param.annotation, store=store + ) + if provider: + _kwargs[param.name] = provider() + + # use bind_partial to allow the caller to still provide their own arguments + # if desired. (i.e. the injected deps are only used if not provided) + bound = _sig.bind_partial(*args, **kwargs) + bound.apply_defaults() + _kwargs.update(**bound.arguments) + + try: # call the function with injected values + result = func(**_kwargs) # type: ignore [arg-type] + except TypeError as e: + # likely a required argument is still missing. + raise TypeError( + f"After injecting dependencies for arguments {set(_kwargs)}, {e}" + ) from e + + if process_return: + processor = get_processor(_sig.return_annotation, store=store) + if processor: + processor(result) + + return result + + out = _exec + + # if it came in as a generatorfunction, it needs to go out as one. + if isgeneratorfunction(func): + + @wraps(func) + def _gexec(*args: P.args, **kwargs: P.kwargs) -> R: # type: ignore [misc] + yield from _exec(*args, **kwargs) # type: ignore [misc] + + out = _gexec + + # update some metadata on the decorated function. + out.__signature__ = sig # type: ignore [attr-defined] + out.__annotations__ = { + **{p.name: p.annotation for p in sig.parameters.values()}, + "return": sig.return_annotation, + } + out.__doc__ = ( + out.__doc__ or "" + ) + "\n\n*This function will inject dependencies when called.*" + out._dependencies_injected = True # type: ignore [attr-defined] + return out + + return _inner(func) if func is not None else _inner def _resolve_sig_or_inform( diff --git a/src/in_n_out/_processors.py b/src/in_n_out/_processors.py index 6ede877..f6fa899 100644 --- a/src/in_n_out/_processors.py +++ b/src/in_n_out/_processors.py @@ -93,12 +93,20 @@ def get_processor( @overload -def clear_processor(type_: Type[T]) -> Union[Callable[[], T], None]: +def clear_processor( + type_: Type[T], + warn_missing: bool = False, + store: Union[str, Store, None] = None, +) -> Union[Callable[[], T], None]: ... @overload -def clear_processor(type_: object) -> Union[Callable[[], Optional[T]], None]: +def clear_processor( + type_: object, + warn_missing: bool = False, + store: Union[str, Store, None] = None, +) -> Union[Callable[[], Optional[T]], None]: ... diff --git a/src/in_n_out/_providers.py b/src/in_n_out/_providers.py index 9f37607..9bdb038 100644 --- a/src/in_n_out/_providers.py +++ b/src/in_n_out/_providers.py @@ -63,12 +63,16 @@ def __exit__(self, *_: Any) -> None: @overload -def get_provider(type_: Type[T]) -> Union[Callable[[], T], None]: +def get_provider( + type_: Type[T], store: Union[str, Store, None] = None +) -> Union[Callable[[], T], None]: ... @overload -def get_provider(type_: object) -> Union[Callable[[], Optional[T]], None]: +def get_provider( + type_: object, store: Union[str, Store, None] = None +) -> Union[Callable[[], Optional[T]], None]: # `object` captures passing get_provider(Optional[type]) ... @@ -110,12 +114,20 @@ def _get_provider( @overload -def clear_provider(type_: Type[T]) -> Union[Callable[[], T], None]: +def clear_provider( + type_: Type[T], + warn_missing: bool = False, + store: Union[str, Store, None] = None, +) -> Union[Callable[[], T], None]: ... @overload -def clear_provider(type_: object) -> Union[Callable[[], Optional[T]], None]: +def clear_provider( + type_: object, + warn_missing: bool = False, + store: Union[str, Store, None] = None, +) -> Union[Callable[[], Optional[T]], None]: ... diff --git a/src/in_n_out/_store.py b/src/in_n_out/_store.py index f5c8354..f8c735d 100644 --- a/src/in_n_out/_store.py +++ b/src/in_n_out/_store.py @@ -19,6 +19,8 @@ Processor = TypeVar("Processor", bound=Callable[[Any], Any]) _GLOBAL = "global" +Namespace = Mapping[str, object] + class Store: """A Store is a collection of providers and processors.""" @@ -94,6 +96,7 @@ def __init__(self, name: str) -> None: self.providers: Dict[Type, Callable[[], Any]] = {} self.opt_providers: Dict[Type, Callable[[], Optional[Any]]] = {} self.processors: Dict[Any, Callable[[Any], Any]] = {} + self._namespace: Union[Namespace, Callable[[], Namespace], None] = None @property def name(self) -> str: @@ -106,6 +109,22 @@ def clear(self) -> None: self.opt_providers.clear() self.processors.clear() + @property + def namespace(self) -> Dict[str, object]: + """Return namespace for type resolution, if this store has one. + + If no namespace is set, this will return an empty `dict`. + """ + if self._namespace is None: + return {} + if callable(self._namespace): + return dict(self._namespace()) + return dict(self._namespace) + + @namespace.setter + def namespace(self, ns: Union[Namespace, Callable[[], Namespace]]): + self._namespace = ns + def _get( self, type_: Union[object, Type[T]], provider: bool, pop: bool ) -> Optional[Callable]: diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..1c69d00 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,12 @@ +import pytest + + +@pytest.fixture +def test_store(): + from in_n_out._store import Store + + store = Store.create("test") + try: + yield store + finally: + Store.destroy("test") diff --git a/tests/test_store.py b/tests/test_store.py index 9497708..ef334e2 100644 --- a/tests/test_store.py +++ b/tests/test_store.py @@ -2,7 +2,7 @@ import pytest -from in_n_out import Store, set_processors, set_providers +from in_n_out import Store, inject_dependencies, provider, set_processors, set_providers from in_n_out._store import _GLOBAL @@ -37,9 +37,8 @@ def test_create_get_destroy(): assert len(Store._instances) == 1 -def test_store_clear(): +def test_store_clear(test_store: Store): - test_store = Store.create("test") assert not test_store.providers assert not test_store.opt_providers assert not test_store.processors @@ -55,3 +54,30 @@ def test_store_clear(): assert not test_store.providers assert not test_store.opt_providers assert not test_store.processors + + +def test_store_namespace(test_store: Store): + class T: + ... + + @provider(store=test_store) + def provide_t() -> T: + return T() + + # namespace can be a static dict + test_store.namespace = {"Hint": T} + + @inject_dependencies(store=test_store) + def use_t(t: "Hint") -> None: # type: ignore # noqa: F821 + return t + + assert isinstance(use_t(), T) + + # namespace can also be a callable + test_store.namespace = lambda: {"Hint2": T} + + @inject_dependencies(store="test") + def use_t2(t: "Hint2") -> None: # type: ignore # noqa: F821 + return t + + assert isinstance(use_t2(), T)