Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add SentenceTransformersDiversityRanker #7095

Merged
merged 11 commits into from
Mar 11, 2024
3 changes: 2 additions & 1 deletion haystack/components/rankers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from haystack.components.rankers.diversity import DiversityRanker
from haystack.components.rankers.lost_in_the_middle import LostInTheMiddleRanker
from haystack.components.rankers.meta_field import MetaFieldRanker
from haystack.components.rankers.transformers_similarity import TransformersSimilarityRanker

# Public API of the rankers subpackage; DiversityRanker is the addition in this change.
__all__ = ["DiversityRanker", "LostInTheMiddleRanker", "MetaFieldRanker", "TransformersSimilarityRanker"]
215 changes: 215 additions & 0 deletions haystack/components/rankers/diversity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
import logging
from typing import Any, Dict, List, Literal, Optional

from haystack import ComponentError, Document, component, default_from_dict, default_to_dict
from haystack.lazy_imports import LazyImport
from haystack.utils import ComponentDevice, Secret, deserialize_secrets_inplace

logger = logging.getLogger(__name__)


with LazyImport(message="Run 'pip install \"sentence-transformers>=2.2.0\"'") as torch_and_transformers_import:
import torch
from sentence_transformers import SentenceTransformer


@component
class DiversityRanker:
    """
    Implements a document ranking algorithm that orders documents in such a way as to maximize the overall diversity
    of the documents.

    It uses a pre-trained Sentence Transformers model to embed the query and the Documents.

    Usage example:
    ```python
    from haystack import Document
    from haystack.components.rankers import DiversityRanker

    ranker = DiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity="cosine")
    docs = [Document(content="Paris"), Document(content="Berlin")]
    query = "What is the capital of Germany?"
    output = ranker.run(query=query, documents=docs)
    docs = output["documents"]
    assert len(docs) == 2
    assert docs[0].content == "Paris"
    ```
    """

    def __init__(
        self,
        model: str = "sentence-transformers/all-MiniLM-L6-v2",
        top_k: int = 10,
        device: Optional[ComponentDevice] = None,
        token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
        similarity: Literal["dot_product", "cosine"] = "dot_product",
        prefix: str = "",
        suffix: str = "",
        meta_fields_to_embed: Optional[List[str]] = None,
        embedding_separator: str = "\n",
    ):
        """
        Initialize a DiversityRanker.

        :param model: Local path or name of the model in Hugging Face's model hub,
            such as `'sentence-transformers/all-MiniLM-L6-v2'`.
        :param top_k: The maximum number of Documents to return per query. Must be > 0.
        :param device: The device on which the model is loaded. If `None`, the default device is automatically
            selected.
        :param token: The API token used to download private models from Hugging Face.
        :param similarity: Similarity metric for comparing embeddings. Can be set to "dot_product" (default) or
            "cosine".
        :param prefix: A string to add to the beginning of each Document text before embedding.
            Can be used to prepend the text with an instruction, as required by some embedding models,
            such as E5 and bge.
        :param suffix: A string to add to the end of each Document text before embedding.
        :param meta_fields_to_embed: List of meta fields that should be embedded along with the Document content.
        :param embedding_separator: Separator used to concatenate the meta fields to the Document content.
        :raises ValueError: If `top_k` is None or not positive, or `similarity` is not a supported metric.
        """
        # Fail fast at construction time if sentence-transformers/torch are missing.
        torch_and_transformers_import.check()

        self.model_name_or_path = model
        if top_k is None or top_k <= 0:
            raise ValueError(f"top_k must be > 0, but got {top_k}")
        self.top_k = top_k
        self.device = ComponentDevice.resolve_device(device)
        self.token = token
        # The actual SentenceTransformer model is loaded lazily in warm_up().
        self.model = None
        if similarity not in ["dot_product", "cosine"]:
            raise ValueError(f"Similarity must be one of 'dot_product' or 'cosine', but got {similarity}.")
        self.similarity = similarity
        self.prefix = prefix
        self.suffix = suffix
        self.meta_fields_to_embed = meta_fields_to_embed or []
        self.embedding_separator = embedding_separator

    def warm_up(self):
        """
        Warm up the model used for scoring the Documents.

        Loads the SentenceTransformer model on first call; subsequent calls are no-ops.
        """
        if self.model is None:
            self.model = SentenceTransformer(
                model_name_or_path=self.model_name_or_path,
                device=self.device.to_torch_str(),
                use_auth_token=self.token.resolve_value() if self.token else None,
            )

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.

        :return: A dictionary with the component's init parameters (the loaded model itself is not serialized).
        """
        return default_to_dict(
            self,
            model=self.model_name_or_path,
            device=self.device.to_dict(),
            token=self.token.to_dict() if self.token else None,
            top_k=self.top_k,
            similarity=self.similarity,
            prefix=self.prefix,
            suffix=self.suffix,
            meta_fields_to_embed=self.meta_fields_to_embed,
            embedding_separator=self.embedding_separator,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "DiversityRanker":
        """
        Deserialize this component from a dictionary.

        :param data: The serialized component produced by `to_dict`.
        :return: A new DiversityRanker instance.
        """
        # The device was serialized with ComponentDevice.to_dict(); restore it before init.
        serialized_device = data["init_parameters"]["device"]
        data["init_parameters"]["device"] = ComponentDevice.from_dict(serialized_device)

        deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
        return default_from_dict(cls, data)

    def _greedy_diversity_order(self, query: str, documents: List[Document]) -> List[Document]:
        """
        Orders the given list of documents to maximize diversity.

        The algorithm first calculates embeddings for each document and the query. It starts by selecting the document
        that is semantically closest to the query. Then, for each remaining document, it selects the one that, on
        average, is least similar to the already selected documents. This process continues until all documents are
        selected, resulting in a list where each subsequent document contributes the most to the overall diversity of
        the selected set.

        :param query: The search query.
        :param documents: The list of Document objects to be ranked.

        :return: A list of documents ordered to maximize diversity.
        """
        # Build the text to embed for each document: prefix + (meta fields + content joined
        # by the separator) + suffix. Empty/missing meta values are skipped.
        texts_to_embed = []
        for doc in documents:
            meta_values_to_embed = [
                str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key]
            ]
            text_to_embed = (
                self.prefix + self.embedding_separator.join(meta_values_to_embed + [doc.content or ""]) + self.suffix
            )
            texts_to_embed.append(text_to_embed)

        # Calculate embeddings
        doc_embeddings = self.model.encode(texts_to_embed, convert_to_tensor=True)  # type: ignore[attr-defined]
        query_embedding = self.model.encode([query], convert_to_tensor=True)  # type: ignore[attr-defined]

        # Normalize embeddings to unit length for computing cosine similarity.
        # NOTE(review): a zero-norm embedding would produce NaNs here — presumably the
        # model never emits an all-zero vector; confirm if a guard (eps clamp) is needed.
        if self.similarity == "cosine":
            doc_embeddings /= torch.norm(doc_embeddings, p=2, dim=-1).unsqueeze(-1)
            query_embedding /= torch.norm(query_embedding, p=2, dim=-1).unsqueeze(-1)

        n = len(documents)
        selected: List[int] = []

        # Compute the similarity vector between the query and documents
        query_doc_sim = query_embedding @ doc_embeddings.T

        # Start with the document with the highest similarity to the query
        selected.append(int(torch.argmax(query_doc_sim).item()))

        selected_sum = doc_embeddings[selected[0]] / n

        while len(selected) < n:
            # Compute mean of dot products of all selected documents and all other documents
            similarities = selected_sum @ doc_embeddings.T
            # Mask documents that are already selected
            similarities[selected] = torch.inf
            # Select the document with the lowest total similarity score
            index_unselected = int(torch.argmin(similarities).item())
            selected.append(index_unselected)
            # It's enough just to add to the selected vectors because dot product is distributive
            # It's divided by n for numerical stability
            selected_sum += doc_embeddings[index_unselected] / n

        ranked_docs: List[Document] = [documents[i] for i in selected]

        return ranked_docs

    @component.output_types(documents=List[Document])
    def run(self, query: str, documents: List[Document], top_k: Optional[int] = None):
        """
        Rank the documents based on their diversity and return the top_k documents.

        :param query: The query.
        :param documents: A list of Document objects that should be ranked.
        :param top_k: The maximum number of documents to return. If `None`, the value set at init time is used.

        :return: A dictionary with key "documents" holding the top_k documents ranked based on diversity.
        :raises ValueError: If the query is empty or `top_k` is not positive.
        :raises ComponentError: If `warm_up()` was not called before `run()`.
        """
        if query is None or len(query) == 0:
            raise ValueError("Query is empty")

        # An empty document list is a valid input; nothing to rank.
        if not documents:
            return {"documents": []}

        if top_k is None:
            top_k = self.top_k
        elif top_k <= 0:
            raise ValueError(f"top_k must be > 0, but got {top_k}")

        if self.model is None:
            raise ComponentError(
                f"The component {self.__class__.__name__} wasn't warmed up. Run 'warm_up()' before calling 'run()'."
            )

        diversity_sorted = self._greedy_diversity_order(query=query, documents=documents)

        return {"documents": diversity_sorted[:top_k]}
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
features:
- |
Add `DiversityRanker`. Diversity Ranker orders documents in such a way as to maximize the overall diversity of the given documents. The ranker leverages sentence-transformer models to calculate semantic embeddings for each document and the query.
Loading