feat(traces): graphql query for document evaluation summary (#1874)
* feat(traces): graphql query for document evaluation summary
RogerHYang authored Dec 7, 2023
1 parent b6e8c73 commit 8a6a063
Showing 5 changed files with 160 additions and 2 deletions.
16 changes: 16 additions & 0 deletions app/schema.graphql
@@ -202,6 +202,21 @@ type DocumentEvaluation implements Evaluation {
documentPosition: Int!
}

"""
Summarization of retrieval metrics: Average NDCG@K, Average Precision@K, Mean Reciprocal Rank, Hit Rate, etc.
"""
type DocumentEvaluationSummary {
evaluationName: String!
averageNdcg(k: Int): Float
countNdcg(k: Int): Int!
averagePrecision(k: Int): Float
countPrecision(k: Int): Int!
meanReciprocalRank: Float
countReciprocalRank: Int!
hitRate: Float
countHit: Int!
}

"""
A collection of retrieval metrics computed on a list of document evaluation scores: NDCG@K, Precision@K, Reciprocal Rank, etc.
"""
@@ -561,6 +576,7 @@ type Query {
"""Names of available document evaluations."""
documentEvaluationNames(spanId: ID): [String!]!
spanEvaluationSummary(evaluationName: String!): EvaluationSummary
documentEvaluationSummary(evaluationName: String!): DocumentEvaluationSummary
traceDatasetInfo: TraceDatasetInfo
validateSpanFilterCondition(condition: String!): ValidationResult!
}
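For illustration only (not part of this commit): a minimal client-side sketch of the new query against the fields added above. It assumes a Phoenix server running locally on the default port 6006, a GraphQL endpoint at /graphql, and an ingested document evaluation named "relevance"; adjust these for your setup.

import requests

# Request the new documentEvaluationSummary field; the evaluation name and the
# cutoff k=5 are made-up example values.
QUERY = """
query {
  documentEvaluationSummary(evaluationName: "relevance") {
    evaluationName
    averageNdcg(k: 5)
    countNdcg(k: 5)
    averagePrecision(k: 5)
    countPrecision(k: 5)
    meanReciprocalRank
    hitRate
    countHit
  }
}
"""

response = requests.post("http://localhost:6006/graphql", json={"query": QUERY})
response.raise_for_status()
print(response.json()["data"]["documentEvaluationSummary"])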
4 changes: 4 additions & 0 deletions src/phoenix/core/evals.py
@@ -122,6 +122,10 @@ def get_evaluations_by_span_id(self, span_id: SpanID) -> List[pb.Evaluation]:
with self._lock:
return list(self._evaluations_by_span_id[span_id].values())

def get_document_evaluation_span_ids(self, name: EvaluationName) -> Tuple[SpanID, ...]:
with self._lock:
return tuple(self._document_evaluations_by_name[name].keys())

def get_document_evaluations_by_span_id(self, span_id: SpanID) -> List[pb.Evaluation]:
all_evaluations: List[pb.Evaluation] = []
with self._lock:
11 changes: 11 additions & 0 deletions src/phoenix/core/traces.py
@@ -121,6 +121,7 @@ def __init__(self) -> None:
self._traces: Dict[TraceID, List[SpanID]] = defaultdict(list)
self._child_span_ids: DefaultDict[SpanID, List[ChildSpanID]] = defaultdict(list)
self._orphan_spans: DefaultDict[ParentSpanID, List[pb.Span]] = defaultdict(list)
self._num_documents: DefaultDict[SpanID, int] = defaultdict(int)
self._start_time_sorted_span_ids: SortedKeyList[SpanID] = SortedKeyList(
key=lambda span_id: self._spans[span_id].start_time.ToDatetime(timezone.utc),
)
@@ -180,6 +181,10 @@ def get_spans(
if span := self[span_id]:
yield span

def get_num_documents(self, span_id: SpanID) -> int:
with self._lock:
return self._num_documents[span_id]

def latency_rank_percent(self, latency_ms: float) -> Optional[float]:
"""
Returns a value between 0 and 100 approximating the rank of the
@@ -319,6 +324,12 @@ def _process_span(self, span: pb.Span) -> None:
if existing_span:
self._token_count_total -= existing_span[LLM_TOKEN_COUNT_TOTAL] or 0
self._token_count_total += new_span[LLM_TOKEN_COUNT_TOTAL] or 0
# Update number of documents
num_documents_update = len(span.retrieval.documents)
if existing_span:
num_documents_update -= len(existing_span.retrieval.documents)
if num_documents_update:
self._num_documents[span_id] += num_documents_update
# Process previously orphaned spans, if any.
for orphan_span in self._orphan_spans.pop(span_id, ()):
self._process_span(orphan_span)
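A note on the bookkeeping above: _num_documents is updated with a delta so that re-processing a span (for example, when an updated copy of the same span arrives) replaces its earlier document count rather than double-counting it. A standalone sketch of the same pattern, using hypothetical names:

from collections import defaultdict
from typing import Dict, Optional

num_documents: Dict[str, int] = defaultdict(int)

def update_num_documents(span_id: str, new_count: int, old_count: Optional[int] = None) -> None:
    # Subtract the span's previous contribution (if any), then add the new count.
    delta = new_count - (old_count or 0)
    if delta:
        num_documents[span_id] += delta

update_num_documents("span-1", 3)                # first sighting: 0 -> 3
update_num_documents("span-1", 5, old_count=3)   # re-processed with 5 documents: 3 -> 5
assert num_documents["span-1"] == 5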
37 changes: 35 additions & 2 deletions src/phoenix/server/api/schema.py
@@ -11,6 +11,7 @@
from strawberry.types import Info
from typing_extensions import Annotated

from phoenix.metrics.retrieval_metrics import RetrievalMetrics
from phoenix.pointcloud.clustering import Hdbscan
from phoenix.server.api.helpers import ensure_list
from phoenix.server.api.input_types.ClusterInput import ClusterInput
@@ -20,14 +21,15 @@
)
from phoenix.server.api.input_types.SpanSort import SpanSort
from phoenix.server.api.types.Cluster import Cluster, to_gql_clusters
from phoenix.trace.filter import SpanFilter
from phoenix.trace.schemas import SpanID

from ...trace.filter import SpanFilter
from ...trace.schemas import SpanID
from .context import Context
from .input_types.TimeRange import TimeRange
from .types.DatasetInfo import TraceDatasetInfo
from .types.DatasetRole import AncillaryDatasetRole, DatasetRole
from .types.Dimension import to_gql_dimension
from .types.DocumentEvaluationSummary import DocumentEvaluationSummary
from .types.EmbeddingDimension import (
DEFAULT_CLUSTER_SELECTION_EPSILON,
DEFAULT_MIN_CLUSTER_SIZE,
@@ -289,6 +291,37 @@ def span_evaluation_summary(
labels = evals.get_span_evaluation_labels(evaluation_name)
return EvaluationSummary(evaluations, labels)

@strawberry.field
def document_evaluation_summary(
self,
info: Info[Context, None],
evaluation_name: str,
) -> Optional[DocumentEvaluationSummary]:
if (evals := info.context.evals) is None:
return None
if (traces := info.context.traces) is None:
return None
span_ids = evals.get_document_evaluation_span_ids(evaluation_name)
if not span_ids:
return None
metrics_collection = []
for span_id in span_ids:
num_documents = traces.get_num_documents(span_id)
if not num_documents:
continue
evaluation_scores = evals.get_document_evaluation_scores(
span_id=span_id,
evaluation_name=evaluation_name,
num_documents=num_documents,
)
metrics_collection.append(RetrievalMetrics(evaluation_scores))
if not metrics_collection:
return None
return DocumentEvaluationSummary(
evaluation_name=evaluation_name,
metrics_collection=metrics_collection,
)

@strawberry.field
def trace_dataset_info(
self,
94 changes: 94 additions & 0 deletions src/phoenix/server/api/types/DocumentEvaluationSummary.py
@@ -0,0 +1,94 @@
import math
from functools import cached_property
from typing import Any, Dict, Iterable, Optional, Tuple

import pandas as pd
import strawberry
from strawberry import UNSET, Private

from phoenix.metrics.retrieval_metrics import RetrievalMetrics


@strawberry.type(
description="Summarization of retrieval metrics: Average NDCG@K, Average "
"Precision@K, Mean Reciprocal Rank, Hit Rate, etc."
)
class DocumentEvaluationSummary:
evaluation_name: str
metrics_collection: Private["pd.Series[Any]"]

def __init__(
self,
evaluation_name: str,
metrics_collection: Iterable[RetrievalMetrics],
) -> None:
self.evaluation_name = evaluation_name
self.metrics_collection = pd.Series(metrics_collection, dtype=object)
self._cached_average_ndcg_results: Dict[Optional[int], Tuple[float, int]] = {}
self._cached_average_precision_results: Dict[Optional[int], Tuple[float, int]] = {}

@strawberry.field
def average_ndcg(self, k: Optional[int] = UNSET) -> Optional[float]:
value, _ = self._average_ndcg(None if k is UNSET else k)
return value if math.isfinite(value) else None

@strawberry.field
def count_ndcg(self, k: Optional[int] = UNSET) -> int:
_, count = self._average_ndcg(None if k is UNSET else k)
return count

@strawberry.field
def average_precision(self, k: Optional[int] = UNSET) -> Optional[float]:
value, _ = self._average_precision(None if k is UNSET else k)
return value if math.isfinite(value) else None

@strawberry.field
def count_precision(self, k: Optional[int] = UNSET) -> int:
_, count = self._average_precision(None if k is UNSET else k)
return count

@strawberry.field
def mean_reciprocal_rank(self) -> Optional[float]:
value, _ = self._average_reciprocal_rank
return value if math.isfinite(value) else None

@strawberry.field
def count_reciprocal_rank(self) -> int:
_, count = self._average_reciprocal_rank
return count

@strawberry.field
def hit_rate(self) -> Optional[float]:
value, _ = self._average_hit
return value if math.isfinite(value) else None

@strawberry.field
def count_hit(self) -> int:
_, count = self._average_hit
return count

def _average_ndcg(self, k: Optional[int] = None) -> Tuple[float, int]:
if (result := self._cached_average_ndcg_results.get(k)) is not None:
return result
values = self.metrics_collection.apply(lambda m: m.ndcg(k))
result = (values.mean(), values.count())
self._cached_average_ndcg_results[k] = result
return result

def _average_precision(self, k: Optional[int] = None) -> Tuple[float, int]:
if (result := self._cached_average_precision_results.get(k)) is not None:
return result
values = self.metrics_collection.apply(lambda m: m.precision(k))
result = (values.mean(), values.count())
self._cached_average_precision_results[k] = result
return result

@cached_property
def _average_reciprocal_rank(self) -> Tuple[float, int]:
values = self.metrics_collection.apply(lambda m: m.reciprocal_rank())
return values.mean(), values.count()

@cached_property
def _average_hit(self) -> Tuple[float, int]:
values = self.metrics_collection.apply(lambda m: m.hit())
return values.mean(), values.count()
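A rough usage sketch of the new type in isolation (illustrative only): the relevance scores below are made up, and RetrievalMetrics is assumed to accept a span's per-document score sequence, matching how the resolver above constructs it. The GraphQL fields delegate to the private aggregation helpers, each of which returns a (mean, count) pair over the collection.

from phoenix.metrics.retrieval_metrics import RetrievalMetrics
from phoenix.server.api.types.DocumentEvaluationSummary import DocumentEvaluationSummary

# One RetrievalMetrics per retrieval span, built from that span's per-document scores.
metrics_collection = [
    RetrievalMetrics([1.0, 0.0, 1.0]),
    RetrievalMetrics([0.0, 1.0, 0.0, 1.0]),
]
summary = DocumentEvaluationSummary(
    evaluation_name="relevance",
    metrics_collection=metrics_collection,
)

avg_ndcg_at_2, ndcg_count = summary._average_ndcg(k=2)  # backs averageNdcg(k: 2) / countNdcg(k: 2)
mrr, rr_count = summary._average_reciprocal_rank        # backs meanReciprocalRank / countReciprocalRank
hit_rate, hit_count = summary._average_hit               # backs hitRate / countHit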
