Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

List fields tags for filtering #1004

Merged
merged 3 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ ignore_missing_imports = True
[mypy-dynaconf.*]
ignore_missing_imports = True

[mypy-litestar.*]
ignore_missing_imports = True

[tool:pytest]
testpaths=tests
python_classes=*Test
Expand Down
76 changes: 73 additions & 3 deletions src/evidently/pydantic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from typing import List
from typing import Optional
from typing import Set
from typing import Tuple
from typing import Type
from typing import TypeVar
from typing import Union
Expand All @@ -24,6 +25,7 @@
if TYPE_CHECKING:
from evidently._pydantic_compat import DictStrAny
from evidently._pydantic_compat import Model
from evidently.core import IncludeTags
T = TypeVar("T")


Expand Down Expand Up @@ -229,25 +231,80 @@ def child(self, item: str) -> "FieldPath":
return FieldPath(self._path + [item], field_value, is_mapping=True)
return FieldPath(self._path + [item], field_value, is_mapping=is_mapping)

def list_nested_fields(self) -> List[str]:
@staticmethod
def _get_field_tags_rec(mro, name):
from evidently.base_metric import BaseResult

cls = mro[0]
if not issubclass(cls, BaseResult):
return None
if name in cls.__config__.field_tags:
return cls.__config__.field_tags[name]
return FieldPath._get_field_tags_rec(mro[1:], name)

@staticmethod
def _get_field_tags(cls, name, type_) -> Optional[Set["IncludeTags"]]:
from evidently.base_metric import BaseResult

if not issubclass(cls, BaseResult):
return None
field_tags = FieldPath._get_field_tags_rec(cls.__mro__, name)
if field_tags is not None:
return field_tags
if isinstance(type_, type) and issubclass(type_, BaseResult):
return type_.__config__.tags
return set()

def list_nested_fields(self, exclude: Set["IncludeTags"] = None) -> List[str]:
if not isinstance(self._cls, type) or not issubclass(self._cls, BaseModel):
return [repr(self)]
res = []
for name, field in self._cls.__fields__.items():
field_value = field.type_
field_tags = self._get_field_tags(self._cls, name, field_value)
if field_tags is not None and (exclude is not None and any(t in exclude for t in field_tags)):
continue
is_mapping = field.shape == SHAPE_DICT
if self.has_instance:
field_value = getattr(self._instance, name)
if is_mapping and isinstance(field_value, dict):
for key, value in field_value.items():
res.extend(FieldPath(self._path + [name, str(key)], value).list_nested_fields())
res.extend(FieldPath(self._path + [name, str(key)], value).list_nested_fields(exclude=exclude))
continue
else:
if is_mapping:
name = f"{name}.*"
res.extend(FieldPath(self._path + [name], field_value).list_nested_fields())
res.extend(FieldPath(self._path + [name], field_value).list_nested_fields(exclude=exclude))
return res

def _list_with_tags(self, current_tags: Set["IncludeTags"]) -> List[Tuple[str, Set["IncludeTags"]]]:
if not isinstance(self._cls, type) or not issubclass(self._cls, BaseModel):
return [(repr(self), current_tags)]
res = []
for name, field in self._cls.__fields__.items():
field_value = field.type_
field_tags = self._get_field_tags(self._cls, name, field_value) or set()

is_mapping = field.shape == SHAPE_DICT
if self.has_instance:
field_value = getattr(self._instance, name)
if is_mapping and isinstance(field_value, dict):
for key, value in field_value.items():
res.extend(
FieldPath(self._path + [name, str(key)], value)._list_with_tags(
current_tags.union(field_tags)
)
)
continue
else:
if is_mapping:
name = f"{name}.*"
res.extend(FieldPath(self._path + [name], field_value)._list_with_tags(current_tags.union(field_tags)))
return res

def list_nested_fields_with_tags(self) -> List[Tuple[str, Set["IncludeTags"]]]:
return self._list_with_tags(set())

def __repr__(self):
return self.get_path()

Expand All @@ -260,6 +317,19 @@ def __dir__(self) -> Iterable[str]:
res.extend(self.list_fields())
return res

def get_field_tags(self, path: List[str]) -> Optional[Set["IncludeTags"]]:
from evidently.base_metric import BaseResult

if not isinstance(self._cls, type) or not issubclass(self._cls, BaseResult):
return None
self_tags = self._cls.__config__.tags
if len(path) == 0:
return self_tags
field_name, *path = path

field_tags = self._get_field_tags(self._cls, field_name, self._cls.__fields__[field_name].type_) or set()
return self_tags.union(field_tags).union(self.child(field_name).get_field_tags(path) or tuple())


@pydantic_type_validator(FieldPath)
def series_validator(value):
Expand Down
99 changes: 99 additions & 0 deletions tests/utils/test_pydantic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from evidently.base_metric import Metric
from evidently.base_metric import MetricResult
from evidently.core import IncludeTags
from evidently.pydantic_utils import PolymorphicModel


Expand Down Expand Up @@ -130,3 +131,101 @@ class Config:

obj = parse_obj_as(SomeModel, {"type": "othersubclass"})
assert obj.__class__ == SomeOtherSubclass


def test_include_exclude():
class SomeModel(MetricResult):
class Config:
field_tags = {"f1": {IncludeTags.Render}}

f1: str
f2: str

assert SomeModel.fields.list_nested_fields(exclude={IncludeTags.Render, IncludeTags.TypeField}) == ["f2"]

# assert SomeModel.fields.list_nested_fields(include={IncludeTags.Render}) == ["f1"]

class SomeNestedModel(MetricResult):
class Config:
tags = {IncludeTags.Render}

f1: str

class SomeOtherModel(MetricResult):
f1: str
f2: SomeNestedModel
f3: SomeModel

assert SomeOtherModel.fields.list_nested_fields(exclude={IncludeTags.Render, IncludeTags.TypeField}) == [
"f1",
"f3.f2",
]
# assert SomeOtherModel.fields.list_nested_fields(include={IncludeTags.Render}) == ["f2.f1", "f3.f1"]


def test_get_field_tags():
class SomeModel(MetricResult):
class Config:
field_tags = {"f1": {IncludeTags.Render}}

f1: str
f2: str

assert SomeModel.fields.get_field_tags(["type"]) == {IncludeTags.TypeField}
assert SomeModel.fields.get_field_tags(["f1"]) == {IncludeTags.Render}
assert SomeModel.fields.get_field_tags(["f2"]) == set()

class SomeNestedModel(MetricResult):
class Config:
tags = {IncludeTags.Render}

f1: str

class SomeOtherModel(MetricResult):
f1: str
f2: SomeNestedModel
f3: SomeModel

assert SomeOtherModel.fields.get_field_tags(["type"]) == {IncludeTags.TypeField}
assert SomeOtherModel.fields.get_field_tags(["f1"]) == set()
assert SomeOtherModel.fields.get_field_tags(["f2"]) == {IncludeTags.Render}
assert SomeOtherModel.fields.get_field_tags(["f2", "f1"]) == {IncludeTags.Render}
assert SomeOtherModel.fields.get_field_tags(["f3"]) == set()
assert SomeOtherModel.fields.get_field_tags(["f3", "f1"]) == {IncludeTags.Render}
assert SomeOtherModel.fields.get_field_tags(["f3", "f2"]) == set()


def test_list_with_tags():
class SomeModel(MetricResult):
class Config:
field_tags = {"f1": {IncludeTags.Render}}

f1: str
f2: str

assert SomeModel.fields.list_nested_fields_with_tags() == [
("type", {IncludeTags.TypeField}),
("f1", {IncludeTags.Render}),
("f2", set()),
]

class SomeNestedModel(MetricResult):
class Config:
tags = {IncludeTags.Render}

f1: str

class SomeOtherModel(MetricResult):
f1: str
f2: SomeNestedModel
f3: SomeModel

assert SomeOtherModel.fields.list_nested_fields_with_tags() == [
("type", {IncludeTags.TypeField}),
("f1", set()),
("f2.type", {IncludeTags.Render, IncludeTags.TypeField}),
("f2.f1", {IncludeTags.Render}),
("f3.type", {IncludeTags.TypeField}),
("f3.f1", {IncludeTags.Render}),
("f3.f2", set()),
]
Loading