Commit 3dfb7bd (1 parent: f066e10)

feat(traces): document retrieval metrics based on document evaluation scores (#1826)

* feat: span document retrieval metrics

Showing 7 changed files with 348 additions and 4 deletions.
@@ -0,0 +1,105 @@
from dataclasses import dataclass, field
from typing import Iterable, Optional, cast

import numpy as np
import pandas as pd
from sklearn.metrics import ndcg_score


@dataclass(frozen=True)
class RetrievalMetrics:
    """
    Ranking metrics computed on a list of evaluation scores sorted from high to
    low by their ranking scores (prior to evaluation). For example, if the items
    are search results and the evaluation scores are their relevance scores (e.g.
    1 if relevant and 0 if not relevant), then the evaluation scores should be
    sorted by the original order of the displayed results, i.e. the first search
    result should go first. For more info on these metrics, see
    https://cran.r-project.org/web/packages/recometrics/vignettes/Evaluating_recommender_systems.html
    """  # noqa: E501

    eval_scores: "pd.Series[float]"
    length: int = field(init=False)
    has_nan: bool = field(init=False)

    def __init__(self, eval_scores: Iterable[float]) -> None:
        _eval_scores = np.fromiter(eval_scores, dtype=float)
        object.__setattr__(self, "length", len(_eval_scores))
        object.__setattr__(self, "has_nan", not np.all(np.isfinite(_eval_scores)))
        if self.length < 2:
            # len < 2 won't work for sklearn.metrics.ndcg_score, so we pad it
            # with zeros (but still keep track of the original length)
            _scores = _eval_scores
            _eval_scores = np.zeros(2)
            _eval_scores[: len(_scores)] = _scores
        # For ranking metrics, the actual scores used for ranking are only
        # needed for sorting the items. Since we assume the items are already
        # sorted from high to low by their ranking scores, we can assign ranking
        # scores to be the reverse of the indices of eval_scores, just so that
        # it goes from high to low.
        ranking_scores = reversed(range(len(_eval_scores)))
        object.__setattr__(
            self,
            "eval_scores",
            pd.Series(_eval_scores, dtype=float, index=ranking_scores),  # type: ignore
        )

    def ndcg(self, k: Optional[int] = None) -> float:
        """
        Normalized Discounted Cumulative Gain (NDCG) at `k` with log base 2
        discounting. If `k` is None, it's set to the length of the scores. If
        `k` < 1, return 0.0.
        """
        if self.has_nan:
            return np.nan
        if k is None:
            k = self.length
        if k < 1:
            return 0.0
        y_true = [self.eval_scores]
        y_score = [self.eval_scores.index]
        # Note that ndcg_score calculates differently depending on whether ties
        # are involved, but this is not an issue for us because our setup has no
        # ties in y_score, so we can set ignore_ties=True.
        return cast(float, ndcg_score(y_true=y_true, y_score=y_score, k=k, ignore_ties=True))

    def precision(self, k: Optional[int] = None) -> float:
        """
        Precision at `k`, defined as the fraction of truthy scores among the
        first `k` positions (1-based index). If `k` is None, it's set to the
        length of the scores. If `k` < 1, return 0.0.
        """
        if self.has_nan:
            return np.nan
        if k is None:
            k = self.length
        if k < 1:
            return 0.0
        return self.eval_scores[:k].astype(bool).sum() / k

    def reciprocal_rank(self) -> float:
        """
        Return `1/R` where `R` is the rank of the first hit, i.e. the 1-based
        index position of the first truthy score, e.g. score=1. If a non-finite
        value (e.g. `NaN`) is encountered before the first (finite) truthy
        score, return `NaN`; otherwise, if no truthy score is found (or if the
        count of scores is zero), return 0.0.
        """
        for i, score in enumerate(self.eval_scores):
            if not np.isfinite(score):
                return np.nan
            if score:
                return 1 / (i + 1)
        return 0.0

    def hit(self) -> float:
        """
        Return 1.0 if any score is truthy (i.e. is a hit), e.g. score=1.
        Otherwise, return `NaN` if any score is non-finite (e.g. `NaN`), or
        return 0.0 if all scores are falsy, e.g. all scores are 0.
        """
        if self.eval_scores.any():
            return 1.0
        if self.has_nan:
            return np.nan
        return 0.0
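For reference, a minimal usage sketch (not part of the diff): it feeds a hypothetical list of document relevance scores, ordered as the documents were retrieved, into the class above and reads off the four metrics.

from phoenix.metrics.retrieval_metrics import RetrievalMetrics

# Hypothetical relevance scores for four retrieved documents, listed in
# retrieval order: the documents at ranks 1 and 3 are relevant.
metrics = RetrievalMetrics([1.0, 0.0, 1.0, 0.0])

print(metrics.ndcg())             # NDCG over all four positions
print(metrics.ndcg(k=2))          # NDCG@2
print(metrics.precision(k=2))     # 0.5 -- one relevant document in the top 2
print(metrics.reciprocal_rank())  # 1.0 -- first relevant document is at rank 1
print(metrics.hit())              # 1.0 -- at least one relevant document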
@@ -0,0 +1,47 @@
import math
import re
from typing import Optional

import strawberry
from strawberry import UNSET, Private

from phoenix.metrics.retrieval_metrics import RetrievalMetrics


def _clean_docstring(docstring: Optional[str]) -> Optional[str]:
    return re.sub(r"\s*\n+\s*", " ", docstring).strip() if docstring else None


_ndcg_docstring = _clean_docstring(RetrievalMetrics.ndcg.__doc__)
_precision_docstring = _clean_docstring(RetrievalMetrics.precision.__doc__)
_reciprocal_rank_docstring = _clean_docstring(RetrievalMetrics.reciprocal_rank.__doc__)
_hit_docstring = _clean_docstring(RetrievalMetrics.hit.__doc__)


@strawberry.type(
    description="A collection of retrieval metrics computed on a list of document "
    "evaluation scores: NDCG@K, Precision@K, Reciprocal Rank, etc."
)
class DocumentRetrievalMetrics:
    evaluation_name: str
    metrics: Private[RetrievalMetrics]

    @strawberry.field(description=_ndcg_docstring)  # type: ignore
    def ndcg(self, k: Optional[int] = UNSET) -> Optional[float]:
        value = self.metrics.ndcg(None if k is UNSET else k)
        return value if math.isfinite(value) else None

    @strawberry.field(description=_precision_docstring)  # type: ignore
    def precision(self, k: Optional[int] = UNSET) -> Optional[float]:
        value = self.metrics.precision(None if k is UNSET else k)
        return value if math.isfinite(value) else None

    @strawberry.field(description=_reciprocal_rank_docstring)  # type: ignore
    def reciprocal_rank(self) -> Optional[float]:
        value = self.metrics.reciprocal_rank()
        return value if math.isfinite(value) else None

    @strawberry.field(description=_hit_docstring)  # type: ignore
    def hit(self) -> Optional[float]:
        value = self.metrics.hit()
        return value if math.isfinite(value) else None
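A minimal sketch (not part of the diff) of how the type above could be served, using a throwaway `Query` root and made-up scores; the wiring into Phoenix's actual GraphQL schema presumably lives in the other changed files of this commit, which aren't shown here. Strawberry exposes the snake_case fields as camelCase, and non-finite metric values come back as null.

import strawberry

from phoenix.metrics.retrieval_metrics import RetrievalMetrics

# DocumentRetrievalMetrics refers to the strawberry type defined above.


@strawberry.type
class Query:
    @strawberry.field
    def document_retrieval_metrics(self) -> DocumentRetrievalMetrics:
        # Made-up scores: only the document at rank 2 is relevant.
        return DocumentRetrievalMetrics(
            evaluation_name="relevance",  # hypothetical evaluation name
            metrics=RetrievalMetrics([0.0, 1.0, 0.0]),
        )


schema = strawberry.Schema(query=Query)
result = schema.execute_sync(
    "{ documentRetrievalMetrics {"
    " evaluationName ndcg(k: 2) precision(k: 2) reciprocalRank hit } }"
)
print(result.data)
# {'documentRetrievalMetrics': {'evaluationName': 'relevance',
#  'ndcg': 0.63..., 'precision': 0.5, 'reciprocalRank': 0.5, 'hit': 1.0}}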