From 38bc78e7444c2db5a6888324184f8f4b71c9f357 Mon Sep 17 00:00:00 2001 From: Shukri Date: Wed, 13 Mar 2024 11:01:14 +0100 Subject: [PATCH 1/2] weaviate: migrate from weaviate python client v3 to v4 (#463) * upgrade to latest weaviate server * upgrade to latest weaviate client * reformat code * create client using v4 api * use v4 api to create collection * store collection obj for convenience * upgrade filters to use v4 api * upgrade batch write to use v4 api * use v4 api cursor to retrieve all docs * upgrade query with filters to use v4 api * upgrade filter documents to use v4 API * update weaviate fixture to align with v4 API * update v4 to v3 conversion logic * fix typo * fix date v4 to v3 conversion logic * hardcode limit in query filter * fix typo * upgrade weaviate server * update v4 to v3 object date conversion the property name will still appear in the object's propertities even though it is not set. So, we need to check if it is not None too * fix invert logic bug * upgrade delete function to v4 API * update bm25 search to v4 API * update count docs to v4 API * update _write to use v4 API * support optional filters in bm25 * update embedding retrieval to use v4 API * update from_dict for v4 API * fix write invalid input test * update other test_from_dict for V4 * update test_to_dict for v4 * update test_init for v4 API * try to pas test_init * pass test_init * add exception handling in _query_paginated * remove commented out code * remove dead code * remove commented out code * return weaviate traceback too when query error occurs * make _query_paginated return an iterator * refactor _to_document * remove v4 to v3 object conv fn * update to_dict serialization * update test case * update weaviate server * updates due to latest client changes * update test case due to latest client changes * Fix filter converters return types * Rework query methods * Fix batch writing errors * Handle different vector types in _to_document * Add pagination tests * Fix pagination test --------- Co-authored-by: Silvano Cerza --- integrations/weaviate/docker-compose.yml | 2 +- integrations/weaviate/pyproject.toml | 2 +- .../document_stores/weaviate/_filters.py | 101 ++--- .../weaviate/document_store.py | 370 ++++++++---------- .../weaviate/tests/test_bm25_retriever.py | 8 - .../weaviate/tests/test_document_store.py | 126 +++--- .../tests/test_embedding_retriever.py | 8 - integrations/weaviate/tests/test_filters.py | 2 +- 8 files changed, 273 insertions(+), 346 deletions(-) diff --git a/integrations/weaviate/docker-compose.yml b/integrations/weaviate/docker-compose.yml index c61b0ed57..f7f033eee 100644 --- a/integrations/weaviate/docker-compose.yml +++ b/integrations/weaviate/docker-compose.yml @@ -8,7 +8,7 @@ services: - '8080' - --scheme - http - image: semitechnologies/weaviate:1.23.2 + image: semitechnologies/weaviate:1.24.1 ports: - 8080:8080 - 50051:50051 diff --git a/integrations/weaviate/pyproject.toml b/integrations/weaviate/pyproject.toml index 421c2ce18..54d9ec21b 100644 --- a/integrations/weaviate/pyproject.toml +++ b/integrations/weaviate/pyproject.toml @@ -24,7 +24,7 @@ classifiers = [ ] dependencies = [ "haystack-ai", - "weaviate-client==3.*", + "weaviate-client", "haystack-pydoc-tools", "python-dateutil", ] diff --git a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/_filters.py b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/_filters.py index a192c6947..a2201f0a5 100644 --- 
a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/_filters.py +++ b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/_filters.py @@ -4,8 +4,11 @@ from haystack.errors import FilterError from pandas import DataFrame +import weaviate +from weaviate.collections.classes.filters import Filter, FilterReturn -def convert_filters(filters: Dict[str, Any]) -> Dict[str, Any]: + +def convert_filters(filters: Dict[str, Any]) -> FilterReturn: """ Convert filters from Haystack format to Weaviate format. """ @@ -14,7 +17,7 @@ def convert_filters(filters: Dict[str, Any]) -> Dict[str, Any]: raise FilterError(msg) if "field" in filters: - return {"operator": "And", "operands": [_parse_comparison_condition(filters)]} + return Filter.all_of([_parse_comparison_condition(filters)]) return _parse_logical_condition(filters) @@ -29,7 +32,7 @@ def convert_filters(filters: Dict[str, Any]) -> Dict[str, Any]: "not in": "in", "AND": "OR", "OR": "AND", - "NOT": "AND", + "NOT": "OR", } @@ -51,7 +54,13 @@ def _invert_condition(filters: Dict[str, Any]) -> Dict[str, Any]: return inverted_condition -def _parse_logical_condition(condition: Dict[str, Any]) -> Dict[str, Any]: +LOGICAL_OPERATORS = { + "AND": Filter.all_of, + "OR": Filter.any_of, +} + + +def _parse_logical_condition(condition: Dict[str, Any]) -> FilterReturn: if "operator" not in condition: msg = f"'operator' key missing in {condition}" raise FilterError(msg) @@ -67,7 +76,7 @@ def _parse_logical_condition(condition: Dict[str, Any]) -> Dict[str, Any]: operands.append(_parse_logical_condition(c)) else: operands.append(_parse_comparison_condition(c)) - return {"operator": operator.lower().capitalize(), "operands": operands} + return LOGICAL_OPERATORS[operator](operands) elif operator == "NOT": inverted_conditions = _invert_condition(condition) return _parse_logical_condition(inverted_conditions) @@ -76,28 +85,6 @@ def _parse_logical_condition(condition: Dict[str, Any]) -> Dict[str, Any]: raise FilterError(msg) -def _infer_value_type(value: Any) -> str: - if value is None: - return "valueNull" - - if isinstance(value, bool): - return "valueBoolean" - if isinstance(value, int): - return "valueInt" - if isinstance(value, float): - return "valueNumber" - - if isinstance(value, str): - try: - parser.isoparse(value) - return "valueDate" - except ValueError: - return "valueText" - - msg = f"Unknown value type {type(value)}" - raise FilterError(msg) - - def _handle_date(value: Any) -> str: if isinstance(value, str): try: @@ -107,25 +94,22 @@ def _handle_date(value: Any) -> str: return value -def _equal(field: str, value: Any) -> Dict[str, Any]: +def _equal(field: str, value: Any) -> FilterReturn: if value is None: - return {"path": field, "operator": "IsNull", "valueBoolean": True} - return {"path": field, "operator": "Equal", _infer_value_type(value): _handle_date(value)} + return weaviate.classes.query.Filter.by_property(field).is_none(True) + return weaviate.classes.query.Filter.by_property(field).equal(_handle_date(value)) -def _not_equal(field: str, value: Any) -> Dict[str, Any]: +def _not_equal(field: str, value: Any) -> FilterReturn: if value is None: - return {"path": field, "operator": "IsNull", "valueBoolean": False} - return { - "operator": "Or", - "operands": [ - {"path": field, "operator": "NotEqual", _infer_value_type(value): _handle_date(value)}, - {"path": field, "operator": "IsNull", "valueBoolean": True}, - ], - } + return weaviate.classes.query.Filter.by_property(field).is_none(False) + return 
weaviate.classes.query.Filter.by_property(field).not_equal( + _handle_date(value) + ) | weaviate.classes.query.Filter.by_property(field).is_none(True) -def _greater_than(field: str, value: Any) -> Dict[str, Any]: + +def _greater_than(field: str, value: Any) -> FilterReturn: if value is None: # When the value is None and '>' is used we create a filter that would return a Document # if it has a field set and not set at the same time. @@ -144,10 +128,10 @@ def _greater_than(field: str, value: Any) -> Dict[str, Any]: if type(value) in [list, DataFrame]: msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='" raise FilterError(msg) - return {"path": field, "operator": "GreaterThan", _infer_value_type(value): _handle_date(value)} + return weaviate.classes.query.Filter.by_property(field).greater_than(_handle_date(value)) -def _greater_than_equal(field: str, value: Any) -> Dict[str, Any]: +def _greater_than_equal(field: str, value: Any) -> FilterReturn: if value is None: # When the value is None and '>=' is used we create a filter that would return a Document # if it has a field set and not set at the same time. @@ -166,10 +150,10 @@ def _greater_than_equal(field: str, value: Any) -> Dict[str, Any]: if type(value) in [list, DataFrame]: msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='" raise FilterError(msg) - return {"path": field, "operator": "GreaterThanEqual", _infer_value_type(value): _handle_date(value)} + return weaviate.classes.query.Filter.by_property(field).greater_or_equal(_handle_date(value)) -def _less_than(field: str, value: Any) -> Dict[str, Any]: +def _less_than(field: str, value: Any) -> FilterReturn: if value is None: # When the value is None and '<' is used we create a filter that would return a Document # if it has a field set and not set at the same time. @@ -188,10 +172,10 @@ def _less_than(field: str, value: Any) -> Dict[str, Any]: if type(value) in [list, DataFrame]: msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='" raise FilterError(msg) - return {"path": field, "operator": "LessThan", _infer_value_type(value): _handle_date(value)} + return weaviate.classes.query.Filter.by_property(field).less_than(_handle_date(value)) -def _less_than_equal(field: str, value: Any) -> Dict[str, Any]: +def _less_than_equal(field: str, value: Any) -> FilterReturn: if value is None: # When the value is None and '<=' is used we create a filter that would return a Document # if it has a field set and not set at the same time. 
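For orientation on these hunks: the v3 client expressed filters as plain dicts, while v4 builds typed Filter objects with a fluent API. A minimal sketch of the shape change (illustrative only, not part of the patch; the property names are invented):
```
# v3 style: a comparison filter was a plain dict, e.g.
v3_filter = {"path": "age", "operator": "GreaterThan", "valueInt": 30}

# v4 style: typed Filter objects built via by_property()
from weaviate.classes.query import Filter

v4_filter = Filter.by_property("age").greater_than(30)
# Logical combinations use Filter.all_of / Filter.any_of, or the & and | operators:
combined = Filter.all_of([v4_filter, Filter.by_property("name").equal("Alice")])
```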
@@ -210,22 +194,23 @@ def _less_than_equal(field: str, value: Any) -> Dict[str, Any]: if type(value) in [list, DataFrame]: msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='" raise FilterError(msg) - return {"path": field, "operator": "LessThanEqual", _infer_value_type(value): _handle_date(value)} + return weaviate.classes.query.Filter.by_property(field).less_or_equal(_handle_date(value)) -def _in(field: str, value: Any) -> Dict[str, Any]: +def _in(field: str, value: Any) -> FilterReturn: if not isinstance(value, list): msg = f"{field}'s value must be a list when using 'in' or 'not in' comparators" raise FilterError(msg) - return {"operator": "And", "operands": [_equal(field, v) for v in value]} + return weaviate.classes.query.Filter.by_property(field).contains_any(value) -def _not_in(field: str, value: Any) -> Dict[str, Any]: +def _not_in(field: str, value: Any) -> FilterReturn: if not isinstance(value, list): msg = f"{field}'s value must be a list when using 'in' or 'not in' comparators" raise FilterError(msg) - return {"operator": "And", "operands": [_not_equal(field, v) for v in value]} + operands = [weaviate.classes.query.Filter.by_property(field).not_equal(v) for v in value] + return Filter.all_of(operands) COMPARISON_OPERATORS = { @@ -240,7 +225,7 @@ def _not_in(field: str, value: Any) -> Dict[str, Any]: } -def _parse_comparison_condition(condition: Dict[str, Any]) -> Dict[str, Any]: +def _parse_comparison_condition(condition: Dict[str, Any]) -> FilterReturn: field: str = condition["field"] if field.startswith("meta."): @@ -265,15 +250,11 @@ def _parse_comparison_condition(condition: Dict[str, Any]) -> Dict[str, Any]: return COMPARISON_OPERATORS[operator](field, value) -def _match_no_document(field: str) -> Dict[str, Any]: +def _match_no_document(field: str) -> FilterReturn: """ Returns a filter that will match no Document. This is used to keep the behavior consistent between different Document Stores. """ - return { - "operator": "And", - "operands": [ - {"path": field, "operator": "IsNull", "valueBoolean": False}, - {"path": field, "operator": "IsNull", "valueBoolean": True}, - ], - } + + operands = [weaviate.classes.query.Filter.by_property(field).is_none(val) for val in [False, True]] + return Filter.all_of(operands) diff --git a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py index 071fe336b..34fefa0a5 100644 --- a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py +++ b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py @@ -2,6 +2,8 @@ # # SPDX-License-Identifier: Apache-2.0 import base64 +import datetime +import json from dataclasses import asdict from typing import Any, Dict, List, Optional, Tuple, Union @@ -11,7 +13,8 @@ from haystack.document_stores.types.policy import DuplicatePolicy import weaviate -from weaviate.config import Config, ConnectionConfig +from weaviate.collections.classes.data import DataObject +from weaviate.config import AdditionalConfig from weaviate.embedded import EmbeddedOptions from weaviate.util import generate_uuid5 @@ -42,6 +45,16 @@ {"name": "score", "dataType": ["number"]}, ] +# This is the default limit used when querying documents with WeaviateDocumentStore.
+# +# We picked this as QUERY_MAXIMUM_RESULTS defaults to 10000; trying to get that many +# documents at once will fail, even if the query is paginated. +# This value ensures we get the most documents possible without hitting that limit, though it would +# still fail if the user lowers the QUERY_MAXIMUM_RESULTS environment variable for their Weaviate instance. +# +# See WeaviateDocumentStore._query_with_filters() for more information. +DEFAULT_QUERY_LIMIT = 9999 + class WeaviateDocumentStore: """ @@ -54,13 +67,11 @@ def __init__( self, url: Optional[str] = None, collection_settings: Optional[Dict[str, Any]] = None, auth_client_secret: Optional[AuthCredentials] = None, - timeout_config: TimeoutType = (10, 60), - proxies: Optional[Union[Dict, str]] = None, - trust_env: bool = False, additional_headers: Optional[Dict] = None, - startup_period: Optional[int] = 5, embedded_options: Optional[EmbeddedOptions] = None, - additional_config: Optional[Config] = None, + additional_config: Optional[AdditionalConfig] = None, + grpc_port: int = 50051, + grpc_secure: bool = False, ): """ Create a new instance of WeaviateDocumentStore and connects to the Weaviate instance. @@ -88,46 +99,35 @@ def __init__( - `AuthClientPassword` to use username and password for oidc Resource Owner Password flow - `AuthClientCredentials` to use a client secret for oidc client credential flow - `AuthApiKey` to use an API key - :param timeout_config: Timeout configuration for all requests to the Weaviate server, defaults to (10, 60). - It can be a real number or, a tuple of two real numbers: (connect timeout, read timeout). - If only one real number is passed then both connect and read timeout will be set to - that value, by default (2, 20). - :param proxies: Proxy configuration, defaults to None. - Can be passed as a dict using the - ``requests` format`_, - or a string. If a string is passed it will be used for both HTTP and HTTPS requests. - :param trust_env: Whether to read proxies from the ENV variables, defaults to False. - Proxies will be read from the following ENV variables: - * `HTTP_PROXY` - * `http_proxy` - * `HTTPS_PROXY` - * `https_proxy` - If `proxies` is not None, `trust_env` is ignored. :param additional_headers: Additional headers to include in the requests, defaults to None. Can be used to set OpenAI/HuggingFace keys. OpenAI/HuggingFace key looks like this: ``` {"X-OpenAI-Api-Key": ""}, {"X-HuggingFace-Api-Key": ""} ``` - :param startup_period: How many seconds the client will wait for Weaviate to start before - raising a RequestsConnectionError, defaults to 5. :param embedded_options: If set, create an embedded Weaviate cluster inside the client, defaults to None. For a full list of options see `weaviate.embedded.EmbeddedOptions`. :param additional_config: Additional and advanced configuration options for weaviate, defaults to None. + :param grpc_port: The port to use for the gRPC connection, defaults to 50051. + :param grpc_secure: Whether to use a secure channel for the underlying gRPC API.
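Example (an illustrative sketch, not part of this patch; it assumes a local Weaviate instance listening on the default HTTP and gRPC ports):
```
from haystack_integrations.document_stores.weaviate import WeaviateDocumentStore

store = WeaviateDocumentStore(url="http://localhost:8080")
print(store.count_documents())
```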
""" - self._client = weaviate.Client( - url=url, + # proxies, timeout_config, trust_env are part of additional_config now + # startup_period has been removed + self._client = weaviate.WeaviateClient( + connection_params=( + weaviate.connect.base.ConnectionParams.from_url(url=url, grpc_port=grpc_port, grpc_secure=grpc_secure) + if url + else None + ), auth_client_secret=auth_client_secret.resolve_value() if auth_client_secret else None, - timeout_config=timeout_config, - proxies=proxies, - trust_env=trust_env, + additional_config=additional_config, additional_headers=additional_headers, - startup_period=startup_period, embedded_options=embedded_options, - additional_config=additional_config, + skip_init_checks=False, ) + self._client.connect() # Test connection, it will raise an exception if it fails. - self._client.schema.get() + self._client.collections._get_all(simple=True) if collection_settings is None: collection_settings = { @@ -141,64 +141,53 @@ def __init__( # Set the properties if they're not set collection_settings["properties"] = collection_settings.get("properties", DOCUMENT_COLLECTION_PROPERTIES) - if not self._client.schema.exists(collection_settings["class"]): - self._client.schema.create_class(collection_settings) + if not self._client.collections.exists(collection_settings["class"]): + self._client.collections.create_from_dict(collection_settings) self._url = url self._collection_settings = collection_settings self._auth_client_secret = auth_client_secret - self._timeout_config = timeout_config - self._proxies = proxies - self._trust_env = trust_env self._additional_headers = additional_headers - self._startup_period = startup_period self._embedded_options = embedded_options self._additional_config = additional_config + self._collection = self._client.collections.get(collection_settings["class"]) def to_dict(self) -> Dict[str, Any]: embedded_options = asdict(self._embedded_options) if self._embedded_options else None - additional_config = asdict(self._additional_config) if self._additional_config else None + additional_config = ( + json.loads(self._additional_config.model_dump_json(by_alias=True)) if self._additional_config else None + ) return default_to_dict( self, url=self._url, collection_settings=self._collection_settings, auth_client_secret=self._auth_client_secret.to_dict() if self._auth_client_secret else None, - timeout_config=self._timeout_config, - proxies=self._proxies, - trust_env=self._trust_env, additional_headers=self._additional_headers, - startup_period=self._startup_period, embedded_options=embedded_options, additional_config=additional_config, ) @classmethod def from_dict(cls, data: Dict[str, Any]) -> "WeaviateDocumentStore": - if (timeout_config := data["init_parameters"].get("timeout_config")) is not None: - data["init_parameters"]["timeout_config"] = ( - tuple(timeout_config) if isinstance(timeout_config, list) else timeout_config - ) if (auth_client_secret := data["init_parameters"].get("auth_client_secret")) is not None: data["init_parameters"]["auth_client_secret"] = AuthCredentials.from_dict(auth_client_secret) if (embedded_options := data["init_parameters"].get("embedded_options")) is not None: data["init_parameters"]["embedded_options"] = EmbeddedOptions(**embedded_options) if (additional_config := data["init_parameters"].get("additional_config")) is not None: - additional_config["connection_config"] = ConnectionConfig(**additional_config["connection_config"]) - data["init_parameters"]["additional_config"] = Config(**additional_config) + 
data["init_parameters"]["additional_config"] = AdditionalConfig(**additional_config) return default_from_dict( cls, data, ) def count_documents(self) -> int: - collection_name = self._collection_settings["class"] - res = self._client.query.aggregate(collection_name).with_meta_count().do() - return res.get("data", {}).get("Aggregate", {}).get(collection_name, [{}])[0].get("meta", {}).get("count", 0) + total = self._collection.aggregate.over_all(total_count=True).total_count + return total if total else 0 def _to_data_object(self, document: Document) -> Dict[str, Any]: """ - Convert a Document to a Weviate data object ready to be saved. + Convert a Document to a Weaviate data object ready to be saved. """ data = document.to_dict() # Weaviate forces a UUID as an id. @@ -216,95 +205,82 @@ def _to_data_object(self, document: Document) -> Dict[str, Any]: return data - def _to_document(self, data: Dict[str, Any]) -> Document: + def _to_document(self, data: DataObject[Dict[str, Any], None]) -> Document: """ Convert a data object read from Weaviate into a Document. """ - data["id"] = data.pop("_original_id") - data["embedding"] = data["_additional"].pop("vector") if data["_additional"].get("vector") else None + document_data = data.properties + document_data["id"] = document_data.pop("_original_id") + if isinstance(data.vector, List): + document_data["embedding"] = data.vector + elif isinstance(data.vector, Dict): + document_data["embedding"] = data.vector.get("default") + else: + document_data["embedding"] = None - if (blob_data := data.get("blob_data")) is not None: - data["blob"] = { + if (blob_data := document_data.get("blob_data")) is not None: + document_data["blob"] = { "data": base64.b64decode(blob_data), - "mime_type": data.get("blob_mime_type"), + "mime_type": document_data.get("blob_mime_type"), } - # We always delete these fields as they're not part of the Document dataclass - data.pop("blob_data") - data.pop("blob_mime_type") - - # We don't need these fields anymore, this usually only contains the uuid - # used by Weaviate to identify the object and the embedding vector that we already extracted. - del data["_additional"] - - return Document.from_dict(data) - - def _query_paginated(self, properties: List[str], cursor=None): - collection_name = self._collection_settings["class"] - query = ( - self._client.query.get( - collection_name, - properties, - ) - .with_additional(["id vector"]) - .with_limit(100) - ) - - if cursor: - # Fetch the next set of results - result = query.with_after(cursor).do() - else: - # Fetch the first set of results - result = query.do() - - if "errors" in result: - errors = [e["message"] for e in result.get("errors", {})] - msg = "\n".join(errors) - msg = f"Failed to query documents in Weaviate. Errors:\n{msg}" - raise DocumentStoreError(msg) - - return result["data"]["Get"][collection_name] - - def _query_with_filters(self, properties: List[str], filters: Dict[str, Any]) -> List[Dict[str, Any]]: - collection_name = self._collection_settings["class"] - query = ( - self._client.query.get( - collection_name, - properties, - ) - .with_additional(["id vector"]) - .with_where(convert_filters(filters)) - ) - - result = query.do() - if "errors" in result: - errors = [e["message"] for e in result.get("errors", {})] - msg = "\n".join(errors) - msg = f"Failed to query documents in Weaviate. 
Errors:\n{msg}" - raise DocumentStoreError(msg) + # We always delete these fields as they're not part of the Document dataclass + document_data.pop("blob_data", None) + document_data.pop("blob_mime_type", None) + + for key, value in document_data.items(): + if isinstance(value, datetime.datetime): + document_data[key] = value.strftime("%Y-%m-%dT%H:%M:%SZ") + + return Document.from_dict(document_data) + + def _query(self) -> List[Dict[str, Any]]: + properties = [p.name for p in self._collection.config.get().properties] + try: + result = self._collection.iterator(include_vector=True, return_properties=properties) + except weaviate.exceptions.WeaviateQueryError as e: + msg = f"Failed to query documents in Weaviate. Error: {e.message}" + raise DocumentStoreError(msg) from e + return result - return result["data"]["Get"][collection_name] + def _query_with_filters(self, filters: Dict[str, Any]) -> List[Dict[str, Any]]: + properties = [p.name for p in self._collection.config.get().properties] + # When querying with filters we need to paginate using limit and offset as using + # a cursor with after is not possible. See the official docs: + # https://weaviate.io/developers/weaviate/api/graphql/additional-operators#cursor-with-after + # + # Nonetheless there's also another issue, paginating with limit and offset is not efficient + # and it's still restricted by the QUERY_MAXIMUM_RESULTS environment variable. + # If the sum of limit and offest is greater than QUERY_MAXIMUM_RESULTS an error is raised. + # See the official docs for more: + # https://weaviate.io/developers/weaviate/api/graphql/additional-operators#performance-considerations + offset = 0 + partial_result = None + result = [] + # Keep querying until we get all documents matching the filters + while partial_result is None or len(partial_result.objects) == DEFAULT_QUERY_LIMIT: + try: + partial_result = self._collection.query.fetch_objects( + filters=convert_filters(filters), + include_vector=True, + limit=DEFAULT_QUERY_LIMIT, + offset=offset, + return_properties=properties, + ) + except weaviate.exceptions.WeaviateQueryError as e: + msg = f"Failed to query documents in Weaviate. Error: {e.message}" + raise DocumentStoreError(msg) from e + result.extend(partial_result.objects) + offset += DEFAULT_QUERY_LIMIT + return result def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]: - properties = self._client.schema.get(self._collection_settings["class"]).get("properties", []) - properties = [prop["name"] for prop in properties] - - if filters: - result = self._query_with_filters(properties, filters) - return [self._to_document(doc) for doc in result] - result = [] - - cursor = None - while batch := self._query_paginated(properties, cursor): - # Take the cursor before we convert the batch to Documents as we manipulate - # the batch dictionary and might lose that information. - cursor = batch[-1]["_additional"]["id"] - - for doc in batch: - result.append(self._to_document(doc)) - # Move the cursor to the last returned uuid - return result + if filters: + result = self._query_with_filters(filters) + else: + result = self._query() + return [self._to_document(doc) for doc in result] def _batch_write(self, documents: List[Document]) -> int: """ @@ -312,33 +288,35 @@ def _batch_write(self, documents: List[Document]) -> int: Documents with the same id will be overwritten. Raises in case of errors. 
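For orientation, the v4 batch flow used below can be sketched as follows (illustrative only; `client` stands for a connected `weaviate.WeaviateClient` such as `self._client`, and the collection name "Default" is just an example):
```
from weaviate.util import generate_uuid5

with client.batch.dynamic() as batch:  # dynamically sized batches
    batch.add_object(
        collection="Default",
        uuid=generate_uuid5("doc-1"),
        properties={"content": "hello", "_original_id": "doc-1"},
        vector=[0.1, 0.2, 0.3],
    )
# Failures are collected rather than raised inside the context manager:
for failed in client.batch.failed_objects:
    print(failed.message)
```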
""" - statuses = [] - for doc in documents: - if not isinstance(doc, Document): - msg = f"Expected a Document, got '{type(doc)}' instead." - raise ValueError(msg) - if self._client.batch.num_objects() == self._client.batch.recommended_num_objects: - # Batch is full, let's create the objects - statuses.extend(self._client.batch.create_objects()) - self._client.batch.add_data_object( - uuid=generate_uuid5(doc.id), - data_object=self._to_data_object(doc), - class_name=self._collection_settings["class"], - vector=doc.embedding, + + with self._client.batch.dynamic() as batch: + for doc in documents: + if not isinstance(doc, Document): + msg = f"Expected a Document, got '{type(doc)}' instead." + raise ValueError(msg) + + batch.add_object( + properties=self._to_data_object(doc), + collection=self._collection.name, + uuid=generate_uuid5(doc.id), + vector=doc.embedding, + ) + if failed_objects := self._client.batch.failed_objects: + # We fallback to use the UUID if the _original_id is not present, this is just to be + mapped_objects = {} + for obj in failed_objects: + properties = obj.object_.properties or {} + # We get the object uuid just in case the _original_id is not present. + # That's extremely unlikely to happen but let's stay on the safe side. + id_ = properties.get("_original_id", obj.object_.uuid) + mapped_objects[id_] = obj.message + + msg = "\n".join( + [ + f"Failed to write object with id '{id_}'. Error: '{message}'" + for id_, message in mapped_objects.items() + ] ) - # Write remaining documents - statuses.extend(self._client.batch.create_objects()) - - errors = [] - # Gather errors and number of written documents - for status in statuses: - result_status = status.get("result", {}).get("status") - if result_status == "FAILED": - errors.extend([e["message"] for e in status["result"]["errors"]["error"]]) - - if errors: - msg = "\n".join(errors) - msg = f"Failed to write documents in Weaviate. Errors:\n{msg}" raise DocumentStoreError(msg) # If the document already exists we get no status message back from Weaviate. @@ -359,22 +337,19 @@ def _write(self, documents: List[Document], policy: DuplicatePolicy) -> int: msg = f"Expected a Document, got '{type(doc)}' instead." 
raise ValueError(msg) - if policy == DuplicatePolicy.SKIP and self._client.data_object.exists( - uuid=generate_uuid5(doc.id), - class_name=self._collection_settings["class"], - ): + if policy == DuplicatePolicy.SKIP and self._collection.data.exists(uuid=generate_uuid5(doc.id)): # This Document already exists, we skip it continue try: - self._client.data_object.create( + self._collection.data.insert( uuid=generate_uuid5(doc.id), - data_object=self._to_data_object(doc), - class_name=self._collection_settings["class"], + properties=self._to_data_object(doc), vector=doc.embedding, ) + written += 1 - except weaviate.exceptions.ObjectAlreadyExistsException: + except weaviate.exceptions.UnexpectedStatusCodeError: if policy == DuplicatePolicy.FAIL: duplicate_errors_ids.append(doc.id) if duplicate_errors_ids: @@ -397,37 +372,21 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D return self._write(documents, policy) def delete_documents(self, document_ids: List[str]) -> None: - self._client.batch.delete_objects( - class_name=self._collection_settings["class"], - where={ - "path": ["id"], - "operator": "ContainsAny", - "valueTextArray": [generate_uuid5(doc_id) for doc_id in document_ids], - }, - ) + weaviate_ids = [generate_uuid5(doc_id) for doc_id in document_ids] + self._collection.data.delete_many(where=weaviate.classes.query.Filter.by_id().contains_any(weaviate_ids)) def _bm25_retrieval( self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None ) -> List[Document]: - collection_name = self._collection_settings["class"] - properties = self._client.schema.get(self._collection_settings["class"]).get("properties", []) - properties = [prop["name"] for prop in properties] - - query_builder = ( - self._client.query.get(collection_name, properties=properties) - .with_bm25(query=query, properties=["content"]) - .with_additional(["vector"]) + result = self._collection.query.bm25( + query=query, + filters=convert_filters(filters) if filters else None, + limit=top_k, + include_vector=True, + query_properties=["content"], ) - if filters: - query_builder = query_builder.with_where(convert_filters(filters)) - - if top_k: - query_builder = query_builder.with_limit(top_k) - - result = query_builder.do() - - return [self._to_document(doc) for doc in result["data"]["Get"][collection_name]] + return [self._to_document(doc) for doc in result.objects] def _embedding_retrieval( self, @@ -441,30 +400,13 @@ def _embedding_retrieval( msg = "Can't use 'distance' and 'certainty' parameters together" raise ValueError(msg) - collection_name = self._collection_settings["class"] - properties = self._client.schema.get(self._collection_settings["class"]).get("properties", []) - properties = [prop["name"] for prop in properties] - - near_vector: Dict[str, Union[float, List[float]]] = { - "vector": query_embedding, - } - if distance is not None: - near_vector["distance"] = distance - - if certainty is not None: - near_vector["certainty"] = certainty - - query_builder = ( - self._client.query.get(collection_name, properties=properties) - .with_near_vector(near_vector) - .with_additional(["vector"]) + result = self._collection.query.near_vector( + near_vector=query_embedding, + distance=distance, + certainty=certainty, + include_vector=True, + filters=convert_filters(filters) if filters else None, + limit=top_k, ) - if filters: - query_builder = query_builder.with_where(convert_filters(filters)) - - if top_k: - query_builder = query_builder.with_limit(top_k) - - result = 
query_builder.do() - return [self._to_document(doc) for doc in result["data"]["Get"][collection_name]] + return [self._to_document(doc) for doc in result.objects] diff --git a/integrations/weaviate/tests/test_bm25_retriever.py b/integrations/weaviate/tests/test_bm25_retriever.py index 83f90735b..23b7c8f92 100644 --- a/integrations/weaviate/tests/test_bm25_retriever.py +++ b/integrations/weaviate/tests/test_bm25_retriever.py @@ -38,11 +38,7 @@ def test_to_dict(_mock_weaviate): ], }, "auth_client_secret": None, - "timeout_config": (10, 60), - "proxies": None, - "trust_env": False, "additional_headers": None, - "startup_period": 5, "embedded_options": None, "additional_config": None, }, @@ -76,11 +72,7 @@ def test_from_dict(_mock_weaviate): ], }, "auth_client_secret": None, - "timeout_config": (10, 60), - "proxies": None, - "trust_env": False, "additional_headers": None, - "startup_period": 5, "embedded_options": None, "additional_config": None, }, diff --git a/integrations/weaviate/tests/test_document_store.py b/integrations/weaviate/tests/test_document_store.py index a2b32d578..4c1659a86 100644 --- a/integrations/weaviate/tests/test_document_store.py +++ b/integrations/weaviate/tests/test_document_store.py @@ -7,6 +7,7 @@ from dateutil import parser from haystack.dataclasses.byte_stream import ByteStream from haystack.dataclasses.document import Document +from haystack.document_stores.errors import DocumentStoreError from haystack.testing.document_store import ( TEST_EMBEDDING_1, TEST_EMBEDDING_2, @@ -24,8 +25,10 @@ from numpy import array_equal as np_array_equal from numpy import float32 as np_float32 from pandas import DataFrame -from weaviate.auth import AuthApiKey as WeaviateAuthApiKey -from weaviate.config import Config +from weaviate.collections.classes.data import DataObject + +# from weaviate.auth import AuthApiKey as WeaviateAuthApiKey +from weaviate.config import AdditionalConfig, ConnectionConfig, Proxies, Timeout from weaviate.embedded import ( DEFAULT_BINARY_PATH, DEFAULT_GRPC_PORT, @@ -53,7 +56,7 @@ def document_store(self, request) -> WeaviateDocumentStore: collection_settings=collection_settings, ) yield store - store._client.schema.delete_class(collection_settings["class"]) + store._client.collections.delete(collection_settings["class"]) @pytest.fixture def filterable_docs(self) -> List[Document]: @@ -145,49 +148,48 @@ def assert_documents_are_equal(self, received: List[Document], expected: List[Do for key in meta_keys: assert received_meta.get(key) == expected_meta.get(key) - @patch("haystack_integrations.document_stores.weaviate.document_store.weaviate.Client") + @patch("haystack_integrations.document_stores.weaviate.document_store.weaviate.WeaviateClient") def test_init(self, mock_weaviate_client_class, monkeypatch): mock_client = MagicMock() - mock_client.schema.exists.return_value = False + mock_client.collections.exists.return_value = False mock_weaviate_client_class.return_value = mock_client monkeypatch.setenv("WEAVIATE_API_KEY", "my_api_key") WeaviateDocumentStore( - url="http://localhost:8080", collection_settings={"class": "My_collection"}, auth_client_secret=AuthApiKey(), - proxies={"http": "http://proxy:1234"}, additional_headers={"X-HuggingFace-Api-Key": "MY_HUGGINGFACE_KEY"}, embedded_options=EmbeddedOptions( persistence_data_path=DEFAULT_PERSISTENCE_DATA_PATH, binary_path=DEFAULT_BINARY_PATH, - version="1.23.0", + version="1.23.7", hostname="127.0.0.1", ), - additional_config=Config(grpc_port_experimental=12345), + additional_config=AdditionalConfig( + 
proxies={"http": "http://proxy:1234"}, trust_env=False, timeout=(10, 60) + ), ) # Verify client is created with correct parameters + mock_weaviate_client_class.assert_called_once_with( - url="http://localhost:8080", - auth_client_secret=WeaviateAuthApiKey("my_api_key"), - timeout_config=(10, 60), - proxies={"http": "http://proxy:1234"}, - trust_env=False, + auth_client_secret=AuthApiKey().resolve_value(), + connection_params=None, additional_headers={"X-HuggingFace-Api-Key": "MY_HUGGINGFACE_KEY"}, - startup_period=5, embedded_options=EmbeddedOptions( persistence_data_path=DEFAULT_PERSISTENCE_DATA_PATH, binary_path=DEFAULT_BINARY_PATH, - version="1.23.0", + version="1.23.7", hostname="127.0.0.1", ), - additional_config=Config(grpc_port_experimental=12345), + skip_init_checks=False, + additional_config=AdditionalConfig( + proxies={"http": "http://proxy:1234"}, trust_env=False, timeout=(10, 60) + ), ) # Verify collection is created - mock_client.schema.get.assert_called_once() - mock_client.schema.exists.assert_called_once_with("My_collection") - mock_client.schema.create_class.assert_called_once_with( + mock_client.collections.exists.assert_called_once_with("My_collection") + mock_client.collections.create_from_dict.assert_called_once_with( {"class": "My_collection", "properties": DOCUMENT_COLLECTION_PROPERTIES} ) @@ -197,7 +199,6 @@ def test_to_dict(self, _mock_weaviate, monkeypatch): document_store = WeaviateDocumentStore( url="http://localhost:8080", auth_client_secret=AuthApiKey(), - proxies={"http": "http://proxy:1234"}, additional_headers={"X-HuggingFace-Api-Key": "MY_HUGGINGFACE_KEY"}, embedded_options=EmbeddedOptions( persistence_data_path=DEFAULT_PERSISTENCE_DATA_PATH, @@ -205,7 +206,12 @@ def test_to_dict(self, _mock_weaviate, monkeypatch): version="1.23.0", hostname="127.0.0.1", ), - additional_config=Config(grpc_port_experimental=12345), + additional_config=AdditionalConfig( + connection=ConnectionConfig(), + timeout=(30, 90), + trust_env=False, + proxies={"http": "http://proxy:1234"}, + ), ) assert document_store.to_dict() == { "type": "haystack_integrations.document_stores.weaviate.document_store.WeaviateDocumentStore", @@ -229,11 +235,7 @@ def test_to_dict(self, _mock_weaviate, monkeypatch): "api_key": {"env_vars": ["WEAVIATE_API_KEY"], "strict": True, "type": "env_var"} }, }, - "timeout_config": (10, 60), - "proxies": {"http": "http://proxy:1234"}, - "trust_env": False, "additional_headers": {"X-HuggingFace-Api-Key": "MY_HUGGINGFACE_KEY"}, - "startup_period": 5, "embedded_options": { "persistence_data_path": DEFAULT_PERSISTENCE_DATA_PATH, "binary_path": DEFAULT_BINARY_PATH, @@ -244,11 +246,14 @@ def test_to_dict(self, _mock_weaviate, monkeypatch): "grpc_port": DEFAULT_GRPC_PORT, }, "additional_config": { - "grpc_port_experimental": 12345, - "connection_config": { + "connection": { "session_pool_connections": 20, - "session_pool_maxsize": 20, + "session_pool_maxsize": 100, + "session_pool_max_retries": 3, }, + "proxies": {"http": "http://proxy:1234", "https": None, "grpc": None}, + "timeout": [30, 90], + "trust_env": False, }, }, } @@ -268,11 +273,7 @@ def test_from_dict(self, _mock_weaviate, monkeypatch): "api_key": {"env_vars": ["WEAVIATE_API_KEY"], "strict": True, "type": "env_var"} }, }, - "timeout_config": [10, 60], - "proxies": {"http": "http://proxy:1234"}, - "trust_env": False, "additional_headers": {"X-HuggingFace-Api-Key": "MY_HUGGINGFACE_KEY"}, - "startup_period": 5, "embedded_options": { "persistence_data_path": DEFAULT_PERSISTENCE_DATA_PATH, "binary_path": 
DEFAULT_BINARY_PATH, @@ -283,11 +284,13 @@ def test_from_dict(self, _mock_weaviate, monkeypatch): "grpc_port": DEFAULT_GRPC_PORT, }, "additional_config": { - "grpc_port_experimental": 12345, - "connection_config": { + "connection": { "session_pool_connections": 20, "session_pool_maxsize": 20, }, + "proxies": {"http": "http://proxy:1234"}, + "timeout": [10, 60], + "trust_env": False, }, }, } @@ -307,11 +310,10 @@ def test_from_dict(self, _mock_weaviate, monkeypatch): ], } assert document_store._auth_client_secret == AuthApiKey() - assert document_store._timeout_config == (10, 60) - assert document_store._proxies == {"http": "http://proxy:1234"} - assert not document_store._trust_env + assert document_store._additional_config.timeout == Timeout(query=10, insert=60) + assert document_store._additional_config.proxies == Proxies(http="http://proxy:1234", https=None, grpc=None) + assert not document_store._additional_config.trust_env assert document_store._additional_headers == {"X-HuggingFace-Api-Key": "MY_HUGGINGFACE_KEY"} - assert document_store._startup_period == 5 assert document_store._embedded_options.persistence_data_path == DEFAULT_PERSISTENCE_DATA_PATH assert document_store._embedded_options.binary_path == DEFAULT_BINARY_PATH assert document_store._embedded_options.version == "1.23.0" @@ -319,9 +321,8 @@ def test_from_dict(self, _mock_weaviate, monkeypatch): assert document_store._embedded_options.hostname == "127.0.0.1" assert document_store._embedded_options.additional_env_vars is None assert document_store._embedded_options.grpc_port == DEFAULT_GRPC_PORT - assert document_store._additional_config.grpc_port_experimental == 12345 - assert document_store._additional_config.connection_config.session_pool_connections == 20 - assert document_store._additional_config.connection_config.session_pool_maxsize == 20 + assert document_store._additional_config.connection.session_pool_connections == 20 + assert document_store._additional_config.connection.session_pool_maxsize == 20 def test_to_data_object(self, document_store, test_files_path): doc = Document(content="test doc") @@ -353,18 +354,18 @@ def test_to_data_object(self, document_store, test_files_path): def test_to_document(self, document_store, test_files_path): image = ByteStream.from_file_path(test_files_path / "robot1.jpg", mime_type="image/jpeg") - data = { - "_additional": { - "vector": [1, 2, 3], + data = DataObject( + properties={ + "_original_id": "123", + "content": "some content", + "blob_data": base64.b64encode(image.data).decode(), + "blob_mime_type": "image/jpeg", + "dataframe": None, + "score": None, + "key": "value", }, - "_original_id": "123", - "content": "some content", - "blob_data": base64.b64encode(image.data).decode(), - "blob_mime_type": "image/jpeg", - "dataframe": None, - "score": None, - "meta": {"key": "value"}, - } + vector={"default": [1, 2, 3]}, + ) doc = document_store._to_document(data) assert doc.id == "123" @@ -626,3 +627,22 @@ def test_embedding_retrieval_with_certainty(self, document_store): def test_embedding_retrieval_with_distance_and_certainty(self, document_store): with pytest.raises(ValueError): document_store._embedding_retrieval(query_embedding=[], distance=0.1, certainty=0.1) + + def test_filter_documents_below_default_limit(self, document_store): + docs = [] + for index in range(9998): + docs.append(Document(content="This is some content", meta={"index": index})) + document_store.write_documents(docs) + result = document_store.filter_documents( + {"field": "content", "operator": "==", 
"value": "This is some content"} + ) + + assert len(result) == 9998 + + def test_filter_documents_over_default_limit(self, document_store): + docs = [] + for index in range(10000): + docs.append(Document(content="This is some content", meta={"index": index})) + document_store.write_documents(docs) + with pytest.raises(DocumentStoreError): + document_store.filter_documents({"field": "content", "operator": "==", "value": "This is some content"}) diff --git a/integrations/weaviate/tests/test_embedding_retriever.py b/integrations/weaviate/tests/test_embedding_retriever.py index 7f07d8a24..a406c40db 100644 --- a/integrations/weaviate/tests/test_embedding_retriever.py +++ b/integrations/weaviate/tests/test_embedding_retriever.py @@ -49,11 +49,7 @@ def test_to_dict(_mock_weaviate): ], }, "auth_client_secret": None, - "timeout_config": (10, 60), - "proxies": None, - "trust_env": False, "additional_headers": None, - "startup_period": 5, "embedded_options": None, "additional_config": None, }, @@ -89,11 +85,7 @@ def test_from_dict(_mock_weaviate): ], }, "auth_client_secret": None, - "timeout_config": (10, 60), - "proxies": None, - "trust_env": False, "additional_headers": None, - "startup_period": 5, "embedded_options": None, "additional_config": None, }, diff --git a/integrations/weaviate/tests/test_filters.py b/integrations/weaviate/tests/test_filters.py index cf38d84be..c32d69e2f 100644 --- a/integrations/weaviate/tests/test_filters.py +++ b/integrations/weaviate/tests/test_filters.py @@ -19,7 +19,7 @@ def test_invert_conditions(): inverted = _invert_condition(filters) assert inverted == { - "operator": "AND", + "operator": "OR", "conditions": [ {"field": "meta.number", "operator": "!=", "value": 100}, {"field": "meta.name", "operator": "!=", "value": "name_0"}, From b0b71e4f618849889f32131c01e0aaa648caf162 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Wed, 13 Mar 2024 13:35:01 +0100 Subject: [PATCH 2/2] Refactor tests (#574) * first refactorings * separate unit tests in pgvector * small change to weaviate * fix format * usefixtures when possible --- .../chroma/tests/test_document_store.py | 30 -- .../tests/test_cohere_chat_generator.py | 12 - integrations/deepeval/tests/test_evaluator.py | 1 + .../tests/test_document_store.py | 58 +-- .../mongodb_atlas/tests/test_retriever.py | 52 ++- .../opensearch/tests/test_document_store.py | 53 +-- integrations/pgvector/tests/conftest.py | 36 ++ .../pgvector/tests/test_document_store.py | 369 +++++++++--------- .../tests/test_embedding_retrieval.py | 1 + integrations/pgvector/tests/test_filters.py | 226 ++++++----- integrations/pgvector/tests/test_retriever.py | 25 +- .../weaviate/tests/test_document_store.py | 1 + 12 files changed, 453 insertions(+), 411 deletions(-) diff --git a/integrations/chroma/tests/test_document_store.py b/integrations/chroma/tests/test_document_store.py index 8d61e63ed..5b827a984 100644 --- a/integrations/chroma/tests/test_document_store.py +++ b/integrations/chroma/tests/test_document_store.py @@ -60,7 +60,6 @@ def assert_documents_are_equal(self, received: List[Document], expected: List[Do assert doc_received.content == doc_expected.content assert doc_received.meta == doc_expected.meta - @pytest.mark.unit def test_ne_filter(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): """ We customize this test because Chroma consider "not equal" true when @@ -72,14 +71,12 @@ def test_ne_filter(self, document_store: ChromaDocumentStore, filterable_docs: L result, [doc for doc in filterable_docs if 
doc.meta.get("page", "100") != "100"] ) - @pytest.mark.unit def test_delete_empty(self, document_store: ChromaDocumentStore): """ Deleting a non-existing document should not raise with Chroma """ document_store.delete_documents(["test"]) - @pytest.mark.unit def test_delete_not_empty_nonexisting(self, document_store: ChromaDocumentStore): """ Deleting a non-existing document should not raise with Chroma @@ -131,144 +128,117 @@ def test_same_collection_name_reinitialization(self): ChromaDocumentStore("test_name") @pytest.mark.skip(reason="Filter on array contents is not supported.") - @pytest.mark.unit def test_filter_document_array(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): pass @pytest.mark.skip(reason="Filter on dataframe contents is not supported.") - @pytest.mark.unit def test_filter_document_dataframe(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): pass @pytest.mark.skip(reason="Filter on table contents is not supported.") - @pytest.mark.unit def test_eq_filter_table(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): pass @pytest.mark.skip(reason="Filter on embedding value is not supported.") - @pytest.mark.unit def test_eq_filter_embedding(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): pass @pytest.mark.skip(reason="$in operator is not supported.") - @pytest.mark.unit def test_in_filter_explicit(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): pass @pytest.mark.skip(reason="$in operator is not supported. Filter on table contents is not supported.") - @pytest.mark.unit def test_in_filter_table(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): pass @pytest.mark.skip(reason="$in operator is not supported.") - @pytest.mark.unit def test_in_filter_embedding(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): pass @pytest.mark.skip(reason="Filter on table contents is not supported.") - @pytest.mark.unit def test_ne_filter_table(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): pass @pytest.mark.skip(reason="Filter on embedding value is not supported.") - @pytest.mark.unit def test_ne_filter_embedding(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): pass @pytest.mark.skip(reason="$nin operator is not supported. Filter on table contents is not supported.") - @pytest.mark.unit def test_nin_filter_table(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): pass @pytest.mark.skip(reason="$nin operator is not supported. 
Filter on embedding value is not supported.") - @pytest.mark.unit def test_nin_filter_embedding(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): pass @pytest.mark.skip(reason="$nin operator is not supported.") - @pytest.mark.unit def test_nin_filter(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): pass @pytest.mark.skip(reason="Filter syntax not supported.") - @pytest.mark.unit def test_filter_simple_implicit_and_with_multi_key_dict( self, document_store: ChromaDocumentStore, filterable_docs: List[Document] ): pass @pytest.mark.skip(reason="Filter syntax not supported.") - @pytest.mark.unit def test_filter_simple_explicit_and_with_multikey_dict( self, document_store: ChromaDocumentStore, filterable_docs: List[Document] ): pass @pytest.mark.skip(reason="Filter syntax not supported.") - @pytest.mark.unit def test_filter_simple_explicit_and_with_list( self, document_store: ChromaDocumentStore, filterable_docs: List[Document] ): pass @pytest.mark.skip(reason="Filter syntax not supported.") - @pytest.mark.unit def test_filter_simple_implicit_and(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): pass @pytest.mark.skip(reason="Filter syntax not supported.") - @pytest.mark.unit def test_filter_nested_explicit_and(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): pass @pytest.mark.skip(reason="Filter syntax not supported.") - @pytest.mark.unit def test_filter_nested_implicit_and(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): pass @pytest.mark.skip(reason="Filter syntax not supported.") - @pytest.mark.unit def test_filter_simple_or(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): pass @pytest.mark.skip(reason="Filter syntax not supported.") - @pytest.mark.unit def test_filter_nested_or(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): pass @pytest.mark.skip(reason="Filter on table contents is not supported.") - @pytest.mark.unit def test_filter_nested_and_or_explicit(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): pass @pytest.mark.skip(reason="Filter syntax not supported.") - @pytest.mark.unit def test_filter_nested_and_or_implicit(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): pass @pytest.mark.skip(reason="Filter syntax not supported.") - @pytest.mark.unit def test_filter_nested_or_and(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): pass @pytest.mark.skip(reason="Filter syntax not supported.") - @pytest.mark.unit def test_filter_nested_multiple_identical_operators_same_level( self, document_store: ChromaDocumentStore, filterable_docs: List[Document] ): pass @pytest.mark.skip(reason="Duplicate policy not supported.") - @pytest.mark.unit def test_write_duplicate_fail(self, document_store: ChromaDocumentStore): pass @pytest.mark.skip(reason="Duplicate policy not supported.") - @pytest.mark.unit def test_write_duplicate_skip(self, document_store: ChromaDocumentStore): pass @pytest.mark.skip(reason="Duplicate policy not supported.") - @pytest.mark.unit def test_write_duplicate_overwrite(self, document_store: ChromaDocumentStore): pass diff --git a/integrations/cohere/tests/test_cohere_chat_generator.py b/integrations/cohere/tests/test_cohere_chat_generator.py index 7fd588fec..9a822856e 100644 --- a/integrations/cohere/tests/test_cohere_chat_generator.py +++ b/integrations/cohere/tests/test_cohere_chat_generator.py @@ -53,7 
+53,6 @@ def chat_messages(): class TestCohereChatGenerator: - @pytest.mark.unit def test_init_default(self, monkeypatch): monkeypatch.setenv("COHERE_API_KEY", "test-api-key") @@ -64,14 +63,12 @@ def test_init_default(self, monkeypatch): assert component.api_base_url == cohere.COHERE_API_URL assert not component.generation_kwargs - @pytest.mark.unit def test_init_fail_wo_api_key(self, monkeypatch): monkeypatch.delenv("COHERE_API_KEY", raising=False) monkeypatch.delenv("CO_API_KEY", raising=False) with pytest.raises(ValueError): CohereChatGenerator() - @pytest.mark.unit def test_init_with_parameters(self): component = CohereChatGenerator( api_key=Secret.from_token("test-api-key"), @@ -86,7 +83,6 @@ def test_init_with_parameters(self): assert component.api_base_url == "test-base-url" assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} - @pytest.mark.unit def test_to_dict_default(self, monkeypatch): monkeypatch.setenv("COHERE_API_KEY", "test-api-key") component = CohereChatGenerator() @@ -102,7 +98,6 @@ def test_to_dict_default(self, monkeypatch): }, } - @pytest.mark.unit def test_to_dict_with_parameters(self, monkeypatch): monkeypatch.setenv("COHERE_API_KEY", "test-api-key") monkeypatch.setenv("CO_API_KEY", "fake-api-key") @@ -125,7 +120,6 @@ def test_to_dict_with_parameters(self, monkeypatch): }, } - @pytest.mark.unit def test_to_dict_with_lambda_streaming_callback(self, monkeypatch): monkeypatch.setenv("COHERE_API_KEY", "test-api-key") component = CohereChatGenerator( @@ -146,7 +140,6 @@ def test_to_dict_with_lambda_streaming_callback(self, monkeypatch): }, } - @pytest.mark.unit def test_from_dict(self, monkeypatch): monkeypatch.setenv("COHERE_API_KEY", "fake-api-key") monkeypatch.setenv("CO_API_KEY", "fake-api-key") @@ -166,7 +159,6 @@ def test_from_dict(self, monkeypatch): assert component.api_base_url == "test-base-url" assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} - @pytest.mark.unit def test_from_dict_fail_wo_env_var(self, monkeypatch): monkeypatch.delenv("COHERE_API_KEY", raising=False) monkeypatch.delenv("CO_API_KEY", raising=False) @@ -183,7 +175,6 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch): with pytest.raises(ValueError): CohereChatGenerator.from_dict(data) - @pytest.mark.unit def test_run(self, chat_messages, mock_chat_response): # noqa: ARG002 component = CohereChatGenerator(api_key=Secret.from_token("test-api-key")) response = component.run(chat_messages) @@ -195,13 +186,11 @@ def test_run(self, chat_messages, mock_chat_response): # noqa: ARG002 assert len(response["replies"]) == 1 assert [isinstance(reply, ChatMessage) for reply in response["replies"]] - @pytest.mark.unit def test_message_to_dict(self, chat_messages): obj = CohereChatGenerator(api_key=Secret.from_token("test-api-key")) dictionary = [obj._message_to_dict(message) for message in chat_messages] assert dictionary == [{"user_name": "Chatbot", "text": "What's the capital of France"}] - @pytest.mark.unit def test_run_with_params(self, chat_messages, mock_chat_response): component = CohereChatGenerator( api_key=Secret.from_token("test-api-key"), generation_kwargs={"max_tokens": 10, "temperature": 0.5} @@ -220,7 +209,6 @@ def test_run_with_params(self, chat_messages, mock_chat_response): assert len(response["replies"]) == 1 assert [isinstance(reply, ChatMessage) for reply in response["replies"]] - @pytest.mark.unit def test_run_streaming(self, chat_messages, mock_chat_response): streaming_call_count = 0 diff --git 
a/integrations/deepeval/tests/test_evaluator.py b/integrations/deepeval/tests/test_evaluator.py index 8534ef687..7d1946185 100644 --- a/integrations/deepeval/tests/test_evaluator.py +++ b/integrations/deepeval/tests/test_evaluator.py @@ -270,6 +270,7 @@ def test_evaluator_outputs(metric, inputs, expected_outputs, metric_params, monk # OpenAI API. It is parameterized by the metric, the inputs to the evalutor # and the metric parameters. @pytest.mark.skipif("OPENAI_API_KEY" not in os.environ, reason="OPENAI_API_KEY not set") +@pytest.mark.integration @pytest.mark.parametrize( "metric, inputs, metric_params", [ diff --git a/integrations/elasticsearch/tests/test_document_store.py b/integrations/elasticsearch/tests/test_document_store.py index e46e76ed2..308486a78 100644 --- a/integrations/elasticsearch/tests/test_document_store.py +++ b/integrations/elasticsearch/tests/test_document_store.py @@ -15,6 +15,36 @@ from haystack_integrations.document_stores.elasticsearch import ElasticsearchDocumentStore +@patch("haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch") +def test_to_dict(_mock_elasticsearch_client): + document_store = ElasticsearchDocumentStore(hosts="some hosts") + res = document_store.to_dict() + assert res == { + "type": "haystack_integrations.document_stores.elasticsearch.document_store.ElasticsearchDocumentStore", + "init_parameters": { + "hosts": "some hosts", + "index": "default", + "embedding_similarity_function": "cosine", + }, + } + + +@patch("haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch") +def test_from_dict(_mock_elasticsearch_client): + data = { + "type": "haystack_integrations.document_stores.elasticsearch.document_store.ElasticsearchDocumentStore", + "init_parameters": { + "hosts": "some hosts", + "index": "default", + "embedding_similarity_function": "cosine", + }, + } + document_store = ElasticsearchDocumentStore.from_dict(data) + assert document_store._hosts == "some hosts" + assert document_store._index == "default" + assert document_store._embedding_similarity_function == "cosine" + + @pytest.mark.integration class TestDocumentStore(DocumentStoreBaseTests): """ @@ -67,34 +97,6 @@ def assert_documents_are_equal(self, received: List[Document], expected: List[Do super().assert_documents_are_equal(received, expected) - @patch("haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch") - def test_to_dict(self, _mock_elasticsearch_client): - document_store = ElasticsearchDocumentStore(hosts="some hosts") - res = document_store.to_dict() - assert res == { - "type": "haystack_integrations.document_stores.elasticsearch.document_store.ElasticsearchDocumentStore", - "init_parameters": { - "hosts": "some hosts", - "index": "default", - "embedding_similarity_function": "cosine", - }, - } - - @patch("haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch") - def test_from_dict(self, _mock_elasticsearch_client): - data = { - "type": "haystack_integrations.document_stores.elasticsearch.document_store.ElasticsearchDocumentStore", - "init_parameters": { - "hosts": "some hosts", - "index": "default", - "embedding_similarity_function": "cosine", - }, - } - document_store = ElasticsearchDocumentStore.from_dict(data) - assert document_store._hosts == "some hosts" - assert document_store._index == "default" - assert document_store._embedding_similarity_function == "cosine" - def test_user_agent_header(self, document_store: ElasticsearchDocumentStore): assert 
document_store._client._headers["user-agent"].startswith("haystack-py-ds/") diff --git a/integrations/mongodb_atlas/tests/test_retriever.py b/integrations/mongodb_atlas/tests/test_retriever.py index ec44513e2..4ef5222ce 100644 --- a/integrations/mongodb_atlas/tests/test_retriever.py +++ b/integrations/mongodb_atlas/tests/test_retriever.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 -from unittest.mock import Mock +from unittest.mock import MagicMock, Mock, patch import pytest from haystack.dataclasses import Document @@ -10,34 +10,48 @@ from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore -@pytest.fixture -def document_store(): - store = MongoDBAtlasDocumentStore( - database_name="haystack_integration_test", - collection_name="test_embeddings_collection", - vector_search_index="cosine_index", - ) - return store +class TestRetriever: + @pytest.fixture + def mock_client(self): + with patch( + "haystack_integrations.document_stores.mongodb_atlas.document_store.MongoClient" + ) as mock_mongo_client: + mock_connection = MagicMock() + mock_database = MagicMock() + mock_collection_names = MagicMock(return_value=["test_embeddings_collection"]) + mock_database.list_collection_names = mock_collection_names + mock_connection.__getitem__.return_value = mock_database + mock_mongo_client.return_value = mock_connection + yield mock_mongo_client -class TestRetriever: - def test_init_default(self, document_store: MongoDBAtlasDocumentStore): - retriever = MongoDBAtlasEmbeddingRetriever(document_store=document_store) - assert retriever.document_store == document_store + def test_init_default(self): + mock_store = Mock(spec=MongoDBAtlasDocumentStore) + retriever = MongoDBAtlasEmbeddingRetriever(document_store=mock_store) + assert retriever.document_store == mock_store assert retriever.filters == {} assert retriever.top_k == 10 - def test_init(self, document_store: MongoDBAtlasDocumentStore): + def test_init(self): + mock_store = Mock(spec=MongoDBAtlasDocumentStore) retriever = MongoDBAtlasEmbeddingRetriever( - document_store=document_store, + document_store=mock_store, filters={"field": "value"}, top_k=5, ) - assert retriever.document_store == document_store + assert retriever.document_store == mock_store assert retriever.filters == {"field": "value"} assert retriever.top_k == 5 - def test_to_dict(self, document_store: MongoDBAtlasDocumentStore): + def test_to_dict(self, mock_client, monkeypatch): # noqa: ARG002 mock_client appears unused but is required + monkeypatch.setenv("MONGO_CONNECTION_STRING", "test_conn_str") + + document_store = MongoDBAtlasDocumentStore( + database_name="haystack_integration_test", + collection_name="test_embeddings_collection", + vector_search_index="cosine_index", + ) + retriever = MongoDBAtlasEmbeddingRetriever(document_store=document_store, filters={"field": "value"}, top_k=5) res = retriever.to_dict() assert res == { @@ -61,7 +75,9 @@ def test_to_dict(self, document_store: MongoDBAtlasDocumentStore): }, } - def test_from_dict(self): + def test_from_dict(self, mock_client, monkeypatch): # noqa: ARG002 mock_client appears unused but is required + monkeypatch.setenv("MONGO_CONNECTION_STRING", "test_conn_str") + data = { "type": "haystack_integrations.components.retrievers.mongodb_atlas.embedding_retriever.MongoDBAtlasEmbeddingRetriever", # noqa: E501 "init_parameters": { diff --git a/integrations/opensearch/tests/test_document_store.py 
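For readers unfamiliar with the client-mocking pattern introduced in the MongoDB tests above, here is a standalone sketch of the same idea. The patch target and collection name mirror the fixture above; the surrounding code is illustrative, not part of this change:

    from unittest.mock import MagicMock, patch

    # Patch the client class at the import site of the module under test so
    # that constructing the document store never opens a real connection.
    with patch(
        "haystack_integrations.document_stores.mongodb_atlas.document_store.MongoClient"
    ) as mock_mongo_client:
        mock_connection = MagicMock()
        mock_database = MagicMock()
        # Report the expected collection as existing, presumably so the
        # store's constructor does not try to create it.
        mock_database.list_collection_names = MagicMock(return_value=["test_embeddings_collection"])
        mock_connection.__getitem__.return_value = mock_database  # client["db"] -> mock_database
        mock_mongo_client.return_value = mock_connection
        # Any MongoDBAtlasDocumentStore built inside this block sees only mocks.
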
b/integrations/opensearch/tests/test_document_store.py index e3a314141..bc0d1c434 100644 --- a/integrations/opensearch/tests/test_document_store.py +++ b/integrations/opensearch/tests/test_document_store.py @@ -14,6 +14,34 @@ from opensearchpy.exceptions import RequestError +@patch("haystack_integrations.document_stores.opensearch.document_store.OpenSearch") +def test_to_dict(_mock_opensearch_client): + document_store = OpenSearchDocumentStore(hosts="some hosts") + res = document_store.to_dict() + assert res == { + "type": "haystack_integrations.document_stores.opensearch.document_store.OpenSearchDocumentStore", + "init_parameters": { + "hosts": "some hosts", + "index": "default", + }, + } + + +@patch("haystack_integrations.document_stores.opensearch.document_store.OpenSearch") +def test_from_dict(_mock_opensearch_client): + data = { + "type": "haystack_integrations.document_stores.opensearch.document_store.OpenSearchDocumentStore", + "init_parameters": { + "hosts": "some hosts", + "index": "default", + }, + } + document_store = OpenSearchDocumentStore.from_dict(data) + assert document_store._hosts == "some hosts" + assert document_store._index == "default" + + +@pytest.mark.integration class TestDocumentStore(DocumentStoreBaseTests): """ Common test cases will be provided by `DocumentStoreBaseTests` but @@ -87,31 +115,6 @@ def assert_documents_are_equal(self, received: List[Document], expected: List[Do super().assert_documents_are_equal(received, expected) - @patch("haystack_integrations.document_stores.opensearch.document_store.OpenSearch") - def test_to_dict(self, _mock_opensearch_client): - document_store = OpenSearchDocumentStore(hosts="some hosts") - res = document_store.to_dict() - assert res == { - "type": "haystack_integrations.document_stores.opensearch.document_store.OpenSearchDocumentStore", - "init_parameters": { - "hosts": "some hosts", - "index": "default", - }, - } - - @patch("haystack_integrations.document_stores.opensearch.document_store.OpenSearch") - def test_from_dict(self, _mock_opensearch_client): - data = { - "type": "haystack_integrations.document_stores.opensearch.document_store.OpenSearchDocumentStore", - "init_parameters": { - "hosts": "some hosts", - "index": "default", - }, - } - document_store = OpenSearchDocumentStore.from_dict(data) - assert document_store._hosts == "some hosts" - assert document_store._index == "default" - def test_write_documents(self, document_store: OpenSearchDocumentStore): docs = [Document(id="1")] assert document_store.write_documents(docs) == 1 diff --git a/integrations/pgvector/tests/conftest.py b/integrations/pgvector/tests/conftest.py index 068f2ac54..94b35a04d 100644 --- a/integrations/pgvector/tests/conftest.py +++ b/integrations/pgvector/tests/conftest.py @@ -1,4 +1,5 @@ import os +from unittest.mock import patch import pytest from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore @@ -24,3 +25,38 @@ def document_store(request): yield store store.delete_table() + + +@pytest.fixture +def patches_for_unit_tests(): + with patch("haystack_integrations.document_stores.pgvector.document_store.connect") as mock_connect, patch( + "haystack_integrations.document_stores.pgvector.document_store.register_vector" + ) as mock_register, patch( + "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore.delete_table" + ) as mock_delete, patch( + "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore._create_table_if_not_exists" + ) as mock_create, patch( + 
"haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore._handle_hnsw" + ) as mock_hnsw: + + yield mock_connect, mock_register, mock_delete, mock_create, mock_hnsw + + +@pytest.fixture +def mock_store(patches_for_unit_tests, monkeypatch): # noqa: ARG001 patches are not explicitly called but necessary + monkeypatch.setenv("PG_CONN_STR", "some-connection-string") + table_name = "haystack" + embedding_dimension = 768 + vector_function = "cosine_similarity" + recreate_table = True + search_strategy = "exact_nearest_neighbor" + + store = PgvectorDocumentStore( + table_name=table_name, + embedding_dimension=embedding_dimension, + vector_function=vector_function, + recreate_table=recreate_table, + search_strategy=search_strategy, + ) + + yield store diff --git a/integrations/pgvector/tests/test_document_store.py b/integrations/pgvector/tests/test_document_store.py index 1e158f134..bf5ccd5d4 100644 --- a/integrations/pgvector/tests/test_document_store.py +++ b/integrations/pgvector/tests/test_document_store.py @@ -13,6 +13,7 @@ from pandas import DataFrame +@pytest.mark.integration class TestDocumentStore(CountDocumentsTest, WriteDocumentsTest, DeleteDocumentsTest): def test_write_documents(self, document_store: PgvectorDocumentStore): docs = [Document(id="1")] @@ -25,7 +26,6 @@ def test_write_blob(self, document_store: PgvectorDocumentStore): docs = [Document(id="1", blob=bytestream)] document_store.write_documents(docs) - # TODO: update when filters are implemented retrieved_docs = document_store.filter_documents() assert retrieved_docs == docs @@ -35,185 +35,194 @@ def test_write_dataframe(self, document_store: PgvectorDocumentStore): document_store.write_documents(docs) - # TODO: update when filters are implemented retrieved_docs = document_store.filter_documents() assert retrieved_docs == docs - def test_init(self): - document_store = PgvectorDocumentStore( - table_name="my_table", - embedding_dimension=512, - vector_function="l2_distance", - recreate_table=True, - search_strategy="hnsw", - hnsw_recreate_index_if_exists=True, - hnsw_index_creation_kwargs={"m": 32, "ef_construction": 128}, - hnsw_ef_search=50, - ) - - assert document_store.table_name == "my_table" - assert document_store.embedding_dimension == 512 - assert document_store.vector_function == "l2_distance" - assert document_store.recreate_table - assert document_store.search_strategy == "hnsw" - assert document_store.hnsw_recreate_index_if_exists - assert document_store.hnsw_index_creation_kwargs == {"m": 32, "ef_construction": 128} - assert document_store.hnsw_ef_search == 50 - - def test_to_dict(self): - document_store = PgvectorDocumentStore( - table_name="my_table", - embedding_dimension=512, - vector_function="l2_distance", - recreate_table=True, - search_strategy="hnsw", - hnsw_recreate_index_if_exists=True, - hnsw_index_creation_kwargs={"m": 32, "ef_construction": 128}, - hnsw_ef_search=50, - ) - - assert document_store.to_dict() == { - "type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore", - "init_parameters": { - "connection_string": {"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"}, - "table_name": "my_table", - "embedding_dimension": 512, - "vector_function": "l2_distance", - "recreate_table": True, - "search_strategy": "hnsw", - "hnsw_recreate_index_if_exists": True, - "hnsw_index_creation_kwargs": {"m": 32, "ef_construction": 128}, - "hnsw_ef_search": 50, - }, - } - - def test_from_haystack_to_pg_documents(self): - haystack_docs = [ - 
Document( - id="1", - content="This is a text", - meta={"meta_key": "meta_value"}, - embedding=[0.1, 0.2, 0.3], - score=0.5, - ), - Document( - id="2", - dataframe=DataFrame({"col1": [1, 2], "col2": [3, 4]}), - meta={"meta_key": "meta_value"}, - embedding=[0.4, 0.5, 0.6], - score=0.6, - ), - Document( - id="3", - blob=ByteStream(b"test", meta={"blob_meta_key": "blob_meta_value"}, mime_type="mime_type"), - meta={"meta_key": "meta_value"}, - embedding=[0.7, 0.8, 0.9], - score=0.7, - ), - ] - - with patch( - "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore.__init__" - ) as mock_init: - mock_init.return_value = None - ds = PgvectorDocumentStore(connection_string="test") - - pg_docs = ds._from_haystack_to_pg_documents(haystack_docs) - - assert pg_docs[0]["id"] == "1" - assert pg_docs[0]["content"] == "This is a text" - assert pg_docs[0]["dataframe"] is None - assert pg_docs[0]["blob_data"] is None - assert pg_docs[0]["blob_meta"] is None - assert pg_docs[0]["blob_mime_type"] is None - assert pg_docs[0]["meta"].obj == {"meta_key": "meta_value"} - assert pg_docs[0]["embedding"] == [0.1, 0.2, 0.3] - assert "score" not in pg_docs[0] - - assert pg_docs[1]["id"] == "2" - assert pg_docs[1]["content"] is None - assert pg_docs[1]["dataframe"].obj == DataFrame({"col1": [1, 2], "col2": [3, 4]}).to_json() - assert pg_docs[1]["blob_data"] is None - assert pg_docs[1]["blob_meta"] is None - assert pg_docs[1]["blob_mime_type"] is None - assert pg_docs[1]["meta"].obj == {"meta_key": "meta_value"} - assert pg_docs[1]["embedding"] == [0.4, 0.5, 0.6] - assert "score" not in pg_docs[1] - - assert pg_docs[2]["id"] == "3" - assert pg_docs[2]["content"] is None - assert pg_docs[2]["dataframe"] is None - assert pg_docs[2]["blob_data"] == b"test" - assert pg_docs[2]["blob_meta"].obj == {"blob_meta_key": "blob_meta_value"} - assert pg_docs[2]["blob_mime_type"] == "mime_type" - assert pg_docs[2]["meta"].obj == {"meta_key": "meta_value"} - assert pg_docs[2]["embedding"] == [0.7, 0.8, 0.9] - assert "score" not in pg_docs[2] - - def test_from_pg_to_haystack_documents(self): - pg_docs = [ - { - "id": "1", - "content": "This is a text", - "dataframe": None, - "blob_data": None, - "blob_meta": None, - "blob_mime_type": None, - "meta": {"meta_key": "meta_value"}, - "embedding": "[0.1, 0.2, 0.3]", - }, - { - "id": "2", - "content": None, - "dataframe": DataFrame({"col1": [1, 2], "col2": [3, 4]}).to_json(), - "blob_data": None, - "blob_meta": None, - "blob_mime_type": None, - "meta": {"meta_key": "meta_value"}, - "embedding": "[0.4, 0.5, 0.6]", - }, - { - "id": "3", - "content": None, - "dataframe": None, - "blob_data": b"test", - "blob_meta": {"blob_meta_key": "blob_meta_value"}, - "blob_mime_type": "mime_type", - "meta": {"meta_key": "meta_value"}, - "embedding": "[0.7, 0.8, 0.9]", - }, - ] - - with patch( - "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore.__init__" - ) as mock_init: - mock_init.return_value = None - ds = PgvectorDocumentStore(connection_string="test") - - haystack_docs = ds._from_pg_to_haystack_documents(pg_docs) - - assert haystack_docs[0].id == "1" - assert haystack_docs[0].content == "This is a text" - assert haystack_docs[0].dataframe is None - assert haystack_docs[0].blob is None - assert haystack_docs[0].meta == {"meta_key": "meta_value"} - assert haystack_docs[0].embedding == [0.1, 0.2, 0.3] - assert haystack_docs[0].score is None - - assert haystack_docs[1].id == "2" - assert haystack_docs[1].content is None - assert 
haystack_docs[1].dataframe.equals(DataFrame({"col1": [1, 2], "col2": [3, 4]})) - assert haystack_docs[1].blob is None - assert haystack_docs[1].meta == {"meta_key": "meta_value"} - assert haystack_docs[1].embedding == [0.4, 0.5, 0.6] - assert haystack_docs[1].score is None - - assert haystack_docs[2].id == "3" - assert haystack_docs[2].content is None - assert haystack_docs[2].dataframe is None - assert haystack_docs[2].blob.data == b"test" - assert haystack_docs[2].blob.meta == {"blob_meta_key": "blob_meta_value"} - assert haystack_docs[2].blob.mime_type == "mime_type" - assert haystack_docs[2].meta == {"meta_key": "meta_value"} - assert haystack_docs[2].embedding == [0.7, 0.8, 0.9] - assert haystack_docs[2].score is None + +@pytest.mark.usefixtures("patches_for_unit_tests") +def test_init(monkeypatch): + monkeypatch.setenv("PG_CONN_STR", "some_connection_string") + + document_store = PgvectorDocumentStore( + table_name="my_table", + embedding_dimension=512, + vector_function="l2_distance", + recreate_table=True, + search_strategy="hnsw", + hnsw_recreate_index_if_exists=True, + hnsw_index_creation_kwargs={"m": 32, "ef_construction": 128}, + hnsw_ef_search=50, + ) + + assert document_store.table_name == "my_table" + assert document_store.embedding_dimension == 512 + assert document_store.vector_function == "l2_distance" + assert document_store.recreate_table + assert document_store.search_strategy == "hnsw" + assert document_store.hnsw_recreate_index_if_exists + assert document_store.hnsw_index_creation_kwargs == {"m": 32, "ef_construction": 128} + assert document_store.hnsw_ef_search == 50 + + +@pytest.mark.usefixtures("patches_for_unit_tests") +def test_to_dict(monkeypatch): + monkeypatch.setenv("PG_CONN_STR", "some_connection_string") + + document_store = PgvectorDocumentStore( + table_name="my_table", + embedding_dimension=512, + vector_function="l2_distance", + recreate_table=True, + search_strategy="hnsw", + hnsw_recreate_index_if_exists=True, + hnsw_index_creation_kwargs={"m": 32, "ef_construction": 128}, + hnsw_ef_search=50, + ) + + assert document_store.to_dict() == { + "type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore", + "init_parameters": { + "connection_string": {"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"}, + "table_name": "my_table", + "embedding_dimension": 512, + "vector_function": "l2_distance", + "recreate_table": True, + "search_strategy": "hnsw", + "hnsw_recreate_index_if_exists": True, + "hnsw_index_creation_kwargs": {"m": 32, "ef_construction": 128}, + "hnsw_ef_search": 50, + }, + } + + +def test_from_haystack_to_pg_documents(): + haystack_docs = [ + Document( + id="1", + content="This is a text", + meta={"meta_key": "meta_value"}, + embedding=[0.1, 0.2, 0.3], + score=0.5, + ), + Document( + id="2", + dataframe=DataFrame({"col1": [1, 2], "col2": [3, 4]}), + meta={"meta_key": "meta_value"}, + embedding=[0.4, 0.5, 0.6], + score=0.6, + ), + Document( + id="3", + blob=ByteStream(b"test", meta={"blob_meta_key": "blob_meta_value"}, mime_type="mime_type"), + meta={"meta_key": "meta_value"}, + embedding=[0.7, 0.8, 0.9], + score=0.7, + ), + ] + + with patch( + "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore.__init__" + ) as mock_init: + mock_init.return_value = None + ds = PgvectorDocumentStore(connection_string="test") + + pg_docs = ds._from_haystack_to_pg_documents(haystack_docs) + + assert pg_docs[0]["id"] == "1" + assert pg_docs[0]["content"] == "This is a text" + assert 
pg_docs[0]["dataframe"] is None + assert pg_docs[0]["blob_data"] is None + assert pg_docs[0]["blob_meta"] is None + assert pg_docs[0]["blob_mime_type"] is None + assert pg_docs[0]["meta"].obj == {"meta_key": "meta_value"} + assert pg_docs[0]["embedding"] == [0.1, 0.2, 0.3] + assert "score" not in pg_docs[0] + + assert pg_docs[1]["id"] == "2" + assert pg_docs[1]["content"] is None + assert pg_docs[1]["dataframe"].obj == DataFrame({"col1": [1, 2], "col2": [3, 4]}).to_json() + assert pg_docs[1]["blob_data"] is None + assert pg_docs[1]["blob_meta"] is None + assert pg_docs[1]["blob_mime_type"] is None + assert pg_docs[1]["meta"].obj == {"meta_key": "meta_value"} + assert pg_docs[1]["embedding"] == [0.4, 0.5, 0.6] + assert "score" not in pg_docs[1] + + assert pg_docs[2]["id"] == "3" + assert pg_docs[2]["content"] is None + assert pg_docs[2]["dataframe"] is None + assert pg_docs[2]["blob_data"] == b"test" + assert pg_docs[2]["blob_meta"].obj == {"blob_meta_key": "blob_meta_value"} + assert pg_docs[2]["blob_mime_type"] == "mime_type" + assert pg_docs[2]["meta"].obj == {"meta_key": "meta_value"} + assert pg_docs[2]["embedding"] == [0.7, 0.8, 0.9] + assert "score" not in pg_docs[2] + + +def test_from_pg_to_haystack_documents(): + pg_docs = [ + { + "id": "1", + "content": "This is a text", + "dataframe": None, + "blob_data": None, + "blob_meta": None, + "blob_mime_type": None, + "meta": {"meta_key": "meta_value"}, + "embedding": "[0.1, 0.2, 0.3]", + }, + { + "id": "2", + "content": None, + "dataframe": DataFrame({"col1": [1, 2], "col2": [3, 4]}).to_json(), + "blob_data": None, + "blob_meta": None, + "blob_mime_type": None, + "meta": {"meta_key": "meta_value"}, + "embedding": "[0.4, 0.5, 0.6]", + }, + { + "id": "3", + "content": None, + "dataframe": None, + "blob_data": b"test", + "blob_meta": {"blob_meta_key": "blob_meta_value"}, + "blob_mime_type": "mime_type", + "meta": {"meta_key": "meta_value"}, + "embedding": "[0.7, 0.8, 0.9]", + }, + ] + + with patch( + "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore.__init__" + ) as mock_init: + mock_init.return_value = None + ds = PgvectorDocumentStore(connection_string="test") + + haystack_docs = ds._from_pg_to_haystack_documents(pg_docs) + + assert haystack_docs[0].id == "1" + assert haystack_docs[0].content == "This is a text" + assert haystack_docs[0].dataframe is None + assert haystack_docs[0].blob is None + assert haystack_docs[0].meta == {"meta_key": "meta_value"} + assert haystack_docs[0].embedding == [0.1, 0.2, 0.3] + assert haystack_docs[0].score is None + + assert haystack_docs[1].id == "2" + assert haystack_docs[1].content is None + assert haystack_docs[1].dataframe.equals(DataFrame({"col1": [1, 2], "col2": [3, 4]})) + assert haystack_docs[1].blob is None + assert haystack_docs[1].meta == {"meta_key": "meta_value"} + assert haystack_docs[1].embedding == [0.4, 0.5, 0.6] + assert haystack_docs[1].score is None + + assert haystack_docs[2].id == "3" + assert haystack_docs[2].content is None + assert haystack_docs[2].dataframe is None + assert haystack_docs[2].blob.data == b"test" + assert haystack_docs[2].blob.meta == {"blob_meta_key": "blob_meta_value"} + assert haystack_docs[2].blob.mime_type == "mime_type" + assert haystack_docs[2].meta == {"meta_key": "meta_value"} + assert haystack_docs[2].embedding == [0.7, 0.8, 0.9] + assert haystack_docs[2].score is None diff --git a/integrations/pgvector/tests/test_embedding_retrieval.py b/integrations/pgvector/tests/test_embedding_retrieval.py index 1d5e8e297..2c384f57c 
100644 --- a/integrations/pgvector/tests/test_embedding_retrieval.py +++ b/integrations/pgvector/tests/test_embedding_retrieval.py @@ -10,6 +10,7 @@ from numpy.random import rand +@pytest.mark.integration class TestEmbeddingRetrieval: @pytest.fixture def document_store_w_hnsw_index(self, request): diff --git a/integrations/pgvector/tests/test_filters.py b/integrations/pgvector/tests/test_filters.py index 8b2dc8ec9..bda10e3c0 100644 --- a/integrations/pgvector/tests/test_filters.py +++ b/integrations/pgvector/tests/test_filters.py @@ -15,6 +15,7 @@ from psycopg.types.json import Jsonb +@pytest.mark.integration class TestFilters(FilterDocumentsTest): def assert_documents_are_equal(self, received: List[Document], expected: List[Document]): """ @@ -35,6 +36,9 @@ def assert_documents_are_equal(self, received: List[Document], expected: List[Do received_doc.embedding, expected_doc.embedding = None, None assert received_doc == expected_doc + @pytest.mark.skip(reason="NOT operator is not supported in PgvectorDocumentStore") + def test_not_operator(self, document_store, filterable_docs): ... + def test_complex_filter(self, document_store, filterable_docs): document_store.write_documents(filterable_docs) filters = { @@ -69,111 +73,119 @@ def test_complex_filter(self, document_store, filterable_docs): ], ) - @pytest.mark.skip(reason="NOT operator is not supported in PgvectorDocumentStore") - def test_not_operator(self, document_store, filterable_docs): ... - def test_treat_meta_field(self): - assert _treat_meta_field(field="meta.number", value=9) == "(meta->>'number')::integer" - assert _treat_meta_field(field="meta.number", value=[1, 2, 3]) == "(meta->>'number')::integer" - assert _treat_meta_field(field="meta.name", value="my_name") == "meta->>'name'" - assert _treat_meta_field(field="meta.name", value=["my_name"]) == "meta->>'name'" - assert _treat_meta_field(field="meta.number", value=1.1) == "(meta->>'number')::real" - assert _treat_meta_field(field="meta.number", value=[1.1, 2.2, 3.3]) == "(meta->>'number')::real" - assert _treat_meta_field(field="meta.bool", value=True) == "(meta->>'bool')::boolean" - assert _treat_meta_field(field="meta.bool", value=[True, False, True]) == "(meta->>'bool')::boolean" - - # do not cast the field if its value is not one of the known types, an empty list or None - assert _treat_meta_field(field="meta.other", value={"a": 3, "b": "example"}) == "meta->>'other'" - assert _treat_meta_field(field="meta.empty_list", value=[]) == "meta->>'empty_list'" - assert _treat_meta_field(field="meta.name", value=None) == "meta->>'name'" - - def test_comparison_condition_dataframe_jsonb_conversion(self): - dataframe = DataFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]}) - condition = {"field": "meta.df", "operator": "==", "value": dataframe} - field, values = _parse_comparison_condition(condition) - assert field == "(meta.df)::jsonb = %s" - - # we check each slot of the Jsonb object because it does not implement __eq__ - assert values[0].obj == Jsonb(dataframe.to_json()).obj - assert values[0].dumps == Jsonb(dataframe.to_json()).dumps - - def test_comparison_condition_missing_operator(self): - condition = {"field": "meta.type", "value": "article"} - with pytest.raises(FilterError): - _parse_comparison_condition(condition) - - def test_comparison_condition_missing_value(self): - condition = {"field": "meta.type", "operator": "=="} - with pytest.raises(FilterError): - _parse_comparison_condition(condition) - - def test_comparison_condition_unknown_operator(self): - condition = {"field": 
"meta.type", "operator": "unknown", "value": "article"} - with pytest.raises(FilterError): - _parse_comparison_condition(condition) - - def test_logical_condition_missing_operator(self): - condition = {"conditions": []} - with pytest.raises(FilterError): - _parse_logical_condition(condition) - - def test_logical_condition_missing_conditions(self): - condition = {"operator": "AND"} - with pytest.raises(FilterError): - _parse_logical_condition(condition) - - def test_logical_condition_unknown_operator(self): - condition = {"operator": "unknown", "conditions": []} - with pytest.raises(FilterError): - _parse_logical_condition(condition) - - def test_logical_condition_nested(self): - condition = { - "operator": "AND", - "conditions": [ - { - "operator": "OR", - "conditions": [ - {"field": "meta.domain", "operator": "!=", "value": "science"}, - {"field": "meta.chapter", "operator": "in", "value": ["intro", "conclusion"]}, - ], - }, - { - "operator": "OR", - "conditions": [ - {"field": "meta.number", "operator": ">=", "value": 90}, - {"field": "meta.author", "operator": "not in", "value": ["John", "Jane"]}, - ], - }, - ], - } - query, values = _parse_logical_condition(condition) - assert query == ( - "((meta->>'domain' IS DISTINCT FROM %s OR meta->>'chapter' = ANY(%s)) " - "AND ((meta->>'number')::integer >= %s OR meta->>'author' IS NULL OR meta->>'author' != ALL(%s)))" - ) - assert values == ["science", [["intro", "conclusion"]], 90, [["John", "Jane"]]] - - def test_convert_filters_to_where_clause_and_params(self): - filters = { - "operator": "AND", - "conditions": [ - {"field": "meta.number", "operator": "==", "value": 100}, - {"field": "meta.chapter", "operator": "==", "value": "intro"}, - ], - } - where_clause, params = _convert_filters_to_where_clause_and_params(filters) - assert where_clause == SQL(" WHERE ") + SQL("((meta->>'number')::integer = %s AND meta->>'chapter' = %s)") - assert params == (100, "intro") - - def test_convert_filters_to_where_clause_and_params_handle_null(self): - filters = { - "operator": "AND", - "conditions": [ - {"field": "meta.number", "operator": "==", "value": None}, - {"field": "meta.chapter", "operator": "==", "value": "intro"}, - ], - } - where_clause, params = _convert_filters_to_where_clause_and_params(filters) - assert where_clause == SQL(" WHERE ") + SQL("(meta->>'number' IS NULL AND meta->>'chapter' = %s)") - assert params == ("intro",) +def test_treat_meta_field(): + assert _treat_meta_field(field="meta.number", value=9) == "(meta->>'number')::integer" + assert _treat_meta_field(field="meta.number", value=[1, 2, 3]) == "(meta->>'number')::integer" + assert _treat_meta_field(field="meta.name", value="my_name") == "meta->>'name'" + assert _treat_meta_field(field="meta.name", value=["my_name"]) == "meta->>'name'" + assert _treat_meta_field(field="meta.number", value=1.1) == "(meta->>'number')::real" + assert _treat_meta_field(field="meta.number", value=[1.1, 2.2, 3.3]) == "(meta->>'number')::real" + assert _treat_meta_field(field="meta.bool", value=True) == "(meta->>'bool')::boolean" + assert _treat_meta_field(field="meta.bool", value=[True, False, True]) == "(meta->>'bool')::boolean" + + # do not cast the field if its value is not one of the known types, an empty list or None + assert _treat_meta_field(field="meta.other", value={"a": 3, "b": "example"}) == "meta->>'other'" + assert _treat_meta_field(field="meta.empty_list", value=[]) == "meta->>'empty_list'" + assert _treat_meta_field(field="meta.name", value=None) == "meta->>'name'" + + +def 
test_comparison_condition_dataframe_jsonb_conversion(): + dataframe = DataFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]}) + condition = {"field": "meta.df", "operator": "==", "value": dataframe} + field, values = _parse_comparison_condition(condition) + assert field == "(meta.df)::jsonb = %s" + + # we check each slot of the Jsonb object because it does not implement __eq__ + assert values[0].obj == Jsonb(dataframe.to_json()).obj + assert values[0].dumps == Jsonb(dataframe.to_json()).dumps + + +def test_comparison_condition_missing_operator(): + condition = {"field": "meta.type", "value": "article"} + with pytest.raises(FilterError): + _parse_comparison_condition(condition) + + +def test_comparison_condition_missing_value(): + condition = {"field": "meta.type", "operator": "=="} + with pytest.raises(FilterError): + _parse_comparison_condition(condition) + + +def test_comparison_condition_unknown_operator(): + condition = {"field": "meta.type", "operator": "unknown", "value": "article"} + with pytest.raises(FilterError): + _parse_comparison_condition(condition) + + +def test_logical_condition_missing_operator(): + condition = {"conditions": []} + with pytest.raises(FilterError): + _parse_logical_condition(condition) + + +def test_logical_condition_missing_conditions(): + condition = {"operator": "AND"} + with pytest.raises(FilterError): + _parse_logical_condition(condition) + + +def test_logical_condition_unknown_operator(): + condition = {"operator": "unknown", "conditions": []} + with pytest.raises(FilterError): + _parse_logical_condition(condition) + + +def test_logical_condition_nested(): + condition = { + "operator": "AND", + "conditions": [ + { + "operator": "OR", + "conditions": [ + {"field": "meta.domain", "operator": "!=", "value": "science"}, + {"field": "meta.chapter", "operator": "in", "value": ["intro", "conclusion"]}, + ], + }, + { + "operator": "OR", + "conditions": [ + {"field": "meta.number", "operator": ">=", "value": 90}, + {"field": "meta.author", "operator": "not in", "value": ["John", "Jane"]}, + ], + }, + ], + } + query, values = _parse_logical_condition(condition) + assert query == ( + "((meta->>'domain' IS DISTINCT FROM %s OR meta->>'chapter' = ANY(%s)) " + "AND ((meta->>'number')::integer >= %s OR meta->>'author' IS NULL OR meta->>'author' != ALL(%s)))" + ) + assert values == ["science", [["intro", "conclusion"]], 90, [["John", "Jane"]]] + + +def test_convert_filters_to_where_clause_and_params(): + filters = { + "operator": "AND", + "conditions": [ + {"field": "meta.number", "operator": "==", "value": 100}, + {"field": "meta.chapter", "operator": "==", "value": "intro"}, + ], + } + where_clause, params = _convert_filters_to_where_clause_and_params(filters) + assert where_clause == SQL(" WHERE ") + SQL("((meta->>'number')::integer = %s AND meta->>'chapter' = %s)") + assert params == (100, "intro") + + +def test_convert_filters_to_where_clause_and_params_handle_null(): + filters = { + "operator": "AND", + "conditions": [ + {"field": "meta.number", "operator": "==", "value": None}, + {"field": "meta.chapter", "operator": "==", "value": "intro"}, + ], + } + where_clause, params = _convert_filters_to_where_clause_and_params(filters) + assert where_clause == SQL(" WHERE ") + SQL("(meta->>'number' IS NULL AND meta->>'chapter' = %s)") + assert params == ("intro",) diff --git a/integrations/pgvector/tests/test_retriever.py b/integrations/pgvector/tests/test_retriever.py index 8eab10de5..61381c24e 100644 --- a/integrations/pgvector/tests/test_retriever.py +++ 
b/integrations/pgvector/tests/test_retriever.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 from unittest.mock import Mock +import pytest from haystack.dataclasses import Document from haystack.utils.auth import EnvVarSecret from haystack_integrations.components.retrievers.pgvector import PgvectorEmbeddingRetriever @@ -10,25 +11,25 @@ class TestRetriever: - def test_init_default(self, document_store: PgvectorDocumentStore): - retriever = PgvectorEmbeddingRetriever(document_store=document_store) - assert retriever.document_store == document_store + def test_init_default(self, mock_store): + retriever = PgvectorEmbeddingRetriever(document_store=mock_store) + assert retriever.document_store == mock_store assert retriever.filters == {} assert retriever.top_k == 10 - assert retriever.vector_function == document_store.vector_function + assert retriever.vector_function == mock_store.vector_function - def test_init(self, document_store: PgvectorDocumentStore): + def test_init(self, mock_store): retriever = PgvectorEmbeddingRetriever( - document_store=document_store, filters={"field": "value"}, top_k=5, vector_function="l2_distance" + document_store=mock_store, filters={"field": "value"}, top_k=5, vector_function="l2_distance" ) - assert retriever.document_store == document_store + assert retriever.document_store == mock_store assert retriever.filters == {"field": "value"} assert retriever.top_k == 5 assert retriever.vector_function == "l2_distance" - def test_to_dict(self, document_store: PgvectorDocumentStore): + def test_to_dict(self, mock_store): retriever = PgvectorEmbeddingRetriever( - document_store=document_store, filters={"field": "value"}, top_k=5, vector_function="l2_distance" + document_store=mock_store, filters={"field": "value"}, top_k=5, vector_function="l2_distance" ) res = retriever.to_dict() t = "haystack_integrations.components.retrievers.pgvector.embedding_retriever.PgvectorEmbeddingRetriever" @@ -39,7 +40,7 @@ def test_to_dict(self, document_store: PgvectorDocumentStore): "type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore", "init_parameters": { "connection_string": {"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"}, - "table_name": "haystack_test_to_dict", + "table_name": "haystack", "embedding_dimension": 768, "vector_function": "cosine_similarity", "recreate_table": True, @@ -55,7 +56,9 @@ def test_to_dict(self, document_store: PgvectorDocumentStore): }, } - def test_from_dict(self): + @pytest.mark.usefixtures("patches_for_unit_tests") + def test_from_dict(self, monkeypatch): + monkeypatch.setenv("PG_CONN_STR", "some-connection-string") t = "haystack_integrations.components.retrievers.pgvector.embedding_retriever.PgvectorEmbeddingRetriever" data = { "type": t, diff --git a/integrations/weaviate/tests/test_document_store.py b/integrations/weaviate/tests/test_document_store.py index 4c1659a86..a2ae9cb70 100644 --- a/integrations/weaviate/tests/test_document_store.py +++ b/integrations/weaviate/tests/test_document_store.py @@ -38,6 +38,7 @@ ) +@pytest.mark.integration class TestWeaviateDocumentStore(CountDocumentsTest, WriteDocumentsTest, DeleteDocumentsTest, FilterDocumentsTest): @pytest.fixture def document_store(self, request) -> WeaviateDocumentStore:
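A closing note on the environment-variable pattern that recurs throughout these tests: connection secrets such as PG_CONN_STR are resolved from the environment at construction time (see the env_var entries in the to_dict outputs above), so unit tests inject a throwaway value with monkeypatch instead of requiring real credentials. A minimal, hypothetical sketch:

    import os

    def test_env_var_injection(monkeypatch):
        # A dummy value; no server is contacted in a unit test.
        monkeypatch.setenv("PG_CONN_STR", "postgresql://user:pass@localhost:5432/db")
        assert os.environ["PG_CONN_STR"].startswith("postgresql://")
        # A PgvectorDocumentStore constructed here would pick this value up
        # through its env-var Secret and never see real credentials.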