Skip to content

Commit

Permalink
feat: defer the database connection to when it's needed (#766)
Browse files Browse the repository at this point in the history
* fix tests

* fix linting
masci authored May 29, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent 257f992 commit 7d36d02
Showing 2 changed files with 56 additions and 38 deletions.
Original file line number Diff line number Diff line change
@@ -93,48 +93,60 @@ def __init__(
:param **kwargs: Optional arguments that `Elasticsearch` takes.
"""
self._hosts = hosts
self._client = Elasticsearch(
hosts,
headers={"user-agent": f"haystack-py-ds/{haystack_version}"},
**kwargs,
)
self._client = None
self._index = index
self._embedding_similarity_function = embedding_similarity_function
self._custom_mapping = custom_mapping
self._kwargs = kwargs

# Check client connection, this will raise if not connected
self._client.info()

if self._custom_mapping and not isinstance(self._custom_mapping, Dict):
msg = "custom_mapping must be a dictionary"
raise ValueError(msg)

if self._custom_mapping:
mappings = self._custom_mapping
else:
# Configure mapping for the embedding field if none is provided
mappings = {
"properties": {
"embedding": {"type": "dense_vector", "index": True, "similarity": embedding_similarity_function},
"content": {"type": "text"},
},
"dynamic_templates": [
{
"strings": {
"path_match": "*",
"match_mapping_type": "string",
"mapping": {
"type": "keyword",
},
@property
def client(self) -> Elasticsearch:
if self._client is None:
client = Elasticsearch(
self._hosts,
headers={"user-agent": f"haystack-py-ds/{haystack_version}"},
**self._kwargs,
)
# Check client connection, this will raise if not connected
client.info()

if self._custom_mapping:
mappings = self._custom_mapping
else:
# Configure mapping for the embedding field if none is provided
mappings = {
"properties": {
"embedding": {
"type": "dense_vector",
"index": True,
"similarity": self._embedding_similarity_function,
},
"content": {"type": "text"},
},
"dynamic_templates": [
{
"strings": {
"path_match": "*",
"match_mapping_type": "string",
"mapping": {
"type": "keyword",
},
}
}
}
],
}
],
}

# Create the index if it doesn't exist
if not client.indices.exists(index=self._index):
client.indices.create(index=self._index, mappings=mappings)

self._client = client

# Create the index if it doesn't exist
if not self._client.indices.exists(index=index):
self._client.indices.create(index=index, mappings=mappings)
return self._client

def to_dict(self) -> Dict[str, Any]:
"""
@@ -172,7 +184,7 @@ def count_documents(self) -> int:
Returns how many documents are present in the document store.
:returns: Number of documents in the document store.
"""
return self._client.count(index=self._index)["count"]
return self.client.count(index=self._index)["count"]

def _search_documents(self, **kwargs) -> List[Document]:
"""
@@ -187,7 +199,7 @@ def _search_documents(self, **kwargs) -> List[Document]:
from_ = 0
# Handle pagination
while True:
res = self._client.search(
res = self.client.search(
index=self._index,
from_=from_,
**kwargs,
@@ -261,7 +273,7 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D
)

documents_written, errors = helpers.bulk(
client=self._client,
client=self.client,
actions=elasticsearch_actions,
refresh="wait_for",
index=self._index,
@@ -317,7 +329,7 @@ def delete_documents(self, document_ids: List[str]) -> None:
"""

helpers.bulk(
client=self._client,
client=self.client,
actions=({"_op_type": "delete", "_id": id_} for id_ in document_ids),
refresh="wait_for",
index=self._index,
12 changes: 9 additions & 3 deletions integrations/elasticsearch/tests/test_document_store.py
Original file line number Diff line number Diff line change
@@ -15,6 +15,12 @@
from haystack_integrations.document_stores.elasticsearch import ElasticsearchDocumentStore


@patch("haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch")
def test_init_is_lazy(_mock_es_client):
ElasticsearchDocumentStore(hosts="testhost")
_mock_es_client.assert_not_called()


@patch("haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch")
def test_to_dict(_mock_elasticsearch_client):
document_store = ElasticsearchDocumentStore(hosts="some hosts")
@@ -73,7 +79,7 @@ def document_store(self, request):
hosts=hosts, index=index, embedding_similarity_function=embedding_similarity_function
)
yield store
store._client.options(ignore_status=[400, 404]).indices.delete(index=index)
store.client.options(ignore_status=[400, 404]).indices.delete(index=index)

def assert_documents_are_equal(self, received: List[Document], expected: List[Document]):
"""
@@ -101,7 +107,7 @@ def assert_documents_are_equal(self, received: List[Document], expected: List[Do
super().assert_documents_are_equal(received, expected)

def test_user_agent_header(self, document_store: ElasticsearchDocumentStore):
assert document_store._client._headers["user-agent"].startswith("haystack-py-ds/")
assert document_store.client._headers["user-agent"].startswith("haystack-py-ds/")

def test_write_documents(self, document_store: ElasticsearchDocumentStore):
docs = [Document(id="1")]
@@ -308,7 +314,7 @@ def test_init_with_custom_mapping(self, mock_elasticsearch):
)
mock_elasticsearch.return_value = mock_client

ElasticsearchDocumentStore(hosts="some hosts", custom_mapping=custom_mapping)
_ = ElasticsearchDocumentStore(hosts="some hosts", custom_mapping=custom_mapping).client
mock_client.indices.create.assert_called_once_with(
index="default",
mappings=custom_mapping,

0 comments on commit 7d36d02

Please sign in to comment.