From 59780fd38be5437d9996122476db0ac9ddc6084e Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Fri, 3 Dec 2021 10:41:41 -0800 Subject: [PATCH] Use __init_subclass__ instead of metaclass, refs #2 --- asyncinject/__init__.py | 54 ++++++++++++++++++----------------------- 1 file changed, 24 insertions(+), 30 deletions(-) diff --git a/asyncinject/__init__.py b/asyncinject/__init__.py index e8a1f42..8323b67 100644 --- a/asyncinject/__init__.py +++ b/asyncinject/__init__.py @@ -14,33 +14,6 @@ def inject(fn): return fn -class AsyncInjectMeta(type): - def __new__(cls, name, bases, attrs): - # Decorate any items that are 'async def' methods - _registry = {} - new_attrs = {"_registry": _registry} - inject_all = "AsyncInjectAll" in (b.__name__ for b in bases) - for key, value in attrs.items(): - if inspect.iscoroutinefunction(value) and ( - inject_all or getattr(value, "_inject", None) - ): - new_attrs[key] = _make_method(value) - _registry[key] = new_attrs[key] - else: - new_attrs[key] = value - # Gather graph for later dependency resolution - graph = { - key: { - p - for p in inspect.signature(method).parameters.keys() - if p != "self" and not p.startswith("_") - } - for key, method in _registry.items() - } - new_attrs["_graph"] = graph - return super().__new__(cls, name, bases, new_attrs) - - def _make_method(method): parameters = inspect.signature(method).parameters @@ -76,12 +49,33 @@ async def inner(self, **kwargs): return inner -class AsyncInject(metaclass=AsyncInjectMeta): - pass +class AsyncInject: + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + # Decorate any items that are 'async def' methods + cls._registry = {} + inject_all = getattr(cls, "_inject_all", False) + for name in dir(cls): + value = getattr(cls, name) + if inspect.iscoroutinefunction(value) and ( + inject_all or getattr(value, "_inject", None) + ): + setattr(cls, name, _make_method(value)) + cls._registry[name] = getattr(cls, name) + # Gather graph for later dependency resolution + graph = { + key: { + p + for p in inspect.signature(method).parameters.keys() + if p != "self" and not p.startswith("_") + } + for key, method in cls._registry.items() + } + cls._graph = graph class AsyncInjectAll(AsyncInject): - pass + _inject_all = True async def resolve(instance, names, results=None):