-
Notifications
You must be signed in to change notification settings - Fork 336
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(traces): server-side sort of spans by evaluation result (score or label) #1812
Changes from 4 commits
1c9a90f
0d1bd6d
4d47425
fc78f87
6b0eadb
77ebf68
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,50 +1,96 @@ | ||
from enum import Enum | ||
from functools import partial | ||
from typing import Any, Iterable, Iterator | ||
from typing import Any, Iterable, Iterator, Optional, Protocol | ||
|
||
import pandas as pd | ||
import strawberry | ||
from typing_extensions import assert_never | ||
|
||
import phoenix.trace.v1 as pb | ||
from phoenix.core.traces import ( | ||
END_TIME, | ||
LLM_TOKEN_COUNT_COMPLETION, | ||
LLM_TOKEN_COUNT_PROMPT, | ||
LLM_TOKEN_COUNT_TOTAL, | ||
START_TIME, | ||
ComputedAttributes, | ||
) | ||
from phoenix.server.api.types.SortDir import SortDir | ||
from phoenix.trace.schemas import Span | ||
from phoenix.trace import semantic_conventions | ||
from phoenix.trace.schemas import Span, SpanID | ||
|
||
|
||
@strawberry.enum | ||
class SpanColumn(Enum): | ||
startTime = START_TIME | ||
endTime = END_TIME | ||
latencyMs = ComputedAttributes.LATENCY_MS.value | ||
tokenCountTotal = LLM_TOKEN_COUNT_TOTAL | ||
tokenCountPrompt = LLM_TOKEN_COUNT_PROMPT | ||
tokenCountCompletion = LLM_TOKEN_COUNT_COMPLETION | ||
tokenCountTotal = semantic_conventions.LLM_TOKEN_COUNT_TOTAL | ||
tokenCountPrompt = semantic_conventions.LLM_TOKEN_COUNT_PROMPT | ||
tokenCountCompletion = semantic_conventions.LLM_TOKEN_COUNT_COMPLETION | ||
cumulativeTokenCountTotal = ComputedAttributes.CUMULATIVE_LLM_TOKEN_COUNT_TOTAL.value | ||
cumulativeTokenCountPrompt = ComputedAttributes.CUMULATIVE_LLM_TOKEN_COUNT_PROMPT.value | ||
cumulativeTokenCountCompletion = ComputedAttributes.CUMULATIVE_LLM_TOKEN_COUNT_COMPLETION.value | ||
|
||
|
||
@strawberry.enum
class EvalAttr(Enum):
    """Names the field of an evaluation result to read: its score or its label."""

    score = "score"
    label = "label"
|
||
|
||
class SupportsGetSpanEvaluation(Protocol):
    """Structural type for any object that can look up a span's evaluation by name."""

    def get_span_evaluation(self, span_id: SpanID, name: str) -> Optional[pb.Evaluation]:
        """Return the evaluation named `name` for the span `span_id`, if any."""
        ...
|
||
|
||
@strawberry.input | ||
class SpanSort: | ||
""" | ||
The sort column and direction for span connections | ||
""" | ||
class EvalResultKey: | ||
name: str | ||
attr: EvalAttr | ||
|
||
col: SpanColumn | ||
def __call__( | ||
self, | ||
span: Span, | ||
evals: Optional[SupportsGetSpanEvaluation] = None, | ||
) -> Any: | ||
""" | ||
Returns the evaluation result for the given span | ||
""" | ||
if evals is None: | ||
return None | ||
span_id = span.context.span_id | ||
evaluation = evals.get_span_evaluation(span_id, self.name) | ||
if evaluation is None: | ||
return None | ||
result = evaluation.result | ||
if self.attr is EvalAttr.score: | ||
return result.score.value if result.HasField("score") else None | ||
if self.attr is EvalAttr.label: | ||
return result.label.value if result.HasField("label") else None | ||
assert_never(self.attr) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Minor nit on semantics: making the key a callable rather than having an explicit method that has a well-formed name is confusing for me. For example, the code below doesn't read well to me - the key appears to be a simple param but then it's getting partially called with evals. It forces me to have to go to the definition to understand what calling the key does, which is a level of indirection that I think you can avoid by just having a good name for a method or function. `_key = partial(self.eval_result_key, evals=evals)` There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I can change it. Not a problem. But I will just note the reason why it was done.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I see, yeah I think Key being a callable makes sense but in this case it requires a partial application to work. But this makes sense. Your call. Thanks for the explanation. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. No problem. I already changed it. |
||
|
||
|
||
@strawberry.input( | ||
description="The sort key and direction for span connections. Must " | ||
"specify one and only one of either `col` or `eval_result_key`." | ||
) | ||
class SpanSort: | ||
col: Optional[SpanColumn] = None | ||
eval_result_key: Optional[EvalResultKey] = None | ||
dir: SortDir | ||
|
||
def __call__(self, spans: Iterable[Span]) -> Iterator[Span]: | ||
def __call__( | ||
self, | ||
spans: Iterable[Span], | ||
evals: Optional[SupportsGetSpanEvaluation] = None, | ||
) -> Iterator[Span]: | ||
""" | ||
Sorts the spans by the given column and direction | ||
Sorts the spans by the given key (column or eval) and direction | ||
""" | ||
if self.eval_result_key is not None: | ||
_key = partial(self.eval_result_key, evals=evals) | ||
else: | ||
_key = partial(_get_column, span_column=self.col or SpanColumn.startTime) | ||
yield from pd.Series(spans, dtype=object).sort_values( | ||
key=lambda series: series.apply(partial(_get_column, span_column=self.col)), | ||
key=lambda series: series.apply(_key), | ||
ascending=self.dir.value == SortDir.asc.value, | ||
) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
from collections import namedtuple | ||
from itertools import count, islice | ||
from random import random | ||
|
||
import phoenix.trace.v1 as pb | ||
import pytest | ||
from google.protobuf.wrappers_pb2 import DoubleValue, StringValue | ||
from phoenix.server.api.input_types.SpanSort import EvalAttr, EvalResultKey, SpanColumn, SpanSort | ||
from phoenix.server.api.types.SortDir import SortDir | ||
|
||
|
||
@pytest.mark.parametrize("col", [SpanColumn.endTime, SpanColumn.latencyMs])
def test_sort_by_col(spans, col):
    """Sorting the first three spans by a column, descending, gives the expected order."""
    first, second, third = islice(spans, 3)
    sorter = SpanSort(col=col, dir=SortDir.desc)
    ordered = list(sorter([first, second, third]))
    # The `spans` fixture gives only even-indexed spans a value for these
    # columns, so the odd-indexed span is expected last.
    assert ordered == [third, first, second]
|
||
|
||
@pytest.mark.parametrize("eval_attr", list(EvalAttr))
def test_sort_by_eval(spans, evals, eval_name, eval_attr):
    """Sort spans by an evaluation result, once per `EvalAttr` member.

    Covers a name that has evaluations and a name that has none.
    """
    span0, span1, span2 = islice(spans, 3)

    eval_result_key = EvalResultKey(name=eval_name, attr=eval_attr)
    sort = SpanSort(eval_result_key=eval_result_key, dir=SortDir.desc)
    # The `evals` fixture only has evaluations for span ids 0 and 1, so
    # span2 is expected last.
    assert list(sort([span0, span2, span1], evals)) == [span1, span0, span2]

    # non-existent evaluation name: no span has a sort value, so the input
    # order is expected back unchanged. `EvalResultKey.name` is declared as
    # `str`, hence the explicit conversion of the random float.
    no_op_key = EvalResultKey(name=str(random()), attr=eval_attr)
    no_op_sort = SpanSort(eval_result_key=no_op_key, dir=SortDir.desc)
    assert list(no_op_sort([span2, span0, span1], evals)) == [span2, span0, span1]
|
||
|
||
# Lightweight stand-ins that expose only the attributes these tests populate.
Span = namedtuple("Span", ["context", "end_time", "attributes"])
Context = namedtuple("Context", ["span_id"])
Evals = namedtuple("Evals", ["get_span_evaluation"])
|
||
|
||
@pytest.fixture
def evals(eval_name):
    """Stub evaluation lookup registered under `eval_name`.

    Span id 0 has a result with only a score; span id 1 has a result with
    both a score and a label.
    """
    score_only = pb.Evaluation.Result(score=DoubleValue(value=0))
    score_and_label = pb.Evaluation.Result(
        score=DoubleValue(value=1),
        label=StringValue(value="1"),
    )
    by_span_id = {
        0: pb.Evaluation(result=score_only),
        1: pb.Evaluation(result=score_and_label),
    }
    by_name = {eval_name: by_span_id}

    def get_span_evaluation(span_id, name):
        # Unknown names and unknown span ids both resolve to None.
        return by_name.get(name, {}).get(span_id)

    return Evals(get_span_evaluation)
|
||
|
||
@pytest.fixture
def eval_name():
    """Evaluation name shared by the fixtures and tests in this module."""
    return "correctness"
|
||
|
||
@pytest.fixture
def spans():
    """Endless stream of stub spans whose span ids count up from 0.

    Even-numbered spans carry their index as `end_time` and as the latency
    attribute; odd-numbered spans have `end_time=None` and no attributes.
    """

    def make_span(i):
        is_odd = i % 2
        return Span(
            context=Context(i),
            end_time=None if is_odd else i,
            attributes={} if is_odd else {SpanColumn.latencyMs.value: i},
        )

    # Lazy and unbounded: callers take what they need with `islice`.
    return map(make_span, count())
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`oneOf` for input types is not available in Strawberry.