Skip to content

Commit

Permalink
Improve Text Generation Instructions and Tests (#690)
Browse files Browse the repository at this point in the history
Co-authored-by: b.nativi <[email protected]>
  • Loading branch information
bnativi and b.nativi authored Aug 6, 2024
1 parent 93b4bb0 commit f28acba
Show file tree
Hide file tree
Showing 24 changed files with 1,931 additions and 1,127 deletions.
201 changes: 123 additions & 78 deletions api/tests/functional-tests/backend/core/test_llm_clients.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def rag_data(
annotations=[
schemas.Annotation(
text=RAG_PREDICTIONS[i],
context=RAG_CONTEXT[i],
context_list=RAG_CONTEXT[i],
)
],
)
Expand Down Expand Up @@ -376,7 +376,7 @@ def two_text_generation_datasets(
annotations=[
schemas.Annotation(
text=RAG_PREDICTIONS[i],
context=RAG_CONTEXT[i],
context_list=RAG_CONTEXT[i],
)
],
)
Expand Down Expand Up @@ -540,40 +540,40 @@ def mocked_coherence(
def mocked_context_relevance(
    self,
    query: str,
    context_list: list[str],
):
    """Mocked LLM-client method returning a canned context-relevance score.

    Looks up the (query, context_list) pair in a fixed table so the
    functional tests get deterministic scores without calling a real LLM.
    (The scraped diff showed both the old ``context`` parameter and the new
    ``context_list`` one plus a duplicate return; only the post-commit
    version is kept here.)

    Raises
    ------
    KeyError
        If the (query, context_list) pair is not one of the canned cases.
    """
    ret_dict = {
        (RAG_QUERIES[0], tuple(RAG_CONTEXT[0])): 0.75,
        (RAG_QUERIES[1], tuple(RAG_CONTEXT[1])): 1.0,
        (RAG_QUERIES[2], tuple(RAG_CONTEXT[2])): 0.25,
    }
    # Tuples are used as keys because lists are unhashable.
    return ret_dict[(query, tuple(context_list))]


def mocked_faithfulness(
    self,
    text: str,
    context_list: list[str],
):
    """Mocked LLM-client method returning a canned faithfulness score.

    Looks up the (text, context_list) pair in a fixed table so the
    functional tests get deterministic scores without calling a real LLM.
    (The scraped diff showed both the old ``context`` parameter and the new
    ``context_list`` one plus a duplicate return; only the post-commit
    version is kept here.)

    Raises
    ------
    KeyError
        If the (text, context_list) pair is not one of the canned cases.
    """
    ret_dict = {
        (RAG_PREDICTIONS[0], tuple(RAG_CONTEXT[0])): 0.4,
        (RAG_PREDICTIONS[1], tuple(RAG_CONTEXT[1])): 0.55,
        (RAG_PREDICTIONS[2], tuple(RAG_CONTEXT[2])): 0.6666666666666666,
    }
    # Tuples are used as keys because lists are unhashable.
    return ret_dict[(text, tuple(context_list))]


def mocked_hallucination(
    self,
    text: str,
    context_list: list[str],
):
    """Mocked LLM-client method returning a canned hallucination score.

    Looks up the (text, context_list) pair in a fixed table so the
    functional tests get deterministic scores without calling a real LLM.
    (The scraped diff showed both the old ``context`` parameter and the new
    ``context_list`` one plus a duplicate return; only the post-commit
    version is kept here.)

    Raises
    ------
    KeyError
        If the (text, context_list) pair is not one of the canned cases.
    """
    ret_dict = {
        (RAG_PREDICTIONS[0], tuple(RAG_CONTEXT[0])): 0.0,
        (RAG_PREDICTIONS[1], tuple(RAG_CONTEXT[1])): 0.0,
        (RAG_PREDICTIONS[2], tuple(RAG_CONTEXT[2])): 0.25,
    }
    # Tuples are used as keys because lists are unhashable.
    return ret_dict[(text, tuple(context_list))]


def mocked_toxicity(
Expand Down
18 changes: 9 additions & 9 deletions api/tests/unit-tests/schemas/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,7 @@ def test_ContextRelevanceMetric():
parameters={
"dataset_uid": "01",
"dataset_name": "test_dataset",
"context": ["context1", "context2"],
"context_list": ["context1", "context2"],
},
)

Expand All @@ -651,7 +651,7 @@ def test_ContextRelevanceMetric():
parameters={
"dataset_uid": "01",
"dataset_name": "test_dataset",
"context": ["context1", "context2"],
"context_list": ["context1", "context2"],
},
)

Expand All @@ -661,7 +661,7 @@ def test_ContextRelevanceMetric():
parameters={
"dataset_uid": "01",
"dataset_name": "test_dataset",
"context": ["context1", "context2"],
"context_list": ["context1", "context2"],
},
)

Expand All @@ -686,7 +686,7 @@ def test_FaithfulnessMetric():
"dataset_uid": "01",
"dataset_name": "test_dataset",
"prediction": "some prediction",
"context": ["context1", "context2"],
"context_list": ["context1", "context2"],
},
)

Expand All @@ -697,7 +697,7 @@ def test_FaithfulnessMetric():
"dataset_uid": "01",
"dataset_name": "test_dataset",
"prediction": "some prediction",
"context": ["context1", "context2"],
"context_list": ["context1", "context2"],
},
)

Expand All @@ -708,7 +708,7 @@ def test_FaithfulnessMetric():
"dataset_uid": "01",
"dataset_name": "test_dataset",
"prediction": "some prediction",
"context": ["context1", "context2"],
"context_list": ["context1", "context2"],
},
)

Expand All @@ -733,7 +733,7 @@ def test_HallucinationMetric():
"dataset_uid": "01",
"dataset_name": "test_dataset",
"prediction": "some prediction",
"context": ["context1", "context2"],
"context_list": ["context1", "context2"],
},
)

Expand All @@ -744,7 +744,7 @@ def test_HallucinationMetric():
"dataset_uid": "01",
"dataset_name": "test_dataset",
"prediction": "some prediction",
"context": ["context1", "context2"],
"context_list": ["context1", "context2"],
},
)

Expand All @@ -755,7 +755,7 @@ def test_HallucinationMetric():
"dataset_uid": "01",
"dataset_name": "test_dataset",
"prediction": "some prediction",
"context": ["context1", "context2"],
"context_list": ["context1", "context2"],
},
)

Expand Down
14 changes: 3 additions & 11 deletions api/valor_api/backend/core/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,6 @@ def _create_embedding(
return row.id


def _format_context(
context: str | list[str] | None,
) -> list[str] | None:
if isinstance(context, str):
context = [context]
return context


def create_annotations(
db: Session,
annotations: list[list[schemas.Annotation]],
Expand Down Expand Up @@ -116,7 +108,7 @@ def create_annotations(
db=db, value=annotation.embedding
),
"text": annotation.text,
"context": _format_context(annotation.context),
"context_list": annotation.context_list,
"is_instance": annotation.is_instance,
"implied_task_types": annotation.implied_task_types,
}
Expand Down Expand Up @@ -176,7 +168,7 @@ def create_skipped_annotations(
raster=None,
embedding_id=None,
text=None,
context=None,
context_list=None,
is_instance=False,
implied_task_types=[TaskType.EMPTY],
)
Expand Down Expand Up @@ -283,7 +275,7 @@ def get_annotation(
raster=raster,
embedding=embedding,
text=annotation.text,
context=annotation.context,
context_list=annotation.context_list,
is_instance=annotation.is_instance,
implied_task_types=annotation.implied_task_types,
)
Expand Down
Loading

0 comments on commit f28acba

Please sign in to comment.