-
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Switch from class-based to function-based (#9)
Major rewrite of functionality. This will break anything that uses this library! * Use functions instead of class methods, closes #8 * Also includes parallel= mechanism from #6
- Loading branch information
Showing
3 changed files
with
173 additions
and
282 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,125 +1,87 @@ | ||
import asyncio | ||
from functools import wraps | ||
import inspect | ||
|
||
try: | ||
import graphlib | ||
except ImportError: | ||
from . import vendored_graphlib as graphlib | ||
import asyncio | ||
|
||
|
||
def inject(fn): | ||
"Mark method as having dependency-injected parameters" | ||
fn._inject = True | ||
return fn | ||
|
||
|
||
def _make_method(method): | ||
parameters = inspect.signature(method).parameters | ||
|
||
@wraps(method) | ||
async def inner(self, **kwargs): | ||
# Any parameters not provided by kwargs are resolved from registry | ||
to_resolve = [ | ||
p | ||
for p in parameters | ||
# Not already provided | ||
if p not in kwargs | ||
# Not self | ||
and p != "self" | ||
# Doesn't have a default value | ||
and parameters[p].default is inspect._empty | ||
] | ||
missing = [p for p in to_resolve if p not in self._registry] | ||
assert ( | ||
not missing | ||
), "The following DI parameters could not be found in the registry: {}".format( | ||
missing | ||
) | ||
|
||
results = {} | ||
results.update(kwargs) | ||
if to_resolve: | ||
resolved_parameters = await resolve(self, to_resolve, results) | ||
results.update(resolved_parameters) | ||
return await method( | ||
self, **{k: v for k, v in results.items() if k in parameters} | ||
) | ||
|
||
return inner | ||
|
||
|
||
class AsyncInject: | ||
def _log(self, message): | ||
pass | ||
|
||
def __init_subclass__(cls, **kwargs): | ||
super().__init_subclass__(**kwargs) | ||
# Decorate any items that are 'async def' methods | ||
cls._registry = {} | ||
inject_all = getattr(cls, "_inject_all", False) | ||
for name in dir(cls): | ||
value = getattr(cls, name) | ||
if inspect.iscoroutinefunction(value) and ( | ||
inject_all or getattr(value, "_inject", None) | ||
): | ||
setattr(cls, name, _make_method(value)) | ||
cls._registry[name] = getattr(cls, name) | ||
# Gather graph for later dependency resolution | ||
graph = { | ||
key: { | ||
p | ||
for p in inspect.signature(method).parameters.keys() | ||
if p != "self" and not p.startswith("_") | ||
class AsyncRegistry: | ||
def __init__(self, *fns, parallel=True, log=None): | ||
self._registry = {} | ||
self._graph = None | ||
self.parallel = parallel | ||
self.log = log or (lambda *args: None) | ||
for fn in fns: | ||
self.register(fn) | ||
|
||
def register(self, fn): | ||
self._registry[fn.__name__] = fn | ||
# Clear _graph cache: | ||
self._graph = None | ||
|
||
@property | ||
def graph(self): | ||
if self._graph is None: | ||
self._graph = { | ||
key: { | ||
p | ||
for p in inspect.signature(fn).parameters.keys() | ||
if not p.startswith("_") | ||
} | ||
for key, fn in self._registry.items() | ||
} | ||
for key, method in cls._registry.items() | ||
} | ||
cls._graph = graph | ||
|
||
|
||
class AsyncInjectAll(AsyncInject): | ||
_inject_all = True | ||
|
||
|
||
async def resolve(instance, names, results=None): | ||
if results is None: | ||
results = {} | ||
|
||
# Come up with an execution plan, just for these nodes | ||
ts = graphlib.TopologicalSorter() | ||
to_do = set(names) | ||
done = set() | ||
while to_do: | ||
item = to_do.pop() | ||
dependencies = instance._graph.get(item) or set() | ||
ts.add(item, *dependencies) | ||
done.add(item) | ||
# Add any not-done dependencies to the queue | ||
to_do.update({k for k in dependencies if k not in done}) | ||
|
||
ts.prepare() | ||
plan = [] | ||
while ts.is_active(): | ||
node_group = ts.get_ready() | ||
plan.append(node_group) | ||
ts.done(*node_group) | ||
|
||
instance._log( | ||
"Resolving {} in {}>".format(names, repr(instance).split(" object at ")[0]) | ||
) | ||
|
||
for node_group in plan: | ||
awaitable_names = [name for name in node_group if name in instance._registry] | ||
instance._log(" Run {}".format(awaitable_names)) | ||
awaitables = [ | ||
instance._registry[name]( | ||
instance, | ||
_results=results, | ||
**{k: v for k, v in results.items() if k in instance._graph[name]}, | ||
) | ||
for name in awaitable_names | ||
] | ||
awaitable_results = await asyncio.gather(*awaitables) | ||
results.update(dict(zip(awaitable_names, awaitable_results))) | ||
|
||
return results | ||
return self._graph | ||
|
||
async def resolve(self, fn, **kwargs): | ||
try: | ||
name = fn.__name__ | ||
except AttributeError: | ||
name = fn | ||
|
||
results = await self.resolve_multi([name], results=kwargs) | ||
return results[name] | ||
|
||
async def resolve_multi(self, names, results=None): | ||
if results is None: | ||
results = {} | ||
|
||
# Come up with an execution plan, just for these nodes | ||
ts = graphlib.TopologicalSorter() | ||
to_do = set(names) | ||
done = set(results.keys()) | ||
while to_do: | ||
item = to_do.pop() | ||
dependencies = self.graph.get(item) or set() | ||
ts.add(item, *dependencies) | ||
done.add(item) | ||
# Add any not-done dependencies to the queue | ||
to_do.update({k for k in dependencies if k not in done}) | ||
|
||
ts.prepare() | ||
plan = [] | ||
while ts.is_active(): | ||
node_group = ts.get_ready() | ||
plan.append(node_group) | ||
ts.done(*node_group) | ||
|
||
self.log("Resolving {}".format(names)) | ||
|
||
for node_group in plan: | ||
awaitable_names = [name for name in node_group if name in self._registry] | ||
self.log(" Run {}".format(sorted(awaitable_names))) | ||
awaitables = [ | ||
self._registry[name]( | ||
**{k: v for k, v in results.items() if k in self.graph[name]}, | ||
) | ||
for name in awaitable_names | ||
] | ||
if self.parallel: | ||
awaitable_results = await asyncio.gather(*awaitables) | ||
else: | ||
awaitable_results = (await fn() for fn in awaitables) | ||
results.update(dict(zip(awaitable_names, awaitable_results))) | ||
|
||
return results |
Oops, something went wrong.