Skip to content

Commit

Permalink
Add defensive check for filter_policy deserialization
Browse files Browse the repository at this point in the history
  • Loading branch information
vblagoje committed Jul 15, 2024
1 parent 05a21f6 commit 50a1ad5
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,10 @@ def from_dict(cls, data: Dict[str, Any]) -> "QdrantEmbeddingRetriever":
"""
document_store = QdrantDocumentStore.from_dict(data["init_parameters"]["document_store"])
data["init_parameters"]["document_store"] = document_store
data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"])
# Pipelines serialized with old versions of the component might not
# have the filter_policy field.
if filter_policy := data["init_parameters"].get("filter_policy"):
data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy)
return default_from_dict(cls, data)

@component.output_types(documents=List[Document])
Expand Down Expand Up @@ -249,7 +252,10 @@ def from_dict(cls, data: Dict[str, Any]) -> "QdrantSparseEmbeddingRetriever":
"""
document_store = QdrantDocumentStore.from_dict(data["init_parameters"]["document_store"])
data["init_parameters"]["document_store"] = document_store
data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"])
# Pipelines serialized with old versions of the component might not
# have the filter_policy field.
if filter_policy := data["init_parameters"].get("filter_policy"):
data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy)
return default_from_dict(cls, data)

@component.output_types(documents=List[Document])
Expand Down Expand Up @@ -394,7 +400,10 @@ def from_dict(cls, data: Dict[str, Any]) -> "QdrantHybridRetriever":
"""
document_store = QdrantDocumentStore.from_dict(data["init_parameters"]["document_store"])
data["init_parameters"]["document_store"] = document_store
data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"])
# Pipelines serialized with old versions of the component might not
# have the filter_policy field.
if filter_policy := data["init_parameters"].get("filter_policy"):
data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy)
return default_from_dict(cls, data)

@component.output_types(documents=List[Document])
Expand Down
25 changes: 25 additions & 0 deletions integrations/qdrant/tests/test_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,31 @@ def test_from_dict(self):
assert retriever._return_embedding is True
assert retriever._score_threshold is None

def test_from_dict_no_filter_policy(self):
data = {
"type": "haystack_integrations.components.retrievers.qdrant.retriever.QdrantSparseEmbeddingRetriever",
"init_parameters": {
"document_store": {
"init_parameters": {"location": ":memory:", "index": "test"},
"type": "haystack_integrations.document_stores.qdrant.document_store.QdrantDocumentStore",
},
"filters": None,
"top_k": 5,
"scale_score": False,
"return_embedding": True,
"score_threshold": None,
},
}
retriever = QdrantSparseEmbeddingRetriever.from_dict(data)
assert isinstance(retriever._document_store, QdrantDocumentStore)
assert retriever._document_store.index == "test"
assert retriever._filters is None
assert retriever._top_k == 5
assert retriever._filter_policy == FilterPolicy.REPLACE # defaults to REPLACE
assert retriever._scale_score is False
assert retriever._return_embedding is True
assert retriever._score_threshold is None

def test_run(self, filterable_docs: List[Document], generate_sparse_embedding):
document_store = QdrantDocumentStore(location=":memory:", index="Boi", use_sparse_embeddings=True)

Expand Down

0 comments on commit 50a1ad5

Please sign in to comment.