Skip to content

Commit

Permalink
Pgvector - Embedding Retriever (#320)
Browse files Browse the repository at this point in the history
* squash

* squash

* Update integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py

Co-authored-by: Massimiliano Pippi <[email protected]>

* Update integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py

Co-authored-by: Massimiliano Pippi <[email protected]>

* Update integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py

Co-authored-by: Massimiliano Pippi <[email protected]>

* Update integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py

Co-authored-by: Massimiliano Pippi <[email protected]>

* fix fmt

* adjust docstrings

* Update integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/embedding_retriever.py

Co-authored-by: Massimiliano Pippi <[email protected]>

* Update integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/embedding_retriever.py

Co-authored-by: Massimiliano Pippi <[email protected]>

* improve docstrings

* fmt

---------

Co-authored-by: Massimiliano Pippi <[email protected]>
  • Loading branch information
anakin87 and masci authored Feb 1, 2024
1 parent 61daacb commit 3454815
Show file tree
Hide file tree
Showing 3 changed files with 222 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from .embedding_retriever import PgvectorEmbeddingRetriever

__all__ = ["PgvectorEmbeddingRetriever"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Literal, Optional

from haystack import component, default_from_dict, default_to_dict
from haystack.dataclasses import Document
from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore
from haystack_integrations.document_stores.pgvector.document_store import VALID_VECTOR_FUNCTIONS


@component
class PgvectorEmbeddingRetriever:
"""
Retrieves documents from the PgvectorDocumentStore, based on their dense embeddings.
Needs to be connected to the PgvectorDocumentStore.
"""

def __init__(
self,
*,
document_store: PgvectorDocumentStore,
filters: Optional[Dict[str, Any]] = None,
top_k: int = 10,
vector_function: Optional[Literal["cosine_similarity", "inner_product", "l2_distance"]] = None,
):
"""
Create the PgvectorEmbeddingRetriever component.
:param document_store: An instance of PgvectorDocumentStore.
:param filters: Filters applied to the retrieved Documents. Defaults to None.
:param top_k: Maximum number of Documents to return, defaults to 10.
:param vector_function: The similarity function to use when searching for similar embeddings.
Defaults to the one set in the `document_store` instance.
"cosine_similarity" and "inner_product" are similarity functions and
higher scores indicate greater similarity between the documents.
"l2_distance" returns the straight-line distance between vectors,
and the most similar documents are the ones with the smallest score.
Important: if the document store is using the "hnsw" search strategy, the vector function
should match the one utilized during index creation to take advantage of the index.
:type vector_function: Literal["cosine_similarity", "inner_product", "l2_distance"]
:raises ValueError: If `document_store` is not an instance of PgvectorDocumentStore.
"""
if not isinstance(document_store, PgvectorDocumentStore):
msg = "document_store must be an instance of PgvectorDocumentStore"
raise ValueError(msg)

if vector_function and vector_function not in VALID_VECTOR_FUNCTIONS:
msg = f"vector_function must be one of {VALID_VECTOR_FUNCTIONS}"
raise ValueError(msg)

self.document_store = document_store
self.filters = filters or {}
self.top_k = top_k
self.vector_function = vector_function or document_store.vector_function

def to_dict(self) -> Dict[str, Any]:
return default_to_dict(
self,
filters=self.filters,
top_k=self.top_k,
vector_function=self.vector_function,
document_store=self.document_store.to_dict(),
)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "PgvectorEmbeddingRetriever":
data["init_parameters"]["document_store"] = default_from_dict(
PgvectorDocumentStore, data["init_parameters"]["document_store"]
)
return default_from_dict(cls, data)

@component.output_types(documents=List[Document])
def run(
self,
query_embedding: List[float],
filters: Optional[Dict[str, Any]] = None,
top_k: Optional[int] = None,
vector_function: Optional[Literal["cosine_similarity", "inner_product", "l2_distance"]] = None,
):
"""
Retrieve documents from the PgvectorDocumentStore, based on their embeddings.
:param query_embedding: Embedding of the query.
:param filters: Filters applied to the retrieved Documents.
:param top_k: Maximum number of Documents to return.
:param vector_function: The similarity function to use when searching for similar embeddings.
:type vector_function: Literal["cosine_similarity", "inner_product", "l2_distance"]
:return: List of Documents similar to `query_embedding`.
"""
filters = filters or self.filters
top_k = top_k or self.top_k
vector_function = vector_function or self.vector_function

docs = self.document_store._embedding_retrieval(
query_embedding=query_embedding,
filters=filters,
top_k=top_k,
vector_function=vector_function,
)
return {"documents": docs}
112 changes: 112 additions & 0 deletions integrations/pgvector/tests/test_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from unittest.mock import Mock

from haystack.dataclasses import Document
from haystack_integrations.components.retrievers.pgvector import PgvectorEmbeddingRetriever
from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore


class TestRetriever:
def test_init_default(self, document_store: PgvectorDocumentStore):
retriever = PgvectorEmbeddingRetriever(document_store=document_store)
assert retriever.document_store == document_store
assert retriever.filters == {}
assert retriever.top_k == 10
assert retriever.vector_function == document_store.vector_function

def test_init(self, document_store: PgvectorDocumentStore):
retriever = PgvectorEmbeddingRetriever(
document_store=document_store, filters={"field": "value"}, top_k=5, vector_function="l2_distance"
)
assert retriever.document_store == document_store
assert retriever.filters == {"field": "value"}
assert retriever.top_k == 5
assert retriever.vector_function == "l2_distance"

def test_to_dict(self, document_store: PgvectorDocumentStore):
retriever = PgvectorEmbeddingRetriever(
document_store=document_store, filters={"field": "value"}, top_k=5, vector_function="l2_distance"
)
res = retriever.to_dict()
t = "haystack_integrations.components.retrievers.pgvector.embedding_retriever.PgvectorEmbeddingRetriever"
assert res == {
"type": t,
"init_parameters": {
"document_store": {
"type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore",
"init_parameters": {
"connection_string": "postgresql://postgres:postgres@localhost:5432/postgres",
"table_name": "haystack_test_to_dict",
"embedding_dimension": 768,
"vector_function": "cosine_similarity",
"recreate_table": True,
"search_strategy": "exact_nearest_neighbor",
"hnsw_recreate_index_if_exists": False,
"hnsw_index_creation_kwargs": {},
"hnsw_ef_search": None,
},
},
"filters": {"field": "value"},
"top_k": 5,
"vector_function": "l2_distance",
},
}

def test_from_dict(self):
t = "haystack_integrations.components.retrievers.pgvector.embedding_retriever.PgvectorEmbeddingRetriever"
data = {
"type": t,
"init_parameters": {
"document_store": {
"type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore",
"init_parameters": {
"connection_string": "postgresql://postgres:postgres@localhost:5432/postgres",
"table_name": "haystack_test_to_dict",
"embedding_dimension": 768,
"vector_function": "cosine_similarity",
"recreate_table": True,
"search_strategy": "exact_nearest_neighbor",
"hnsw_recreate_index_if_exists": False,
"hnsw_index_creation_kwargs": {},
"hnsw_ef_search": None,
},
},
"filters": {"field": "value"},
"top_k": 5,
"vector_function": "l2_distance",
},
}

retriever = PgvectorEmbeddingRetriever.from_dict(data)
document_store = retriever.document_store

assert isinstance(document_store, PgvectorDocumentStore)
assert document_store.connection_string == "postgresql://postgres:postgres@localhost:5432/postgres"
assert document_store.table_name == "haystack_test_to_dict"
assert document_store.embedding_dimension == 768
assert document_store.vector_function == "cosine_similarity"
assert document_store.recreate_table
assert document_store.search_strategy == "exact_nearest_neighbor"
assert not document_store.hnsw_recreate_index_if_exists
assert document_store.hnsw_index_creation_kwargs == {}
assert document_store.hnsw_ef_search is None

assert retriever.filters == {"field": "value"}
assert retriever.top_k == 5
assert retriever.vector_function == "l2_distance"

def test_run(self):
mock_store = Mock(spec=PgvectorDocumentStore)
doc = Document(content="Test doc", embedding=[0.1, 0.2])
mock_store._embedding_retrieval.return_value = [doc]

retriever = PgvectorEmbeddingRetriever(document_store=mock_store, vector_function="l2_distance")
res = retriever.run(query_embedding=[0.3, 0.5])

mock_store._embedding_retrieval.assert_called_once_with(
query_embedding=[0.3, 0.5], filters={}, top_k=10, vector_function="l2_distance"
)

assert res == {"documents": [doc]}

0 comments on commit 3454815

Please sign in to comment.