Skip to content

Commit

Permalink
Merge pull request #14 from modern-python/feature/injecting-factory
Browse files Browse the repository at this point in the history
implement factories injecting
  • Loading branch information
lesnik512 authored Nov 10, 2024
2 parents 4ff483c + 48423de commit 7d74cab
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 0 deletions.
5 changes: 5 additions & 0 deletions packages/modern-di/modern_di/providers/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from modern_di import Container
from modern_di.providers.abstract import AbstractCreatorProvider
from modern_di.providers.injected_factory import InjectedFactory


T_co = typing.TypeVar("T_co", covariant=True)
Expand All @@ -21,6 +22,10 @@ def __init__(
) -> None:
super().__init__(scope, creator, *args, **kwargs)

@property
def factory_provider(self) -> InjectedFactory[T_co]:
return InjectedFactory(self)

async def async_resolve(self, container: Container) -> T_co:
container = container.find_container(self.scope)
if (override := container.fetch_override(self.provider_id)) is not None:
Expand Down
24 changes: 24 additions & 0 deletions packages/modern-di/modern_di/providers/injected_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import functools
import typing

from modern_di import Container
from modern_di.providers.abstract import AbstractProvider


T_co = typing.TypeVar("T_co", covariant=True)
P = typing.ParamSpec("P")


class InjectedFactory(typing.Generic[T_co]):
__slots__ = ("_factory_provider",)

def __init__(self, factory_provider: AbstractProvider[T_co]) -> None:
self._factory_provider = factory_provider

async def async_resolve(self, container: Container) -> typing.Callable[[], T_co]:
await self._factory_provider.async_resolve(container)
return functools.partial(self._factory_provider.sync_resolve, container)

def sync_resolve(self, container: Container) -> typing.Callable[[], T_co]:
self._factory_provider.sync_resolve(container)
return functools.partial(self._factory_provider.sync_resolve, container)
54 changes: 54 additions & 0 deletions packages/modern-di/tests_core/providers/test_injected_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import dataclasses
import datetime

import pytest
from modern_di import Container, Scope, providers

from tests_core.creators import create_async_resource, create_sync_resource


@dataclasses.dataclass(kw_only=True, slots=True)
class DependentCreator:
dep1: datetime.datetime


async_resource = providers.Resource(Scope.APP, create_async_resource)
sync_resource = providers.Resource(Scope.APP, create_sync_resource)
request_sync_factory = providers.Factory(Scope.REQUEST, DependentCreator, dep1=sync_resource.cast)
request_async_factory = providers.Factory(Scope.REQUEST, DependentCreator, dep1=async_resource.cast)


async def test_injected_async_factory() -> None:
async with (
Container(scope=Scope.APP) as app_container,
app_container.build_child_container(scope=Scope.REQUEST) as request_container,
):
factory = await request_async_factory.factory_provider.async_resolve(request_container)
instance1, instance2 = factory(), factory()
assert instance1 is not instance2
assert isinstance(instance1, DependentCreator)
assert isinstance(instance2, DependentCreator)


async def test_injected_sync_factory() -> None:
with (
Container(scope=Scope.APP) as app_container,
app_container.build_child_container(scope=Scope.REQUEST) as request_container,
):
factory = request_sync_factory.factory_provider.sync_resolve(request_container)
instance1, instance2 = factory(), factory()
assert instance1 is not instance2
assert isinstance(instance1, DependentCreator)
assert isinstance(instance2, DependentCreator)


async def test_injected_async_factory_in_sync_mode() -> None:
with (
Container(scope=Scope.APP) as app_container,
app_container.build_child_container(scope=Scope.REQUEST) as request_container,
):
with pytest.raises(RuntimeError, match="Resolving async resource in sync container is not allowed"):
await request_async_factory.factory_provider.async_resolve(request_container)

with pytest.raises(RuntimeError, match="Async resource cannot be resolved synchronously"):
request_async_factory.factory_provider.sync_resolve(request_container)

0 comments on commit 7d74cab

Please sign in to comment.