Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
ZanSara committed Feb 14, 2024
1 parent ebced28 commit 90dce1a
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 68 deletions.
7 changes: 4 additions & 3 deletions integrations/mongodb_atlas/examples/example.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import os

from pymongo.mongo_client import MongoClient
from pymongo.server_api import ServerApi

# Read the connection string from the environment rather than hard-coding
# credentials in the example. The expected format is:
# "mongodb+srv://<username>:<password>@<host>/?retryWrites=true&w=majority"
uri = os.environ["MONGO_CONNECTION_STRING"]
# Create a new client and connect to the server
client = MongoClient(uri, server_api=ServerApi("1"))
# Send a ping to confirm a successful connection
try:
    client.admin.command("ping")
    print("Pinged your deployment. You successfully connected to MongoDB!")
except Exception as e:
    # Example script boundary: report the failure instead of crashing with a traceback.
    print(e)
Original file line number Diff line number Diff line change
@@ -1,75 +1,52 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional, Union

import re
import logging
import re
from typing import Any, Dict, List, Optional, Union

from pymongo import InsertOne, ReplaceOne, UpdateOne, MongoClient
from pymongo.collection import Collection
from pymongo.driver_info import DriverInfo
from pymongo.errors import BulkWriteError
from haystack import default_to_dict
from haystack.dataclasses.document import Document
from haystack.document_stores.types import DuplicatePolicy
from haystack.document_stores.errors import DuplicateDocumentError

from haystack.document_stores.types import DuplicatePolicy
from haystack_integrations.document_stores.mongodb_atlas.filters import haystack_filters_to_mongo

from pymongo import InsertOne, MongoClient, ReplaceOne, UpdateOne
from pymongo.collection import Collection
from pymongo.driver_info import DriverInfo
from pymongo.errors import BulkWriteError

logger = logging.getLogger(__name__)


METRIC_TYPES = ["euclidean", "cosine", "dotProduct"]


class MongoDBAtlasDocumentStore:
def __init__(
self,
*,
mongo_connection_string: str,
database_name: str,
collection_name: str,
vector_search_index: Optional[str] = None,
embedding_dim: int = 768,
similarity: str = "cosine",
embedding_field: str = "embedding",
recreate_index: bool = False,
):
"""
Creates a new MongoDBAtlasDocumentStore instance.
This Document Store uses MongoDB Atlas as a backend (https://www.mongodb.com/docs/atlas/getting-started/).
:param mongo_connection_string: MongoDB Atlas connection string in the format:
:param mongo_connection_string: MongoDB Atlas connection string in the format:
"mongodb+srv://{mongo_atlas_username}:{mongo_atlas_password}@{mongo_atlas_host}/?{mongo_atlas_params_string}".
This can be obtained on the MongoDB Atlas Dashboard by clicking on the `CONNECT` button.
:param database_name: Name of the database to use.
:param collection_name: Name of the collection to use.
:param vector_search_index: The name of the index to use for vector search. To use the search index it must have been created in the Atlas web UI before. None by default.
:param embedding_dim: Dimensionality of embeddings, 768 by default.
:param similarity: The similarity function to use for the embeddings. One of "euclidean", "cosine" or "dotProduct". "cosine" is the default.
:param embedding_field: The name of the field in the document that contains the embedding.
:param recreate_index: Whether to recreate the index when initializing the document store.
"""
if similarity not in METRIC_TYPES:
raise ValueError(
"MongoDB Atlas currently supports dotProduct, cosine and euclidean metrics. Please set similarity to one of the above."
)
if collection_name and not bool(re.match(r"^[a-zA-Z0-9\-_]+$", collection_name)):
raise ValueError(
f'Invalid collection name: "{collection_name}". Index name can only contain letters, numbers, hyphens, or underscores.'
)

msg = f'Invalid collection name: "{collection_name}". It can only contain letters, numbers, -, or _.'
raise ValueError(msg)

self.mongo_connection_string = mongo_connection_string
self.database_name = database_name
self.collection_name = collection_name
self.similarity = similarity
self.embedding_field = embedding_field
self.embedding_dim = embedding_dim
self.index = collection_name
self.recreate_index = recreate_index
self.vector_search_index = vector_search_index

self.connection: MongoClient = MongoClient(
self.mongo_connection_string, driver=DriverInfo(name="MongoDBAtlasHaystackIntegration")
Expand All @@ -84,9 +61,6 @@ def __init__(
self.database.create_collection(self.collection_name)
self._get_collection().create_index("id", unique=True)

def _create_document_field_map(self) -> Dict:
return {self.embedding_field: "embedding"}

def _get_collection(self) -> Collection:
"""
Returns the collection named by index or returns the collection specified when the
Expand All @@ -103,10 +77,6 @@ def to_dict(self) -> Dict[str, Any]:
mongo_connection_string=self.mongo_connection_string,
database_name=self.database_name,
collection_name=self.collection_name,
vector_search_index=self.vector_search_index,
embedding_dim=self.embedding_dim,
similarity=self.similarity,
embedding_field=self.embedding_field,
recreate_index=self.recreate_index,
)

Expand All @@ -115,7 +85,7 @@ def count_documents(self, filters: Optional[Dict[str, Any]] = None) -> int:
Returns how many documents are present in the document store.
"""
return self._get_collection().count_documents({} if filters is None else filters)

def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]:
"""
Returns the documents that match the filters provided.
Expand Down Expand Up @@ -166,12 +136,11 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D
else:
operations = [ReplaceOne({"id": doc["id"]}, upsert=True, replacement=doc) for doc in mongo_documents]

print(operations)

try:
collection.bulk_write(operations)
except BulkWriteError as e:
raise DuplicateDocumentError(f"Duplicate documents found: {e.details['writeErrors']}")
msg = f"Duplicate documents found: {e.details['writeErrors']}"
raise DuplicateDocumentError(msg) from e

return written_docs

Expand All @@ -184,7 +153,3 @@ def delete_documents(self, document_ids: List[str]) -> None:
if not document_ids:
return
self._get_collection().delete_many(filter={"id": {"$in": document_ids}})




Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
from typing import Optional


class MongoDBAtlasDocumentStoreError(Exception):
    """
    Exception for issues that occur in a MongoDBAtlas document store.

    :param message: Optional human-readable description of the failure. When omitted, the exception
        is created without arguments, so `str(error)` is an empty string.
    """

    def __init__(self, message: Optional[str] = None):
        # Exception.__init__ accepts positional arguments only: forwarding `message=message`
        # as a keyword raises TypeError every time the exception is constructed.
        if message is None:
            super().__init__()
        else:
            super().__init__(message)
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import warnings
import logging

logger = logging.getLogger(__name__)


def haystack_filters_to_mongo(_):
    """
    Placeholder for converting Haystack filters into a MongoDB query.

    Filtering is not implemented yet: the argument is ignored, a warning is logged, and an empty
    dictionary is returned.

    :param _: The Haystack filters. Currently unused.
    :returns: An empty dictionary.
    """
    # TODO
    logger.warning("Filtering not yet implemented for MongoDBAtlasDocumentStore")
    return {}
9 changes: 3 additions & 6 deletions integrations/mongodb_atlas/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,13 @@
from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore
from pandas import DataFrame

import pytest
from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore


@pytest.fixture
def document_store():
    """
    Provide a MongoDBAtlasDocumentStore for a test and drop its collection afterwards.

    The connection string is read from the `MONGO_CONNECTION_STRING` environment variable.
    """
    store = MongoDBAtlasDocumentStore(
        mongo_connection_string=os.environ["MONGO_CONNECTION_STRING"],
        database_name="ClusterTest",
        collection_name="test",
    )
    yield store
    # Teardown: remove the collection used by the test.
    store._get_collection().drop()
Expand Down Expand Up @@ -54,11 +51,11 @@ def test_write_dataframe(self, document_store: MongoDBAtlasDocumentStore):
assert retrieved_docs == docs

@patch("haystack_integrations.document_stores.mongodb_atlas.document_store.MongoClient")
def test_to_dict(self, client_mock):
def test_to_dict(self, _):
document_store = MongoDBAtlasDocumentStore(
mongo_connection_string="mongo_connection_string",
database_name="database_name",
collection_name="collection_name"
collection_name="collection_name",
)
assert document_store.to_dict() == {
"type": "haystack_integrations.document_stores.mongodb_atlas.document_store.MongoDBAtlasDocumentStore",
Expand Down

0 comments on commit 90dce1a

Please sign in to comment.