Skip to content

Commit

Permalink
Support resolving unregistered functions, closes #13
Browse files Browse the repository at this point in the history
  • Loading branch information
simonw committed Apr 22, 2022
1 parent 3be0be2 commit 3886726
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 5 deletions.
23 changes: 23 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,29 @@ The HTTP requests to `www.example.com` and `simonwillison.net` will be performed

The library notices that `both()` takes two arguments which are the names of other registered `async def` functions, and will construct an execution plan that executes those two functions in parallel, then passes their results to the `both()` method.

### Resolving an unregistered function

You don't need to register the final function that you pass to `.resolve()` - if you pass an unregistered function, the library will introspect the function's parameters and resolve them directly.

This works with both regular and async functions:

```python
async def one():
return 1

async def two():
return 2

registry = Registry(one, two)

# async def works here too:
def three(one, two):
return one + two

print(await registry.resolve(three))
# Prints 3
```

### Parameters are passed through

Your dependent functions can require keyword arguments which have been passed to the `.resolve()` call:
Expand Down
28 changes: 23 additions & 5 deletions asyncinject/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,17 @@ class Registry:
def __init__(self, *fns, parallel=True, timer=None):
self._registry = {}
self._graph = None
self._reversed = None
self.parallel = parallel
self.timer = timer
for fn in fns:
self.register(fn)

def register(self, fn):
self._registry[fn.__name__] = fn
# Clear _graph cache:
# Clear caches:
self._graph = None
self._reversed = None

def _make_time_logger(self, awaitable):
async def inner():
Expand All @@ -41,12 +43,28 @@ def graph(self):
}
return self._graph

@property
def reversed(self):
if self._reversed is None:
self._reversed = dict(reversed(pair) for pair in self._registry.items())
return self._reversed

async def resolve(self, fn, **kwargs):
try:
name = fn.__name__
except AttributeError:
if not isinstance(fn, str):
# It's a fn - is it a registered one?
name = self.reversed.get(fn)
if name is None:
# Special case - since it is not registered we need to
# introspect its parameters here and use resolve_multi
params = inspect.signature(fn).parameters.keys()
to_resolve = {p for p in params if p not in kwargs}
resolved = await self.resolve_multi(to_resolve, results=kwargs)
result = fn(**{param: resolved[param] for param in params})
if asyncio.iscoroutine(result):
result = await result
return result
else:
name = fn

results = await self.resolve_multi([name], results=kwargs)
return results[name]

Expand Down
27 changes: 27 additions & 0 deletions tests/test_asyncinject.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,3 +154,30 @@ async def d(b, c):
end = time.perf_counter()
# Should have taken ~0.2s
assert 0.18 < (end - start) < 0.22


@pytest.mark.asyncio
@pytest.mark.parametrize("use_async", (True, False))
async def test_resolve_unregistered_function(use_async):
# https://github.com/simonw/asyncinject/issues/13
async def one():
return 1

async def two():
return 2

registry = Registry(one, two)

async def three_async(one, two):
return one + two

def three_not_async(one, two):
return one + two

fn = three_async if use_async else three_not_async
result = await registry.resolve(fn)
assert result == 3

# Test that passing parameters works too
result2 = await registry.resolve(fn, one=2)
assert result2 == 4

0 comments on commit 3886726

Please sign in to comment.