Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Add unwrapping of metadata to aggregation_overviews_to_pandas
Browse files Browse the repository at this point in the history
…-function.

TASK: IL-547
MerlinKallenbornAA committed Jun 13, 2024
1 parent e9cda34 commit 7d06bdc
Showing 3 changed files with 59 additions and 9 deletions.
18 changes: 9 additions & 9 deletions src/documentation/parameter_optimization.ipynb
Original file line number Diff line number Diff line change
@@ -12,7 +12,6 @@
"import string\n",
"from typing import Iterable\n",
"\n",
"import pandas as pd\n",
"from pydantic import BaseModel\n",
"\n",
"from intelligence_layer.core import Input, Task, TaskSpan\n",
@@ -194,7 +193,7 @@
"\n",
" # Model and prompt are stored in the metadata to specify the configuration of the current experiment\n",
" metadata = dict({\"model\": model, \"prompt\": prompt})\n",
" description = f\"|{model}|{prompt}|\"\n",
" description = \"Evaluate dummy task\"\n",
" runner = Runner(dummy_task, dataset_repository, run_repository, EXPERIMENT_NAME)\n",
" run_overview = runner.run_dataset(\n",
" dataset.id, metadata=metadata, description=description\n",
@@ -229,7 +228,7 @@
"metadata": {},
"outputs": [],
"source": [
"# Retrieve all aggregations and filter them by desired criteria, i.e., the `EXPERIMENT_NAME`.\n",
"# Retrieve all aggregations and filter them by desired criteria, i.e., the `EXPERIMENT_NAME`. Filtering can also be done on labels and/or metadata.\n",
"aggregations_of_interest = [\n",
" overview\n",
" for overview in aggregation_repository.aggregation_overviews(\n",
@@ -239,7 +238,10 @@
"]\n",
"\n",
"# Convert the desired aggregation into a pandas dataframe\n",
"formated_aggregations = aggregation_overviews_to_pandas(aggregations_of_interest)"
"formated_aggregations = aggregation_overviews_to_pandas(aggregations_of_interest)\n",
"\n",
"# Print all columns to check for columns of interest\n",
"formated_aggregations.columns"
]
},
{
@@ -258,11 +260,9 @@
"outputs": [],
"source": [
"aggregation_fields = list(DummyAggregatedEvaluation.model_fields.keys())\n",
"formated_aggregations = formated_aggregations[[\"metadata\"] + aggregation_fields]\n",
"# Flatten the metadata dict into columns\n",
"flattened_metadata = pd.json_normalize(formated_aggregations[\"metadata\"])\n",
"formated_aggregations = pd.concat([formated_aggregations, flattened_metadata], axis=1)\n",
"formated_aggregations.drop(columns=[\"metadata\"], inplace=True)\n",
"# Filter for columns of interest\n",
"formated_aggregations = formated_aggregations[[\"model\", \"prompt\", *aggregation_fields]]\n",
"\n",
"display(\n",
" formated_aggregations.sort_values(\n",
" by=\"avg_normalized_capital_count\", ascending=False\n",
Original file line number Diff line number Diff line change
@@ -141,6 +141,7 @@ def aggregation_overviews_to_pandas(
aggregation_overviews: Sequence[AggregationOverview[AggregatedEvaluation]],
unwrap_statistics: bool = True,
strict: bool = True,
unwrap_metadata: bool = True,
) -> pd.DataFrame:
"""Converts aggregation overviews to a pandas table for easier comparison.
@@ -149,6 +150,8 @@ def aggregation_overviews_to_pandas(
unwrap_statistics: Unwrap the `statistics` field in the overviews into separate columns.
Defaults to True.
strict: Allow only overviews with exactly equal `statistics` types. Defaults to True.
unwrap_metadata: Unwrap the `metadata` field in the overviews into separate columns.
Defaults to True.
Returns:
A pandas :class:`DataFrame` containing an overview per row with fields as columns.
@@ -170,6 +173,11 @@ def aggregation_overviews_to_pandas(
df = df.join(pd.DataFrame(df["statistics"].to_list())).drop(
columns=["statistics"]
)
if unwrap_metadata and "metadata" in df.columns:
df = pd.concat([df, pd.json_normalize(df["metadata"])], axis=1).drop( # type: ignore
columns=["metadata"]
)

return df


42 changes: 42 additions & 0 deletions tests/evaluation/infrastructure/test_repository_navigator.py
Original file line number Diff line number Diff line change
@@ -465,6 +465,48 @@ class AggregationDummy2(BaseModel):
assert "statistics" not in df.columns


def test_aggregation_overviews_to_pandas_unwrap_metadata() -> None:
    """Metadata dicts are unwrapped into individual dataframe columns.

    Two overviews share the keys ``model``/``prompt``; the second carries an
    extra key, so the resulting frame must contain the union of all metadata
    keys while the raw ``metadata`` column is dropped.
    """
    # given two aggregation overviews with partially overlapping metadata keys
    shared_metadata = {"model": "model_a", "prompt": "prompt_a"}
    first_overview = AggregationOverview(
        evaluation_overviews=frozenset([]),
        id="aggregation-id",
        start=utc_now(),
        end=utc_now(),
        successful_evaluation_count=5,
        crashed_during_evaluation_count=3,
        description="dummy-evaluator",
        statistics=AggregationDummy(),
        labels=set(),
        metadata=dict(shared_metadata),
    )
    second_overview = AggregationOverview(
        evaluation_overviews=frozenset([]),
        id="aggregation-id2",
        start=utc_now(),
        end=utc_now(),
        successful_evaluation_count=5,
        crashed_during_evaluation_count=3,
        description="dummy-evaluator",
        statistics=AggregationDummy(),
        labels=set(),
        metadata={**shared_metadata, "different_column": "value"},
    )

    # when converting with metadata unwrapping enabled (strict=False because
    # the two overviews do not have identical metadata keys)
    df = aggregation_overviews_to_pandas(
        [first_overview, second_overview], unwrap_metadata=True, strict=False
    )

    # then every metadata key is a column of its own and the dict column is gone
    assert "metadata" not in df.columns
    for column in ("model", "prompt", "different_column"):
        assert column in df.columns
    assert (df["model"] == "model_a").all()
    assert (df["prompt"] == "prompt_a").all()


def test_aggregation_overviews_to_pandas_works_with_eval_overviews() -> None:
# given
eval_overview = EvaluationOverview(

0 comments on commit 7d06bdc

Please sign in to comment.