diff --git a/docs/providers/singleton.md b/docs/providers/singleton.md index 3c42964..ecf8876 100644 --- a/docs/providers/singleton.md +++ b/docs/providers/singleton.md @@ -1,4 +1,4 @@ -# Singleton Provider +# Singleton Singleton providers resolve the dependency only once and cache the resolved instance for future injections. diff --git a/tests/providers/test_attr_getter.py b/tests/providers/test_attr_getter.py index 6876340..b09deda 100644 --- a/tests/providers/test_attr_getter.py +++ b/tests/providers/test_attr_getter.py @@ -1,10 +1,12 @@ import random +import typing from dataclasses import dataclass, field import pytest from that_depends import providers -from that_depends.providers.attr_getter import _get_value_from_object_by_dotted_path +from that_depends.providers.base import _get_value_from_object_by_dotted_path +from that_depends.providers.context_resources import container_context @dataclass @@ -24,25 +26,80 @@ class Settings: nested1_attr: Nested1 = field(default_factory=Nested1) +async def return_settings_async() -> Settings: + return Settings() + + +async def yield_settings_async() -> typing.AsyncIterator[Settings]: + yield Settings() + + +def yield_settings_sync() -> typing.Iterator[Settings]: + yield Settings() + + @dataclass class NestingTestDTO: ... -@pytest.fixture -def some_settings_provider() -> providers.Singleton[Settings]: - return providers.Singleton(Settings) +@pytest.fixture( + params=[ + providers.Resource(yield_settings_sync), + providers.Singleton(Settings), + providers.ContextResource(yield_settings_sync), + providers.Object(Settings()), + providers.Factory(Settings), + providers.Selector(lambda: "sync", sync=providers.Factory(Settings)), + ] +) +def some_sync_settings_provider(request: pytest.FixtureRequest) -> providers.AbstractProvider[Settings]: + return typing.cast(providers.AbstractProvider[Settings], request.param) + + +@pytest.fixture( + params=[ + providers.AsyncFactory(return_settings_async), + providers.Resource(yield_settings_async), + providers.ContextResource(yield_settings_async), + providers.Selector(lambda: "asynchronous", asynchronous=providers.AsyncFactory(return_settings_async)), + ] +) +def some_async_settings_provider(request: pytest.FixtureRequest) -> providers.AbstractProvider[Settings]: + return typing.cast(providers.AbstractProvider[Settings], request.param) -def test_attr_getter_with_zero_attribute_depth(some_settings_provider: providers.Singleton[Settings]) -> None: - attr_getter = some_settings_provider.some_str_value +@container_context() +def test_attr_getter_with_zero_attribute_depth_sync( + some_sync_settings_provider: providers.AbstractProvider[Settings], +) -> None: + attr_getter = some_sync_settings_provider.some_str_value assert attr_getter.sync_resolve() == Settings().some_str_value -def test_attr_getter_with_more_than_zero_attribute_depth(some_settings_provider: providers.Singleton[Settings]) -> None: - attr_getter = some_settings_provider.nested1_attr.nested2_attr.some_const +@container_context() +async def test_attr_getter_with_zero_attribute_depth_async( + some_async_settings_provider: providers.AbstractProvider[Settings], +) -> None: + attr_getter = some_async_settings_provider.some_str_value + assert await attr_getter.async_resolve() == Settings().some_str_value + + +@container_context() +def test_attr_getter_with_more_than_zero_attribute_depth_sync( + some_sync_settings_provider: providers.AbstractProvider[Settings], +) -> None: + attr_getter = some_sync_settings_provider.nested1_attr.nested2_attr.some_const assert attr_getter.sync_resolve() == Nested2().some_const +@container_context() +async def test_attr_getter_with_more_than_zero_attribute_depth_async( + some_async_settings_provider: providers.AbstractProvider[Settings], +) -> None: + attr_getter = some_async_settings_provider.nested1_attr.nested2_attr.some_const + assert await attr_getter.async_resolve() == Nested2().some_const + + @pytest.mark.parametrize( ("field_count", "test_field_name", "test_value"), [(1, "test_field", "sdf6fF^SF(FF*4ffsf"), (5, "nested_field", -252625), (50, "50_lvl_field", 909234235)], @@ -66,10 +123,25 @@ def test_nesting_levels(field_count: int, test_field_name: str, test_value: str assert attr_value == test_value -def test_attr_getter_with_invalid_attribute(some_settings_provider: providers.Singleton[Settings]) -> None: +@container_context() +def test_attr_getter_with_invalid_attribute_sync( + some_sync_settings_provider: providers.AbstractProvider[Settings], +) -> None: + with pytest.raises(AttributeError): + some_sync_settings_provider.nested1_attr.nested2_attr.__some_private__ # noqa: B018 + with pytest.raises(AttributeError): + some_sync_settings_provider.nested1_attr.__another_private__ # noqa: B018 + with pytest.raises(AttributeError): + some_sync_settings_provider.nested1_attr._final_private_ # noqa: B018 + + +@container_context() +async def test_attr_getter_with_invalid_attribute_async( + some_async_settings_provider: providers.AbstractProvider[Settings], +) -> None: with pytest.raises(AttributeError): - some_settings_provider.nested1_attr.nested2_attr.__some_private__ # noqa: B018 + some_async_settings_provider.nested1_attr.nested2_attr.__some_private__ # noqa: B018 with pytest.raises(AttributeError): - some_settings_provider.nested1_attr.__another_private__ # noqa: B018 + some_async_settings_provider.nested1_attr.__another_private__ # noqa: B018 with pytest.raises(AttributeError): - some_settings_provider.nested1_attr._final_private_ # noqa: B018 + some_async_settings_provider.nested1_attr._final_private_ # noqa: B018 diff --git a/that_depends/injection.py b/that_depends/injection.py index 8668d73..6a440a7 100644 --- a/that_depends/injection.py +++ b/that_depends/injection.py @@ -37,7 +37,7 @@ async def inner(*args: P.args, **kwargs: P.kwargs) -> T: if field_name in kwargs: continue - kwargs[field_name] = await field_value.default() + kwargs[field_name] = await field_value.default.async_resolve() injected = True if not injected: warnings.warn( diff --git a/that_depends/providers/__init__.py b/that_depends/providers/__init__.py index 86f9416..c9f7b86 100644 --- a/that_depends/providers/__init__.py +++ b/that_depends/providers/__init__.py @@ -1,5 +1,4 @@ -from that_depends.providers.attr_getter import AttrGetter -from that_depends.providers.base import AbstractProvider +from that_depends.providers.base import AbstractProvider, AttrGetter from that_depends.providers.collections import Dict, List from that_depends.providers.context_resources import ( AsyncContextResource, diff --git a/that_depends/providers/attr_getter.py b/that_depends/providers/attr_getter.py deleted file mode 100644 index 3529651..0000000 --- a/that_depends/providers/attr_getter.py +++ /dev/null @@ -1,41 +0,0 @@ -import typing -from operator import attrgetter - -from that_depends.providers.base import AbstractProvider - - -T_co = typing.TypeVar("T_co", covariant=True) -P = typing.ParamSpec("P") - - -def _get_value_from_object_by_dotted_path(obj: typing.Any, path: str) -> typing.Any: # noqa: ANN401 - attribute_getter = attrgetter(path) - return attribute_getter(obj) - - -class AttrGetter( - AbstractProvider[T_co], -): - __slots__ = "_provider", "_attrs" - - def __init__(self, provider: AbstractProvider[T_co], attr_name: str) -> None: - super().__init__() - self._provider = provider - self._attrs = [attr_name] - - def __getattr__(self, attr: str) -> "AttrGetter[T_co]": - if attr.startswith("_"): - msg = f"'{type(self)}' object has no attribute '{attr}'" - raise AttributeError(msg) - self._attrs.append(attr) - return self - - async def async_resolve(self) -> typing.Any: # noqa: ANN401 - resolved_provider_object = await self._provider.async_resolve() - attribute_path = ".".join(self._attrs) - return _get_value_from_object_by_dotted_path(resolved_provider_object, attribute_path) - - def sync_resolve(self) -> typing.Any: # noqa: ANN401 - resolved_provider_object = self._provider.sync_resolve() - attribute_path = ".".join(self._attrs) - return _get_value_from_object_by_dotted_path(resolved_provider_object, attribute_path) diff --git a/that_depends/providers/base.py b/that_depends/providers/base.py index 04c7d32..9f0a625 100644 --- a/that_depends/providers/base.py +++ b/that_depends/providers/base.py @@ -4,6 +4,7 @@ import inspect import typing from contextlib import contextmanager +from operator import attrgetter T_co = typing.TypeVar("T_co", covariant=True) @@ -18,6 +19,12 @@ def __init__(self) -> None: super().__init__() self._override: typing.Any = None + def __getattr__(self, attr_name: str) -> typing.Any: # noqa: ANN401 + if attr_name.startswith("_"): + msg = f"'{type(self)}' object has no attribute '{attr_name}'" + raise AttributeError(msg) + return AttrGetter(provider=self, attr_name=attr_name) + @abc.abstractmethod async def async_resolve(self) -> T_co: """Resolve dependency asynchronously.""" @@ -137,7 +144,6 @@ def __init__( self._creator: typing.Final = creator self._args: typing.Final = args self._kwargs: typing.Final = kwargs - self._override = None def _is_creator_async( self, _: typing.Callable[P, typing.Iterator[T_co] | typing.AsyncIterator[T_co]] @@ -174,9 +180,12 @@ async def async_resolve(self) -> T_co: T_co, await context.context_stack.enter_async_context( contextlib.asynccontextmanager(self._creator)( - *[await x() if isinstance(x, AbstractProvider) else x for x in self._args], + *[ + await x.async_resolve() if isinstance(x, AbstractProvider) else x + for x in self._args + ], **{ - k: await v() if isinstance(v, AbstractProvider) else v + k: await v.async_resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items() }, ), @@ -228,3 +237,36 @@ def provider(self) -> typing.Callable[[], typing.Coroutine[typing.Any, typing.An @property def sync_provider(self) -> typing.Callable[[], T_co]: return self.sync_resolve + + +def _get_value_from_object_by_dotted_path(obj: typing.Any, path: str) -> typing.Any: # noqa: ANN401 + attribute_getter = attrgetter(path) + return attribute_getter(obj) + + +class AttrGetter( + AbstractProvider[T_co], +): + __slots__ = "_provider", "_attrs" + + def __init__(self, provider: AbstractProvider[T_co], attr_name: str) -> None: + super().__init__() + self._provider = provider + self._attrs = [attr_name] + + def __getattr__(self, attr: str) -> "AttrGetter[T_co]": + if attr.startswith("_"): + msg = f"'{type(self)}' object has no attribute '{attr}'" + raise AttributeError(msg) + self._attrs.append(attr) + return self + + async def async_resolve(self) -> typing.Any: # noqa: ANN401 + resolved_provider_object = await self._provider.async_resolve() + attribute_path = ".".join(self._attrs) + return _get_value_from_object_by_dotted_path(resolved_provider_object, attribute_path) + + def sync_resolve(self) -> typing.Any: # noqa: ANN401 + resolved_provider_object = self._provider.sync_resolve() + attribute_path = ".".join(self._attrs) + return _get_value_from_object_by_dotted_path(resolved_provider_object, attribute_path) diff --git a/that_depends/providers/collections.py b/that_depends/providers/collections.py index 0e3c7f3..c976a56 100644 --- a/that_depends/providers/collections.py +++ b/that_depends/providers/collections.py @@ -13,6 +13,10 @@ def __init__(self, *providers: AbstractProvider[T_co]) -> None: super().__init__() self._providers: typing.Final = providers + def __getattr__(self, attr_name: str) -> typing.Any: # noqa: ANN401 + msg = f"'{type(self)}' object has no attribute '{attr_name}'" + raise AttributeError(msg) + async def async_resolve(self) -> list[T_co]: return [await x.async_resolve() for x in self._providers] @@ -30,6 +34,10 @@ def __init__(self, **providers: AbstractProvider[T_co]) -> None: super().__init__() self._providers: typing.Final = providers + def __getattr__(self, attr_name: str) -> typing.Any: # noqa: ANN401 + msg = f"'{type(self)}' object has no attribute '{attr_name}'" + raise AttributeError(msg) + async def async_resolve(self) -> dict[str, T_co]: return {key: await provider.async_resolve() for key, provider in self._providers.items()} diff --git a/that_depends/providers/factories.py b/that_depends/providers/factories.py index 28774c0..a4370c4 100644 --- a/that_depends/providers/factories.py +++ b/that_depends/providers/factories.py @@ -15,7 +15,6 @@ def __init__(self, factory: type[T_co] | typing.Callable[P, T_co], *args: P.args self._factory: typing.Final = factory self._args: typing.Final = args self._kwargs: typing.Final = kwargs - self._override = None async def async_resolve(self) -> T_co: if self._override: @@ -40,10 +39,10 @@ class AsyncFactory(AbstractFactory[T_co]): __slots__ = "_factory", "_args", "_kwargs", "_override" def __init__(self, factory: typing.Callable[P, typing.Awaitable[T_co]], *args: P.args, **kwargs: P.kwargs) -> None: + super().__init__() self._factory: typing.Final = factory self._args: typing.Final = args self._kwargs: typing.Final = kwargs - self._override = None async def async_resolve(self) -> T_co: if self._override: diff --git a/that_depends/providers/selector.py b/that_depends/providers/selector.py index e2a6aeb..34a0997 100644 --- a/that_depends/providers/selector.py +++ b/that_depends/providers/selector.py @@ -13,7 +13,6 @@ def __init__(self, selector: typing.Callable[[], str], **providers: AbstractProv super().__init__() self._selector: typing.Final = selector self._providers: typing.Final = providers - self._override = None async def async_resolve(self) -> T_co: if self._override: diff --git a/that_depends/providers/singleton.py b/that_depends/providers/singleton.py index 11c86c7..d1d2a34 100644 --- a/that_depends/providers/singleton.py +++ b/that_depends/providers/singleton.py @@ -1,7 +1,6 @@ import asyncio import typing -from that_depends.providers import AttrGetter from that_depends.providers.base import AbstractProvider @@ -17,16 +16,9 @@ def __init__(self, factory: type[T_co] | typing.Callable[P, T_co], *args: P.args self._factory: typing.Final = factory self._args: typing.Final = args self._kwargs: typing.Final = kwargs - self._override = None self._instance: T_co | None = None self._resolving_lock: typing.Final = asyncio.Lock() - def __getattr__(self, attr_name: str) -> typing.Any: # noqa: ANN401 - if attr_name.startswith("_"): - msg = f"'{type(self)}' object has no attribute '{attr_name}'" - raise AttributeError(msg) - return AttrGetter(provider=self, attr_name=attr_name) - async def async_resolve(self) -> T_co: if self._override is not None: return typing.cast(T_co, self._override)