Switch from class-based to function-based (#9)
Major rewrite of functionality. This will break anything that uses this library!

* Use functions instead of class methods, closes #8
* Also includes parallel= mechanism from #6
simonw authored Apr 15, 2022
1 parent 6b36a71 commit 246a0fc
Showing 3 changed files with 173 additions and 282 deletions.
117 changes: 39 additions & 78 deletions README.md
@@ -16,57 +16,57 @@ Install this library using `pip`:

This library is inspired by [pytest fixtures](https://docs.pytest.org/en/6.2.x/fixture.html).

The idea is to simplify executing parallel `asyncio` operations by allowing them to be collected in a class, with the names of parameters to the class methods specifying which other methods should be executed first.
The idea is to simplify executing parallel `asyncio` operations by allowing them to be defined using a collection of functions, where the function arguments represent dependent functions that need to be executed first.

This then allows the library to create and execute a plan for executing various dependent methods in parallel.
The library can then create and execute a plan for executing the required functions in parallel in the most efficient sequence possible.

Here's an example, using the [httpx](https://www.python-httpx.org/) HTTP library.

```python
from asyncinject import AsyncInjectAll
from asyncinject import AsyncRegistry
import httpx


async def get(url):
    async with httpx.AsyncClient() as client:
        return (await client.get(url)).text

class FetchThings(AsyncInjectAll):
    async def example(self):
        return await get("http://www.example.com/")

    async def simonwillison(self):
        return await get("https://simonwillison.net/search/?tag=empty")
async def example():
    return await get("http://www.example.com/")

    async def both(self, example, simonwillison):
        return example + "\n\n" + simonwillison
async def simonwillison():
    return await get("https://simonwillison.net/search/?tag=empty")

async def both(example, simonwillison):
    return example + "\n\n" + simonwillison

combined = await FetchThings().both()
registry = AsyncRegistry(example, simonwillison, both)
combined = await registry.resolve(both)
print(combined)
```
If you run this in `ipython` (which supports top-level await) you will see output that combines HTML from both of those pages.

The HTTP requests to `www.example.com` and `simonwillison.net` will be performed in parallel.

The library will notice that `both()` takes two arguments which are the names of other `async def` methods on that class, and will construct an execution plan that executes those two methods in parallel, then passes their results to the `both()` method.
The library notices that `both()` takes two arguments which are the names of other registered `async def` functions, and constructs an execution plan that runs those two functions in parallel, then passes their results to the `both()` function.
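
The planning step described here can be sketched with the standard library alone. This is an illustration, not the library's actual implementation: the three coroutines return fixed strings instead of making HTTP requests, `build_plan` is a hypothetical helper name, and `graphlib` assumes Python 3.9 or later.

```python
import graphlib
import inspect


# Stand-in coroutines: fixed strings replace the real HTTP requests.
async def example():
    return "example page"


async def simonwillison():
    return "simonwillison page"


async def both(example, simonwillison):
    return example + "\n\n" + simonwillison


def build_plan(*fns):
    """Hypothetical helper: group function names into ordered batches."""
    # Each function depends on whatever its parameters are named.
    graph = {fn.__name__: set(inspect.signature(fn).parameters) for fn in fns}
    sorter = graphlib.TopologicalSorter(graph)
    sorter.prepare()
    plan = []
    while sorter.is_active():
        # Everything returned by get_ready() has no unmet dependencies,
        # so the members of one batch could run at the same time.
        group = sorted(sorter.get_ready())
        plan.append(group)
        sorter.done(*group)
    return plan


plan = build_plan(example, simonwillison, both)
print(plan)
```

Each inner list is a batch whose members do not depend on one another, so the two fetches share the first batch and `both` runs alone in the second.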

### Parameters are passed through

Your dependent methods can require keyword arguments which are passed to the original method.
Your dependent functions can require keyword arguments which have been passed to the `.resolve()` call:

```python
class FetchWithParams(AsyncInjectAll):
    async def get_param_1(self, param1):
        return await get(param1)
async def get_param_1(param1):
    return await get(param1)

    async def get_param_2(self, param2):
        return await get(param2)
async def get_param_2(param2):
    return await get(param2)

    async def both(self, get_param_1, get_param_2):
        return get_param_1 + "\n\n" + get_param_2
async def both(get_param_1, get_param_2):
    return get_param_1 + "\n\n" + get_param_2


combined = await FetchWithParams().both(
combined = await AsyncRegistry(get_param_1, get_param_2, both).resolve(
    both,
    param1 = "http://www.example.com/",
    param2 = "https://simonwillison.net/search/?tag=empty"
)
@@ -77,70 +77,35 @@ print(combined)
You can opt a parameter out of the dependency injection mechanism by assigning it a default value:

```python
class IgnoreDefaultParameters(AsyncInjectAll):
    async def go(self, calc1, x=5):
        return calc1 + x
async def go(calc1, x=5):
    return calc1 + x

    async def calc1(self):
        return 5
async def calc1():
    return 5

print(await IgnoreDefaultParameters().go())
print(await AsyncRegistry(calc1, go).resolve(go))
# Prints 10
```
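
For readers wondering how a parameter with a default value can be told apart from one that needs injecting, the standard `inspect` module exposes that information. The sketch below only shows the detection itself and does not claim to mirror the library's internals; `split_parameters` is a hypothetical helper name.

```python
import inspect


async def calc1():
    return 5


async def go(calc1, x=5):
    return calc1 + x


def split_parameters(fn):
    """Hypothetical helper: separate parameters with and without defaults."""
    params = inspect.signature(fn).parameters
    # Parameters without a default are candidates for injection.
    required = [
        name for name, p in params.items() if p.default is inspect.Parameter.empty
    ]
    # Parameters with a default can simply keep that default.
    defaulted = {
        name: p.default
        for name, p in params.items()
        if p.default is not inspect.Parameter.empty
    }
    return required, defaulted


required, defaulted = split_parameters(go)
print(required, defaulted)
```

Here `calc1` is the only name that would need resolving, while `x` keeps its default of `5`.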

### AsyncInject and @inject

The above example illustrates the `AsyncInjectAll` class, which assumes that every `async def` method on the class should be treated as a dependency injection method.

You can also specify individual methods using the `AsyncInject` base class and the `@inject` decorator:

```python
from asyncinject import AsyncInject, inject

class FetchThings(AsyncInject):
    @inject
    async def example(self):
        return await get("http://www.example.com/")

    @inject
    async def simonwillison(self):
        return await get("https://simonwillison.net/search/?tag=empty")

    @inject
    async def both(self, example, simonwillison):
        return example + "\n\n" + simonwillison
```
### The resolve() function

If you want to execute a set of methods in parallel without defining a third method that lists them as parameters, you can do so using the `resolve()` function. This will execute the specified methods (in parallel, where possible) and return a dictionary of the results.

```python
from asyncinject import resolve

fetcher = FetchThings()
results = await resolve(fetcher, ["example", "simonwillison"])
```
`results` will now be:
```json
{
    "example": "contents of http://www.example.com/",
    "simonwillison": "contents of https://simonwillison.net/search/?tag=empty"
}
```
### Debug logging

You can assign a `_log` method to your class or instance to see the execution plan when it runs. Your `_log` method should take a single `message` argument - the easiest way to do this is to use `print`:
You can pass a `log=` callable to the `AsyncRegistry` constructor. Your function should take a single `message` argument - the easiest way to do this is to use `print`:
```python
fetcher = FetchThings()
fetcher._log = print
combined = await fetcher.both()
combined = await AsyncRegistry(
    get_param_1, get_param_2, both, log=print
).resolve(
    both,
    param1 = "http://www.example.com/",
    param2 = "https://simonwillison.net/search/?tag=empty"
)
```
This will output:
```
Resolving ['example', 'simonwillison'] in <__main__.FetchThings>
Run ['example', 'simonwillison']
Resolving ['both']
Run []
Run ['get_param_2', 'get_param_1']
Run ['both']
```

## Development

To contribute to this library, first checkout the code. Then create a new virtual environment:
@@ -149,10 +114,6 @@ To contribute to this library, first checkout the code. Then create a new virtua
python -m venv venv
source venv/bin/activate

Or if you are using `pipenv`:

pipenv shell

Now install the dependencies and test dependencies:

pip install -e '.[test]'
192 changes: 77 additions & 115 deletions asyncinject/__init__.py
@@ -1,125 +1,87 @@
import asyncio
from functools import wraps
import inspect

try:
    import graphlib
except ImportError:
    from . import vendored_graphlib as graphlib
import asyncio


def inject(fn):
    "Mark method as having dependency-injected parameters"
    fn._inject = True
    return fn


def _make_method(method):
    parameters = inspect.signature(method).parameters

    @wraps(method)
    async def inner(self, **kwargs):
        # Any parameters not provided by kwargs are resolved from registry
        to_resolve = [
            p
            for p in parameters
            # Not already provided
            if p not in kwargs
            # Not self
            and p != "self"
            # Doesn't have a default value
            and parameters[p].default is inspect._empty
        ]
        missing = [p for p in to_resolve if p not in self._registry]
        assert (
            not missing
        ), "The following DI parameters could not be found in the registry: {}".format(
            missing
        )

        results = {}
        results.update(kwargs)
        if to_resolve:
            resolved_parameters = await resolve(self, to_resolve, results)
            results.update(resolved_parameters)
        return await method(
            self, **{k: v for k, v in results.items() if k in parameters}
        )

    return inner


class AsyncInject:
    def _log(self, message):
        pass

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Decorate any items that are 'async def' methods
        cls._registry = {}
        inject_all = getattr(cls, "_inject_all", False)
        for name in dir(cls):
            value = getattr(cls, name)
            if inspect.iscoroutinefunction(value) and (
                inject_all or getattr(value, "_inject", None)
            ):
                setattr(cls, name, _make_method(value))
                cls._registry[name] = getattr(cls, name)
        # Gather graph for later dependency resolution
        graph = {
            key: {
                p
                for p in inspect.signature(method).parameters.keys()
                if p != "self" and not p.startswith("_")
class AsyncRegistry:
    def __init__(self, *fns, parallel=True, log=None):
        self._registry = {}
        self._graph = None
        self.parallel = parallel
        self.log = log or (lambda *args: None)
        for fn in fns:
            self.register(fn)

    def register(self, fn):
        self._registry[fn.__name__] = fn
        # Clear _graph cache:
        self._graph = None

    @property
    def graph(self):
        if self._graph is None:
            self._graph = {
                key: {
                    p
                    for p in inspect.signature(fn).parameters.keys()
                    if not p.startswith("_")
                }
                for key, fn in self._registry.items()
            }
            for key, method in cls._registry.items()
        }
        cls._graph = graph


class AsyncInjectAll(AsyncInject):
    _inject_all = True


async def resolve(instance, names, results=None):
    if results is None:
        results = {}

    # Come up with an execution plan, just for these nodes
    ts = graphlib.TopologicalSorter()
    to_do = set(names)
    done = set()
    while to_do:
        item = to_do.pop()
        dependencies = instance._graph.get(item) or set()
        ts.add(item, *dependencies)
        done.add(item)
        # Add any not-done dependencies to the queue
        to_do.update({k for k in dependencies if k not in done})

    ts.prepare()
    plan = []
    while ts.is_active():
        node_group = ts.get_ready()
        plan.append(node_group)
        ts.done(*node_group)

    instance._log(
        "Resolving {} in {}>".format(names, repr(instance).split(" object at ")[0])
    )

    for node_group in plan:
        awaitable_names = [name for name in node_group if name in instance._registry]
        instance._log(" Run {}".format(awaitable_names))
        awaitables = [
            instance._registry[name](
                instance,
                _results=results,
                **{k: v for k, v in results.items() if k in instance._graph[name]},
            )
            for name in awaitable_names
        ]
        awaitable_results = await asyncio.gather(*awaitables)
        results.update(dict(zip(awaitable_names, awaitable_results)))

    return results
        return self._graph

    async def resolve(self, fn, **kwargs):
        try:
            name = fn.__name__
        except AttributeError:
            name = fn

        results = await self.resolve_multi([name], results=kwargs)
        return results[name]

    async def resolve_multi(self, names, results=None):
        if results is None:
            results = {}

        # Come up with an execution plan, just for these nodes
        ts = graphlib.TopologicalSorter()
        to_do = set(names)
        done = set(results.keys())
        while to_do:
            item = to_do.pop()
            dependencies = self.graph.get(item) or set()
            ts.add(item, *dependencies)
            done.add(item)
            # Add any not-done dependencies to the queue
            to_do.update({k for k in dependencies if k not in done})

        ts.prepare()
        plan = []
        while ts.is_active():
            node_group = ts.get_ready()
            plan.append(node_group)
            ts.done(*node_group)

        self.log("Resolving {}".format(names))

        for node_group in plan:
            awaitable_names = [name for name in node_group if name in self._registry]
            self.log(" Run {}".format(sorted(awaitable_names)))
            awaitables = [
                self._registry[name](
                    **{k: v for k, v in results.items() if k in self.graph[name]},
                )
                for name in awaitable_names
            ]
            if self.parallel:
                awaitable_results = await asyncio.gather(*awaitables)
            else:
                # awaitables holds coroutine objects, so await each one in turn
                awaitable_results = [await awaitable for awaitable in awaitables]
            results.update(dict(zip(awaitable_names, awaitable_results)))

        return results
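
The `parallel=` switch in `resolve_multi()` chooses between gathering a group's awaitables concurrently and awaiting them one at a time. A minimal standard-library sketch of that difference follows; `work` and `run_group` are hypothetical names used only for illustration, and `asyncio.sleep(0)` stands in for real I/O.

```python
import asyncio


async def work(name, log):
    # Record when the coroutine starts and finishes.
    log.append("start " + name)
    await asyncio.sleep(0)  # yield control, standing in for real I/O
    log.append("end " + name)
    return name.upper()


async def run_group(coros, parallel=True):
    if parallel:
        # Schedule every coroutine at once and wait for all of them.
        return await asyncio.gather(*coros)
    # Otherwise await each coroutine object in turn.
    return [await coro for coro in coros]


async def main():
    parallel_log, serial_log = [], []
    parallel_results = await run_group(
        [work("a", parallel_log), work("b", parallel_log)], parallel=True
    )
    serial_results = await run_group(
        [work("a", serial_log), work("b", serial_log)], parallel=False
    )
    return parallel_results, serial_results, parallel_log, serial_log


parallel_results, serial_results, parallel_log, serial_log = asyncio.run(main())
print(parallel_log)
print(serial_log)
```

Both modes return the same results in the same order. Only the interleaving recorded in the logs differs, which makes the sequential mode useful when comparing timings or debugging.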