Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: inject store namespace #7

Merged
merged 4 commits into from
Jun 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 16 additions & 10 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,22 @@ ignore = "D100,D213,D401,D413,D107"
minversion = "6.0"
testpaths = ["tests"]
filterwarnings = ["error"]
addopts = ["--cov"]

[tool.coverage.run]
source = ['src/in_n_out']
command_line = "-m pytest"

# https://coverage.readthedocs.io/en/6.4/config.html
[tool.coverage.report]
exclude_lines = [
"pragma: no cover",
"if TYPE_CHECKING:",
"@overload",
"except ImportError",
]
show_missing = true
skip_covered = true

# https://mypy.readthedocs.io/en/stable/config_file.html
[tool.mypy]
Expand All @@ -120,16 +136,6 @@ pretty = true
modules = ['tests.*']
disallow_untyped_defs = false


# https://coverage.readthedocs.io/en/6.4/config.html
[tool.coverage.report]
exclude_lines = [
"pragma: no cover",
"if TYPE_CHECKING:",
"@overload",
"except ImportError",
]

# https://github.com/cruft/cruft
[tool.cruft]
skip = ["tests"]
Expand Down
193 changes: 115 additions & 78 deletions src/in_n_out/_inject.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import warnings
from functools import wraps
from inspect import isgeneratorfunction
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING, Union, cast, overload

from ._processors import get_processor
from ._providers import get_provider
from ._store import Store
from ._type_resolution import type_resolved_signature

if TYPE_CHECKING:
Expand All @@ -20,13 +21,38 @@
RaiseWarnReturnIgnore = Literal["raise", "warn", "return", "ignore"]


@overload
def inject_dependencies(
func: Callable[P, R],
*,
localns: Optional[dict] = None,
store: Union[str, Store, None] = None,
on_unresolved_required_args: RaiseWarnReturnIgnore = "raise",
on_unannotated_required_args: RaiseWarnReturnIgnore = "warn",
) -> Callable[P, R]:
...


@overload
def inject_dependencies(
func: Literal[None] = None,
*,
localns: Optional[dict] = None,
store: Union[str, Store, None] = None,
on_unresolved_required_args: RaiseWarnReturnIgnore = "raise",
on_unannotated_required_args: RaiseWarnReturnIgnore = "warn",
) -> Callable[[Callable[P, R]], Callable[P, R]]:
...


def inject_dependencies(
func: Callable[P, R] = None,
*,
localns: Optional[dict] = None,
store: Union[str, Store, None] = None,
on_unresolved_required_args: RaiseWarnReturnIgnore = "raise",
on_unannotated_required_args: RaiseWarnReturnIgnore = "warn",
) -> Union[Callable[P, R], Callable[[Callable[P, R]], Callable[P, R]]]:
"""Decorator returns func that can access/process objects based on type hints.

This is form of dependency injection, and result processing. It does 2 things:
Expand All @@ -46,6 +72,9 @@ def inject_dependencies(
a function with type hints
localns : Optional[dict]
Optional local namespace for name resolution, by default None
store : Union[str, Store, None]
Optional store to use when retrieving providers and processors,
by default the global store will be used.
on_unresolved_required_args : RaiseWarnReturnIgnore
What to do when a required parameter (one without a default) is encountered
with an unresolvable type annotation.
Expand Down Expand Up @@ -73,85 +102,93 @@ def inject_dependencies(
Callable
A function with dependencies injected
"""
# if the function takes no arguments and has no return annotation
# there's nothing to be done
if not func.__code__.co_argcount and "return" not in getattr(
func, "__annotations__", {}
):
return func

# get a signature object with all type annotations resolved
# this may result in a NameError if a required argument is unresolveable.
# There may also be unannotated required arguments, which will likely fail
# when the function is called later. We break this out into a seperate
# function to handle notifying the user on these cases.
sig = _resolve_sig_or_inform(
func,
localns,
on_unresolved_required_args,
on_unannotated_required_args,
)
if sig is None: # something went wrong, and the user was notified.
return func
process_return = sig.return_annotation is not sig.empty

# get provider functions for each required parameter
@wraps(func)
def _exec(*args: P.args, **kwargs: P.kwargs) -> R:
# sourcery skip: use-named-expression
# we're actually calling the "injected function" now

_sig = cast("Signature", sig)
# first, get and call the provider functions for each parameter type:
_kwargs = {}
for param in _sig.parameters.values():
provider: Optional[Callable] = get_provider(param.annotation)
if provider:
_kwargs[param.name] = provider()

# use bind_partial to allow the caller to still provide their own arguments
# if desired. (i.e. the injected deps are only used if not provided)
bound = _sig.bind_partial(*args, **kwargs)
bound.apply_defaults()
_kwargs.update(**bound.arguments)

try: # call the function with injected values
result = func(**_kwargs) # type: ignore [arg-type]
except TypeError as e:
# likely a required argument is still missing.
raise TypeError(
f"After injecting dependencies for arguments {set(_kwargs)}, {e}"
) from e

if process_return:
processor = get_processor(_sig.return_annotation)
if processor:
processor(result)

return result

out = _exec

# if it came in as a generatorfunction, it needs to go out as one.
if isgeneratorfunction(func):
_store = store if isinstance(store, Store) else Store.get_store(store)

# inner decorator, allows for optional decorator arguments
def _inner(func: Callable[P, R]) -> Callable[P, R]:
# if the function takes no arguments and has no return annotation
# there's nothing to be done
if not func.__code__.co_argcount and "return" not in getattr(
func, "__annotations__", {}
):
return func

# get a signature object with all type annotations resolved
# this may result in a NameError if a required argument is unresolveable.
# There may also be unannotated required arguments, which will likely fail
# when the function is called later. We break this out into a seperate
# function to handle notifying the user on these cases.
sig = _resolve_sig_or_inform(
func,
localns={**_store.namespace, **(localns or {})},
on_unresolved_required_args=on_unresolved_required_args,
on_unannotated_required_args=on_unannotated_required_args,
)
if sig is None: # something went wrong, and the user was notified.
return func
process_return = sig.return_annotation is not sig.empty

# get provider functions for each required parameter
@wraps(func)
def _gexec(*args: P.args, **kwargs: P.kwargs) -> R: # type: ignore [misc]
yield from _exec(*args, **kwargs) # type: ignore [misc]

out = _gexec

# update some metadata on the decorated function.
out.__signature__ = sig # type: ignore [attr-defined]
out.__annotations__ = {
**{p.name: p.annotation for p in sig.parameters.values()},
"return": sig.return_annotation,
}
out.__doc__ = (
out.__doc__ or ""
) + "\n\n*This function will inject dependencies when called.*"
out._dependencies_injected = True # type: ignore [attr-defined]
return out
def _exec(*args: P.args, **kwargs: P.kwargs) -> R:
# sourcery skip: use-named-expression
# we're actually calling the "injected function" now

_sig = cast("Signature", sig)
# first, get and call the provider functions for each parameter type:
_kwargs = {}
for param in _sig.parameters.values():
provider: Optional[Callable] = get_provider(
param.annotation, store=store
)
if provider:
_kwargs[param.name] = provider()

# use bind_partial to allow the caller to still provide their own arguments
# if desired. (i.e. the injected deps are only used if not provided)
bound = _sig.bind_partial(*args, **kwargs)
bound.apply_defaults()
_kwargs.update(**bound.arguments)

try: # call the function with injected values
result = func(**_kwargs) # type: ignore [arg-type]
except TypeError as e:
# likely a required argument is still missing.
raise TypeError(
f"After injecting dependencies for arguments {set(_kwargs)}, {e}"
) from e

if process_return:
processor = get_processor(_sig.return_annotation, store=store)
if processor:
processor(result)

return result

out = _exec

# if it came in as a generatorfunction, it needs to go out as one.
if isgeneratorfunction(func):

@wraps(func)
def _gexec(*args: P.args, **kwargs: P.kwargs) -> R: # type: ignore [misc]
yield from _exec(*args, **kwargs) # type: ignore [misc]

out = _gexec

# update some metadata on the decorated function.
out.__signature__ = sig # type: ignore [attr-defined]
out.__annotations__ = {
**{p.name: p.annotation for p in sig.parameters.values()},
"return": sig.return_annotation,
}
out.__doc__ = (
out.__doc__ or ""
) + "\n\n*This function will inject dependencies when called.*"
out._dependencies_injected = True # type: ignore [attr-defined]
return out

return _inner(func) if func is not None else _inner


def _resolve_sig_or_inform(
Expand Down
12 changes: 10 additions & 2 deletions src/in_n_out/_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,20 @@ def get_processor(


@overload
def clear_processor(type_: Type[T]) -> Union[Callable[[], T], None]:
def clear_processor(
type_: Type[T],
warn_missing: bool = False,
store: Union[str, Store, None] = None,
) -> Union[Callable[[], T], None]:
...


@overload
def clear_processor(type_: object) -> Union[Callable[[], Optional[T]], None]:
def clear_processor(
type_: object,
warn_missing: bool = False,
store: Union[str, Store, None] = None,
) -> Union[Callable[[], Optional[T]], None]:
...


Expand Down
20 changes: 16 additions & 4 deletions src/in_n_out/_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,16 @@ def __exit__(self, *_: Any) -> None:


@overload
def get_provider(type_: Type[T]) -> Union[Callable[[], T], None]:
def get_provider(
type_: Type[T], store: Union[str, Store, None] = None
) -> Union[Callable[[], T], None]:
...


@overload
def get_provider(type_: object) -> Union[Callable[[], Optional[T]], None]:
def get_provider(
type_: object, store: Union[str, Store, None] = None
) -> Union[Callable[[], Optional[T]], None]:
# `object` captures passing get_provider(Optional[type])
...

Expand Down Expand Up @@ -110,12 +114,20 @@ def _get_provider(


@overload
def clear_provider(type_: Type[T]) -> Union[Callable[[], T], None]:
def clear_provider(
type_: Type[T],
warn_missing: bool = False,
store: Union[str, Store, None] = None,
) -> Union[Callable[[], T], None]:
...


@overload
def clear_provider(type_: object) -> Union[Callable[[], Optional[T]], None]:
def clear_provider(
type_: object,
warn_missing: bool = False,
store: Union[str, Store, None] = None,
) -> Union[Callable[[], Optional[T]], None]:
...


Expand Down
19 changes: 19 additions & 0 deletions src/in_n_out/_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
Processor = TypeVar("Processor", bound=Callable[[Any], Any])
_GLOBAL = "global"

Namespace = Mapping[str, object]


class Store:
"""A Store is a collection of providers and processors."""
Expand Down Expand Up @@ -94,6 +96,7 @@ def __init__(self, name: str) -> None:
self.providers: Dict[Type, Callable[[], Any]] = {}
self.opt_providers: Dict[Type, Callable[[], Optional[Any]]] = {}
self.processors: Dict[Any, Callable[[Any], Any]] = {}
self._namespace: Union[Namespace, Callable[[], Namespace], None] = None

@property
def name(self) -> str:
Expand All @@ -106,6 +109,22 @@ def clear(self) -> None:
self.opt_providers.clear()
self.processors.clear()

@property
def namespace(self) -> Dict[str, object]:
"""Return namespace for type resolution, if this store has one.

If no namespace is set, this will return an empty `dict`.
"""
if self._namespace is None:
return {}
if callable(self._namespace):
return dict(self._namespace())
return dict(self._namespace)

@namespace.setter
def namespace(self, ns: Union[Namespace, Callable[[], Namespace]]):
self._namespace = ns

def _get(
self, type_: Union[object, Type[T]], provider: bool, pop: bool
) -> Optional[Callable]:
Expand Down
12 changes: 12 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import pytest


@pytest.fixture
def test_store():
from in_n_out._store import Store

store = Store.create("test")
try:
yield store
finally:
Store.destroy("test")
Loading