Skip to content

Commit

Permalink
feat: add support for injecting into unbound methods (#25)
Browse files Browse the repository at this point in the history
* feat: add support for injecting into unbound methods

* fix: fix no param
  • Loading branch information
tlambert03 authored Jul 11, 2022
1 parent 1b6de73 commit c55fcc1
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 5 deletions.
4 changes: 4 additions & 0 deletions src/in_n_out/_global.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ def inject(
localns: Optional[dict] = None,
on_unresolved_required_args: Optional[RaiseWarnReturnIgnore] = None,
on_unannotated_required_args: Optional[RaiseWarnReturnIgnore] = None,
guess_self: Optional[bool] = None,
store: Union[str, Store, None] = None,
) -> Callable[P, R]:
...
Expand All @@ -222,6 +223,7 @@ def inject(
localns: Optional[dict] = None,
on_unresolved_required_args: Optional[RaiseWarnReturnIgnore] = None,
on_unannotated_required_args: Optional[RaiseWarnReturnIgnore] = None,
guess_self: Optional[bool] = None,
store: Union[str, Store, None] = None,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
...
Expand All @@ -236,6 +238,7 @@ def inject(
localns: Optional[dict] = None,
on_unresolved_required_args: Optional[RaiseWarnReturnIgnore] = None,
on_unannotated_required_args: Optional[RaiseWarnReturnIgnore] = None,
guess_self: Optional[bool] = None,
store: Union[str, Store, None] = None,
) -> Union[Callable[P, R], Callable[[Callable[P, R]], Callable[P, R]]]:
return _store_or_global(store).inject(
Expand All @@ -245,6 +248,7 @@ def inject(
localns=localns,
on_unresolved_required_args=on_unresolved_required_args,
on_unannotated_required_args=on_unannotated_required_args,
guess_self=guess_self,
)


Expand Down
20 changes: 20 additions & 0 deletions src/in_n_out/_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ def __init__(self, name: str) -> None:
self._namespace: Union[Namespace, Callable[[], Namespace], None] = None
self.on_unresolved_required_args: RaiseWarnReturnIgnore = "raise"
self.on_unannotated_required_args: RaiseWarnReturnIgnore = "warn"
self.guess_self: bool = True

@property
def name(self) -> str:
Expand Down Expand Up @@ -585,6 +586,7 @@ def inject(
localns: Optional[dict] = None,
on_unresolved_required_args: Optional[RaiseWarnReturnIgnore] = None,
on_unannotated_required_args: Optional[RaiseWarnReturnIgnore] = None,
guess_self: Optional[bool] = None,
) -> Callable[P, R]:
...

Expand All @@ -598,6 +600,7 @@ def inject(
localns: Optional[dict] = None,
on_unresolved_required_args: Optional[RaiseWarnReturnIgnore] = None,
on_unannotated_required_args: Optional[RaiseWarnReturnIgnore] = None,
guess_self: Optional[bool] = None,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
...

Expand All @@ -610,6 +613,7 @@ def inject(
localns: Optional[dict] = None,
on_unresolved_required_args: Optional[RaiseWarnReturnIgnore] = None,
on_unannotated_required_args: Optional[RaiseWarnReturnIgnore] = None,
guess_self: Optional[bool] = None,
) -> Union[Callable[P, R], Callable[[Callable[P, R]], Callable[P, R]]]:
"""Decorate `func` to inject dependencies at calltime.
Expand Down Expand Up @@ -661,6 +665,20 @@ def inject(
- 'return': immediately return the original function without warning
- 'ignore': continue decorating without warning.
guess_self : bool
Whether to infer the type of the first argument if the function is an
unbound class method (by default, `True`) This is done as follows:
- if '.' (but not '<locals>') is in the function's __qualname__
- and if the first parameter is named 'self' or starts with "_"
- and if the first parameter annotation is `inspect.empty`
- then the name preceding `func.__name__` in the function's __qualname__
(which is usually the class name), is looked up in the function's
`__globals__` namespace. If found, it is used as the first parameter's
type annotation.
This allows class methods to be injected with instances of the class.
Returns
-------
Callable
Expand Down Expand Up @@ -691,6 +709,7 @@ def inject(
"""
on_unres = on_unresolved_required_args or self.on_unresolved_required_args
on_unann = on_unannotated_required_args or self.on_unannotated_required_args
_guess_self = guess_self or self.guess_self

# inner decorator, allows for optional decorator arguments
def _inner(func: Callable[P, R]) -> Callable[P, R]:
Expand All @@ -716,6 +735,7 @@ def _inner(func: Callable[P, R]) -> Callable[P, R]:
localns={**self.namespace, **(localns or {})},
on_unresolved_required_args=on_unres,
on_unannotated_required_args=on_unann,
guess_self=_guess_self,
)
if sig is None: # something went wrong, and the user was notified.
return func
Expand Down
44 changes: 39 additions & 5 deletions src/in_n_out/_type_resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def type_resolved_signature(
*,
localns: Optional[dict] = None,
raise_unresolved_optional_args: bool = True,
guess_self: bool = True,
) -> Signature:
"""Return a Signature object for a function with resolved type annotations.
Expand All @@ -120,6 +121,17 @@ def type_resolved_signature(
raise_unresolved_optional_args : bool
Whether to raise an exception when an optional parameter (one with a default
value) has an unresolvable type annotation, by default True
guess_self : bool
Whether to infer the type of the first argument if the function is an unbound
class method. This is done as follows:
- if '.' (but not '<locals>') is in the function's __qualname__
- and if the first parameter is named 'self' or starts with "_"
- and if the first parameter annotation is `inspect.empty`
- then the name preceding `func.__name__` in the function's __qualname__
(which is usually the class name), is looked up in the function's
`__globals__` namespace. If found, it is used as the first parameter's
type annotation.
This allows class methods to be injected with instances of the class.
Returns
-------
Expand All @@ -136,11 +148,29 @@ def type_resolved_signature(
an unresolvable type annotation.
"""
sig = Signature.from_callable(func)
hints = {}
if guess_self and sig.parameters:
p0 = next(iter(sig.parameters.values()))
# The best identifier i can figure for a class method is that:
# 1. its qualname contains a period (e.g. "MyClass.my_method"),
# 2. the first parameter tends to be named "self", or some private variable
# 3. the first parameter tends to be unannotated
qualname = getattr(func, "__qualname__", "")
if (
"." in qualname
and "<locals>" not in qualname # don't support locally defd types
and (p0.name == "self" or p0.name.startswith("_"))
and p0.annotation is p0.empty
):
# look up the class name in the function's globals
cls_name = qualname.replace(func.__name__, "").rstrip(".")
func_globals = getattr(func, "__globals__", {})
if cls_name in func_globals:
# add it to the type hints
hints = {p0.name: func_globals[cls_name]}

try:
hints = resolve_type_hints(
func,
localns=localns,
)
hints.update(resolve_type_hints(func, localns=localns))
except (NameError, TypeError) as err:
if raise_unresolved_optional_args:
raise NameError(
Expand Down Expand Up @@ -211,14 +241,18 @@ def _resolve_sig_or_inform(
localns: Optional[dict],
on_unresolved_required_args: RaiseWarnReturnIgnore,
on_unannotated_required_args: RaiseWarnReturnIgnore,
guess_self: bool = True,
) -> Optional[Signature]:
"""Helper function for user warnings/errors during inject_dependencies.
all parameters are described above in inject_dependencies
"""
try:
sig = type_resolved_signature(
func, localns=localns, raise_unresolved_optional_args=False
func,
localns=localns,
raise_unresolved_optional_args=False,
guess_self=guess_self,
)
except NameError as e:
errmsg = str(e)
Expand Down
11 changes: 11 additions & 0 deletions tests/test_injection.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,3 +205,14 @@ def f(x: int):
with test_store.register(providers={Optional[int]: lambda: 2}):
f()
mock.assert_called_once_with(2)


class Foo:
def method(self):
return self


def test_inject_instance_into_unbound_method():
foo = Foo()
with register(providers={Foo: lambda: foo}):
assert inject(Foo.method)() == foo

0 comments on commit c55fcc1

Please sign in to comment.