diff --git a/integrations/chroma/src/haystack_integrations/components/retrievers/chroma/retriever.py b/integrations/chroma/src/haystack_integrations/components/retrievers/chroma/retriever.py index 7138eff88..1dbf3a61d 100644 --- a/integrations/chroma/src/haystack_integrations/components/retrievers/chroma/retriever.py +++ b/integrations/chroma/src/haystack_integrations/components/retrievers/chroma/retriever.py @@ -55,7 +55,7 @@ def __init__(self, document_store: ChromaDocumentStore, filters: Optional[Dict[s def run( self, query: str, - _: Optional[Dict[str, Any]] = None, # filters not yet supported + filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None, ): """ @@ -64,6 +64,7 @@ def run( :param query: The input data for the retriever. In this case, a plain-text query. :param top_k: The maximum number of documents to retrieve. If not specified, the default value from the constructor is used. + :param filters: a dictionary of filters to apply to the search. Accepts filters in haystack format. :returns: A dictionary with the following keys: - `documents`: List of documents returned by the search engine. @@ -71,7 +72,7 @@ def run( """ top_k = top_k or self.top_k - return {"documents": self.document_store.search([query], top_k)[0]} + return {"documents": self.document_store.search([query], top_k, filters)[0]} @classmethod def from_dict(cls, data: Dict[str, Any]) -> "ChromaQueryTextRetriever": diff --git a/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py b/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py index 6d795f8ca..02acbe8dc 100644 --- a/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py +++ b/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py @@ -209,16 +209,30 @@ def delete_documents(self, document_ids: List[str]) -> None: """ self._collection.delete(ids=document_ids) - def search(self, queries: List[str], top_k: int) -> List[List[Document]]: + def search(self, queries: List[str], top_k: int, filters: Optional[Dict[str, Any]] = None) -> List[List[Document]]: """Search the documents in the store using the provided text queries. :param queries: the list of queries to search for. :param top_k: top_k documents to return for each query. + :param filters: a dictionary of filters to apply to the search. Accepts filters in haystack format. :returns: matching documents for each query. """ - results = self._collection.query( - query_texts=queries, n_results=top_k, include=["embeddings", "documents", "metadatas", "distances"] - ) + if filters is None: + results = self._collection.query( + query_texts=queries, + n_results=top_k, + include=["embeddings", "documents", "metadatas", "distances"], + ) + else: + chroma_filters = self._normalize_filters(filters=filters) + results = self._collection.query( + query_texts=queries, + n_results=top_k, + where=chroma_filters[1], + where_document=chroma_filters[2], + include=["embeddings", "documents", "metadatas", "distances"], + ) + return self._query_result_to_documents(results) def search_embeddings( diff --git a/integrations/chroma/tests/test_document_store.py b/integrations/chroma/tests/test_document_store.py index cddc66e3f..742f3305e 100644 --- a/integrations/chroma/tests/test_document_store.py +++ b/integrations/chroma/tests/test_document_store.py @@ -131,10 +131,6 @@ def test_same_collection_name_reinitialization(self): ChromaDocumentStore("test_name") ChromaDocumentStore("test_name") - @pytest.mark.skip(reason="Filter on array contents is not supported.") - def test_filter_document_array(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): - pass - @pytest.mark.skip(reason="Filter on dataframe contents is not supported.") def test_filter_document_dataframe(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): pass @@ -147,10 +143,6 @@ def test_eq_filter_table(self, document_store: ChromaDocumentStore, filterable_d def test_eq_filter_embedding(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): pass - @pytest.mark.skip(reason="$in operator is not supported.") - def test_in_filter_explicit(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): - pass - @pytest.mark.skip(reason="$in operator is not supported. Filter on table contents is not supported.") def test_in_filter_table(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): pass @@ -185,12 +177,6 @@ def test_filter_simple_implicit_and_with_multi_key_dict( ): pass - @pytest.mark.skip(reason="Filter syntax not supported.") - def test_filter_simple_explicit_and_with_multikey_dict( - self, document_store: ChromaDocumentStore, filterable_docs: List[Document] - ): - pass - @pytest.mark.skip(reason="Filter syntax not supported.") def test_filter_simple_explicit_and_with_list( self, document_store: ChromaDocumentStore, filterable_docs: List[Document] @@ -201,10 +187,6 @@ def test_filter_simple_explicit_and_with_list( def test_filter_simple_implicit_and(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): pass - @pytest.mark.skip(reason="Filter syntax not supported.") - def test_filter_nested_explicit_and(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): - pass - @pytest.mark.skip(reason="Filter syntax not supported.") def test_filter_nested_implicit_and(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): pass @@ -234,15 +216,3 @@ def test_filter_nested_multiple_identical_operators_same_level( self, document_store: ChromaDocumentStore, filterable_docs: List[Document] ): pass - - @pytest.mark.skip(reason="Duplicate policy not supported.") - def test_write_duplicate_fail(self, document_store: ChromaDocumentStore): - pass - - @pytest.mark.skip(reason="Duplicate policy not supported.") - def test_write_duplicate_skip(self, document_store: ChromaDocumentStore): - pass - - @pytest.mark.skip(reason="Duplicate policy not supported.") - def test_write_duplicate_overwrite(self, document_store: ChromaDocumentStore): - pass