Skip to content

Commit

Permalink
feat: Add SimilarityRanker to Haystack 2.0 (#5923)
Browse files Browse the repository at this point in the history
* Initial SimilarityRanker
  • Loading branch information
vblagoje authored Oct 6, 2023
1 parent ccc9f01 commit 1cdff64
Show file tree
Hide file tree
Showing 4 changed files with 189 additions and 0 deletions.
3 changes: 3 additions & 0 deletions haystack/preview/components/rankers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from haystack.preview.components.rankers.similarity import SimilarityRanker

__all__ = ["SimilarityRanker"]
108 changes: 108 additions & 0 deletions haystack/preview/components/rankers/similarity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import logging
from pathlib import Path
from typing import List, Union, Dict, Any

from haystack.preview import ComponentError, Document, component, default_from_dict, default_to_dict
from haystack.preview.lazy_imports import LazyImport

logger = logging.getLogger(__name__)


with LazyImport(message="Run 'pip install transformers[torch,sentencepiece]==4.32.1'") as torch_and_transformers_import:
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer


@component
class SimilarityRanker:
"""
Ranks documents based on query similarity.
Usage example:
```
from haystack.preview import Document
from haystack.preview.components.rankers import SimilarityRanker
sampler = SimilarityRanker()
docs = [Document(text="Paris"), Document(text="Berlin")]
query = "City in Germany"
output = sampler.run(query=query, documents=docs)
docs = output["documents"]
assert len(docs) == 2
assert docs[0].text == "Berlin"
```
"""

def __init__(
self, model_name_or_path: Union[str, Path] = "cross-encoder/ms-marco-MiniLM-L-6-v2", device: str = "cpu"
):
"""
Creates an instance of SimilarityRanker.
:param model_name_or_path: Path to a pre-trained sentence-transformers model.
:param device: torch device (for example, cuda:0, cpu, mps) to limit model inference to a specific device.
"""
torch_and_transformers_import.check()

self.model_name_or_path = model_name_or_path
self.device = device
self.model = None
self.tokenizer = None

def warm_up(self):
"""
Warm up the model and tokenizer used in scoring the documents.
"""
if self.model_name_or_path and not self.model:
self.model = AutoModelForSequenceClassification.from_pretrained(self.model_name_or_path)
self.model = self.model.to(self.device)
self.model.eval()
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)

def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
"""
return default_to_dict(self, device=self.device, model_name_or_path=self.model_name_or_path)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "SimilarityRanker":
"""
Deserialize this component from a dictionary.
"""
return default_from_dict(cls, data)

@component.output_types(documents=List[Document])
def run(self, query: str, documents: List[Document]):
"""
Returns a list of documents ranked by their similarity to the given query
:param query: Query string.
:param documents: List of Documents.
:return: List of Documents sorted by (desc.) similarity with the query.
"""
if not documents:
return {"documents": []}

# If a model path is provided but the model isn't loaded
if self.model_name_or_path and not self.model:
raise ComponentError(
f"The component {self.__class__.__name__} not warmed up. Run 'warm_up()' before calling 'run()'."
)

query_doc_pairs = [[query, doc.text] for doc in documents]
features = self.tokenizer(
query_doc_pairs, padding=True, truncation=True, return_tensors="pt"
).to( # type: ignore
self.device
)
with torch.inference_mode():
similarity_scores = self.model(**features).logits.squeeze() # type: ignore

_, sorted_indices = torch.sort(similarity_scores, descending=True)
ranked_docs = []
for sorted_index_tensor in sorted_indices:
i = sorted_index_tensor.item()
documents[i].score = similarity_scores[i].item()
ranked_docs.append(documents[i])
return {"documents": ranked_docs}
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
preview:
- |
Adds SimilarityRanker, a component that ranks a list of Documents based on their similarity to the query.
74 changes: 74 additions & 0 deletions test/preview/components/rankers/test_similarity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import pytest

from haystack.preview import Document, ComponentError
from haystack.preview.components.rankers.similarity import SimilarityRanker


class TestSimilarityRanker:
@pytest.mark.unit
def test_to_dict(self):
component = SimilarityRanker(model_name_or_path="cross-encoder/ms-marco-MiniLM-L-6-v2")
data = component.to_dict()
assert data == {
"type": "SimilarityRanker",
"init_parameters": {"device": "cpu", "model_name_or_path": "cross-encoder/ms-marco-MiniLM-L-6-v2"},
}

@pytest.mark.unit
def test_to_dict_with_custom_init_parameters(self):
component = SimilarityRanker()
data = component.to_dict()
assert data == {
"type": "SimilarityRanker",
"init_parameters": {"device": "cpu", "model_name_or_path": "cross-encoder/ms-marco-MiniLM-L-6-v2"},
}

@pytest.mark.integration
def test_from_dict(self):
data = {
"type": "SimilarityRanker",
"init_parameters": {"device": "cpu", "model_name_or_path": "cross-encoder/ms-marco-MiniLM-L-6-v2"},
}
component = SimilarityRanker.from_dict(data)
assert component.model_name_or_path == "cross-encoder/ms-marco-MiniLM-L-6-v2"

@pytest.mark.integration
@pytest.mark.parametrize(
"query,docs_before_texts,expected_first_text",
[
("City in Bosnia and Herzegovina", ["Berlin", "Belgrade", "Sarajevo"], "Sarajevo"),
("Machine learning", ["Python", "Bakery in Paris", "Tesla Giga Berlin"], "Python"),
("Cubist movement", ["Nirvana", "Pablo Picasso", "Coffee"], "Pablo Picasso"),
],
)
def test_run(self, query, docs_before_texts, expected_first_text):
"""
Test if the component ranks documents correctly.
"""
ranker = SimilarityRanker(model_name_or_path="cross-encoder/ms-marco-MiniLM-L-6-v2")
ranker.warm_up()
docs_before = [Document(text=text) for text in docs_before_texts]
output = ranker.run(query=query, documents=docs_before)
docs_after = output["documents"]

assert len(docs_after) == 3
assert docs_after[0].text == expected_first_text

sorted_scores = sorted([doc.score for doc in docs_after], reverse=True)
assert [doc.score for doc in docs_after] == sorted_scores

# Returns an empty list if no documents are provided
@pytest.mark.integration
def test_returns_empty_list_if_no_documents_are_provided(self):
sampler = SimilarityRanker()
sampler.warm_up()
output = sampler.run(query="City in Germany", documents=[])
assert output["documents"] == []

# Raises ComponentError if model is not warmed up
@pytest.mark.integration
def test_raises_component_error_if_model_not_warmed_up(self):
sampler = SimilarityRanker()

with pytest.raises(ComponentError):
sampler.run(query="query", documents=[Document(text="document")])

0 comments on commit 1cdff64

Please sign in to comment.