Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: defer the database connection to when it's needed #770

Merged
merged 4 commits into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from haystack.utils import Secret, deserialize_secrets_inplace
from haystack_integrations.document_stores.mongodb_atlas.filters import _normalize_filters
from pymongo import InsertOne, MongoClient, ReplaceOne, UpdateOne
from pymongo.collection import Collection
from pymongo.driver_info import DriverInfo
from pymongo.errors import BulkWriteError

Expand Down Expand Up @@ -81,22 +82,34 @@ def __init__(
msg = f'Invalid collection name: "{collection_name}". It can only contain letters, numbers, -, or _.'
raise ValueError(msg)

resolved_connection_string = mongo_connection_string.resolve_value()
self.resolved_connection_string = mongo_connection_string.resolve_value()
self.mongo_connection_string = mongo_connection_string

self.database_name = database_name
self.collection_name = collection_name
self.vector_search_index = vector_search_index
self._connection: Optional[MongoClient] = None
self._collection: Optional[Collection] = None

self.connection: MongoClient = MongoClient(
resolved_connection_string, driver=DriverInfo(name="MongoDBAtlasHaystackIntegration")
)
database = self.connection[self.database_name]
# Lazily created MongoDB client. The first access builds a MongoClient from
# self.resolved_connection_string and caches it in self._connection; later accesses
# return the cached instance without creating a new client.
@property
def connection(self) -> MongoClient:
if self._connection is None:
# The client is created with driver info named "MongoDBAtlasHaystackIntegration".
self._connection = MongoClient(
self.resolved_connection_string, driver=DriverInfo(name="MongoDBAtlasHaystackIntegration")
)

# NOTE(review): the next four lines use `database`, `collection_name` and `database_name`, none of which
# are defined inside this property. They look like pre-change lines carried over by the diff view (the same
# existence check appears in the `collection` property) — confirm before treating them as part of this method.
if collection_name not in database.list_collection_names():
msg = f"Collection '{collection_name}' does not exist in database '{database_name}'."
raise ValueError(msg)
self.collection = database[self.collection_name]
return self._connection

@property
def collection(self) -> Collection:
    """
    Return the MongoDB collection backing this document store, resolving it on first access.

    The lookup goes through `self.connection`, and the result is cached in `self._collection`
    so that subsequent accesses skip the existence check.

    :returns: The collection named `self.collection_name` in the database `self.database_name`.
    :raises ValueError: If no collection with that name is listed in the database.
    """
    # Fast path: the collection was already resolved by an earlier access.
    if self._collection is not None:
        return self._collection

    db = self.connection[self.database_name]
    existing_names = db.list_collection_names()

    # Fail loudly rather than handing back a handle to a collection that is not listed.
    if self.collection_name not in existing_names:
        msg = f"Collection '{self.collection_name}' does not exist in database '{self.database_name}'."
        raise ValueError(msg)

    self._collection = db[self.collection_name]
    return self._collection

def to_dict(self) -> Dict[str, Any]:
"""
Expand Down
13 changes: 12 additions & 1 deletion integrations/mongodb_atlas/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0
import os
from unittest.mock import patch
from uuid import uuid4

import pytest
Expand All @@ -16,13 +17,23 @@
from pymongo.driver_info import DriverInfo


@patch("haystack_integrations.document_stores.mongodb_atlas.document_store.MongoClient")
def test_init_is_lazy(_mock_client):
    """Constructing the document store must not instantiate ``MongoClient``."""
    init_kwargs = {
        "mongo_connection_string": Secret.from_token("test"),
        "database_name": "database_name",
        "collection_name": "collection_name",
        "vector_search_index": "cosine_index",
    }
    MongoDBAtlasDocumentStore(**init_kwargs)

    # MongoClient is patched in the document_store module, so an eager
    # connection made during __init__ would be recorded as a call on the mock.
    _mock_client.assert_not_called()


@pytest.mark.skipif(
"MONGO_CONNECTION_STRING" not in os.environ,
reason="No MongoDB Atlas connection string provided",
)
@pytest.mark.integration
class TestDocumentStore(DocumentStoreBaseTests):

@pytest.fixture
def document_store(self):
database_name = "haystack_integration_test"
Expand Down