Skip to content

Commit

Permalink
Implement filters for chromaQueryTextRetriever via existing haystack …
Browse files Browse the repository at this point in the history
…filters logic (#705)

* Implement filters for chromaQueryTextRetriever via existing haystack filters logic

Run linter

* un-skip tests

---------

Co-authored-by: Massimiliano Pippi <[email protected]>
  • Loading branch information
jongirard and masci authored May 10, 2024
1 parent c29db9c commit d4a598b
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def __init__(self, document_store: ChromaDocumentStore, filters: Optional[Dict[s
def run(
self,
query: str,
_: Optional[Dict[str, Any]] = None, # filters not yet supported
filters: Optional[Dict[str, Any]] = None,
top_k: Optional[int] = None,
):
"""
Expand All @@ -64,14 +64,15 @@ def run(
:param query: The input data for the retriever. In this case, a plain-text query.
:param top_k: The maximum number of documents to retrieve.
If not specified, the default value from the constructor is used.
:param filters: a dictionary of filters to apply to the search. Accepts filters in haystack format.
:returns: A dictionary with the following keys:
- `documents`: List of documents returned by the search engine.
:raises ValueError: If the specified document store is not found or is not a MemoryDocumentStore instance.
"""
top_k = top_k or self.top_k

return {"documents": self.document_store.search([query], top_k)[0]}
return {"documents": self.document_store.search([query], top_k, filters)[0]}

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "ChromaQueryTextRetriever":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,16 +209,30 @@ def delete_documents(self, document_ids: List[str]) -> None:
"""
self._collection.delete(ids=document_ids)

def search(self, queries: List[str], top_k: int) -> List[List[Document]]:
def search(self, queries: List[str], top_k: int, filters: Optional[Dict[str, Any]] = None) -> List[List[Document]]:
"""Search the documents in the store using the provided text queries.
:param queries: the list of queries to search for.
:param top_k: top_k documents to return for each query.
:param filters: a dictionary of filters to apply to the search. Accepts filters in haystack format.
:returns: matching documents for each query.
"""
results = self._collection.query(
query_texts=queries, n_results=top_k, include=["embeddings", "documents", "metadatas", "distances"]
)
if filters is None:
results = self._collection.query(
query_texts=queries,
n_results=top_k,
include=["embeddings", "documents", "metadatas", "distances"],
)
else:
chroma_filters = self._normalize_filters(filters=filters)
results = self._collection.query(
query_texts=queries,
n_results=top_k,
where=chroma_filters[1],
where_document=chroma_filters[2],
include=["embeddings", "documents", "metadatas", "distances"],
)

return self._query_result_to_documents(results)

def search_embeddings(
Expand Down
30 changes: 0 additions & 30 deletions integrations/chroma/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,6 @@ def test_same_collection_name_reinitialization(self):
ChromaDocumentStore("test_name")
ChromaDocumentStore("test_name")

@pytest.mark.skip(reason="Filter on array contents is not supported.")
def test_filter_document_array(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]):
pass

@pytest.mark.skip(reason="Filter on dataframe contents is not supported.")
def test_filter_document_dataframe(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]):
pass
Expand All @@ -147,10 +143,6 @@ def test_eq_filter_table(self, document_store: ChromaDocumentStore, filterable_d
def test_eq_filter_embedding(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]):
pass

@pytest.mark.skip(reason="$in operator is not supported.")
def test_in_filter_explicit(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]):
pass

@pytest.mark.skip(reason="$in operator is not supported. Filter on table contents is not supported.")
def test_in_filter_table(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]):
pass
Expand Down Expand Up @@ -185,12 +177,6 @@ def test_filter_simple_implicit_and_with_multi_key_dict(
):
pass

@pytest.mark.skip(reason="Filter syntax not supported.")
def test_filter_simple_explicit_and_with_multikey_dict(
self, document_store: ChromaDocumentStore, filterable_docs: List[Document]
):
pass

@pytest.mark.skip(reason="Filter syntax not supported.")
def test_filter_simple_explicit_and_with_list(
self, document_store: ChromaDocumentStore, filterable_docs: List[Document]
Expand All @@ -201,10 +187,6 @@ def test_filter_simple_explicit_and_with_list(
def test_filter_simple_implicit_and(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]):
pass

@pytest.mark.skip(reason="Filter syntax not supported.")
def test_filter_nested_explicit_and(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]):
pass

@pytest.mark.skip(reason="Filter syntax not supported.")
def test_filter_nested_implicit_and(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]):
pass
Expand Down Expand Up @@ -234,15 +216,3 @@ def test_filter_nested_multiple_identical_operators_same_level(
self, document_store: ChromaDocumentStore, filterable_docs: List[Document]
):
pass

@pytest.mark.skip(reason="Duplicate policy not supported.")
def test_write_duplicate_fail(self, document_store: ChromaDocumentStore):
pass

@pytest.mark.skip(reason="Duplicate policy not supported.")
def test_write_duplicate_skip(self, document_store: ChromaDocumentStore):
pass

@pytest.mark.skip(reason="Duplicate policy not supported.")
def test_write_duplicate_overwrite(self, document_store: ChromaDocumentStore):
pass

0 comments on commit d4a598b

Please sign in to comment.