From 38867260bac565262a4766aab5fe57be550691b3 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Fri, 22 Apr 2022 11:44:01 -0700 Subject: [PATCH] Support resolving unregistered functions, closes #13 --- README.md | 23 +++++++++++++++++++++++ asyncinject/__init__.py | 28 +++++++++++++++++++++++----- tests/test_asyncinject.py | 27 +++++++++++++++++++++++++++ 3 files changed, 73 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 26342ad..4c8f80f 100644 --- a/README.md +++ b/README.md @@ -50,6 +50,29 @@ The HTTP requests to `www.example.com` and `simonwillison.net` will be performed The library notices that `both()` takes two arguments which are the names of other registered `async def` functions, and will construct an execution plan that executes those two functions in parallel, then passes their results to the `both()` method. +### Resolving an unregistered function + +You don't need to register the final function that you pass to `.resolve()` - if you pass an unregistered function, the library will introspect the function's parameters and resolve them directly. + +This works with both regular and async functions: + +```python +async def one(): + return 1 + +async def two(): + return 2 + +registry = Registry(one, two) + +# async def works here too: +def three(one, two): + return one + two + +print(await registry.resolve(three)) +# Prints 3 +``` + ### Parameters are passed through Your dependent functions can require keyword arguments which have been passed to the `.resolve()` call: diff --git a/asyncinject/__init__.py b/asyncinject/__init__.py index effc41b..de69bd5 100644 --- a/asyncinject/__init__.py +++ b/asyncinject/__init__.py @@ -12,6 +12,7 @@ class Registry: def __init__(self, *fns, parallel=True, timer=None): self._registry = {} self._graph = None + self._reversed = None self.parallel = parallel self.timer = timer for fn in fns: @@ -19,8 +20,9 @@ def __init__(self, *fns, parallel=True, timer=None): def register(self, fn): self._registry[fn.__name__] = fn - # Clear _graph cache: + # Clear caches: self._graph = None + self._reversed = None def _make_time_logger(self, awaitable): async def inner(): @@ -41,12 +43,28 @@ def graph(self): } return self._graph + @property + def reversed(self): + if self._reversed is None: + self._reversed = dict(reversed(pair) for pair in self._registry.items()) + return self._reversed + async def resolve(self, fn, **kwargs): - try: - name = fn.__name__ - except AttributeError: + if not isinstance(fn, str): + # It's a fn - is it a registered one? + name = self.reversed.get(fn) + if name is None: + # Special case - since it is not registered we need to + # introspect its parameters here and use resolve_multi + params = inspect.signature(fn).parameters.keys() + to_resolve = {p for p in params if p not in kwargs} + resolved = await self.resolve_multi(to_resolve, results=kwargs) + result = fn(**{param: resolved[param] for param in params}) + if asyncio.iscoroutine(result): + result = await result + return result + else: name = fn - results = await self.resolve_multi([name], results=kwargs) return results[name] diff --git a/tests/test_asyncinject.py b/tests/test_asyncinject.py index d6929e8..ad62722 100644 --- a/tests/test_asyncinject.py +++ b/tests/test_asyncinject.py @@ -154,3 +154,30 @@ async def d(b, c): end = time.perf_counter() # Should have taken ~0.2s assert 0.18 < (end - start) < 0.22 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("use_async", (True, False)) +async def test_resolve_unregistered_function(use_async): + # https://github.com/simonw/asyncinject/issues/13 + async def one(): + return 1 + + async def two(): + return 2 + + registry = Registry(one, two) + + async def three_async(one, two): + return one + two + + def three_not_async(one, two): + return one + two + + fn = three_async if use_async else three_not_async + result = await registry.resolve(fn) + assert result == 3 + + # Test that passing parameters works too + result2 = await registry.resolve(fn, one=2) + assert result2 == 4