-
Notifications
You must be signed in to change notification settings - Fork 124
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing 5 changed files with 29 additions and 68 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,12 @@ | ||
from pymongo.mongo_client import MongoClient | ||
from pymongo.server_api import ServerApi | ||
|
||
uri = "mongodb+srv://sarazanzottera:[email protected]/?retryWrites=true&w=majority" | ||
# Create a new client and connect to the server | ||
client = MongoClient(uri, server_api=ServerApi('1')) | ||
client = MongoClient(uri, server_api=ServerApi("1")) | ||
# Send a ping to confirm a successful connection | ||
try: | ||
client.admin.command('ping') | ||
client.admin.command("ping") | ||
print("Pinged your deployment. You successfully connected to MongoDB!") | ||
except Exception as e: | ||
print(e) | ||
print(e) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,75 +1,52 @@ | ||
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
from typing import Any, Dict, List, Optional, Union | ||
|
||
import re | ||
import logging | ||
import re | ||
from typing import Any, Dict, List, Optional, Union | ||
|
||
from pymongo import InsertOne, ReplaceOne, UpdateOne, MongoClient | ||
from pymongo.collection import Collection | ||
from pymongo.driver_info import DriverInfo | ||
from pymongo.errors import BulkWriteError | ||
from haystack import default_to_dict | ||
from haystack.dataclasses.document import Document | ||
from haystack.document_stores.types import DuplicatePolicy | ||
from haystack.document_stores.errors import DuplicateDocumentError | ||
|
||
from haystack.document_stores.types import DuplicatePolicy | ||
from haystack_integrations.document_stores.mongodb_atlas.filters import haystack_filters_to_mongo | ||
|
||
from pymongo import InsertOne, MongoClient, ReplaceOne, UpdateOne | ||
from pymongo.collection import Collection | ||
from pymongo.driver_info import DriverInfo | ||
from pymongo.errors import BulkWriteError | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
METRIC_TYPES = ["euclidean", "cosine", "dotProduct"] | ||
|
||
|
||
class MongoDBAtlasDocumentStore: | ||
def __init__( | ||
self, | ||
*, | ||
mongo_connection_string: str, | ||
database_name: str, | ||
collection_name: str, | ||
vector_search_index: Optional[str] = None, | ||
embedding_dim: int = 768, | ||
similarity: str = "cosine", | ||
embedding_field: str = "embedding", | ||
recreate_index: bool = False, | ||
): | ||
""" | ||
Creates a new MongoDBAtlasDocumentStore instance. | ||
This Document Store uses MongoDB Atlas as a backend (https://www.mongodb.com/docs/atlas/getting-started/). | ||
:param mongo_connection_string: MongoDB Atlas connection string in the format: | ||
:param mongo_connection_string: MongoDB Atlas connection string in the format: | ||
"mongodb+srv://{mongo_atlas_username}:{mongo_atlas_password}@{mongo_atlas_host}/?{mongo_atlas_params_string}". | ||
This can be obtained on the MongoDB Atlas Dashboard by clicking on the `CONNECT` button. | ||
:param database_name: Name of the database to use. | ||
:param collection_name: Name of the collection to use. | ||
:param vector_search_index: The name of the index to use for vector search. To use the search index it must have been created in the Atlas web UI before. None by default. | ||
:param embedding_dim: Dimensionality of embeddings, 768 by default. | ||
:param similarity: The similarity function to use for the embeddings. One of "euclidean", "cosine" or "dotProduct". "cosine" is the default. | ||
:param embedding_field: The name of the field in the document that contains the embedding. | ||
:param recreate_index: Whether to recreate the index when initializing the document store. | ||
""" | ||
if similarity not in METRIC_TYPES: | ||
raise ValueError( | ||
"MongoDB Atlas currently supports dotProduct, cosine and euclidean metrics. Please set similarity to one of the above." | ||
) | ||
if collection_name and not bool(re.match(r"^[a-zA-Z0-9\-_]+$", collection_name)): | ||
raise ValueError( | ||
f'Invalid collection name: "{collection_name}". Index name can only contain letters, numbers, hyphens, or underscores.' | ||
) | ||
|
||
msg = f'Invalid collection name: "{collection_name}". It can only contain letters, numbers, -, or _.' | ||
raise ValueError(msg) | ||
|
||
self.mongo_connection_string = mongo_connection_string | ||
self.database_name = database_name | ||
self.collection_name = collection_name | ||
self.similarity = similarity | ||
self.embedding_field = embedding_field | ||
self.embedding_dim = embedding_dim | ||
self.index = collection_name | ||
self.recreate_index = recreate_index | ||
self.vector_search_index = vector_search_index | ||
|
||
self.connection: MongoClient = MongoClient( | ||
self.mongo_connection_string, driver=DriverInfo(name="MongoDBAtlasHaystackIntegration") | ||
|
@@ -84,9 +61,6 @@ def __init__( | |
self.database.create_collection(self.collection_name) | ||
self._get_collection().create_index("id", unique=True) | ||
|
||
def _create_document_field_map(self) -> Dict: | ||
return {self.embedding_field: "embedding"} | ||
|
||
def _get_collection(self) -> Collection: | ||
""" | ||
Returns the collection named by index or returns the collection specified when the | ||
|
@@ -103,10 +77,6 @@ def to_dict(self) -> Dict[str, Any]: | |
mongo_connection_string=self.mongo_connection_string, | ||
database_name=self.database_name, | ||
collection_name=self.collection_name, | ||
vector_search_index=self.vector_search_index, | ||
embedding_dim=self.embedding_dim, | ||
similarity=self.similarity, | ||
embedding_field=self.embedding_field, | ||
recreate_index=self.recreate_index, | ||
) | ||
|
||
|
@@ -115,7 +85,7 @@ def count_documents(self, filters: Optional[Dict[str, Any]] = None) -> int: | |
Returns how many documents are present in the document store. | ||
""" | ||
return self._get_collection().count_documents({} if filters is None else filters) | ||
|
||
def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]: | ||
""" | ||
Returns the documents that match the filters provided. | ||
|
@@ -166,12 +136,11 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D | |
else: | ||
operations = [ReplaceOne({"id": doc["id"]}, upsert=True, replacement=doc) for doc in mongo_documents] | ||
|
||
print(operations) | ||
|
||
try: | ||
collection.bulk_write(operations) | ||
except BulkWriteError as e: | ||
raise DuplicateDocumentError(f"Duplicate documents found: {e.details['writeErrors']}") | ||
msg = f"Duplicate documents found: {e.details['writeErrors']}" | ||
raise DuplicateDocumentError(msg) from e | ||
|
||
return written_docs | ||
|
||
|
@@ -184,7 +153,3 @@ def delete_documents(self, document_ids: List[str]) -> None: | |
if not document_ids: | ||
return | ||
self._get_collection().delete_many(filter={"id": {"$in": document_ids}}) | ||
|
||
|
||
|
||
|
6 changes: 1 addition & 5 deletions
6
integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/errors.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,4 @@ | ||
from typing import Optional | ||
|
||
|
||
class MongoDBAtlasDocumentStoreError(Exception): | ||
"""Exception for issues that occur in a MongoDBAtlas document store""" | ||
|
||
def __init__(self, message: Optional[str] = None): | ||
super().__init__(message=message) | ||
pass |
10 changes: 6 additions & 4 deletions
10
integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/filters.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,9 @@ | ||
import warnings | ||
import logging | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
def haystack_filters_to_mongo(filters): | ||
|
||
def haystack_filters_to_mongo(_): | ||
# TODO | ||
warnings.warn("Filtering not yet implemented for MongoDBAtlasDocumentStore!") | ||
return {} | ||
logger.warning("Filtering not yet implemented for MongoDBAtlasDocumentStore") | ||
return {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters