Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: MongoDBAtlas Document Store #413

Merged
merged 25 commits into from
Feb 15, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 20 additions & 8 deletions integrations/mongodb_atlas/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,41 @@
[![PyPI - Version](https://img.shields.io/pypi/v/mongodb-atlas-haystack.svg)](https://pypi.org/project/mongodb-atlas-haystack)
[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/mongodb-atlas-haystack.svg)](https://pypi.org/project/mongodb-atlas-haystack)

---
-----

**Table of Contents**

- [mongodb-atlas-haystack](#mongodb-atlas-haystack)
- [Installation](#installation)
- [Testing](#testing)
- [License](#license)
- [Installation](#installation)
- [Contributing](#contributing)
- [License](#license)

## Installation

```console
pip install mongodb-atlas-haystack
```

## Testing
## Contributing

TODO
`hatch` is the best way to interact with this project. To install it:
```sh
pip install hatch
```

```console
To run the linters `ruff` and `mypy`:
```
hatch run lint:all
```

To run all the tests:
```
hatch run test
```

Note: you need your own MongoDB Atlas account to run the tests. You can create one at
https://www.mongodb.com/cloud/atlas/register. Once you have it, export the connection string
to the env var `MONGO_CONNECTION_STRING`. If you forget to do so, all the tests will be skipped.

## License

`mongodb-atlas-haystack` is distributed under the terms of the [Apache-2.0](https://spdx.org/licenses/Apache-2.0.html) license.
2 changes: 1 addition & 1 deletion integrations/mongodb_atlas/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ classifiers = [
"Programming Language :: Python :: Implementation :: PyPy",
]
dependencies = [
"haystack-ai",
"haystack-ai>=2.0.0.b6",
ZanSara marked this conversation as resolved.
Show resolved Hide resolved
"pymongo[srv]",
]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
import re
from typing import Any, Dict, List, Optional, Union

from haystack import default_to_dict
from haystack import default_to_dict, default_from_dict
from haystack.utils import Secret, deserialize_secrets_inplace
from haystack.dataclasses.document import Document
from haystack.document_stores.errors import DuplicateDocumentError
from haystack.document_stores.types import DuplicatePolicy
from haystack_integrations.document_stores.mongodb_atlas.filters import haystack_filters_to_mongo
from pymongo import InsertOne, MongoClient, ReplaceOne, UpdateOne # type: ignore
from pymongo.collection import Collection # type: ignore
from pymongo.driver_info import DriverInfo # type: ignore
from pymongo.errors import BulkWriteError # type: ignore

Expand All @@ -22,10 +22,10 @@ class MongoDBAtlasDocumentStore:
def __init__(
self,
*,
mongo_connection_string: str,
mongo_connection_string: Secret = Secret.from_env_var("MONGO_CONNECTION_STRING"), # noqa: B008
database_name: str,
collection_name: str,
recreate_index: bool = False,
recreate_collection: bool = False,
):
"""
Creates a new MongoDBAtlasDocumentStore instance.
Expand All @@ -35,72 +35,85 @@ def __init__(
:param mongo_connection_string: MongoDB Atlas connection string in the format:
"mongodb+srv://{mongo_atlas_username}:{mongo_atlas_password}@{mongo_atlas_host}/?{mongo_atlas_params_string}".
This can be obtained on the MongoDB Atlas Dashboard by clicking on the `CONNECT` button.
This value will be read automatically from the env var "MONGO_CONNECTION_STRING".
:param database_name: Name of the database to use.
:param collection_name: Name of the collection to use.
:param recreate_index: Whether to recreate the index when initializing the document store.
:param recreate_collection: Whether to recreate the collection when initializing the document store.
"""
if collection_name and not bool(re.match(r"^[a-zA-Z0-9\-_]+$", collection_name)):
msg = f'Invalid collection name: "{collection_name}". It can only contain letters, numbers, -, or _.'
raise ValueError(msg)


resolved_connection_string = mongo_connection_string.resolve_value()
if resolved_connection_string is None:
msg = (
"MongoDBAtlasDocumentStore expects a connection string. "
"Set the MONGO_CONNECTION_STRING environment variable (recommended) or pass it explicitly."
)
raise ValueError(msg)
ZanSara marked this conversation as resolved.
Show resolved Hide resolved
self.mongo_connection_string = mongo_connection_string

self.database_name = database_name
self.collection_name = collection_name
self.recreate_index = recreate_index
self.recreate_collection = recreate_collection

self.connection: MongoClient = MongoClient(
self.mongo_connection_string, driver=DriverInfo(name="MongoDBAtlasHaystackIntegration")
resolved_connection_string, driver=DriverInfo(name="MongoDBAtlasHaystackIntegration")
)
self.database = self.connection[self.database_name]
database = self.connection[self.database_name]

if self.recreate_index:
self._get_collection().drop()
if self.recreate_collection and self.collection_name in database.list_collection_names():
database[self.collection_name].drop()

# Implicitly create the collection if it doesn't exist
if collection_name not in self.database.list_collection_names():
self.database.create_collection(self.collection_name)
self._get_collection().create_index("id", unique=True)
if collection_name not in database.list_collection_names():
database.create_collection(self.collection_name)
database[self.collection_name].create_index("id", unique=True)

def _get_collection(self) -> Collection:
"""
Returns the collection named by index or returns the collection specified when the
driver was initialized.
"""
return self.database[self.collection_name]
self.collection = database[self.collection_name]

def to_dict(self) -> Dict[str, Any]:
"""
Utility function that serializes this Document Store's configuration into a dictionary.
"""
return default_to_dict(
self,
mongo_connection_string=self.mongo_connection_string,
mongo_connection_string=self.mongo_connection_string.to_dict(),
database_name=self.database_name,
collection_name=self.collection_name,
recreate_index=self.recreate_index,
recreate_collection=self.recreate_collection,
)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "MongoDBAtlasDocumentStore":
    """
    Builds a Document Store instance from a serialized configuration dictionary.

    :param data: Dictionary with an "init_parameters" entry holding the constructor arguments.
    :return: The instance produced by `default_from_dict` for this class.
    """
    init_parameters = data["init_parameters"]
    # The helper works in place on the init parameters; presumably it turns the serialized
    # connection-string entry back into a Secret before the constructor is called.
    deserialize_secrets_inplace(init_parameters, keys=["mongo_connection_string"])
    return default_from_dict(cls, data)

def count_documents(self, filters: Optional[Dict[str, Any]] = None) -> int:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"""
Returns how many documents are present in the document store.
ZanSara marked this conversation as resolved.
Show resolved Hide resolved

:param filters: The filters to apply. It counts only the documents that match the filters.
"""
return self._get_collection().count_documents({} if filters is None else filters)
return self.collection.count_documents(filters or {})

def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]:
"""
Returns the documents that match the filters provided.

For a detailed specification of the filters,
refer to the [documentation](https://docs.haystack.deepset.ai/v2.0/docs/metadata-filtering)
refer to the [documentation](https://docs.haystack.deepset.ai/v2.0/docs/metadata-filtering).

:param filters: The filters to apply to the document list.
:param filters: The filters to apply. It returns only the documents that match the filters.
:return: A list of Documents that match the given filters.
"""
mongo_filters = haystack_filters_to_mongo(filters)
collection = self._get_collection()
documents = list(collection.find(mongo_filters))
documents = list(self.collection.find(mongo_filters))
for doc in documents:
doc.pop("_id", None)
doc.pop("_id", None) # MongoDB's internal id doesn't belong in a Haystack document, so we remove it.
return [Document.from_dict(doc) for doc in documents]

def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE) -> int:
Expand All @@ -122,22 +135,21 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D
if policy == DuplicatePolicy.NONE:
policy = DuplicatePolicy.FAIL

collection = self._get_collection()
mongo_documents = [doc.to_dict() for doc in documents]
operations: List[Union[UpdateOne, InsertOne, ReplaceOne]]
written_docs = len(documents)

if policy == DuplicatePolicy.SKIP:
operations = [UpdateOne({"id": doc["id"]}, {"$setOnInsert": doc}, upsert=True) for doc in mongo_documents]
existing_documents = collection.count_documents({"id": {"$in": [doc.id for doc in documents]}})
existing_documents = self.collection.count_documents({"id": {"$in": [doc.id for doc in documents]}})
written_docs -= existing_documents
elif policy == DuplicatePolicy.FAIL:
operations = [InsertOne(doc) for doc in mongo_documents]
else:
operations = [ReplaceOne({"id": doc["id"]}, upsert=True, replacement=doc) for doc in mongo_documents]

try:
collection.bulk_write(operations)
self.collection.bulk_write(operations)
except BulkWriteError as e:
msg = f"Duplicate documents found: {e.details['writeErrors']}"
raise DuplicateDocumentError(msg) from e
Expand All @@ -152,4 +164,4 @@ def delete_documents(self, document_ids: List[str]) -> None:
"""
if not document_ids:
return
self._get_collection().delete_many(filter={"id": {"$in": document_ids}})
self.collection.delete_many(filter={"id": {"$in": document_ids}})
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
import logging

logger = logging.getLogger(__name__)


def haystack_filters_to_mongo(filters):
    """
    Converts Haystack filters into a MongoDB query document.

    Filtering is not implemented yet, so only "no filters" is supported.

    :param filters: Haystack filters, or a falsy value (`None`, `{}`) for no filtering.
    :return: An empty dictionary when no filters are given.
    :raises NotImplementedError: If any filters are provided.
    """
    # TODO: implement the actual conversion.
    if filters:
        # A bare string cannot be raised in Python 3 (it fails with TypeError),
        # so the message is wrapped in a real exception type.
        msg = "Filtering not yet implemented for MongoDBAtlasDocumentStore"
        raise NotImplementedError(msg)
    return {}
40 changes: 33 additions & 7 deletions integrations/mongodb_atlas/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from uuid import uuid4

import pytest
from haystack.utils import Secret
from haystack.dataclasses.document import ByteStream, Document
from haystack.document_stores.errors import DuplicateDocumentError
from haystack.document_stores.types import DuplicatePolicy
Expand All @@ -17,12 +18,11 @@
@pytest.fixture
def document_store(request):
store = MongoDBAtlasDocumentStore(
mongo_connection_string=os.environ["MONGO_CONNECTION_STRING"],
database_name="ClusterTest",
collection_name="test_" + request.node.name + str(uuid4()),
database_name="haystack_integration_test",
collection_name=request.node.name + str(uuid4()),
)
yield store
store._get_collection().drop()
store.collection.drop()


@pytest.mark.skipif(
Expand Down Expand Up @@ -54,16 +54,42 @@ def test_write_dataframe(self, document_store: MongoDBAtlasDocumentStore):
@patch("haystack_integrations.document_stores.mongodb_atlas.document_store.MongoClient")
def test_to_dict(self, _):
ZanSara marked this conversation as resolved.
Show resolved Hide resolved
document_store = MongoDBAtlasDocumentStore(
mongo_connection_string="mongo_connection_string",
database_name="database_name",
collection_name="collection_name",
)
assert document_store.to_dict() == {
"type": "haystack_integrations.document_stores.mongodb_atlas.document_store.MongoDBAtlasDocumentStore",
"init_parameters": {
"mongo_connection_string": "mongo_connection_string",
"mongo_connection_string": {
"env_vars": [
"MONGO_CONNECTION_STRING",
],
"strict": True,
"type": "env_var",
},
"database_name": "database_name",
"collection_name": "collection_name",
"recreate_index": False,
"recreate_collection": False,
},
}

@patch("haystack_integrations.document_stores.mongodb_atlas.document_store.MongoClient")
def test_from_dict(self, _):
    """Checks that `from_dict` restores every init parameter from its serialized form."""
    docstore = MongoDBAtlasDocumentStore.from_dict(
        {
            "type": "haystack_integrations.document_stores.mongodb_atlas.document_store.MongoDBAtlasDocumentStore",
            "init_parameters": {
                "mongo_connection_string": {
                    "env_vars": [
                        "MONGO_CONNECTION_STRING",
                    ],
                    "strict": True,
                    "type": "env_var",
                },
                "database_name": "database_name",
                "collection_name": "collection_name",
                "recreate_collection": True,
            },
        }
    )
    assert docstore.mongo_connection_string == Secret.from_env_var("MONGO_CONNECTION_STRING")
    assert docstore.database_name == "database_name"
    assert docstore.collection_name == "collection_name"
    # Identity check: the flag must round-trip as a real bool, not just a truthy value.
    assert docstore.recreate_collection is True
Loading