Skip to content

Commit

Permalink
feat: defer the database connection to when it's needed (#769)
Browse files Browse the repository at this point in the history
* feat: defer the database connection to when it's needed

* fix typing

* test init is lazy
  • Loading branch information
masci authored May 29, 2024
1 parent 7d36d02 commit 5eebd84
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def __init__(
"Set the ASTRA_DB_API_ENDPOINT environment variable (recommended) or pass it explicitly."
)
raise ValueError(msg)
self.resolved_api_endpoint = resolved_api_endpoint

resolved_token = token.resolve_value()
if resolved_token is None:
Expand All @@ -93,6 +94,7 @@ def __init__(
"Set the ASTRA_DB_APPLICATION_TOKEN environment variable (recommended) or pass it explicitly."
)
raise ValueError(msg)
self.resolved_token = resolved_token

self.api_endpoint = api_endpoint
self.token = token
Expand All @@ -101,15 +103,20 @@ def __init__(
self.duplicates_policy = duplicates_policy
self.similarity = similarity
self.namespace = namespace

self.index = AstraClient(
resolved_api_endpoint,
resolved_token,
self.collection_name,
self.embedding_dimension,
self.similarity,
namespace,
)
self._index: Optional[AstraClient] = None

@property
def index(self) -> AstraClient:
if self._index is None:
self._index = AstraClient(
self.resolved_api_endpoint,
self.resolved_token,
self.collection_name,
self.embedding_dimension,
self.similarity,
self.namespace,
)
return self._index

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "AstraDocumentStore":
Expand Down
10 changes: 8 additions & 2 deletions integrations/astra/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,19 @@ def mock_auth(monkeypatch):
monkeypatch.setenv("ASTRA_DB_APPLICATION_TOKEN", "test_token")


@mock.patch("haystack_integrations.document_stores.astra.astra_client.AstraDB")
def test_init_is_lazy(_mock_client):
_ = AstraDocumentStore()
_mock_client.assert_not_called()


def test_namespace_init(mock_auth): # noqa
with mock.patch("haystack_integrations.document_stores.astra.astra_client.AstraDB") as client:
AstraDocumentStore()
_ = AstraDocumentStore().index
assert "namespace" in client.call_args.kwargs
assert client.call_args.kwargs["namespace"] is None

AstraDocumentStore(namespace="foo")
_ = AstraDocumentStore(namespace="foo").index
assert "namespace" in client.call_args.kwargs
assert client.call_args.kwargs["namespace"] == "foo"

Expand Down

0 comments on commit 5eebd84

Please sign in to comment.