Skip to content

Commit

Permalink
feat: defer the database connection to when it's needed (#804)
Browse files Browse the repository at this point in the history
* feat: defer the database connection to when it's needed

* linting

* fix tests
  • Loading branch information
masci authored Jun 11, 2024
1 parent 29b363c commit 7524022
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -66,35 +66,45 @@ def __init__(
"""
self.api_key = api_key
spec = spec or DEFAULT_STARTER_PLAN_SPEC
self.namespace = namespace
self.batch_size = batch_size
self.metric = metric
self.spec = spec
self.dimension = dimension
self.index_name = index

self._index = None
self._dummy_vector = [-10.0] * self.dimension

@property
def index(self):
if self._index is not None:
return self._index

client = Pinecone(api_key=api_key.resolve_value(), source_tag="haystack")
client = Pinecone(api_key=self.api_key.resolve_value(), source_tag="haystack")

if index not in client.list_indexes().names():
logger.info(f"Index {index} does not exist. Creating a new index.")
pinecone_spec = self._convert_dict_spec_to_pinecone_object(spec)
client.create_index(name=index, dimension=dimension, spec=pinecone_spec, metric=metric)
if self.index_name not in client.list_indexes().names():
logger.info(f"Index {self.index_name} does not exist. Creating a new index.")
pinecone_spec = self._convert_dict_spec_to_pinecone_object(self.spec)
client.create_index(name=self.index_name, dimension=self.dimension, spec=pinecone_spec, metric=self.metric)
else:
logger.info(
f"Index {index} already exists. Connecting to it. `dimension`, `spec`, and `metric` will be ignored."
f"Connecting to existing index {self.index_name}. `dimension`, `spec`, and `metric` will be ignored."
)

self._index = client.Index(name=index)
self._index = client.Index(name=self.index_name)

actual_dimension = self._index.describe_index_stats().get("dimension")
if actual_dimension and actual_dimension != dimension:
if actual_dimension and actual_dimension != self.dimension:
logger.warning(
f"Dimension of index {index} is {actual_dimension}, but {dimension} was specified. "
f"Dimension of index {self.index_name} is {actual_dimension}, but {self.dimension} was specified. "
"The specified dimension will be ignored."
"If you need an index with a different dimension, please create a new one."
)
self.dimension = actual_dimension or dimension

self.dimension = actual_dimension or self.dimension
self._dummy_vector = [-10.0] * self.dimension
self.index = index
self.namespace = namespace
self.batch_size = batch_size
self.metric = metric
self.spec = spec

return self._index

@staticmethod
def _convert_dict_spec_to_pinecone_object(spec: Dict[str, Any]):
Expand Down Expand Up @@ -135,7 +145,7 @@ def to_dict(self) -> Dict[str, Any]:
self,
api_key=self.api_key.to_dict(),
spec=self.spec,
index=self.index,
index=self.index_name,
dimension=self.dimension,
namespace=self.namespace,
batch_size=self.batch_size,
Expand All @@ -147,7 +157,7 @@ def count_documents(self) -> int:
Returns how many documents are present in the document store.
"""
try:
count = self._index.describe_index_stats()["namespaces"][self.namespace]["vector_count"]
count = self.index.describe_index_stats()["namespaces"][self.namespace]["vector_count"]
except KeyError:
count = 0
return count
Expand All @@ -174,9 +184,7 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D

documents_for_pinecone = self._convert_documents_to_pinecone_format(documents)

result = self._index.upsert(
vectors=documents_for_pinecone, namespace=self.namespace, batch_size=self.batch_size
)
result = self.index.upsert(vectors=documents_for_pinecone, namespace=self.namespace, batch_size=self.batch_size)

written_docs = result["upserted_count"]
return written_docs
Expand Down Expand Up @@ -214,7 +222,7 @@ def delete_documents(self, document_ids: List[str]) -> None:
:param document_ids: the document ids to delete
"""
self._index.delete(ids=document_ids, namespace=self.namespace)
self.index.delete(ids=document_ids, namespace=self.namespace)

def _embedding_retrieval(
self,
Expand Down Expand Up @@ -247,7 +255,7 @@ def _embedding_retrieval(
filters = convert(filters)
filters = _normalize_filters(filters) if filters else None

result = self._index.query(
result = self.index.query(
vector=query_embedding,
top_k=top_k,
namespace=namespace or self.namespace,
Expand Down
2 changes: 1 addition & 1 deletion integrations/pinecone/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,6 @@ def delete_documents_and_wait(filters):

yield store
try:
store._index.delete(delete_all=True, namespace=namespace)
store.index.delete(delete_all=True, namespace=namespace)
except NotFoundException:
pass
30 changes: 23 additions & 7 deletions integrations/pinecone/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@
from haystack_integrations.document_stores.pinecone import PineconeDocumentStore


@patch("haystack_integrations.document_stores.pinecone.document_store.Pinecone")
def test_init_is_lazy(_mock_client):
_ = PineconeDocumentStore(api_key=Secret.from_token("fake-api-key"))
_mock_client.assert_not_called()


@patch("haystack_integrations.document_stores.pinecone.document_store.Pinecone")
def test_init(mock_pinecone):
mock_pinecone.return_value.Index.return_value.describe_index_stats.return_value = {"dimension": 60}
Expand All @@ -25,9 +31,12 @@ def test_init(mock_pinecone):
metric="euclidean",
)

# Trigger an actual connection
_ = document_store.index

mock_pinecone.assert_called_with(api_key="fake-api-key", source_tag="haystack")

assert document_store.index == "my_index"
assert document_store.index_name == "my_index"
assert document_store.namespace == "test"
assert document_store.batch_size == 50
assert document_store.dimension == 60
Expand All @@ -38,14 +47,17 @@ def test_init(mock_pinecone):
def test_init_api_key_in_environment_variable(mock_pinecone, monkeypatch):
monkeypatch.setenv("PINECONE_API_KEY", "env-api-key")

PineconeDocumentStore(
ds = PineconeDocumentStore(
index="my_index",
namespace="test",
batch_size=50,
dimension=30,
metric="euclidean",
)

# Trigger an actual connection
_ = ds.index

mock_pinecone.assert_called_with(api_key="env-api-key", source_tag="haystack")


Expand All @@ -61,6 +73,9 @@ def test_to_from_dict(mock_pinecone, monkeypatch):
metric="euclidean",
)

# Trigger an actual connection
_ = document_store.index

dict_output = {
"type": "haystack_integrations.document_stores.pinecone.document_store.PineconeDocumentStore",
"init_parameters": {
Expand All @@ -83,7 +98,7 @@ def test_to_from_dict(mock_pinecone, monkeypatch):

document_store = PineconeDocumentStore.from_dict(dict_output)
assert document_store.api_key == Secret.from_env_var("PINECONE_API_KEY", strict=True)
assert document_store.index == "my_index"
assert document_store.index_name == "my_index"
assert document_store.namespace == "test"
assert document_store.batch_size == 50
assert document_store.dimension == 60
Expand All @@ -94,9 +109,9 @@ def test_to_from_dict(mock_pinecone, monkeypatch):
def test_init_fails_wo_api_key(monkeypatch):
monkeypatch.delenv("PINECONE_API_KEY", raising=False)
with pytest.raises(ValueError):
PineconeDocumentStore(
_ = PineconeDocumentStore(
index="my_index",
)
).index


def test_convert_dict_spec_to_pinecone_object_serverless():
Expand All @@ -108,7 +123,6 @@ def test_convert_dict_spec_to_pinecone_object_serverless():


def test_convert_dict_spec_to_pinecone_object_pod():

dict_spec = {"pod": {"replicas": 1, "shards": 1, "pods": 1, "pod_type": "p1.x1", "environment": "us-west1-gcp"}}
pinecone_object = PineconeDocumentStore._convert_dict_spec_to_pinecone_object(dict_spec)

Expand Down Expand Up @@ -141,14 +155,16 @@ def test_serverless_index_creation_from_scratch(sleep_time):

time.sleep(sleep_time)

PineconeDocumentStore(
ds = PineconeDocumentStore(
index=index_name,
namespace="test",
batch_size=50,
dimension=30,
metric="euclidean",
spec={"serverless": {"region": "us-east-1", "cloud": "aws"}},
)
# Trigger the connection
_ = ds.index

index_description = client.describe_index(name=index_name)
assert index_description["name"] == index_name
Expand Down
2 changes: 1 addition & 1 deletion integrations/pinecone/tests/test_embedding_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def test_from_dict(mock_pinecone, monkeypatch):

document_store = retriever.document_store
assert document_store.api_key == Secret.from_env_var("PINECONE_API_KEY", strict=True)
assert document_store.index == "default"
assert document_store.index_name == "default"
assert document_store.namespace == "test-namespace"
assert document_store.batch_size == 50
assert document_store.dimension == 512
Expand Down

0 comments on commit 7524022

Please sign in to comment.