Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make attrgetter extract value from any nesting level of objects #57

Merged
merged 2 commits into from
Jul 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions tests/test_attr_getter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import random
from dataclasses import dataclass, field

import pytest

from that_depends import providers
from that_depends.providers.attr_getter import _get_value_from_object_by_dotted_path


@dataclass
class Nested2:
some_const = 144


@dataclass
class Nested1:
nested2_attr: Nested2 = field(default_factory=Nested2)


@dataclass
class Settings:
some_str_value: str = "some_string_value"
some_int_value: int = 3453621
nested1_attr: Nested1 = field(default_factory=Nested1)


@dataclass
class NestingTestDTO: ...


@pytest.fixture()
def some_settings_provider() -> providers.Singleton[Settings]:
return providers.Singleton(Settings)


def test_attr_getter_with_zero_attribute_depth(some_settings_provider: providers.Singleton[Settings]) -> None:
attr_getter = some_settings_provider.some_str_value
assert attr_getter.sync_resolve() == Settings().some_str_value


def test_attr_getter_with_more_than_zero_attribute_depth(some_settings_provider: providers.Singleton[Settings]) -> None:
attr_getter = some_settings_provider.nested1_attr.nested2_attr.some_const
assert attr_getter.sync_resolve() == Nested2().some_const


@pytest.mark.parametrize(
("field_count", "test_field_name", "test_value"),
[(1, "test_field", "sdf6fF^SF(FF*4ffsf"), (5, "nested_field", -252625), (50, "50_lvl_field", 909234235)],
)
def test_nesting_levels(field_count: int, test_field_name: str, test_value: str | int) -> None:
obj = NestingTestDTO()
fields = [f"field_{i}" for i in range(1, field_count + 1)]
random.shuffle(fields)

attr_path = ".".join(fields) + f".{test_field_name}"
obj_copy = obj

while fields:
field_name = fields.pop(0)
setattr(obj_copy, field_name, NestingTestDTO())
obj_copy = obj_copy.__getattribute__(field_name)

setattr(obj_copy, test_field_name, test_value)

attr_value = _get_value_from_object_by_dotted_path(obj, attr_path)
assert attr_value == test_value
22 changes: 18 additions & 4 deletions that_depends/providers/attr_getter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import typing
from operator import attrgetter

from that_depends.providers.base import AbstractProvider

Expand All @@ -7,15 +8,28 @@
P = typing.ParamSpec("P")


def _get_value_from_object_by_dotted_path(obj: typing.Any, path: str) -> typing.Any: # noqa: ANN401
attribute_getter = attrgetter(path)
return attribute_getter(obj)


class AttrGetter(AbstractProvider[T]):
__slots__ = "_provider", "_attr_name"
__slots__ = "_provider", "_attrs"

def __init__(self, provider: AbstractProvider[T], attr_name: str) -> None:
self._provider = provider
self._attr_name = attr_name
self._attrs = [attr_name]

def __getattr__(self, attr: str) -> "AttrGetter[T]":
self._attrs.append(attr)
return self

async def async_resolve(self) -> typing.Any: # noqa: ANN401
return getattr(await self._provider.async_resolve(), self._attr_name)
resolved_provider_object = await self._provider.async_resolve()
attribute_path = ".".join(self._attrs)
return _get_value_from_object_by_dotted_path(resolved_provider_object, attribute_path)

def sync_resolve(self) -> typing.Any: # noqa: ANN401
return getattr(self._provider.sync_resolve(), self._attr_name)
resolved_provider_object = self._provider.sync_resolve()
attribute_path = ".".join(self._attrs)
return _get_value_from_object_by_dotted_path(resolved_provider_object, attribute_path)
Loading