Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RAGAS] fix: Metric parameter validation and metric descriptors #555

Merged
merged 4 commits into from
Mar 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from .metrics import (
METRIC_DESCRIPTORS,
InputConverters,
MetricParamsValidator,
OutputConverters,
RagasMetric,
)
Expand Down Expand Up @@ -66,7 +65,7 @@ def __init__(
on required parameters.
"""
self.metric = metric if isinstance(metric, RagasMetric) else RagasMetric.from_str(metric)
self.metric_params = metric_params or {}
self.metric_params = metric_params
self.descriptor = METRIC_DESCRIPTORS[self.metric]

self._init_backend()
Expand All @@ -79,10 +78,24 @@ def _init_backend(self):
self._backend_callable = RagasEvaluator._invoke_evaluate

def _init_metric(self):
MetricParamsValidator.validate_metric_parameters(
self.metric, self.descriptor.init_parameters, self.metric_params
)
self._backend_metric = self.descriptor.backend(**self.metric_params)
if self.descriptor.init_parameters is not None:
if self.metric_params is None:
msg = f"Ragas metric '{self.metric}' expected init parameters but got none"
raise ValueError(msg)
elif not all(k in self.descriptor.init_parameters for k in self.metric_params.keys()):
msg = (
f"Invalid init parameters for Ragas metric '{self.metric}'. "
f"Expected: {self.descriptor.init_parameters}"
)
raise ValueError(msg)
elif self.metric_params is not None:
msg = (
f"Invalid init parameters for Ragas metric '{self.metric}'. "
f"None expected but {self.metric_params} given"
)
raise ValueError(msg)
metric_params = self.metric_params or {}
self._backend_metric = self.descriptor.backend(**metric_params)

@staticmethod
def _invoke_evaluate(dataset: Dataset, metric: Metric) -> Result:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ class MetricDescriptor:
backend: Type[Metric]
input_parameters: Dict[str, Type]
input_converter: Callable[[Any], Iterable[Dict[str, str]]]
output_converter: Callable[[Result, RagasMetric, Dict[str, Any]], List[MetricResult]]
output_converter: Callable[[Result, RagasMetric, Optional[Dict[str, Any]]], List[MetricResult]]
init_parameters: Optional[List[str]] = None

@classmethod
Expand All @@ -143,7 +143,9 @@ def new(
metric: RagasMetric,
backend: Type[Metric],
input_converter: Callable[[Any], Iterable[Dict[str, str]]],
output_converter: Optional[Callable[[Result, RagasMetric, Dict[str, Any]], List[MetricResult]]] = None,
output_converter: Optional[
Callable[[Result, RagasMetric, Optional[Dict[str, Any]]], List[MetricResult]]
] = None,
*,
init_parameters: Optional[List[str]] = None,
) -> "MetricDescriptor":
Expand All @@ -166,24 +168,6 @@ def new(
)


class MetricParamsValidator:
"""
Validates metric parameters.

Depending on the metric type, different metric parameters are allowed.
The validator functions are responsible for validating the parameters and raising an error if they are invalid.
"""

@staticmethod
def validate_metric_parameters(metric: RagasMetric, allowed: List[str], received: Dict[str, Any]) -> None:
if not set(received).issubset(allowed):
msg = (
f"Invalid init parameters for Ragas metric '{metric}'. "
f"Allowed metric parameters {allowed} but got '{received}'"
)
raise ValueError(msg)


class InputConverters:
"""
Converters for input parameters.
Expand Down Expand Up @@ -292,12 +276,15 @@ def _extract_default_results(output: Result, metric_name: str) -> List[MetricRes
raise ValueError(msg) from e

@staticmethod
def default(output: Result, metric: RagasMetric, _: Dict) -> List[MetricResult]:
def default(output: Result, metric: RagasMetric, _: Optional[Dict]) -> List[MetricResult]:
metric_name = metric.value
return OutputConverters._extract_default_results(output, metric_name)

@staticmethod
def aspect_critique(output: Result, _: RagasMetric, metric_params: Dict[str, Any]) -> List[MetricResult]:
def aspect_critique(output: Result, _: RagasMetric, metric_params: Optional[Dict[str, Any]]) -> List[MetricResult]:
if metric_params is None:
msg = "Aspect critique metric requires metric parameters"
raise ValueError(msg)
metric_name = metric_params["name"]
return OutputConverters._extract_default_results(output, metric_name)

Expand All @@ -307,55 +294,50 @@ def aspect_critique(output: Result, _: RagasMetric, metric_params: Dict[str, Any
RagasMetric.ANSWER_CORRECTNESS,
AnswerCorrectness,
InputConverters.question_response_ground_truth, # type: ignore
init_parameters=["name", "weights", "answer_similarity"],
init_parameters=["weights"],
),
RagasMetric.FAITHFULNESS: MetricDescriptor.new(
RagasMetric.FAITHFULNESS,
Faithfulness,
InputConverters.question_context_response, # type: ignore
init_parameters=["name"],
),
RagasMetric.ANSWER_SIMILARITY: MetricDescriptor.new(
RagasMetric.ANSWER_SIMILARITY,
AnswerSimilarity,
InputConverters.response_ground_truth, # type: ignore
init_parameters=["name", "model_name", "threshold"],
init_parameters=["threshold"],
),
RagasMetric.CONTEXT_PRECISION: MetricDescriptor.new(
RagasMetric.CONTEXT_PRECISION,
ContextPrecision,
InputConverters.question_context_ground_truth, # type: ignore
init_parameters=["name"],
),
RagasMetric.CONTEXT_UTILIZATION: MetricDescriptor.new(
RagasMetric.CONTEXT_UTILIZATION,
ContextUtilization,
InputConverters.question_context_response, # type: ignore
init_parameters=["name"],
),
RagasMetric.CONTEXT_RECALL: MetricDescriptor.new(
RagasMetric.CONTEXT_RECALL,
ContextRecall,
InputConverters.question_context_ground_truth, # type: ignore
init_parameters=["name"],
),
RagasMetric.ASPECT_CRITIQUE: MetricDescriptor.new(
RagasMetric.ASPECT_CRITIQUE,
AspectCritique,
InputConverters.question_context_response, # type: ignore
OutputConverters.aspect_critique,
init_parameters=["name", "definition", "strictness", "llm"],
init_parameters=["name", "definition", "strictness"],
),
RagasMetric.CONTEXT_RELEVANCY: MetricDescriptor.new(
RagasMetric.CONTEXT_RELEVANCY,
ContextRelevancy,
InputConverters.question_context, # type: ignore
init_parameters=["name"],
),
RagasMetric.ANSWER_RELEVANCY: MetricDescriptor.new(
RagasMetric.ANSWER_RELEVANCY,
AnswerRelevancy,
InputConverters.question_context_response, # type: ignore
init_parameters=["name", "strictness", "embeddings"],
init_parameters=["strictness"],
),
}
Loading