Skip to content

Commit

Permalink
@Inject decorator or inject_all = True for AsyncBase, refs #878
Browse files Browse the repository at this point in the history
  • Loading branch information
simonw committed Nov 16, 2021
1 parent 86aaa7c commit 22f41f7
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
12 changes: 11 additions & 1 deletion datasette/utils/asyncdi.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,23 @@
from . import vendored_graphlib as graphlib


def inject(fn):
fn._inject = True
return fn


class AsyncMeta(type):
def __new__(cls, name, bases, attrs):
# Decorate any items that are 'async def' methods
_registry = {}
new_attrs = {"_registry": _registry}
inject_all = attrs.get("inject_all")
for key, value in attrs.items():
if inspect.iscoroutinefunction(value) and not value.__name__ == "resolve":
if (
inspect.iscoroutinefunction(value)
and not value.__name__ == "resolve"
and (inject_all or getattr(value, "_inject", None))
):
new_attrs[key] = make_method(value)
_registry[key] = new_attrs[key]
else:
Expand Down
12 changes: 11 additions & 1 deletion tests/test_asyncdi.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import asyncio
from datasette.utils.asyncdi import AsyncBase
from datasette.utils.asyncdi import AsyncBase, inject
import pytest
from random import random

Expand All @@ -8,15 +8,22 @@ class Simple(AsyncBase):
def __init__(self):
self.log = []

@inject
async def two(self):
self.log.append("two")

@inject
async def one(self, two):
self.log.append("one")
return self.log

async def not_inject(self, one, two):
return one + two


class Complex(AsyncBase):
inject_all = True

def __init__(self):
self.log = []

Expand All @@ -40,6 +47,8 @@ async def go(self, a):


class WithParameters(AsyncBase):
inject_all = True

async def go(self, calc1, calc2, param1):
return param1 + calc1 + calc2

Expand All @@ -53,6 +62,7 @@ async def calc2(self):
@pytest.mark.asyncio
async def test_simple():
assert await Simple().one() == ["two", "one"]
assert await Simple().not_inject(6, 7) == 13


@pytest.mark.asyncio
Expand Down

0 comments on commit 22f41f7

Please sign in to comment.