Skip to content

Commit

Permalink
add resolver and refactor tests (#12)
Browse files Browse the repository at this point in the history
* add resolver and refactor tests

* python 3.10 tests fix
  • Loading branch information
lesnik512 authored May 11, 2024
1 parent 58393e8 commit 51155ed
Show file tree
Hide file tree
Showing 9 changed files with 170 additions and 71 deletions.
6 changes: 2 additions & 4 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@ jobs:
with:
repo-token: ${{ github.token }}
- run: pip install poetry
- run: task install
- run: task lint-ci
- run: task install lint-ci

pytest:
runs-on: ubuntu-latest
Expand All @@ -43,5 +42,4 @@ jobs:
with:
repo-token: ${{ github.token }}
- run: pip install poetry
- run: task install
- run: task tests
- run: task install tests
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,15 @@ async def main():
some_dependency = await DIContainer.independent_factory()
```
2. No wiring for injections in function arguments -> achieved by decision that only one instance of container is supported

```python
from tests import container
from that_depends import Provide, inject


@inject
async def some_function(
independent_factory: container.IndependentFactory = Provide[container.DIContainer.independent_factory],
independent_factory: container.SimpleFactory = Provide[container.DIContainer.independent_factory],
) -> None:
assert independent_factory.dep1
```
Expand Down
44 changes: 22 additions & 22 deletions tests/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,23 @@
logger = logging.getLogger(__name__)


def create_sync_resource() -> typing.Iterator[str]:
def create_sync_resource() -> typing.Iterator[datetime.datetime]:
logger.debug("Resource initiated")
yield "sync resource"
logger.debug("Resource destructed")
try:
yield datetime.datetime.now(tz=datetime.timezone.utc)
finally:
logger.debug("Resource destructed")


async def create_async_resource() -> typing.AsyncIterator[str]:
logger.debug("Async resource initiated")
yield "async resource"
logger.debug("Async resource destructed")
async def create_async_resource() -> typing.AsyncIterator[datetime.datetime]:
try:
yield datetime.datetime.now(tz=datetime.timezone.utc)
finally:
logger.debug("Async resource destructed")


@dataclasses.dataclass(kw_only=True, slots=True)
class IndependentFactory:
class SimpleFactory:
dep1: str
dep2: int

Expand All @@ -32,15 +35,16 @@ async def async_factory() -> datetime.datetime:


@dataclasses.dataclass(kw_only=True, slots=True)
class SyncDependentFactory:
independent_factory: IndependentFactory
sync_resource: str
class DependentFactory:
simple_factory: SimpleFactory
sync_resource: datetime.datetime
async_resource: datetime.datetime


@dataclasses.dataclass(kw_only=True, slots=True)
class AsyncDependentFactory:
independent_factory: IndependentFactory
async_resource: str
class FreeFactory:
dependent_factory: DependentFactory
sync_resource: str


@dataclasses.dataclass(kw_only=True, slots=True)
Expand All @@ -53,16 +57,12 @@ class DIContainer(BaseContainer):
async_resource = providers.AsyncResource(create_async_resource)
sequence = providers.List(sync_resource, async_resource)

independent_factory = providers.Factory(IndependentFactory, dep1="text", dep2=123)
simple_factory = providers.Factory(SimpleFactory, dep1="text", dep2=123)
async_factory = providers.AsyncFactory(async_factory)
sync_dependent_factory = providers.Factory(
SyncDependentFactory,
independent_factory=independent_factory,
dependent_factory = providers.Factory(
DependentFactory,
simple_factory=simple_factory,
sync_resource=sync_resource,
)
async_dependent_factory = providers.Factory(
AsyncDependentFactory,
independent_factory=independent_factory,
async_resource=async_resource,
)
singleton = providers.Singleton(SingletonFactory, dep1=True)
24 changes: 17 additions & 7 deletions tests/test_fastapi_di.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import datetime
import typing

import fastapi
Expand All @@ -14,18 +15,27 @@

@app.get("/")
async def read_root(
sync_dependency: typing.Annotated[
container.AsyncDependentFactory,
fastapi.Depends(container.DIContainer.async_dependent_factory),
dependency: typing.Annotated[
container.DependentFactory,
fastapi.Depends(container.DIContainer.dependent_factory),
],
) -> str:
return sync_dependency.async_resource
free_dependency: typing.Annotated[
container.FreeFactory,
fastapi.Depends(container.DIContainer.resolver(container.FreeFactory)),
],
) -> datetime.datetime:
assert dependency.sync_resource == free_dependency.dependent_factory.sync_resource
assert dependency.async_resource == free_dependency.dependent_factory.async_resource
return dependency.async_resource


client = TestClient(app)


def test_read_main() -> None:
async def test_read_main() -> None:
response = client.get("/")
assert response.status_code == status.HTTP_200_OK
assert response.json() == "async resource"
assert (
datetime.datetime.fromisoformat(response.json().replace("Z", "+00:00"))
== await container.DIContainer.async_resource()
)
14 changes: 8 additions & 6 deletions tests/test_injection.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import datetime

import pytest

from tests import container
Expand All @@ -12,12 +14,12 @@ def create_fixture_one() -> int:
@inject
async def test_injection(
fixture_one: int,
independent_factory: container.IndependentFactory = Provide[container.DIContainer.independent_factory],
async_dependent_factory: container.AsyncDependentFactory = Provide[container.DIContainer.async_dependent_factory],
simple_factory: container.SimpleFactory = Provide[container.DIContainer.simple_factory],
dependent_factory: container.DependentFactory = Provide[container.DIContainer.dependent_factory],
default_zero: int = 0,
) -> None:
assert independent_factory.dep1
assert async_dependent_factory.async_resource == "async resource"
assert simple_factory.dep1
assert isinstance(dependent_factory.async_resource, datetime.datetime)
assert True
assert default_zero == 0
assert fixture_one == 1
Expand All @@ -26,9 +28,9 @@ async def test_injection(
async def test_wrong_injection() -> None:
@inject
async def inner(
_: container.IndependentFactory = Provide[container.DIContainer.independent_factory],
_: container.SimpleFactory = Provide[container.DIContainer.simple_factory],
) -> None:
"""Do nothing."""

with pytest.raises(RuntimeError, match="Injected arguments must not be redefined"):
await inner(_=container.IndependentFactory(dep1="1", dep2=2))
await inner(_=container.SimpleFactory(dep1="1", dep2=2))
11 changes: 8 additions & 3 deletions tests/test_litestar_di_simple.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import datetime

from litestar import Litestar, get
from litestar.di import Provide
from litestar.status_codes import HTTP_200_OK
Expand All @@ -7,15 +9,18 @@


@get("/")
async def index(injected: str) -> str:
async def index(injected: datetime.datetime) -> datetime.datetime:
return injected


app = Litestar([index], dependencies={"injected": Provide(container.DIContainer.async_resource)})


def test_litestar_di() -> None:
async def test_litestar_di() -> None:
with TestClient(app=app) as client:
response = client.get("/")
assert response.status_code == HTTP_200_OK, response.text
assert response.text == "async resource"
assert (
datetime.datetime.fromisoformat(response.json().replace("Z", "+00:00"))
== await container.DIContainer.async_resource()
)
66 changes: 42 additions & 24 deletions tests/test_main_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,51 +7,61 @@
from that_depends import inject, providers


async def test_main_providers() -> None:
independent_factory = await DIContainer.independent_factory()
sync_dependent_factory = await DIContainer.sync_dependent_factory()
async_dependent_factory = await DIContainer.async_dependent_factory()
async def test_factory_providers() -> None:
simple_factory = await DIContainer.simple_factory()
dependent_factory = await DIContainer.dependent_factory()
async_factory = await DIContainer.async_factory()
sync_resource = await DIContainer.sync_resource()
async_resource = await DIContainer.async_resource()

assert dependent_factory.simple_factory is not simple_factory
assert dependent_factory.sync_resource == sync_resource
assert dependent_factory.async_resource == async_resource
assert isinstance(async_factory, datetime.datetime)


async def test_list_provider() -> None:
sequence = await DIContainer.sequence()
sync_resource = await DIContainer.sync_resource()
async_resource = await DIContainer.async_resource()

assert sequence == [sync_resource, async_resource]


async def test_singleton_provider() -> None:
singleton1 = await DIContainer.singleton()
singleton2 = await DIContainer.singleton()
async_factory = await DIContainer.async_factory()

assert sync_dependent_factory.independent_factory is not independent_factory
assert sync_dependent_factory.sync_resource == "sync resource"
assert async_dependent_factory.async_resource == "async resource"
assert sequence == ["sync resource", "async resource"]
assert singleton1 is singleton2
assert isinstance(async_factory, datetime.datetime)


@inject
async def test_main_providers_overriding() -> None:
async_resource_mock = "async overriding"
sync_resource_mock = "sync overriding"
async_factory_mock = datetime.datetime.now(tz=datetime.timezone.utc)
independent_factory_mock = container.IndependentFactory(dep1="override", dep2=999)
async def test_providers_overriding() -> None:
async_resource_mock = datetime.datetime.fromisoformat("2023-01-01")
sync_resource_mock = datetime.datetime.fromisoformat("2024-01-01")
async_factory_mock = datetime.datetime.fromisoformat("2025-01-01")
simple_factory_mock = container.SimpleFactory(dep1="override", dep2=999)
singleton_mock = container.SingletonFactory(dep1=False)
container.DIContainer.async_resource.override(async_resource_mock)
container.DIContainer.sync_resource.override(sync_resource_mock)
container.DIContainer.independent_factory.override(independent_factory_mock)
container.DIContainer.simple_factory.override(simple_factory_mock)
container.DIContainer.singleton.override(singleton_mock)
container.DIContainer.async_factory.override(async_factory_mock)

await container.DIContainer.independent_factory()
sync_dependent_factory = await container.DIContainer.sync_dependent_factory()
async_dependent_factory = await container.DIContainer.async_dependent_factory()
await container.DIContainer.simple_factory()
dependent_factory = await container.DIContainer.dependent_factory()
singleton = await container.DIContainer.singleton()
async_factory = await container.DIContainer.async_factory()

assert sync_dependent_factory.independent_factory.dep1 == independent_factory_mock.dep1
assert sync_dependent_factory.independent_factory.dep2 == independent_factory_mock.dep2
assert sync_dependent_factory.sync_resource == sync_resource_mock
assert async_dependent_factory.async_resource == async_resource_mock
assert dependent_factory.simple_factory.dep1 == simple_factory_mock.dep1
assert dependent_factory.simple_factory.dep2 == simple_factory_mock.dep2
assert dependent_factory.sync_resource == sync_resource_mock
assert dependent_factory.async_resource == async_resource_mock
assert singleton is singleton_mock
assert async_factory is async_factory_mock

container.DIContainer.reset_override()
assert (await container.DIContainer.async_resource()) == "async resource"
assert (await container.DIContainer.async_resource()) != async_resource_mock


def test_wrong_providers_init() -> None:
Expand All @@ -65,3 +75,11 @@ def test_wrong_providers_init() -> None:
def test_container_init_error() -> None:
with pytest.raises(RuntimeError, match="DIContainer should not be instantiated"):
DIContainer()


async def test_free_dependency() -> None:
resolver = DIContainer.resolver(container.FreeFactory)
dep1 = await resolver()
dep2 = await DIContainer.resolve(container.FreeFactory)
assert dep1
assert dep2
29 changes: 29 additions & 0 deletions tests/test_resolver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import dataclasses

import pytest

from tests import container
from tests.container import DIContainer


@dataclasses.dataclass(kw_only=True, slots=True)
class WrongFactory:
some_dep: int = 1
not_existing_name: container.DependentFactory
sync_resource: str


async def test_dependency_resolver() -> None:
resolver = DIContainer.resolver(container.FreeFactory)
dep1 = await resolver()
dep2 = await DIContainer.resolve(container.FreeFactory)

assert dep1
assert dep2
assert dep1 is not dep2


async def test_dependency_resolver_failed() -> None:
resolver = DIContainer.resolver(WrongFactory)
with pytest.raises(RuntimeError, match="Provider is not found, field_name='not_existing_name'"):
await resolver()
Loading

0 comments on commit 51155ed

Please sign in to comment.