diff --git a/chromadb/api/types.py b/chromadb/api/types.py index 7781c422572..b8f520e4e15 100644 --- a/chromadb/api/types.py +++ b/chromadb/api/types.py @@ -169,6 +169,30 @@ class QueryResult(TypedDict): distances: Optional[List[List[float]]] +Rank = Union[int, float] + + +class RankerScore(TypedDict): + ranker_id: str + rank: Rank + + +class RankerQueryResult(QueryResult): + ranks: Optional[List[List[RankerScore]]] + + +Rankable = Union[str, int, QueryResult] +R = TypeVar("R", bound=Rankable, contravariant=True) + + +class RankingFunction(Protocol[R]): + def get_id(self) -> str: + ... + + def __call__(self, results: R) -> RankerQueryResult: + ... + + class IndexMetadata(TypedDict): dimensionality: int # The current number of elements in the index (total = additions - deletes) diff --git a/chromadb/utils/ranking_functions.py b/chromadb/utils/ranking_functions.py new file mode 100644 index 00000000000..8cac417e2cc --- /dev/null +++ b/chromadb/utils/ranking_functions.py @@ -0,0 +1,11 @@ +from chromadb.api.types import RankingFunction + + +class BM25ServerSideRankingFunction(RankingFunction): + def __init__(self, k1=1.2, b=0.75): + self.k1 = k1 + self.b = b + + def rank(self, query, documents): + # ... (implementation of the BM25 ranking function) + return ranked_documents