Skip to content

Commit

Permalink
fix semantic similarity descriptor (#1260)
Browse files Browse the repository at this point in the history
  • Loading branch information
mike0sv authored Aug 23, 2024
1 parent db68276 commit 9c909aa
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 10 deletions.
9 changes: 9 additions & 0 deletions src/evidently/descriptors/semantic_similarity.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
from typing import List

from evidently.features.generated_features import FeatureDescriptor
from evidently.features.generated_features import GeneratedFeature
from evidently.features.generated_features import GeneratedFeatures
from evidently.features.generated_features import MultiColumnFeatureDescriptor
from evidently.features.semantic_similarity_feature import SemanticSimilarityFeature


class SemanticSimilarity(MultiColumnFeatureDescriptor):
def feature(self, columns: List[str]) -> GeneratedFeature:
return SemanticSimilarityFeature(columns=columns, display_name=self.display_name)


class SemanticSimilatiryDescriptor(FeatureDescriptor):
with_column: str

def feature(self, column_name: str) -> GeneratedFeatures:
return SemanticSimilarityFeature(columns=[column_name, self.with_column], display_name=self.display_name)
21 changes: 11 additions & 10 deletions src/evidently/features/generated_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,14 @@ def _as_column(self) -> "ColumnName":
return self._create_column(self.name)


class GeneralDescriptor(EvidentlyBaseModel):
class BaseDescriptor(EvidentlyBaseModel):
class Config:
is_base_type = True

display_name: Optional[str] = None


class GeneralDescriptor(BaseDescriptor):
@abc.abstractmethod
def feature(self) -> GeneratedFeatures:
raise NotImplementedError()
Expand All @@ -194,9 +199,7 @@ def as_column(self) -> "ColumnName":
return self.feature().as_column()


class MultiColumnFeatureDescriptor(EvidentlyBaseModel):
display_name: Optional[str] = None

class MultiColumnFeatureDescriptor(BaseDescriptor):
def feature(self, columns: List[str]) -> GeneratedFeature:
raise NotImplementedError()

Expand All @@ -207,11 +210,13 @@ def on(self, columns: List[str]) -> "ColumnName":
return self.feature(columns).as_column()


class FeatureDescriptor(EvidentlyBaseModel):
class FeatureDescriptor(BaseDescriptor):
class Config:
is_base_type = True

display_name: Optional[str] = None
@abc.abstractmethod
def feature(self, column_name: str) -> GeneratedFeatures:
raise NotImplementedError

def for_column(self, column_name: str) -> "ColumnName":
feature = self.feature(column_name)
Expand All @@ -221,7 +226,3 @@ def for_column(self, column_name: str) -> "ColumnName":

def on(self, column_name: str) -> "ColumnName":
return self.for_column(column_name)

@abc.abstractmethod
def feature(self, column_name: str) -> GeneratedFeatures:
raise NotImplementedError

0 comments on commit 9c909aa

Please sign in to comment.