Skip to content

Commit

Permalink
Enable attribute resolution for all providers. (#97)
Browse files Browse the repository at this point in the history
* Remove redundant resets of override in __init__ methods.

* Made attribute resolution possible for all providers.

* Disable attribute resolution for collections.

* Add tests for Selector provider.

---------

Co-authored-by: Alexander <[email protected]>
  • Loading branch information
alexanderlazarev0 and Alexander authored Oct 3, 2024
1 parent 78a6202 commit 630733a
Show file tree
Hide file tree
Showing 10 changed files with 141 additions and 71 deletions.
2 changes: 1 addition & 1 deletion docs/providers/singleton.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Singleton Provider
# Singleton

Singleton providers resolve the dependency only once and cache the resolved instance for future injections.

Expand Down
96 changes: 84 additions & 12 deletions tests/providers/test_attr_getter.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import random
import typing
from dataclasses import dataclass, field

import pytest

from that_depends import providers
from that_depends.providers.attr_getter import _get_value_from_object_by_dotted_path
from that_depends.providers.base import _get_value_from_object_by_dotted_path
from that_depends.providers.context_resources import container_context


@dataclass
Expand All @@ -24,25 +26,80 @@ class Settings:
nested1_attr: Nested1 = field(default_factory=Nested1)


async def return_settings_async() -> Settings:
return Settings()


async def yield_settings_async() -> typing.AsyncIterator[Settings]:
yield Settings()


def yield_settings_sync() -> typing.Iterator[Settings]:
yield Settings()


@dataclass
class NestingTestDTO: ...


@pytest.fixture
def some_settings_provider() -> providers.Singleton[Settings]:
return providers.Singleton(Settings)
@pytest.fixture(
params=[
providers.Resource(yield_settings_sync),
providers.Singleton(Settings),
providers.ContextResource(yield_settings_sync),
providers.Object(Settings()),
providers.Factory(Settings),
providers.Selector(lambda: "sync", sync=providers.Factory(Settings)),
]
)
def some_sync_settings_provider(request: pytest.FixtureRequest) -> providers.AbstractProvider[Settings]:
return typing.cast(providers.AbstractProvider[Settings], request.param)


@pytest.fixture(
params=[
providers.AsyncFactory(return_settings_async),
providers.Resource(yield_settings_async),
providers.ContextResource(yield_settings_async),
providers.Selector(lambda: "asynchronous", asynchronous=providers.AsyncFactory(return_settings_async)),
]
)
def some_async_settings_provider(request: pytest.FixtureRequest) -> providers.AbstractProvider[Settings]:
return typing.cast(providers.AbstractProvider[Settings], request.param)


def test_attr_getter_with_zero_attribute_depth(some_settings_provider: providers.Singleton[Settings]) -> None:
attr_getter = some_settings_provider.some_str_value
@container_context()
def test_attr_getter_with_zero_attribute_depth_sync(
some_sync_settings_provider: providers.AbstractProvider[Settings],
) -> None:
attr_getter = some_sync_settings_provider.some_str_value
assert attr_getter.sync_resolve() == Settings().some_str_value


def test_attr_getter_with_more_than_zero_attribute_depth(some_settings_provider: providers.Singleton[Settings]) -> None:
attr_getter = some_settings_provider.nested1_attr.nested2_attr.some_const
@container_context()
async def test_attr_getter_with_zero_attribute_depth_async(
some_async_settings_provider: providers.AbstractProvider[Settings],
) -> None:
attr_getter = some_async_settings_provider.some_str_value
assert await attr_getter.async_resolve() == Settings().some_str_value


@container_context()
def test_attr_getter_with_more_than_zero_attribute_depth_sync(
some_sync_settings_provider: providers.AbstractProvider[Settings],
) -> None:
attr_getter = some_sync_settings_provider.nested1_attr.nested2_attr.some_const
assert attr_getter.sync_resolve() == Nested2().some_const


@container_context()
async def test_attr_getter_with_more_than_zero_attribute_depth_async(
some_async_settings_provider: providers.AbstractProvider[Settings],
) -> None:
attr_getter = some_async_settings_provider.nested1_attr.nested2_attr.some_const
assert await attr_getter.async_resolve() == Nested2().some_const


@pytest.mark.parametrize(
("field_count", "test_field_name", "test_value"),
[(1, "test_field", "sdf6fF^SF(FF*4ffsf"), (5, "nested_field", -252625), (50, "50_lvl_field", 909234235)],
Expand All @@ -66,10 +123,25 @@ def test_nesting_levels(field_count: int, test_field_name: str, test_value: str
assert attr_value == test_value


def test_attr_getter_with_invalid_attribute(some_settings_provider: providers.Singleton[Settings]) -> None:
@container_context()
def test_attr_getter_with_invalid_attribute_sync(
some_sync_settings_provider: providers.AbstractProvider[Settings],
) -> None:
with pytest.raises(AttributeError):
some_sync_settings_provider.nested1_attr.nested2_attr.__some_private__ # noqa: B018
with pytest.raises(AttributeError):
some_sync_settings_provider.nested1_attr.__another_private__ # noqa: B018
with pytest.raises(AttributeError):
some_sync_settings_provider.nested1_attr._final_private_ # noqa: B018


@container_context()
async def test_attr_getter_with_invalid_attribute_async(
some_async_settings_provider: providers.AbstractProvider[Settings],
) -> None:
with pytest.raises(AttributeError):
some_settings_provider.nested1_attr.nested2_attr.__some_private__ # noqa: B018
some_async_settings_provider.nested1_attr.nested2_attr.__some_private__ # noqa: B018
with pytest.raises(AttributeError):
some_settings_provider.nested1_attr.__another_private__ # noqa: B018
some_async_settings_provider.nested1_attr.__another_private__ # noqa: B018
with pytest.raises(AttributeError):
some_settings_provider.nested1_attr._final_private_ # noqa: B018
some_async_settings_provider.nested1_attr._final_private_ # noqa: B018
2 changes: 1 addition & 1 deletion that_depends/injection.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ async def inner(*args: P.args, **kwargs: P.kwargs) -> T:
if field_name in kwargs:
continue

kwargs[field_name] = await field_value.default()
kwargs[field_name] = await field_value.default.async_resolve()
injected = True
if not injected:
warnings.warn(
Expand Down
3 changes: 1 addition & 2 deletions that_depends/providers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from that_depends.providers.attr_getter import AttrGetter
from that_depends.providers.base import AbstractProvider
from that_depends.providers.base import AbstractProvider, AttrGetter
from that_depends.providers.collections import Dict, List
from that_depends.providers.context_resources import (
AsyncContextResource,
Expand Down
41 changes: 0 additions & 41 deletions that_depends/providers/attr_getter.py

This file was deleted.

48 changes: 45 additions & 3 deletions that_depends/providers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import inspect
import typing
from contextlib import contextmanager
from operator import attrgetter


T_co = typing.TypeVar("T_co", covariant=True)
Expand All @@ -18,6 +19,12 @@ def __init__(self) -> None:
super().__init__()
self._override: typing.Any = None

def __getattr__(self, attr_name: str) -> typing.Any: # noqa: ANN401
if attr_name.startswith("_"):
msg = f"'{type(self)}' object has no attribute '{attr_name}'"
raise AttributeError(msg)
return AttrGetter(provider=self, attr_name=attr_name)

@abc.abstractmethod
async def async_resolve(self) -> T_co:
"""Resolve dependency asynchronously."""
Expand Down Expand Up @@ -137,7 +144,6 @@ def __init__(
self._creator: typing.Final = creator
self._args: typing.Final = args
self._kwargs: typing.Final = kwargs
self._override = None

def _is_creator_async(
self, _: typing.Callable[P, typing.Iterator[T_co] | typing.AsyncIterator[T_co]]
Expand Down Expand Up @@ -174,9 +180,12 @@ async def async_resolve(self) -> T_co:
T_co,
await context.context_stack.enter_async_context(
contextlib.asynccontextmanager(self._creator)(
*[await x() if isinstance(x, AbstractProvider) else x for x in self._args],
*[
await x.async_resolve() if isinstance(x, AbstractProvider) else x
for x in self._args
],
**{
k: await v() if isinstance(v, AbstractProvider) else v
k: await v.async_resolve() if isinstance(v, AbstractProvider) else v
for k, v in self._kwargs.items()
},
),
Expand Down Expand Up @@ -228,3 +237,36 @@ def provider(self) -> typing.Callable[[], typing.Coroutine[typing.Any, typing.An
@property
def sync_provider(self) -> typing.Callable[[], T_co]:
return self.sync_resolve


def _get_value_from_object_by_dotted_path(obj: typing.Any, path: str) -> typing.Any: # noqa: ANN401
attribute_getter = attrgetter(path)
return attribute_getter(obj)


class AttrGetter(
AbstractProvider[T_co],
):
__slots__ = "_provider", "_attrs"

def __init__(self, provider: AbstractProvider[T_co], attr_name: str) -> None:
super().__init__()
self._provider = provider
self._attrs = [attr_name]

def __getattr__(self, attr: str) -> "AttrGetter[T_co]":
if attr.startswith("_"):
msg = f"'{type(self)}' object has no attribute '{attr}'"
raise AttributeError(msg)
self._attrs.append(attr)
return self

async def async_resolve(self) -> typing.Any: # noqa: ANN401
resolved_provider_object = await self._provider.async_resolve()
attribute_path = ".".join(self._attrs)
return _get_value_from_object_by_dotted_path(resolved_provider_object, attribute_path)

def sync_resolve(self) -> typing.Any: # noqa: ANN401
resolved_provider_object = self._provider.sync_resolve()
attribute_path = ".".join(self._attrs)
return _get_value_from_object_by_dotted_path(resolved_provider_object, attribute_path)
8 changes: 8 additions & 0 deletions that_depends/providers/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ def __init__(self, *providers: AbstractProvider[T_co]) -> None:
super().__init__()
self._providers: typing.Final = providers

def __getattr__(self, attr_name: str) -> typing.Any: # noqa: ANN401
msg = f"'{type(self)}' object has no attribute '{attr_name}'"
raise AttributeError(msg)

async def async_resolve(self) -> list[T_co]:
return [await x.async_resolve() for x in self._providers]

Expand All @@ -30,6 +34,10 @@ def __init__(self, **providers: AbstractProvider[T_co]) -> None:
super().__init__()
self._providers: typing.Final = providers

def __getattr__(self, attr_name: str) -> typing.Any: # noqa: ANN401
msg = f"'{type(self)}' object has no attribute '{attr_name}'"
raise AttributeError(msg)

async def async_resolve(self) -> dict[str, T_co]:
return {key: await provider.async_resolve() for key, provider in self._providers.items()}

Expand Down
3 changes: 1 addition & 2 deletions that_depends/providers/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ def __init__(self, factory: type[T_co] | typing.Callable[P, T_co], *args: P.args
self._factory: typing.Final = factory
self._args: typing.Final = args
self._kwargs: typing.Final = kwargs
self._override = None

async def async_resolve(self) -> T_co:
if self._override:
Expand All @@ -40,10 +39,10 @@ class AsyncFactory(AbstractFactory[T_co]):
__slots__ = "_factory", "_args", "_kwargs", "_override"

def __init__(self, factory: typing.Callable[P, typing.Awaitable[T_co]], *args: P.args, **kwargs: P.kwargs) -> None:
super().__init__()
self._factory: typing.Final = factory
self._args: typing.Final = args
self._kwargs: typing.Final = kwargs
self._override = None

async def async_resolve(self) -> T_co:
if self._override:
Expand Down
1 change: 0 additions & 1 deletion that_depends/providers/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ def __init__(self, selector: typing.Callable[[], str], **providers: AbstractProv
super().__init__()
self._selector: typing.Final = selector
self._providers: typing.Final = providers
self._override = None

async def async_resolve(self) -> T_co:
if self._override:
Expand Down
8 changes: 0 additions & 8 deletions that_depends/providers/singleton.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
import typing

from that_depends.providers import AttrGetter
from that_depends.providers.base import AbstractProvider


Expand All @@ -17,16 +16,9 @@ def __init__(self, factory: type[T_co] | typing.Callable[P, T_co], *args: P.args
self._factory: typing.Final = factory
self._args: typing.Final = args
self._kwargs: typing.Final = kwargs
self._override = None
self._instance: T_co | None = None
self._resolving_lock: typing.Final = asyncio.Lock()

def __getattr__(self, attr_name: str) -> typing.Any: # noqa: ANN401
if attr_name.startswith("_"):
msg = f"'{type(self)}' object has no attribute '{attr_name}'"
raise AttributeError(msg)
return AttrGetter(provider=self, attr_name=attr_name)

async def async_resolve(self) -> T_co:
if self._override is not None:
return typing.cast(T_co, self._override)
Expand Down

0 comments on commit 630733a

Please sign in to comment.