diff --git a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py index 4cb5b8659..83e93e269 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py @@ -12,6 +12,7 @@ from haystack.utils import Secret, deserialize_secrets_inplace from haystack_integrations.document_stores.mongodb_atlas.filters import _normalize_filters from pymongo import InsertOne, MongoClient, ReplaceOne, UpdateOne +from pymongo.collection import Collection from pymongo.driver_info import DriverInfo from pymongo.errors import BulkWriteError @@ -81,22 +82,34 @@ def __init__( msg = f'Invalid collection name: "{collection_name}". It can only contain letters, numbers, -, or _.' raise ValueError(msg) - resolved_connection_string = mongo_connection_string.resolve_value() + self.resolved_connection_string = mongo_connection_string.resolve_value() self.mongo_connection_string = mongo_connection_string self.database_name = database_name self.collection_name = collection_name self.vector_search_index = vector_search_index + self._connection: Optional[MongoClient] = None + self._collection: Optional[Collection] = None - self.connection: MongoClient = MongoClient( - resolved_connection_string, driver=DriverInfo(name="MongoDBAtlasHaystackIntegration") - ) - database = self.connection[self.database_name] + @property + def connection(self) -> MongoClient: + if self._connection is None: + self._connection = MongoClient( + self.resolved_connection_string, driver=DriverInfo(name="MongoDBAtlasHaystackIntegration") + ) - if collection_name not in database.list_collection_names(): - msg = f"Collection '{collection_name}' does not exist in database '{database_name}'." - raise ValueError(msg) - self.collection = database[self.collection_name] + return self._connection + + @property + def collection(self) -> Collection: + if self._collection is None: + database = self.connection[self.database_name] + + if self.collection_name not in database.list_collection_names(): + msg = f"Collection '{self.collection_name}' does not exist in database '{self.database_name}'." + raise ValueError(msg) + self._collection = database[self.collection_name] + return self._collection def to_dict(self) -> Dict[str, Any]: """ diff --git a/integrations/mongodb_atlas/tests/test_document_store.py b/integrations/mongodb_atlas/tests/test_document_store.py index 89810ec8b..453d9d16c 100644 --- a/integrations/mongodb_atlas/tests/test_document_store.py +++ b/integrations/mongodb_atlas/tests/test_document_store.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 import os +from unittest.mock import patch from uuid import uuid4 import pytest @@ -16,13 +17,23 @@ from pymongo.driver_info import DriverInfo +@patch("haystack_integrations.document_stores.mongodb_atlas.document_store.MongoClient") +def test_init_is_lazy(_mock_client): + MongoDBAtlasDocumentStore( + mongo_connection_string=Secret.from_token("test"), + database_name="database_name", + collection_name="collection_name", + vector_search_index="cosine_index", + ) + _mock_client.assert_not_called() + + @pytest.mark.skipif( "MONGO_CONNECTION_STRING" not in os.environ, reason="No MongoDB Atlas connection string provided", ) @pytest.mark.integration class TestDocumentStore(DocumentStoreBaseTests): - @pytest.fixture def document_store(self): database_name = "haystack_integration_test"