From 38b3472bb2dd2f74b128d25bff26849b156bae42 Mon Sep 17 00:00:00 2001 From: Ashwin Mathur <97467100+awinml@users.noreply.github.com> Date: Mon, 11 Mar 2024 17:44:59 +0530 Subject: [PATCH] feat: Add `SentenceTransformersDiversityRanker` (#7095) * Add Diversity Ranker * Update tests * Add separate suffix, prefix params for query and documents; allow empty query * Update docstrings * Make changes based on review * Add additional tests * Add test for warm up * Update release notes --------- Co-authored-by: Sebastian Husch Lee --- haystack/components/rankers/__init__.py | 8 +- .../sentence_transformers_diversity.py | 246 ++++++++ ...add-diversity-ranker-6ecee21134eda673.yaml | 6 + .../test_sentence_transformers_diversity.py | 554 ++++++++++++++++++ 4 files changed, 813 insertions(+), 1 deletion(-) create mode 100644 haystack/components/rankers/sentence_transformers_diversity.py create mode 100644 releasenotes/notes/add-diversity-ranker-6ecee21134eda673.yaml create mode 100644 test/components/rankers/test_sentence_transformers_diversity.py diff --git a/haystack/components/rankers/__init__.py b/haystack/components/rankers/__init__.py index bb8c7dd999..282cf5cf2f 100644 --- a/haystack/components/rankers/__init__.py +++ b/haystack/components/rankers/__init__.py @@ -1,5 +1,11 @@ from haystack.components.rankers.lost_in_the_middle import LostInTheMiddleRanker from haystack.components.rankers.meta_field import MetaFieldRanker +from haystack.components.rankers.sentence_transformers_diversity import SentenceTransformersDiversityRanker from haystack.components.rankers.transformers_similarity import TransformersSimilarityRanker -__all__ = ["LostInTheMiddleRanker", "MetaFieldRanker", "TransformersSimilarityRanker"] +__all__ = [ + "LostInTheMiddleRanker", + "MetaFieldRanker", + "SentenceTransformersDiversityRanker", + "TransformersSimilarityRanker", +] diff --git a/haystack/components/rankers/sentence_transformers_diversity.py b/haystack/components/rankers/sentence_transformers_diversity.py new file mode 100644 index 0000000000..86915eb657 --- /dev/null +++ b/haystack/components/rankers/sentence_transformers_diversity.py @@ -0,0 +1,246 @@ +from typing import Any, Dict, List, Literal, Optional + +from haystack import ComponentError, Document, component, default_from_dict, default_to_dict, logging +from haystack.lazy_imports import LazyImport +from haystack.utils import ComponentDevice, Secret, deserialize_secrets_inplace + +logger = logging.getLogger(__name__) + + +with LazyImport(message="Run 'pip install \"sentence-transformers>=2.2.0\"'") as torch_and_sentence_transformers_import: + import torch + from sentence_transformers import SentenceTransformer + + +@component +class SentenceTransformersDiversityRanker: + """ + Implements a document ranking algorithm that orders documents in such a way as to maximize the overall diversity + of the documents. + + This component provides functionality to rank a list of documents based on their similarity with respect to the + query to maximize the overall diversity. It uses a pre-trained Sentence Transformers model to embed the query and + the Documents. + + Usage example: + ```python + from haystack import Document + from haystack.components.rankers import SentenceTransformersDiversityRanker + + ranker = SentenceTransformersDiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity="cosine") + ranker.warm_up() + + docs = [Document(content="Paris"), Document(content="Berlin")] + query = "What is the capital of germany?" + output = ranker.run(query=query, documents=docs) + docs = output["documents"] + ``` + """ + + def __init__( + self, + model: str = "sentence-transformers/all-MiniLM-L6-v2", + top_k: int = 10, + device: Optional[ComponentDevice] = None, + token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False), + similarity: Literal["dot_product", "cosine"] = "cosine", + query_prefix: str = "", + query_suffix: str = "", + document_prefix: str = "", + document_suffix: str = "", + meta_fields_to_embed: Optional[List[str]] = None, + embedding_separator: str = "\n", + ): + """ + Initialize a SentenceTransformersDiversityRanker. + + :param model: Local path or name of the model in Hugging Face's model hub, + such as `'sentence-transformers/all-MiniLM-L6-v2'`. + :param top_k: The maximum number of Documents to return per query. + :param device: The device on which the model is loaded. If `None`, the default device is automatically + selected. + :param token: The API token used to download private models from Hugging Face. + :param similarity: Similarity metric for comparing embeddings. Can be set to "dot_product" (default) or + "cosine". + :param query_prefix: A string to add to the beginning of the query text before ranking. + Can be used to prepend the text with an instruction, as required by some embedding models, + such as E5 and BGE. + :param query_suffix: A string to add to the end of the query text before ranking. + :param document_prefix: A string to add to the beginning of each Document text before ranking. + Can be used to prepend the text with an instruction, as required by some embedding models, + such as E5 and BGE. + :param document_suffix: A string to add to the end of each Document text before ranking. + :param meta_fields_to_embed: List of meta fields that should be embedded along with the Document content. + :param embedding_separator: Separator used to concatenate the meta fields to the Document content. + """ + torch_and_sentence_transformers_import.check() + + self.model_name_or_path = model + if top_k is None or top_k <= 0: + raise ValueError(f"top_k must be > 0, but got {top_k}") + self.top_k = top_k + self.device = ComponentDevice.resolve_device(device) + self.token = token + self.model = None + if similarity not in ["dot_product", "cosine"]: + raise ValueError(f"Similarity must be one of 'dot_product' or 'cosine', but got {similarity}.") + self.similarity = similarity + self.query_prefix = query_prefix + self.document_prefix = document_prefix + self.query_suffix = query_suffix + self.document_suffix = document_suffix + self.meta_fields_to_embed = meta_fields_to_embed or [] + self.embedding_separator = embedding_separator + + def warm_up(self): + """ + Initializes the component. + """ + if self.model is None: + self.model = SentenceTransformer( + model_name_or_path=self.model_name_or_path, + device=self.device.to_torch_str(), + use_auth_token=self.token.resolve_value() if self.token else None, + ) + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + return default_to_dict( + self, + model=self.model_name_or_path, + device=self.device.to_dict(), + token=self.token.to_dict() if self.token else None, + top_k=self.top_k, + similarity=self.similarity, + query_prefix=self.query_prefix, + document_prefix=self.document_prefix, + query_suffix=self.query_suffix, + document_suffix=self.document_suffix, + meta_fields_to_embed=self.meta_fields_to_embed, + embedding_separator=self.embedding_separator, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersDiversityRanker": + """ + Deserializes the component from a dictionary. + + :param data: + The dictionary to deserialize from. + :returns: + The deserialized component. + """ + serialized_device = data["init_parameters"]["device"] + data["init_parameters"]["device"] = ComponentDevice.from_dict(serialized_device) + + deserialize_secrets_inplace(data["init_parameters"], keys=["token"]) + return default_from_dict(cls, data) + + def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]: + """ + Prepare the texts to embed by concatenating the Document text with the metadata fields to embed. + """ + texts_to_embed = [] + for doc in documents: + meta_values_to_embed = [ + str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] + ] + text_to_embed = ( + self.document_prefix + + self.embedding_separator.join(meta_values_to_embed + [doc.content or ""]) + + self.document_suffix + ) + texts_to_embed.append(text_to_embed) + + return texts_to_embed + + def _greedy_diversity_order(self, query: str, documents: List[Document]) -> List[Document]: + """ + Orders the given list of documents to maximize diversity. + + The algorithm first calculates embeddings for each document and the query. It starts by selecting the document + that is semantically closest to the query. Then, for each remaining document, it selects the one that, on + average, is least similar to the already selected documents. This process continues until all documents are + selected, resulting in a list where each subsequent document contributes the most to the overall diversity of + the selected set. + + :param query: The search query. + :param documents: The list of Document objects to be ranked. + + :return: A list of documents ordered to maximize diversity. + """ + texts_to_embed = self._prepare_texts_to_embed(documents) + + # Calculate embeddings + doc_embeddings = self.model.encode(texts_to_embed, convert_to_tensor=True) # type: ignore[attr-defined] + query_embedding = self.model.encode([self.query_prefix + query + self.query_suffix], convert_to_tensor=True) # type: ignore[attr-defined] + + # Normalize embeddings to unit length for computing cosine similarity + if self.similarity == "cosine": + doc_embeddings /= torch.norm(doc_embeddings, p=2, dim=-1).unsqueeze(-1) + query_embedding /= torch.norm(query_embedding, p=2, dim=-1).unsqueeze(-1) + + n = len(documents) + selected: List[int] = [] + + # Compute the similarity vector between the query and documents + query_doc_sim = query_embedding @ doc_embeddings.T + + # Start with the document with the highest similarity to the query + selected.append(int(torch.argmax(query_doc_sim).item())) + + selected_sum = doc_embeddings[selected[0]] / n + + while len(selected) < n: + # Compute mean of dot products of all selected documents and all other documents + similarities = selected_sum @ doc_embeddings.T + # Mask documents that are already selected + similarities[selected] = torch.inf + # Select the document with the lowest total similarity score + index_unselected = int(torch.argmin(similarities).item()) + selected.append(index_unselected) + # It's enough just to add to the selected vectors because dot product is distributive + # It's divided by n for numerical stability + selected_sum += doc_embeddings[index_unselected] / n + + ranked_docs: List[Document] = [documents[i] for i in selected] + + return ranked_docs + + @component.output_types(documents=List[Document]) + def run(self, query: str, documents: List[Document], top_k: Optional[int] = None): + """ + Rank the documents based on their diversity. + + :param query: The search query. + :param documents: List of Document objects to be ranker. + :param top_k: Optional. An integer to override the top_k set during initialization. + + :returns: A dictionary with the following key: + - `documents`: List of Document objects that have been selected based on the diversity ranking. + + :raises ValueError: If the top_k value is less than or equal to 0. + """ + if not documents: + return {"documents": []} + + if top_k is None: + top_k = self.top_k + elif top_k <= 0: + raise ValueError(f"top_k must be > 0, but got {top_k}") + + if self.model is None: + error_msg = ( + "The component SentenceTransformersDiversityRanker wasn't warmed up. " + "Run 'warm_up()' before calling 'run()'." + ) + raise ComponentError(error_msg) + + diversity_sorted = self._greedy_diversity_order(query=query, documents=documents) + + return {"documents": diversity_sorted[:top_k]} diff --git a/releasenotes/notes/add-diversity-ranker-6ecee21134eda673.yaml b/releasenotes/notes/add-diversity-ranker-6ecee21134eda673.yaml new file mode 100644 index 0000000000..0a9dfa9058 --- /dev/null +++ b/releasenotes/notes/add-diversity-ranker-6ecee21134eda673.yaml @@ -0,0 +1,6 @@ +--- +features: + - | + Add `SentenceTransformersDiversityRanker`. + The Diversity Ranker orders documents in such a way as to maximize the overall diversity of the given documents. + The ranker leverages sentence-transformer models to calculate semantic embeddings for each document and the query. diff --git a/test/components/rankers/test_sentence_transformers_diversity.py b/test/components/rankers/test_sentence_transformers_diversity.py new file mode 100644 index 0000000000..0ed6a579bb --- /dev/null +++ b/test/components/rankers/test_sentence_transformers_diversity.py @@ -0,0 +1,554 @@ +from unittest.mock import MagicMock, call, patch + +import pytest +import torch + +from haystack import ComponentError, Document +from haystack.components.rankers import SentenceTransformersDiversityRanker +from haystack.utils import ComponentDevice +from haystack.utils.auth import Secret + + +def mock_encode_response(texts, **kwargs): + if texts == ["city"]: + return torch.tensor([[1.0, 1.0]]) + elif texts == ["Eiffel Tower", "Berlin", "Bananas"]: + return torch.tensor([[1.0, 0.0], [0.8, 0.8], [0.0, 1.0]]) + else: + return torch.tensor([[0.0, 1.0]] * len(texts)) + + +class TestSentenceTransformersDiversityRanker: + def test_init(self): + component = SentenceTransformersDiversityRanker() + assert component.model_name_or_path == "sentence-transformers/all-MiniLM-L6-v2" + assert component.top_k == 10 + assert component.device == ComponentDevice.resolve_device(None) + assert component.similarity == "cosine" + assert component.token == Secret.from_env_var("HF_API_TOKEN", strict=False) + assert component.query_prefix == "" + assert component.document_prefix == "" + assert component.query_suffix == "" + assert component.document_suffix == "" + assert component.meta_fields_to_embed == [] + assert component.embedding_separator == "\n" + + def test_init_with_custom_init_parameters(self): + component = SentenceTransformersDiversityRanker( + model="sentence-transformers/msmarco-distilbert-base-v4", + top_k=5, + device=ComponentDevice.from_str("cuda:0"), + token=Secret.from_token("fake-api-token"), + similarity="dot_product", + query_prefix="query:", + document_prefix="document:", + query_suffix="query suffix", + document_suffix="document suffix", + meta_fields_to_embed=["meta_field"], + embedding_separator="--", + ) + assert component.model_name_or_path == "sentence-transformers/msmarco-distilbert-base-v4" + assert component.top_k == 5 + assert component.device == ComponentDevice.from_str("cuda:0") + assert component.similarity == "dot_product" + assert component.token == Secret.from_token("fake-api-token") + assert component.query_prefix == "query:" + assert component.document_prefix == "document:" + assert component.query_suffix == "query suffix" + assert component.document_suffix == "document suffix" + assert component.meta_fields_to_embed == ["meta_field"] + assert component.embedding_separator == "--" + + def test_to_dict(self): + component = SentenceTransformersDiversityRanker() + data = component.to_dict() + assert data == { + "type": "haystack.components.rankers.sentence_transformers_diversity.SentenceTransformersDiversityRanker", + "init_parameters": { + "model": "sentence-transformers/all-MiniLM-L6-v2", + "top_k": 10, + "device": ComponentDevice.resolve_device(None).to_dict(), + "similarity": "cosine", + "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"}, + "query_prefix": "", + "document_prefix": "", + "query_suffix": "", + "document_suffix": "", + "meta_fields_to_embed": [], + "embedding_separator": "\n", + }, + } + + def test_from_dict(self): + data = { + "type": "haystack.components.rankers.sentence_transformers_diversity.SentenceTransformersDiversityRanker", + "init_parameters": { + "model": "sentence-transformers/all-MiniLM-L6-v2", + "top_k": 10, + "device": ComponentDevice.resolve_device(None).to_dict(), + "similarity": "cosine", + "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"}, + "query_prefix": "", + "document_prefix": "", + "query_suffix": "", + "document_suffix": "", + "meta_fields_to_embed": [], + "embedding_separator": "\n", + }, + } + ranker = SentenceTransformersDiversityRanker.from_dict(data) + + assert ranker.model_name_or_path == "sentence-transformers/all-MiniLM-L6-v2" + assert ranker.top_k == 10 + assert ranker.device == ComponentDevice.resolve_device(None) + assert ranker.similarity == "cosine" + assert ranker.token == Secret.from_env_var("HF_API_TOKEN", strict=False) + assert ranker.query_prefix == "" + assert ranker.document_prefix == "" + assert ranker.query_suffix == "" + assert ranker.document_suffix == "" + assert ranker.meta_fields_to_embed == [] + assert ranker.embedding_separator == "\n" + + def test_to_dict_with_custom_init_parameters(self): + component = SentenceTransformersDiversityRanker( + model="sentence-transformers/msmarco-distilbert-base-v4", + top_k=5, + device=ComponentDevice.from_str("cuda:0"), + token=Secret.from_env_var("ENV_VAR", strict=False), + similarity="dot_product", + query_prefix="query:", + document_prefix="document:", + query_suffix="query suffix", + document_suffix="document suffix", + meta_fields_to_embed=["meta_field"], + embedding_separator="--", + ) + data = component.to_dict() + assert data == { + "type": "haystack.components.rankers.sentence_transformers_diversity.SentenceTransformersDiversityRanker", + "init_parameters": { + "model": "sentence-transformers/msmarco-distilbert-base-v4", + "top_k": 5, + "device": ComponentDevice.from_str("cuda:0").to_dict(), + "token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}, + "similarity": "dot_product", + "query_prefix": "query:", + "document_prefix": "document:", + "query_suffix": "query suffix", + "document_suffix": "document suffix", + "meta_fields_to_embed": ["meta_field"], + "embedding_separator": "--", + }, + } + + def test_from_dict_with_custom_init_parameters(self): + data = { + "type": "haystack.components.rankers.sentence_transformers_diversity.SentenceTransformersDiversityRanker", + "init_parameters": { + "model": "sentence-transformers/msmarco-distilbert-base-v4", + "top_k": 5, + "device": ComponentDevice.from_str("cuda:0").to_dict(), + "token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}, + "similarity": "dot_product", + "query_prefix": "query:", + "document_prefix": "document:", + "query_suffix": "query suffix", + "document_suffix": "document suffix", + "meta_fields_to_embed": ["meta_field"], + "embedding_separator": "--", + }, + } + ranker = SentenceTransformersDiversityRanker.from_dict(data) + + assert ranker.model_name_or_path == "sentence-transformers/msmarco-distilbert-base-v4" + assert ranker.top_k == 5 + assert ranker.device == ComponentDevice.from_str("cuda:0") + assert ranker.similarity == "dot_product" + assert ranker.token == Secret.from_env_var("ENV_VAR", strict=False) + assert ranker.query_prefix == "query:" + assert ranker.document_prefix == "document:" + assert ranker.query_suffix == "query suffix" + assert ranker.document_suffix == "document suffix" + assert ranker.meta_fields_to_embed == ["meta_field"] + assert ranker.embedding_separator == "--" + + def test_run_incorrect_similarity(self): + """ + Tests that run method raises ValueError if similarity is incorrect + """ + similarity = "incorrect" + with pytest.raises( + ValueError, match=f"Similarity must be one of 'dot_product' or 'cosine', but got {similarity}." + ): + SentenceTransformersDiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity) + + @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) + def test_run_without_warm_up(self, similarity): + """ + Tests that run method raises ComponentError if model is not warmed up + """ + ranker = SentenceTransformersDiversityRanker( + model="sentence-transformers/all-MiniLM-L6-v2", top_k=1, similarity=similarity + ) + documents = [Document(content="doc1"), Document(content="doc2")] + + error_msg = "The component SentenceTransformersDiversityRanker wasn't warmed up." + with pytest.raises(ComponentError, match=error_msg): + ranker.run(query="test query", documents=documents) + + @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) + def test_warm_up(self, similarity): + """ + Test that ranker loads the SentenceTransformer model correctly during warm up. + """ + mock_model_class = MagicMock() + mock_model_instance = MagicMock() + mock_model_class.return_value = mock_model_instance + + with patch( + "haystack.components.rankers.sentence_transformers_diversity.SentenceTransformer", new=mock_model_class + ): + ranker = SentenceTransformersDiversityRanker(model="mock_model_name", similarity=similarity) + + assert ranker.model is None + + ranker.warm_up() + + mock_model_class.assert_called_once_with( + model_name_or_path="mock_model_name", + device=ComponentDevice.resolve_device(None).to_torch_str(), + use_auth_token=None, + ) + assert ranker.model == mock_model_instance + + @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) + def test_run_empty_query(self, similarity): + """ + Test that ranker can be run with an empty query. + """ + ranker = SentenceTransformersDiversityRanker( + model="sentence-transformers/all-MiniLM-L6-v2", top_k=3, similarity=similarity + ) + ranker.model = MagicMock() + ranker.model.encode = MagicMock(side_effect=mock_encode_response) + documents = [Document(content="doc1"), Document(content="doc2")] + + result = ranker.run(query="", documents=documents) + ranked_docs = result["documents"] + + assert isinstance(ranked_docs, list) + assert len(ranked_docs) == 2 + assert all(isinstance(doc, Document) for doc in ranked_docs) + + @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) + def test_run_top_k(self, similarity): + """ + Test that run method returns the correct number of documents for different top_k values passed at + initialization and runtime. + """ + ranker = SentenceTransformersDiversityRanker( + model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity, top_k=3 + ) + ranker.model = MagicMock() + ranker.model.encode = MagicMock(side_effect=mock_encode_response) + query = "test query" + documents = [ + Document(content="doc1"), + Document(content="doc2"), + Document(content="doc3"), + Document(content="doc4"), + ] + + result = ranker.run(query=query, documents=documents) + ranked_docs = result["documents"] + + assert isinstance(ranked_docs, list) + assert len(ranked_docs) == 3 + assert all(isinstance(doc, Document) for doc in ranked_docs) + + # Passing a different top_k at runtime + result = ranker.run(query=query, documents=documents, top_k=2) + ranked_docs = result["documents"] + + assert isinstance(ranked_docs, list) + assert len(ranked_docs) == 2 + assert all(isinstance(doc, Document) for doc in ranked_docs) + + @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) + def test_run_negative_top_k_at_init(self, similarity): + """ + Tests that run method raises an error for negative top-k set at init. + """ + with pytest.raises(ValueError, match="top_k must be > 0, but got"): + SentenceTransformersDiversityRanker( + model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity, top_k=-5 + ) + + @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) + def test_run_top_k_is_none_at_init(self, similarity): + """ + Tests that run method raises an error for top-k set to None at init. + """ + with pytest.raises(ValueError, match="top_k must be > 0, but got"): + SentenceTransformersDiversityRanker( + model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity, top_k=None + ) + + @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) + def test_run_negative_top_k(self, similarity): + """ + Tests that run method raises an error for negative top-k set at runtime. + """ + ranker = SentenceTransformersDiversityRanker( + model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity, top_k=10 + ) + ranker.model = MagicMock() + query = "test" + documents = [Document(content="doc1"), Document(content="doc2"), Document(content="doc3")] + + with pytest.raises(ValueError, match="top_k must be > 0, but got"): + ranker.run(query=query, documents=documents, top_k=-5) + + @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) + def test_run_top_k_is_none(self, similarity): + """ + Tests that run method returns the correct order of documents for top-k set to None. + """ + # Setting top_k to None is ignored during runtime, it should use top_k set at init. + ranker = SentenceTransformersDiversityRanker( + model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity, top_k=2 + ) + ranker.model = MagicMock() + ranker.model.encode = MagicMock(side_effect=mock_encode_response) + query = "test" + documents = [Document(content="doc1"), Document(content="doc2"), Document(content="doc3")] + result = ranker.run(query=query, documents=documents, top_k=None) + + assert len(result["documents"]) == 2 + + @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) + def test_run_no_documents_provided(self, similarity): + """ + Test that run method returns an empty list if no documents are supplied. + """ + ranker = SentenceTransformersDiversityRanker( + model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity + ) + ranker.model = MagicMock() + query = "test query" + documents = [] + results = ranker.run(query=query, documents=documents) + + assert len(results["documents"]) == 0 + + @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) + def test_run_with_less_documents_than_top_k(self, similarity): + """ + Tests that run method returns the correct number of documents for top_k values greater than number of documents. + """ + ranker = SentenceTransformersDiversityRanker( + model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity, top_k=5 + ) + ranker.model = MagicMock() + ranker.model.encode = MagicMock(side_effect=mock_encode_response) + query = "test" + documents = [Document(content="doc1"), Document(content="doc2"), Document(content="doc3")] + result = ranker.run(query=query, documents=documents) + + assert len(result["documents"]) == 3 + + @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) + def test_run_single_document_corner_case(self, similarity): + """ + Tests that run method returns the correct number of documents for a single document + """ + ranker = SentenceTransformersDiversityRanker( + model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity + ) + ranker.model = MagicMock() + ranker.model.encode = MagicMock(side_effect=mock_encode_response) + query = "test" + documents = [Document(content="doc1")] + result = ranker.run(query=query, documents=documents) + + assert len(result["documents"]) == 1 + + @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) + def test_prepare_texts_to_embed(self, similarity): + """ + Test creation of texts to embed from documents with meta fields, document prefix and suffix. + """ + ranker = SentenceTransformersDiversityRanker( + model="sentence-transformers/all-MiniLM-L6-v2", + similarity=similarity, + document_prefix="test doc: ", + document_suffix=" end doc.", + meta_fields_to_embed=["meta_field"], + embedding_separator="\n", + ) + documents = [Document(content=f"document number {i}", meta={"meta_field": f"meta_value {i}"}) for i in range(5)] + texts = ranker._prepare_texts_to_embed(documents=documents) + + assert texts == [ + "test doc: meta_value 0\ndocument number 0 end doc.", + "test doc: meta_value 1\ndocument number 1 end doc.", + "test doc: meta_value 2\ndocument number 2 end doc.", + "test doc: meta_value 3\ndocument number 3 end doc.", + "test doc: meta_value 4\ndocument number 4 end doc.", + ] + + @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) + def test_encode_text(self, similarity): + """ + Test addition of suffix and prefix to the query and documents when creating embeddings. + """ + ranker = SentenceTransformersDiversityRanker( + model="sentence-transformers/all-MiniLM-L6-v2", + similarity=similarity, + query_prefix="test query: ", + query_suffix=" end query.", + document_prefix="test doc: ", + document_suffix=" end doc.", + meta_fields_to_embed=["meta_field"], + embedding_separator="\n", + ) + query = "query" + documents = [Document(content=f"document number {i}", meta={"meta_field": f"meta_value {i}"}) for i in range(5)] + ranker.model = MagicMock() + ranker.model.encode = MagicMock(side_effect=mock_encode_response) + ranker.run(query=query, documents=documents) + + assert ranker.model.encode.call_count == 2 + ranker.model.assert_has_calls( + [ + call.encode( + [ + "test doc: meta_value 0\ndocument number 0 end doc.", + "test doc: meta_value 1\ndocument number 1 end doc.", + "test doc: meta_value 2\ndocument number 2 end doc.", + "test doc: meta_value 3\ndocument number 3 end doc.", + "test doc: meta_value 4\ndocument number 4 end doc.", + ], + convert_to_tensor=True, + ), + call.encode(["test query: query end query."], convert_to_tensor=True), + ] + ) + + @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) + def test_run_greedy_diversity_order(self, similarity): + """ + Tests that the given list of documents is ordered to maximize diversity. + """ + ranker = SentenceTransformersDiversityRanker( + model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity + ) + query = "city" + documents = [Document(content="Eiffel Tower"), Document(content="Berlin"), Document(content="Bananas")] + ranker.model = MagicMock() + ranker.model.encode = MagicMock(side_effect=mock_encode_response) + + ranked_docs = ranker._greedy_diversity_order(query=query, documents=documents) + ranked_text = " ".join([doc.content for doc in ranked_docs]) + + assert ranked_text == "Berlin Eiffel Tower Bananas" + + @pytest.mark.integration + @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) + def test_run(self, similarity): + """ + Tests that run method returns documents in the correct order + """ + ranker = SentenceTransformersDiversityRanker( + model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity + ) + ranker.warm_up() + query = "city" + documents = [ + Document(content="France"), + Document(content="Germany"), + Document(content="Eiffel Tower"), + Document(content="Berlin"), + Document(content="Bananas"), + Document(content="Silicon Valley"), + Document(content="Brandenburg Gate"), + ] + result = ranker.run(query=query, documents=documents) + ranked_docs = result["documents"] + ranked_order = ", ".join([doc.content for doc in ranked_docs]) + expected_order = "Berlin, Bananas, Eiffel Tower, Silicon Valley, France, Brandenburg Gate, Germany" + + assert ranked_order == expected_order + + @pytest.mark.integration + @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) + def test_run_real_world_use_case(self, similarity): + ranker = SentenceTransformersDiversityRanker( + model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity + ) + ranker.warm_up() + query = "What are the reasons for long-standing animosities between Russia and Poland?" + + doc1 = Document( + "One of the earliest known events in Russian-Polish history dates back to 981, when the Grand Prince of Kiev , " + "Vladimir Svyatoslavich , seized the Cherven Cities from the Duchy of Poland . The relationship between two by " + "that time was mostly close and cordial, as there had been no serious wars between both. In 966, Poland " + "accepted Christianity from Rome while Kievan Rus' —the ancestor of Russia, Ukraine and Belarus—was " + "Christianized by Constantinople. In 1054, the internal Christian divide formally split the Church into " + "the Catholic and Orthodox branches separating the Poles from the Eastern Slavs." + ) + doc2 = Document( + "Since the fall of the Soviet Union , with Lithuania , Ukraine and Belarus regaining independence, the " + "Polish Russian border has mostly been replaced by borders with the respective countries, but there still " + "is a 210 km long border between Poland and the Kaliningrad Oblast" + ) + doc3 = Document( + "As part of Poland's plans to become fully energy independent from Russia within the next years, Piotr " + "Wozniak, president of state-controlled oil and gas company PGNiG , stated in February 2019: 'The strategy of " + "the company is just to forget about Eastern suppliers and especially about Gazprom .'[53] In 2020, the " + "Stockholm Arbitrary Tribunal ruled that PGNiG's long-term contract gas price with Gazprom linked to oil prices " + "should be changed to approximate the Western European gas market price, backdated to 1 November 2014 when " + "PGNiG requested a price review under the contract. Gazprom had to refund about $1.5 billion to PGNiG." + ) + doc4 = Document( + "Both Poland and Russia had accused each other for their historical revisionism . Russia has repeatedly " + "accused Poland for not honoring Soviet Red Army soldiers fallen in World War II for Poland, notably in " + "2017, in which Poland was thought on 'attempting to impose its own version of history' after Moscow was " + "not allowed to join an international effort to renovate a World War II museum at Sobibór , site of a " + "notorious Sobibor extermination camp." + ) + doc5 = Document( + "President of Russia Vladimir Putin and Prime Minister of Poland Leszek Miller in 2002 Modern Polish Russian " + "relations begin with the fall of communism in1989 in Poland ( Solidarity and the Polish Round Table " + "Agreement ) and 1991 in Russia ( dissolution of the Soviet Union ). With a new democratic government after " + "the 1989 elections , Poland regained full sovereignty, [2] and what was the Soviet Union, became 15 newly " + "independent states , including the Russian Federation . Relations between modern Poland and Russia suffer " + "from constant ups and downs." + ) + doc6 = Document( + "Soviet influence in Poland finally ended with the Round Table Agreement of 1989 guaranteeing free elections " + "in Poland, the Revolutions of 1989 against Soviet-sponsored Communist governments in the Eastern Block , and " + "finally the formal dissolution of the Warsaw Pact." + ) + doc7 = Document( + "Dmitry Medvedev and then Polish Prime Minister Donald Tusk , 6 December 2010 BBC News reported that one of " + "the main effects of the 2010 Polish Air Force Tu-154 crash would be the impact it has on Russian-Polish " + "relations. [38] It was thought if the inquiry into the crash were not transparent, it would increase " + "suspicions toward Russia in Poland." + ) + doc8 = Document( + "Soviet control over the Polish People's Republic lessened after Stalin's death and Gomulka's Thaw , and " + "ceased completely after the fall of the communist government in Poland in late 1989, although the " + "Soviet-Russian Northern Group of Forces did not leave Polish soil until 1993. The continuing Soviet military " + "presence allowed the Soviet Union to heavily influence Polish politics." + ) + + documents = [doc1, doc2, doc3, doc4, doc5, doc6, doc7, doc8] + result = ranker.run(query=query, documents=documents) + expected_order = [doc5, doc7, doc3, doc1, doc4, doc2, doc6, doc8] + expected_content = " ".join([doc.content or "" for doc in expected_order]) + result_content = " ".join([doc.content or "" for doc in result["documents"]]) + + # Check the order of ranked documents by comparing the content of the ranked documents + assert result_content == expected_content