Skip to content

Commit

Permalink
feat: defer the database connection to when it's needed (#770)
Browse files Browse the repository at this point in the history
* feat: defer the database connection to when it's needed

* lazy collection too

* add test

* linting
  • Loading branch information
masci authored May 29, 2024
1 parent 5eebd84 commit 588d654
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from haystack.utils import Secret, deserialize_secrets_inplace
from haystack_integrations.document_stores.mongodb_atlas.filters import _normalize_filters
from pymongo import InsertOne, MongoClient, ReplaceOne, UpdateOne
from pymongo.collection import Collection
from pymongo.driver_info import DriverInfo
from pymongo.errors import BulkWriteError

Expand Down Expand Up @@ -81,22 +82,34 @@ def __init__(
msg = f'Invalid collection name: "{collection_name}". It can only contain letters, numbers, -, or _.'
raise ValueError(msg)

resolved_connection_string = mongo_connection_string.resolve_value()
self.resolved_connection_string = mongo_connection_string.resolve_value()
self.mongo_connection_string = mongo_connection_string

self.database_name = database_name
self.collection_name = collection_name
self.vector_search_index = vector_search_index
self._connection: Optional[MongoClient] = None
self._collection: Optional[Collection] = None

self.connection: MongoClient = MongoClient(
resolved_connection_string, driver=DriverInfo(name="MongoDBAtlasHaystackIntegration")
)
database = self.connection[self.database_name]
@property
def connection(self) -> MongoClient:
if self._connection is None:
self._connection = MongoClient(
self.resolved_connection_string, driver=DriverInfo(name="MongoDBAtlasHaystackIntegration")
)

if collection_name not in database.list_collection_names():
msg = f"Collection '{collection_name}' does not exist in database '{database_name}'."
raise ValueError(msg)
self.collection = database[self.collection_name]
return self._connection

@property
def collection(self) -> Collection:
    """
    Return the collection backing this document store, resolving it on first access.

    The lookup goes through `self.connection`, so nothing is requested from the
    database until this property is read. Once resolved, the collection object is
    cached in `self._collection` and returned directly on later reads.

    :raises ValueError: If `self.collection_name` is not among the names reported by
        `list_collection_names()` for the database `self.database_name`.
    :returns: The cached collection object.
    """
    # Fast path: already resolved by an earlier access.
    if self._collection is not None:
        return self._collection

    db = self.connection[self.database_name]
    existing_names = db.list_collection_names()
    if self.collection_name not in existing_names:
        msg = f"Collection '{self.collection_name}' does not exist in database '{self.database_name}'."
        raise ValueError(msg)

    # Cache only after the existence check has passed.
    self._collection = db[self.collection_name]
    return self._collection

def to_dict(self) -> Dict[str, Any]:
"""
Expand Down
13 changes: 12 additions & 1 deletion integrations/mongodb_atlas/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0
import os
from unittest.mock import patch
from uuid import uuid4

import pytest
Expand All @@ -16,13 +17,23 @@
from pymongo.driver_info import DriverInfo


@patch("haystack_integrations.document_stores.mongodb_atlas.document_store.MongoClient")
def test_init_is_lazy(_mock_client):
    """Constructing the document store must not instantiate `MongoClient`."""
    init_kwargs = {
        "mongo_connection_string": Secret.from_token("test"),
        "database_name": "database_name",
        "collection_name": "collection_name",
        "vector_search_index": "cosine_index",
    }

    MongoDBAtlasDocumentStore(**init_kwargs)

    # The patched client class is only expected to be called on first use of the store.
    _mock_client.assert_not_called()


@pytest.mark.skipif(
"MONGO_CONNECTION_STRING" not in os.environ,
reason="No MongoDB Atlas connection string provided",
)
@pytest.mark.integration
class TestDocumentStore(DocumentStoreBaseTests):

@pytest.fixture
def document_store(self):
database_name = "haystack_integration_test"
Expand Down

0 comments on commit 588d654

Please sign in to comment.