Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(traces): server-side sort of spans by evaluation result (score or label) #1812

Merged
merged 6 commits into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion app/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,16 @@ type EmbeddingMetadata {
linkToData: String
}

enum EvalAttr {
score
label
}

input EvalResultKey {
name: String!
attr: EvalAttr!
}

interface Evaluation {
"""Name of the evaluation, e.g. 'helpfulness' or 'relevance'."""
name: String!
Expand Down Expand Up @@ -654,8 +664,12 @@ enum SpanKind {
unknown
}

"""
The sort key and direction for span connections. Must specify one and only one of either `col` or `eval_result_key`.
"""
input SpanSort {
col: SpanColumn!
col: SpanColumn = null
evalResultKey: EvalResultKey = null
Comment on lines +671 to +672
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oneOf for input types is not available in strawberry

dir: SortDir!
}

Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 8 additions & 2 deletions app/src/pages/tracing/__generated__/TracesTableQuery.graphql.ts

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions src/phoenix/core/evals.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ def _process_evaluation(self, evaluation: pb.Evaluation) -> None:
else:
assert_never(subject_id_kind)

def get_span_evaluation(self, span_id: SpanID, name: str) -> Optional[pb.Evaluation]:
    """Return the evaluation called `name` that is stored for `span_id`.

    The lookup is performed while holding `self._lock`. `None` is returned
    when the span's mapping has no entry under `name`.
    """
    with self._lock:
        # NOTE(review): the span is indexed directly; if
        # `_evaluations_by_span_id` is a defaultdict this inserts an empty
        # entry for unknown span IDs, and if it is a plain dict it raises
        # KeyError -- confirm against the attribute's definition.
        evaluations_for_span = self._evaluations_by_span_id[span_id]
        return evaluations_for_span.get(name)

def get_span_evaluation_names(self) -> List[EvaluationName]:
with self._lock:
return list(self._span_evaluations_by_name.keys())
Expand Down
78 changes: 62 additions & 16 deletions src/phoenix/server/api/input_types/SpanSort.py
Original file line number Diff line number Diff line change
@@ -1,50 +1,96 @@
from enum import Enum
from functools import partial
from typing import Any, Iterable, Iterator
from typing import Any, Iterable, Iterator, Optional, Protocol

import pandas as pd
import strawberry
from typing_extensions import assert_never

import phoenix.trace.v1 as pb
from phoenix.core.traces import (
END_TIME,
LLM_TOKEN_COUNT_COMPLETION,
LLM_TOKEN_COUNT_PROMPT,
LLM_TOKEN_COUNT_TOTAL,
START_TIME,
ComputedAttributes,
)
from phoenix.server.api.types.SortDir import SortDir
from phoenix.trace.schemas import Span
from phoenix.trace import semantic_conventions
from phoenix.trace.schemas import Span, SpanID


@strawberry.enum
class SpanColumn(Enum):
startTime = START_TIME
endTime = END_TIME
latencyMs = ComputedAttributes.LATENCY_MS.value
tokenCountTotal = LLM_TOKEN_COUNT_TOTAL
tokenCountPrompt = LLM_TOKEN_COUNT_PROMPT
tokenCountCompletion = LLM_TOKEN_COUNT_COMPLETION
tokenCountTotal = semantic_conventions.LLM_TOKEN_COUNT_TOTAL
tokenCountPrompt = semantic_conventions.LLM_TOKEN_COUNT_PROMPT
tokenCountCompletion = semantic_conventions.LLM_TOKEN_COUNT_COMPLETION
cumulativeTokenCountTotal = ComputedAttributes.CUMULATIVE_LLM_TOKEN_COUNT_TOTAL.value
cumulativeTokenCountPrompt = ComputedAttributes.CUMULATIVE_LLM_TOKEN_COUNT_PROMPT.value
cumulativeTokenCountCompletion = ComputedAttributes.CUMULATIVE_LLM_TOKEN_COUNT_COMPLETION.value


@strawberry.enum
class EvalAttr(Enum):
    """Which attribute of an evaluation result a span sort should use."""

    # Sort by the result's `score` field.
    score = "score"
    # Sort by the result's `label` field.
    label = "label"


class SupportsGetSpanEvaluation(Protocol):
    """Structural type for any object that can look up a span's evaluation by name."""

    def get_span_evaluation(self, span_id: SpanID, name: str) -> Optional[pb.Evaluation]:
        """Return the evaluation called `name` for `span_id`, or `None` if there is none."""
        ...


@strawberry.input
class SpanSort:
"""
The sort column and direction for span connections
"""
class EvalResultKey:
name: str
attr: EvalAttr

col: SpanColumn
def __call__(
self,
span: Span,
evals: Optional[SupportsGetSpanEvaluation] = None,
) -> Any:
"""
Returns the evaluation result for the given span
"""
if evals is None:
return None
span_id = span.context.span_id
evaluation = evals.get_span_evaluation(span_id, self.name)
if evaluation is None:
return None
result = evaluation.result
if self.attr is EvalAttr.score:
return result.score.value if result.HasField("score") else None
if self.attr is EvalAttr.label:
return result.label.value if result.HasField("label") else None
assert_never(self.attr)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor nit on semantics: making the key a callable rather than having an explicit method that has a well-formed name is confusing for me. For example, the code below doesn't read well to me - the key appears to be a simple param but then it's getting partially called with evals. It forces me to have to go to definition to understand what calling the key does, which is a level of indirection that I think you can avoid by just having a good name for a method or function.

_key = partial(self.eval_result_key, evals=evals)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can change it. Not a problem. But will just note the reason why it was done.

  1. key is a callable in both python and pandas, so it's part of the python tradition.
  2. python functions are objects, so initializing it with parameters alludes to python's functional paradigm.
  3. it's fun since it's a programmer's version of a double entendre (given 1 and 2 above)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, yeah I think Key being a callable makes sense but in this case it requires a partial application to work. But this makes sense. Your call. Thanks for the explanation.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No problem. I already changed it.



@strawberry.input(
description="The sort key and direction for span connections. Must "
"specify one and only one of either `col` or `eval_result_key`."
)
class SpanSort:
col: Optional[SpanColumn] = None
eval_result_key: Optional[EvalResultKey] = None
dir: SortDir

def __call__(self, spans: Iterable[Span]) -> Iterator[Span]:
def __call__(
self,
spans: Iterable[Span],
evals: Optional[SupportsGetSpanEvaluation] = None,
) -> Iterator[Span]:
"""
Sorts the spans by the given column and direction
Sorts the spans by the given key (column or eval) and direction
"""
if self.eval_result_key is not None:
_key = partial(self.eval_result_key, evals=evals)
else:
_key = partial(_get_column, span_column=self.col or SpanColumn.startTime)
yield from pd.Series(spans, dtype=object).sort_values(
key=lambda series: series.apply(partial(_get_column, span_column=self.col)),
key=lambda series: series.apply(_key),
ascending=self.dir.value == SortDir.asc.value,
)

Expand Down
2 changes: 1 addition & 1 deletion src/phoenix/server/api/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def spans(
if predicate:
spans = filter(predicate, spans)
if sort:
spans = sort(spans)
spans = sort(spans, evals=info.context.evals)
data = list(map(to_gql_span, spans))
return connection_from_list(data=data, args=args)

Expand Down
60 changes: 60 additions & 0 deletions tests/server/api/input_types/test_SpanSort.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from collections import namedtuple
from itertools import count, islice
from random import random

import phoenix.trace.v1 as pb
import pytest
from google.protobuf.wrappers_pb2 import DoubleValue, StringValue
from phoenix.server.api.input_types.SpanSort import EvalAttr, EvalResultKey, SpanColumn, SpanSort
from phoenix.server.api.types.SortDir import SortDir


@pytest.mark.parametrize("col", [SpanColumn.endTime, SpanColumn.latencyMs])
def test_sort_by_col(spans, col):
    """Sorting by a span column in descending order puts spans lacking a value last.

    In the `spans` fixture only even-indexed spans carry a value for the
    parametrized columns, so the expected order is the third span, then the
    first, then the second (which has no value).
    """
    first, second, third = islice(spans, 3)
    descending = SpanSort(col=col, dir=SortDir.desc)
    ordered = list(descending([first, second, third]))
    assert ordered == [third, first, second]


@pytest.mark.parametrize("eval_attr", list(EvalAttr))
def test_sort_by_eval(spans, evals, eval_name, eval_attr):
    """Sorting by an evaluation result orders spans by that result's score or label.

    The `evals` fixture only has evaluations for the first two spans, so in
    descending order the span with the higher result comes first and the span
    without an evaluation comes last. Sorting by an evaluation name that does
    not exist must leave the input order untouched.
    """
    span0, span1, span2 = islice(spans, 3)

    eval_result_key = EvalResultKey(name=eval_name, attr=eval_attr)
    sort = SpanSort(eval_result_key=eval_result_key, dir=SortDir.desc)
    assert list(sort([span0, span2, span1], evals)) == [span1, span0, span2]

    # Non-existent evaluation name. `EvalResultKey.name` is declared as `str`,
    # so the random value is stringified instead of being passed as a float.
    no_op_key = EvalResultKey(name=str(random()), attr=eval_attr)
    no_op_sort = SpanSort(eval_result_key=no_op_key, dir=SortDir.desc)
    assert list(no_op_sort([span2, span0, span1], evals)) == [span2, span0, span1]


# Lightweight test doubles exposing only the attributes these tests set and read.
Span = namedtuple("Span", ["context", "end_time", "attributes"])
Context = namedtuple("Context", ["span_id"])
# Holds a single callable under the name `get_span_evaluation`.
Evals = namedtuple("Evals", ["get_span_evaluation"])


@pytest.fixture
def evals(eval_name):
    """Provide an `Evals` double with evaluations for span IDs 0 and 1 under `eval_name`.

    Span 0 has a score of 0 and no label; span 1 has a score of 1 and the
    label "1". Any other span ID or evaluation name looks up to `None`.
    """
    score_only = pb.Evaluation.Result(score=DoubleValue(value=0))
    score_and_label = pb.Evaluation.Result(
        score=DoubleValue(value=1),
        label=StringValue(value="1"),
    )
    by_span_id = {
        0: pb.Evaluation(result=score_only),
        1: pb.Evaluation(result=score_and_label),
    }
    evaluations = {eval_name: by_span_id}

    def lookup(span_id, name):
        return evaluations.get(name, {}).get(span_id)

    return Evals(lookup)


@pytest.fixture
def eval_name():
    """Evaluation name under which the `evals` fixture stores its evaluations."""
    return "correctness"


@pytest.fixture
def spans():
    """Provide an endless, lazily generated stream of `Span` doubles with IDs 0, 1, 2, ...

    Even-numbered spans have `end_time` and a latency attribute equal to their
    index; odd-numbered spans have `end_time=None` and no attributes.
    """

    def generate():
        for i in count():
            is_odd = i % 2
            yield Span(
                context=Context(i),
                end_time=None if is_odd else i,
                attributes={} if is_odd else {SpanColumn.latencyMs.value: i},
            )

    return generate()
Loading