From 7524022524ca38889a39744b2661913ef2552206 Mon Sep 17 00:00:00 2001
From: Massimiliano Pippi
Date: Tue, 11 Jun 2024 11:03:27 +0200
Subject: [PATCH] feat: defer the database connection to when it's needed
 (#804)

* feat: defer the database connection to when it's needed

* linting

* fix tests
---
 .../pinecone/document_store.py                | 54 +++++++++++--------
 integrations/pinecone/tests/conftest.py      |  2 +-
 .../pinecone/tests/test_document_store.py    | 30 ++++++++---
 .../tests/test_embedding_retriever.py        |  2 +-
 4 files changed, 56 insertions(+), 32 deletions(-)

diff --git a/integrations/pinecone/src/haystack_integrations/document_stores/pinecone/document_store.py b/integrations/pinecone/src/haystack_integrations/document_stores/pinecone/document_store.py
index 0e87f97fc..1fd3adf40 100644
--- a/integrations/pinecone/src/haystack_integrations/document_stores/pinecone/document_store.py
+++ b/integrations/pinecone/src/haystack_integrations/document_stores/pinecone/document_store.py
@@ -66,35 +66,45 @@ def __init__(
         """
         self.api_key = api_key
         spec = spec or DEFAULT_STARTER_PLAN_SPEC
+        self.namespace = namespace
+        self.batch_size = batch_size
+        self.metric = metric
+        self.spec = spec
+        self.dimension = dimension
+        self.index_name = index
+
+        self._index = None
+        self._dummy_vector = [-10.0] * self.dimension
+
+    @property
+    def index(self):
+        if self._index is not None:
+            return self._index
 
-        client = Pinecone(api_key=api_key.resolve_value(), source_tag="haystack")
+        client = Pinecone(api_key=self.api_key.resolve_value(), source_tag="haystack")
 
-        if index not in client.list_indexes().names():
-            logger.info(f"Index {index} does not exist. Creating a new index.")
-            pinecone_spec = self._convert_dict_spec_to_pinecone_object(spec)
-            client.create_index(name=index, dimension=dimension, spec=pinecone_spec, metric=metric)
+        if self.index_name not in client.list_indexes().names():
+            logger.info(f"Index {self.index_name} does not exist. Creating a new index.")
+            pinecone_spec = self._convert_dict_spec_to_pinecone_object(self.spec)
+            client.create_index(name=self.index_name, dimension=self.dimension, spec=pinecone_spec, metric=self.metric)
         else:
             logger.info(
-                f"Index {index} already exists. Connecting to it. `dimension`, `spec`, and `metric` will be ignored."
+                f"Connecting to existing index {self.index_name}. `dimension`, `spec`, and `metric` will be ignored."
             )
 
-        self._index = client.Index(name=index)
+        self._index = client.Index(name=self.index_name)
 
         actual_dimension = self._index.describe_index_stats().get("dimension")
-        if actual_dimension and actual_dimension != dimension:
+        if actual_dimension and actual_dimension != self.dimension:
             logger.warning(
-                f"Dimension of index {index} is {actual_dimension}, but {dimension} was specified. "
+                f"Dimension of index {self.index_name} is {actual_dimension}, but {self.dimension} was specified. "
                 "The specified dimension will be ignored."
                 "If you need an index with a different dimension, please create a new one."
             )
-        self.dimension = actual_dimension or dimension
-
+        self.dimension = actual_dimension or self.dimension
         self._dummy_vector = [-10.0] * self.dimension
-        self.index = index
-        self.namespace = namespace
-        self.batch_size = batch_size
-        self.metric = metric
-        self.spec = spec
+
+        return self._index
 
     @staticmethod
     def _convert_dict_spec_to_pinecone_object(spec: Dict[str, Any]):
@@ -135,7 +145,7 @@ def to_dict(self) -> Dict[str, Any]:
             self,
             api_key=self.api_key.to_dict(),
             spec=self.spec,
-            index=self.index,
+            index=self.index_name,
             dimension=self.dimension,
             namespace=self.namespace,
             batch_size=self.batch_size,
@@ -147,7 +157,7 @@ def count_documents(self) -> int:
         Returns how many documents are present in the document store.
         """
         try:
-            count = self._index.describe_index_stats()["namespaces"][self.namespace]["vector_count"]
+            count = self.index.describe_index_stats()["namespaces"][self.namespace]["vector_count"]
         except KeyError:
             count = 0
         return count
@@ -174,9 +184,7 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D
 
         documents_for_pinecone = self._convert_documents_to_pinecone_format(documents)
 
-        result = self._index.upsert(
-            vectors=documents_for_pinecone, namespace=self.namespace, batch_size=self.batch_size
-        )
+        result = self.index.upsert(vectors=documents_for_pinecone, namespace=self.namespace, batch_size=self.batch_size)
         written_docs = result["upserted_count"]
         return written_docs
 
@@ -214,7 +222,7 @@ def delete_documents(self, document_ids: List[str]) -> None:
 
         :param document_ids: the document ids to delete
         """
-        self._index.delete(ids=document_ids, namespace=self.namespace)
+        self.index.delete(ids=document_ids, namespace=self.namespace)
 
     def _embedding_retrieval(
         self,
@@ -247,7 +255,7 @@ def _embedding_retrieval(
             filters = convert(filters)
         filters = _normalize_filters(filters) if filters else None
 
-        result = self._index.query(
+        result = self.index.query(
             vector=query_embedding,
             top_k=top_k,
             namespace=namespace or self.namespace,
diff --git a/integrations/pinecone/tests/conftest.py b/integrations/pinecone/tests/conftest.py
index d6f58b6aa..6c3f7f39b 100644
--- a/integrations/pinecone/tests/conftest.py
+++ b/integrations/pinecone/tests/conftest.py
@@ -51,6 +51,6 @@ def delete_documents_and_wait(filters):
     yield store
 
     try:
-        store._index.delete(delete_all=True, namespace=namespace)
+        store.index.delete(delete_all=True, namespace=namespace)
     except NotFoundException:
         pass
diff --git a/integrations/pinecone/tests/test_document_store.py b/integrations/pinecone/tests/test_document_store.py
index f89208f48..459401800 100644
--- a/integrations/pinecone/tests/test_document_store.py
+++ b/integrations/pinecone/tests/test_document_store.py
@@ -12,6 +12,12 @@
 from haystack_integrations.document_stores.pinecone import PineconeDocumentStore
 
 
+@patch("haystack_integrations.document_stores.pinecone.document_store.Pinecone")
+def test_init_is_lazy(_mock_client):
+    _ = PineconeDocumentStore(api_key=Secret.from_token("fake-api-key"))
+    _mock_client.assert_not_called()
+
+
 @patch("haystack_integrations.document_stores.pinecone.document_store.Pinecone")
 def test_init(mock_pinecone):
     mock_pinecone.return_value.Index.return_value.describe_index_stats.return_value = {"dimension": 60}
@@ -25,9 +31,12 @@ def test_init(mock_pinecone):
         metric="euclidean",
     )
 
+    # Trigger an actual connection
+    _ = document_store.index
+
     mock_pinecone.assert_called_with(api_key="fake-api-key", source_tag="haystack")
 
-    assert document_store.index == "my_index"
+    assert document_store.index_name == "my_index"
     assert document_store.namespace == "test"
     assert document_store.batch_size == 50
     assert document_store.dimension == 60
@@ -38,7 +47,7 @@ def test_init(mock_pinecone):
 def test_init_api_key_in_environment_variable(mock_pinecone, monkeypatch):
     monkeypatch.setenv("PINECONE_API_KEY", "env-api-key")
 
-    PineconeDocumentStore(
+    ds = PineconeDocumentStore(
         index="my_index",
         namespace="test",
         batch_size=50,
@@ -46,6 +55,9 @@ def test_init_api_key_in_environment_variable(mock_pinecone, monkeypatch):
         metric="euclidean",
     )
 
+    # Trigger an actual connection
+    _ = ds.index
+
     mock_pinecone.assert_called_with(api_key="env-api-key", source_tag="haystack")
 
 
@@ -61,6 +73,9 @@ def test_to_from_dict(mock_pinecone, monkeypatch):
         metric="euclidean",
     )
 
+    # Trigger an actual connection
+    _ = document_store.index
+
     dict_output = {
         "type": "haystack_integrations.document_stores.pinecone.document_store.PineconeDocumentStore",
         "init_parameters": {
@@ -83,7 +98,7 @@
     document_store = PineconeDocumentStore.from_dict(dict_output)
 
     assert document_store.api_key == Secret.from_env_var("PINECONE_API_KEY", strict=True)
-    assert document_store.index == "my_index"
+    assert document_store.index_name == "my_index"
     assert document_store.namespace == "test"
     assert document_store.batch_size == 50
     assert document_store.dimension == 60
@@ -94,9 +109,9 @@
 def test_init_fails_wo_api_key(monkeypatch):
     monkeypatch.delenv("PINECONE_API_KEY", raising=False)
     with pytest.raises(ValueError):
-        PineconeDocumentStore(
+        _ = PineconeDocumentStore(
             index="my_index",
-        )
+        ).index
 
 
 def test_convert_dict_spec_to_pinecone_object_serverless():
@@ -108,7 +123,6 @@
 def test_convert_dict_spec_to_pinecone_object_pod():
-
     dict_spec = {"pod": {"replicas": 1, "shards": 1, "pods": 1, "pod_type": "p1.x1", "environment": "us-west1-gcp"}}
     pinecone_object = PineconeDocumentStore._convert_dict_spec_to_pinecone_object(dict_spec)
@@ -141,7 +155,7 @@ def test_serverless_index_creation_from_scratch(sleep_time):
 
     time.sleep(sleep_time)
 
-    PineconeDocumentStore(
+    ds = PineconeDocumentStore(
         index=index_name,
         namespace="test",
         batch_size=50,
@@ -149,6 +163,8 @@ def test_serverless_index_creation_from_scratch(sleep_time):
         metric="euclidean",
         spec={"serverless": {"region": "us-east-1", "cloud": "aws"}},
     )
+    # Trigger the connection
+    _ = ds.index
 
     index_description = client.describe_index(name=index_name)
     assert index_description["name"] == index_name
diff --git a/integrations/pinecone/tests/test_embedding_retriever.py b/integrations/pinecone/tests/test_embedding_retriever.py
index 76e930737..537b4c933 100644
--- a/integrations/pinecone/tests/test_embedding_retriever.py
+++ b/integrations/pinecone/tests/test_embedding_retriever.py
@@ -91,7 +91,7 @@ def test_from_dict(mock_pinecone, monkeypatch):
     document_store = retriever.document_store
 
     assert document_store.api_key == Secret.from_env_var("PINECONE_API_KEY", strict=True)
-    assert document_store.index == "default"
+    assert document_store.index_name == "default"
     assert document_store.namespace == "test-namespace"
     assert document_store.batch_size == 50
     assert document_store.dimension == 512