More RAG and Summarization Metrics (#701)
Co-authored-by: b.nativi <[email protected]>
bnativi and b.nativi authored Aug 29, 2024
1 parent cbd1051 commit 8de55e2
Showing 19 changed files with 2,963 additions and 561 deletions.
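For orientation: the diffs below add several LLM-guided RAG and summarization metric types to the text-generation task, among them AnswerCorrectness, ContextPrecision, ContextRecall, and SummaryCoherence (the last of which appears to replace the old Coherence metric). The usage sketch below is illustrative only and is not part of the commit; the import paths and the llm_api_params payload shape are assumptions, while the EvaluationParameters fields come straight from the tests that follow.

# Illustrative sketch, not part of this commit's diff. Import paths and the
# llm_api_params payload shape are assumptions; the field names mirror the tests.
from valor_api import enums, schemas
from valor_api.enums import MetricType

llm_api_params = {"client": "openai", "data": {"model": "gpt-4o"}}  # assumed shape

schemas.EvaluationParameters(
    task_type=enums.TaskType.TEXT_GENERATION,
    metrics_to_return=[
        MetricType.AnswerCorrectness,
        MetricType.ContextPrecision,
        MetricType.ContextRecall,
        MetricType.SummaryCoherence,
    ],
    llm_api_params=llm_api_params,
)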
632 changes: 519 additions & 113 deletions api/tests/functional-tests/backend/core/test_llm_clients.py

Large diffs are not rendered by default.

424 changes: 368 additions & 56 deletions api/tests/functional-tests/backend/metrics/test_text_generation.py

Large diffs are not rendered by default.

20 changes: 9 additions & 11 deletions api/tests/unit-tests/schemas/test_evaluation.py
@@ -59,10 +59,12 @@ def test_EvaluationParameters(llm_api_params):
schemas.EvaluationParameters(
task_type=enums.TaskType.TEXT_GENERATION,
metrics_to_return=[
MetricType.AnswerCorrectness,
MetricType.AnswerRelevance,
MetricType.Bias,
MetricType.BLEU,
MetricType.Coherence,
MetricType.ContextPrecision,
MetricType.ContextRecall,
MetricType.ContextRelevance,
MetricType.Faithfulness,
MetricType.Hallucination,
@@ -76,10 +78,12 @@
schemas.EvaluationParameters(
task_type=enums.TaskType.TEXT_GENERATION,
metrics_to_return=[
MetricType.AnswerCorrectness,
MetricType.AnswerRelevance,
MetricType.Bias,
MetricType.BLEU,
MetricType.Coherence,
MetricType.ContextPrecision,
MetricType.ContextRecall,
MetricType.ContextRelevance,
MetricType.Faithfulness,
MetricType.Hallucination,
@@ -167,15 +171,13 @@ def test_EvaluationParameters(llm_api_params):
)

# If any llm-guided metrics are requested, then llm_api_params must be provided.
# Purposely uses only a subset of metrics_to_return to increase test variation.
with pytest.raises(ValidationError):
schemas.EvaluationParameters(
task_type=enums.TaskType.TEXT_GENERATION,
metrics_to_return=[
MetricType.AnswerRelevance,
MetricType.Bias,
MetricType.BLEU,
MetricType.Coherence,
MetricType.ContextRelevance,
MetricType.Faithfulness,
MetricType.Hallucination,
MetricType.ROUGE,
@@ -195,19 +197,15 @@ def test_EvaluationParameters(llm_api_params):
bleu_weights=[1.1, 0.3, -0.5, 0.1],
)

# BLEU weights must sum to 1.
# BLEU weights must sum to 1. metrics_to_return here are all metrics applicable to summarization.
with pytest.raises(ValidationError):
schemas.EvaluationParameters(
task_type=enums.TaskType.TEXT_GENERATION,
metrics_to_return=[
MetricType.AnswerRelevance,
MetricType.Bias,
MetricType.BLEU,
MetricType.Coherence,
MetricType.ContextRelevance,
MetricType.Faithfulness,
MetricType.Hallucination,
MetricType.ROUGE,
MetricType.SummaryCoherence,
MetricType.Toxicity,
],
llm_api_params=llm_api_params,
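The hunks above exercise two validation rules: requesting any LLM-guided metric without llm_api_params must fail, and bleu_weights must be non-negative and sum to 1. A minimal standalone sketch of that behavior follows, written as a simplified Pydantic model rather than Valor's actual EvaluationParameters; the names and the set of LLM-guided metrics are assumptions inferred from the tests.

# Standalone sketch of the validation the tests above check for; simplified,
# not Valor's implementation.
from pydantic import BaseModel, model_validator

LLM_GUIDED = {
    "AnswerCorrectness", "AnswerRelevance", "Bias", "ContextPrecision",
    "ContextRecall", "ContextRelevance", "Faithfulness", "Hallucination",
    "SummaryCoherence", "Toxicity",
}

class EvaluationParametersSketch(BaseModel):
    task_type: str
    metrics_to_return: list[str]
    llm_api_params: dict | None = None
    bleu_weights: list[float] | None = None

    @model_validator(mode="after")
    def _check(self):
        # LLM-guided metrics need an LLM client configuration.
        if self.llm_api_params is None and any(
            m in LLM_GUIDED for m in self.metrics_to_return
        ):
            raise ValueError("llm_api_params is required for LLM-guided metrics")
        # BLEU weights must be non-negative and sum to 1.
        if self.bleu_weights is not None and (
            any(w < 0 for w in self.bleu_weights)
            or abs(sum(self.bleu_weights) - 1.0) > 1e-9
        ):
            raise ValueError("bleu_weights must be non-negative and sum to 1")
        return self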
160 changes: 146 additions & 14 deletions api/tests/unit-tests/schemas/test_metrics.py
@@ -434,6 +434,50 @@ def test_DetailedPrecisionRecallCurve():
}


def test_AnswerCorrectnessMetric():
metric = schemas.AnswerCorrectnessMetric(
value=0.52,
parameters={
"dataset_uid": "01",
"dataset_name": "test_dataset",
"prediction": "some prediction",
},
)

with pytest.raises(ValidationError):
schemas.AnswerCorrectnessMetric(
value=None, # type: ignore
parameters={
"dataset_uid": "01",
"dataset_name": "test_dataset",
"prediction": "some prediction",
},
)

with pytest.raises(ValidationError):
schemas.AnswerCorrectnessMetric(
value={"key": 0.3}, # type: ignore
parameters={
"dataset_uid": "01",
"dataset_name": "test_dataset",
"prediction": "some prediction",
},
)

with pytest.raises(ValidationError):
schemas.AnswerCorrectnessMetric(
value=0.0, # type: ignore
parameters="not a valid parameter", # type: ignore
)

assert all(
[
key in ["value", "type", "evaluation_id", "parameters"]
for key in metric.db_mapping(evaluation_id=1)
]
)
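test_AnswerCorrectnessMetric implies the shape shared by these metric schemas: a float value, a parameters dict, and a db_mapping helper that returns a dict keyed by value, type, evaluation_id, and parameters. A minimal sketch under those assumptions, not Valor's actual class:

# Minimal sketch of the schema shape implied by the test above; the field and
# key names come from the test, everything else is an assumption.
from pydantic import BaseModel

class AnswerCorrectnessMetricSketch(BaseModel):
    value: float
    parameters: dict

    def db_mapping(self, evaluation_id: int) -> dict:
        # Return the metric as a dict with the keys the test's assertion allows.
        return {
            "value": self.value,
            "type": "AnswerCorrectness",
            "evaluation_id": evaluation_id,
            "parameters": self.parameters,
        }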


def test_AnswerRelevanceMetric():
metric = schemas.AnswerRelevanceMetric(
value=0.421,
@@ -581,49 +625,83 @@ def test_BLEUMetric():
)


def test_CoherenceMetric():
metric = schemas.CoherenceMetric(
value=3,
def test_ContextPrecisionMetric():
metric = schemas.ContextPrecisionMetric(
value=0.873,
parameters={
"dataset_uid": "01",
"dataset_name": "test_dataset",
"prediction": "some prediction",
"context_list": ["context1", "context2"],
},
)

with pytest.raises(ValidationError):
schemas.CoherenceMetric(
schemas.ContextPrecisionMetric(
value=None, # type: ignore
parameters={
"dataset_uid": "01",
"dataset_name": "test_dataset",
"prediction": "some prediction",
"context_list": ["context1", "context2"],
},
)

with pytest.raises(ValidationError):
schemas.CoherenceMetric(
value=2.5, # type: ignore
schemas.ContextPrecisionMetric(
value={"key": 0.222}, # type: ignore
parameters={
"dataset_uid": "01",
"dataset_name": "test_dataset",
"prediction": "some prediction",
"context_list": ["context1", "context2"],
},
)

with pytest.raises(ValidationError):
schemas.CoherenceMetric(
value={"key": 4}, # type: ignore
schemas.ContextPrecisionMetric(
value=0.501, # type: ignore
parameters="not a valid parameter", # type: ignore
)

assert all(
[
key in ["value", "type", "evaluation_id", "parameters"]
for key in metric.db_mapping(evaluation_id=1)
]
)


def test_ContextRecallMetric():
metric = schemas.ContextRecallMetric(
value=0.8,
parameters={
"dataset_uid": "01",
"dataset_name": "test_dataset",
"context_list": ["context1", "context2"],
},
)

with pytest.raises(ValidationError):
schemas.ContextRecallMetric(
value="value", # type: ignore
parameters={
"dataset_uid": "01",
"dataset_name": "test_dataset",
"prediction": "some prediction",
"context_list": ["context1", "context2"],
},
)

with pytest.raises(ValidationError):
schemas.CoherenceMetric(
value=5, # type: ignore
schemas.ContextRecallMetric(
value={"key": 0.5}, # type: ignore
parameters={
"dataset_uid": "01",
"dataset_name": "test_dataset",
"context_list": ["context1", "context2"],
},
)

with pytest.raises(ValidationError):
schemas.ContextRecallMetric(
value=0.6, # type: ignore
parameters="not a valid parameter", # type: ignore
)

@@ -838,6 +916,60 @@ def test_ROUGEMetric():
)


def test_SummaryCoherenceMetric():
metric = schemas.SummaryCoherenceMetric(
value=3,
parameters={
"dataset_uid": "01",
"dataset_name": "test_dataset",
"prediction": "some summary",
},
)

with pytest.raises(ValidationError):
schemas.SummaryCoherenceMetric(
value=None, # type: ignore
parameters={
"dataset_uid": "01",
"dataset_name": "test_dataset",
"prediction": "some summary",
},
)

with pytest.raises(ValidationError):
schemas.SummaryCoherenceMetric(
value=2.5, # type: ignore
parameters={
"dataset_uid": "01",
"dataset_name": "test_dataset",
"prediction": "some summary",
},
)

with pytest.raises(ValidationError):
schemas.SummaryCoherenceMetric(
value={"key": 4}, # type: ignore
parameters={
"dataset_uid": "01",
"dataset_name": "test_dataset",
"prediction": "some summary",
},
)

with pytest.raises(ValidationError):
schemas.SummaryCoherenceMetric(
value=5, # type: ignore
parameters="not a valid parameter", # type: ignore
)

assert all(
[
key in ["value", "type", "evaluation_id", "parameters"]
for key in metric.db_mapping(evaluation_id=1)
]
)
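Unlike the float-valued metrics above, test_SummaryCoherenceMetric accepts whole numbers (3 and 5) and rejects 2.5, which points to an integer coherence score. A sketch under that assumption; the 1-5 bounds are an inference, not stated in this diff:

# Sketch only: an integer-scored metric. StrictInt rejects 2.5 as the test
# expects; the ge=1/le=5 bounds are an inference, not taken from this diff.
from pydantic import BaseModel, Field, StrictInt

class SummaryCoherenceMetricSketch(BaseModel):
    value: StrictInt = Field(ge=1, le=5)
    parameters: dict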


def test_ToxicityMetric():
metric = schemas.ToxicityMetric(
value=0.4,