feat(traces): evaluation annotations on traces for associating spans with eval metrics (#1693)

* feat: initial associations of evaluations to traces

* add some documentation

* wip: add dataframe utils

* Switch to a single evaluation per dataframe

* make copy the default

* fix doc string

* fix name

* fix notebook

* Add immutability

* remove value from being required

* fix tutorials formatting

* make type a string to see if it fixes tests

* fix test to handle un-parsable

* Update src/phoenix/trace/trace_eval_dataset.py

Co-authored-by: Xander Song <[email protected]>

* Update src/phoenix/trace/trace_eval_dataset.py

Co-authored-by: Xander Song <[email protected]>

* change to trace_evaluations

* cleanup

* Fix formatting

* pr comments

* cleanup notebook

* make sure columns are dropped

* remove unused test

---------

Co-authored-by: Xander Song <[email protected]>
mikeldking and axiomofjoy authored Dec 1, 2023
1 parent 13d019f commit a218a65
Showing 7 changed files with 356 additions and 102 deletions.
4 changes: 4 additions & 0 deletions src/phoenix/trace/__init__.py
@@ -0,0 +1,4 @@
from .span_evaluations import SpanEvaluations
from .trace_dataset import TraceDataset

__all__ = ["TraceDataset", "SpanEvaluations"]
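With this re-export in place, both classes are importable from the package root, e.g. `from phoenix.trace import TraceDataset, SpanEvaluations`.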
72 changes: 72 additions & 0 deletions src/phoenix/trace/span_evaluations.py
@@ -0,0 +1,72 @@
import pandas as pd

EVALUATIONS_INDEX_NAME = "context.span_id"
RESULTS_COLUMN_NAMES = ["score", "label", "explanation"]

EVAL_NAME_COLUMN_PREFIX = "eval."


class SpanEvaluations:
"""
SpanEvaluations is a set of evaluation annotations for a set of spans.
SpanEvaluations encompasses the evaluation annotations for a single evaluation task
such as toxicity or hallucinations.
SpanEvaluations can be appended to TraceDatasets so that the spans and
evaluations can be joined and analyzed together.
Parameters
__________
eval_name: str
the name of the evaluation, e.x. 'toxicity'
dataframe: pandas.DataFrame
the pandas dataframe containing the evaluation annotations Each row
represents the evaluations on a span.
Example
_______
DataFrame of evaluations for toxicity may look like:
| span_id | score | label | explanation |
|---------|--------------------|--------------------|--------------------|
| span_1 | 1 | toxic | bad language |
| span_2 | 0 | non-toxic | violence |
| span_3 | 1 | toxic | discrimination |
"""

    dataframe: pd.DataFrame

    eval_name: str  # The name for the evaluation, e.g. 'toxicity'

    def __init__(self, eval_name: str, dataframe: pd.DataFrame):
        self.eval_name = eval_name

        # If the dataframe contains the index column, set the index to that column
        if EVALUATIONS_INDEX_NAME in dataframe.columns:
            dataframe = dataframe.set_index(EVALUATIONS_INDEX_NAME)

        # Validate that the dataframe is indexed by context.span_id
        if dataframe.index.name != EVALUATIONS_INDEX_NAME:
            raise ValueError(
                f"The dataframe index must be '{EVALUATIONS_INDEX_NAME}' but was "
                f"'{dataframe.index.name}'"
            )

        # Drop the unnecessary columns
        extra_column_names = dataframe.columns.difference(RESULTS_COLUMN_NAMES)
        self.dataframe = dataframe.drop(extra_column_names, axis=1)

    def get_dataframe(self, prefix_columns_with_name: bool = True) -> pd.DataFrame:
        """
        Returns a copy of the dataframe with the evaluation annotations.

        Parameters
        ----------
        prefix_columns_with_name: bool
            if True, the columns will be prefixed with the eval_name,
            e.g. 'eval.toxicity.score'
        """
        if prefix_columns_with_name:
            prefix = f"{EVAL_NAME_COLUMN_PREFIX}{self.eval_name}."
            return self.dataframe.add_prefix(prefix)
        return self.dataframe.copy()
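
A minimal usage sketch of SpanEvaluations as defined above; the toxicity dataframe here is illustrative, not taken from the repository:

import pandas as pd

from phoenix.trace.span_evaluations import SpanEvaluations

# An evaluations dataframe keyed by span ID. The extra "notes" column is
# dropped by the constructor, which keeps only score, label, and explanation.
toxicity_df = pd.DataFrame(
    {
        "context.span_id": ["span_1", "span_2"],
        "score": [1, 0],
        "label": ["toxic", "non-toxic"],
        "notes": ["dropped", "dropped"],
    }
)

evals = SpanEvaluations(eval_name="toxicity", dataframe=toxicity_df)

# By default the columns come back prefixed with the eval name:
# ['eval.toxicity.label', 'eval.toxicity.score']
print(sorted(evals.get_dataframe().columns))

The prefix convention is what lets several evaluation tasks be concatenated onto one spans dataframe without column collisions.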
62 changes: 55 additions & 7 deletions src/phoenix/trace/trace_dataset.py
@@ -1,7 +1,7 @@
import json
import uuid
from datetime import datetime
from typing import Iterator, List, Optional, cast
from typing import Iterable, Iterator, List, Optional, cast

import pandas as pd
from pandas import DataFrame, read_parquet
@@ -10,6 +10,7 @@

from ..config import DATASET_DIR, GENERATED_DATASET_NAME_PREFIX
from .schemas import ATTRIBUTE_PREFIX, CONTEXT_PREFIX, Span
from .span_evaluations import EVALUATIONS_INDEX_NAME, SpanEvaluations
from .span_json_decoder import json_to_span
from .span_json_encoder import span_to_json

@@ -40,25 +41,41 @@ class TraceDataset:
"""
A TraceDataset is a wrapper around a dataframe which is a flattened representation
of Spans. The collection of spans trace the LLM application's execution.
Parameters
__________
dataframe: pandas.DataFrame
the pandas dataframe containing the tracing data. Each row represents a span.
"""

    name: str
    dataframe: pd.DataFrame
    evaluations: List[SpanEvaluations] = []
    _data_file_name: str = "data.parquet"

    def __init__(self, dataframe: DataFrame, name: Optional[str] = None):
    def __init__(
        self,
        dataframe: DataFrame,
        name: Optional[str] = None,
        evaluations: Iterable[SpanEvaluations] = (),
    ):
"""
Constructs a TraceDataset from a dataframe of spans. Optionally takes in
evaluations for the spans in the dataset.
Parameters
__________
dataframe: pandas.DataFrame
the pandas dataframe containing the tracing data. Each row
represents a span.
evaluations: Optional[Iterable[SpanEvaluations]]
an optional list of evaluations for the spans in the dataset. If
provided, the evaluations can be materialized into a unified
dataframe as annotations.
"""
# Validate the the dataframe has required fields
if missing_columns := set(REQUIRED_COLUMNS) - set(dataframe.columns):
raise ValueError(
f"The dataframe is missing some required columns: {', '.join(missing_columns)}"
)
self.dataframe = normalize_dataframe(dataframe)
self.name = name or f"{GENERATED_DATASET_NAME_PREFIX}{str(uuid.uuid4())}"
self.evaluations = list(evaluations)

    @classmethod
    def from_spans(cls, spans: List[Span]) -> "TraceDataset":
@@ -133,3 +150,34 @@ def to_disc(self) -> None:
            allow_truncated_timestamps=True,
            coerce_timestamps="ms",
        )

    def append_evaluations(self, evaluations: SpanEvaluations) -> None:
        """Adds an evaluation to the traces."""
        # Append the evaluations to the list of evaluations
        self.evaluations.append(evaluations)

    def get_evals_dataframe(self) -> DataFrame:
        """
        Creates a flat dataframe of all the evaluations for the dataset.
        """
        return pd.concat(
            [evals.get_dataframe(prefix_columns_with_name=True) for evals in self.evaluations],
            axis=1,
        )

    def get_spans_dataframe(self, include_evaluations: bool = True) -> DataFrame:
        """
        Converts the dataset to a dataframe of spans. If evaluations are included,
        the evaluations are merged into the dataframe.

        Parameters
        ----------
        include_evaluations: bool
            if True, the evaluations are merged into the dataframe
        """
        if not include_evaluations:
            return self.dataframe.copy()
        evals_df = self.get_evals_dataframe()
        # Make sure the index is set to the span_id
        df = self.dataframe.set_index(EVALUATIONS_INDEX_NAME, drop=False)
        return pd.concat([df, evals_df], axis=1)
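
Putting the two pieces together, a hedged end-to-end sketch of the new TraceDataset surface; `spans_df` and `toxicity_df` are assumed inputs (a spans dataframe containing the required columns including "context.span_id", and an evaluations dataframe indexed by span ID):

from phoenix.trace import SpanEvaluations, TraceDataset

# spans_df is assumed to satisfy REQUIRED_COLUMNS, including "context.span_id".
dataset = TraceDataset(spans_df)

# Evaluations can also be attached after construction.
dataset.append_evaluations(SpanEvaluations("toxicity", toxicity_df))

# Spans and evaluations are joined on the span-ID index; evaluation columns
# appear as e.g. "eval.toxicity.score" alongside the span columns.
annotated = dataset.get_spans_dataframe(include_evaluations=True)

Because the frames are concatenated on the `context.span_id` index, spans without an evaluation simply carry NaN in the evaluation columns.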
24 changes: 24 additions & 0 deletions tests/trace/test_span_evaluations.py
@@ -0,0 +1,24 @@
import pandas as pd
from phoenix.trace.span_evaluations import SpanEvaluations


def test_span_evaluations_construction():
    num_records = 5
    span_ids = [f"span_{index}" for index in range(num_records)]

    eval_ds = SpanEvaluations(
        eval_name="my_eval",
        dataframe=pd.DataFrame(
            {
                "context.span_id": span_ids,
                "label": [index for index in range(num_records)],
                "score": [index for index in range(num_records)],
                "random_column": [index for index in range(num_records)],
            }
        ).set_index("context.span_id"),
    )

    # make sure the dataframe only has the needed values
    assert "context.span_id" not in eval_ds.dataframe.columns
    assert "random_column" not in eval_ds.dataframe.columns
    assert "score" in eval_ds.dataframe.columns
54 changes: 54 additions & 0 deletions tests/trace/test_trace_dataset.py
@@ -13,6 +13,7 @@
    SpanKind,
    SpanStatusCode,
)
from phoenix.trace.span_evaluations import SpanEvaluations
from phoenix.trace.trace_dataset import TraceDataset


@@ -142,3 +143,56 @@ def test_dataset_construction_from_spans():
    )
    dataset = TraceDataset.from_spans(spans)
    assert_frame_equal(expected_dataframe, dataset.dataframe[expected_dataframe.columns])


def test_dataset_construction_with_evaluations():
    num_records = 5
    span_ids = [f"span_{index}" for index in range(num_records)]
    traces_df = pd.DataFrame(
        {
            "name": [f"name_{index}" for index in range(num_records)],
            "span_kind": ["LLM" for index in range(num_records)],
            "parent_id": [None for index in range(num_records)],
            "start_time": [datetime.now() for index in range(num_records)],
            "end_time": [datetime.now() for index in range(num_records)],
            "message": [f"message_{index}" for index in range(num_records)],
            "status_code": ["OK" for index in range(num_records)],
            "status_message": ["" for index in range(num_records)],
            "context.trace_id": [f"trace_{index}" for index in range(num_records)],
            "context.span_id": span_ids,
        }
    ).set_index("context.span_id", drop=False)
    eval_ds_1 = SpanEvaluations(
        eval_name="fake_eval_1",
        dataframe=pd.DataFrame(
            {
                "context.span_id": span_ids,
                "score": [index for index in range(num_records)],
            }
        ).set_index("context.span_id"),
    )
    eval_ds_2 = SpanEvaluations(
        eval_name="fake_eval_2",
        dataframe=pd.DataFrame(
            {
                "context.span_id": span_ids,
                "score": [index for index in range(num_records)],
            }
        ).set_index("context.span_id"),
    )
    ds = TraceDataset(traces_df, evaluations=[eval_ds_1, eval_ds_2])
    evals_df = ds.get_evals_dataframe()
    assert "eval.fake_eval_1.score" in evals_df.columns
    assert "eval.fake_eval_2.score" in evals_df.columns
    assert len(evals_df) == num_records
    df_with_evals = ds.get_spans_dataframe(include_evaluations=True)
    # Validate that the length of the dataframe is the same
    assert len(df_with_evals) == len(traces_df)
    # Validate that the evaluation columns are present
    assert "eval.fake_eval_1.score" in df_with_evals.columns
    assert "eval.fake_eval_2.score" in df_with_evals.columns
    # Validate that the evaluation columns contain the correct values
    assert list(df_with_evals["eval.fake_eval_1.score"]) == list(eval_ds_1.dataframe["score"])
    assert list(df_with_evals["eval.fake_eval_2.score"]) == list(eval_ds_2.dataframe["score"])
    # Validate that the output contains a span_id column
    assert "context.span_id" in df_with_evals.columns