From 90dce1ae942dc960d50129ca1a32939c0ad6a96a Mon Sep 17 00:00:00 2001 From: ZanSara Date: Wed, 14 Feb 2024 14:37:33 +0100 Subject: [PATCH] lint --- .../mongodb_atlas/examples/example.py | 7 +- .../mongodb_atlas/document_store.py | 65 +++++-------------- .../document_stores/mongodb_atlas/errors.py | 6 +- .../document_stores/mongodb_atlas/filters.py | 10 +-- .../tests/test_document_store.py | 9 +-- 5 files changed, 29 insertions(+), 68 deletions(-) diff --git a/integrations/mongodb_atlas/examples/example.py b/integrations/mongodb_atlas/examples/example.py index 5f5077103..8370b25cb 100644 --- a/integrations/mongodb_atlas/examples/example.py +++ b/integrations/mongodb_atlas/examples/example.py @@ -1,11 +1,12 @@ from pymongo.mongo_client import MongoClient from pymongo.server_api import ServerApi + uri = "mongodb+srv://sarazanzottera:dN3hmY9RNxRDni13@clustertest.gwkckbk.mongodb.net/?retryWrites=true&w=majority" # Create a new client and connect to the server -client = MongoClient(uri, server_api=ServerApi('1')) +client = MongoClient(uri, server_api=ServerApi("1")) # Send a ping to confirm a successful connection try: - client.admin.command('ping') + client.admin.command("ping") print("Pinged your deployment. You successfully connected to MongoDB!") except Exception as e: - print(e) \ No newline at end of file + print(e) diff --git a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py index c2d81b8ec..64d324fe0 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py @@ -1,39 +1,30 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List, Optional, Union - -import re import logging +import re +from typing import Any, Dict, List, Optional, Union -from pymongo import InsertOne, ReplaceOne, UpdateOne, MongoClient -from pymongo.collection import Collection -from pymongo.driver_info import DriverInfo -from pymongo.errors import BulkWriteError from haystack import default_to_dict from haystack.dataclasses.document import Document -from haystack.document_stores.types import DuplicatePolicy from haystack.document_stores.errors import DuplicateDocumentError - +from haystack.document_stores.types import DuplicatePolicy from haystack_integrations.document_stores.mongodb_atlas.filters import haystack_filters_to_mongo - +from pymongo import InsertOne, MongoClient, ReplaceOne, UpdateOne +from pymongo.collection import Collection +from pymongo.driver_info import DriverInfo +from pymongo.errors import BulkWriteError logger = logging.getLogger(__name__) -METRIC_TYPES = ["euclidean", "cosine", "dotProduct"] - - class MongoDBAtlasDocumentStore: def __init__( self, + *, mongo_connection_string: str, database_name: str, collection_name: str, - vector_search_index: Optional[str] = None, - embedding_dim: int = 768, - similarity: str = "cosine", - embedding_field: str = "embedding", recreate_index: bool = False, ): """ @@ -41,35 +32,21 @@ def __init__( This Document Store uses MongoDB Atlas as a backend (https://www.mongodb.com/docs/atlas/getting-started/). - :param mongo_connection_string: MongoDB Atlas connection string in the format: + :param mongo_connection_string: MongoDB Atlas connection string in the format: "mongodb+srv://{mongo_atlas_username}:{mongo_atlas_password}@{mongo_atlas_host}/?{mongo_atlas_params_string}". This can be obtained on the MongoDB Atlas Dashboard by clicking on the `CONNECT` button. :param database_name: Name of the database to use. :param collection_name: Name of the collection to use. - :param vector_search_index: The name of the index to use for vector search. To use the search index it must have been created in the Atlas web UI before. None by default. - :param embedding_dim: Dimensionality of embeddings, 768 by default. - :param similarity: The similarity function to use for the embeddings. One of "euclidean", "cosine" or "dotProduct". "cosine" is the default. - :param embedding_field: The name of the field in the document that contains the embedding. :param recreate_index: Whether to recreate the index when initializing the document store. """ - if similarity not in METRIC_TYPES: - raise ValueError( - "MongoDB Atlas currently supports dotProduct, cosine and euclidean metrics. Please set similarity to one of the above." - ) if collection_name and not bool(re.match(r"^[a-zA-Z0-9\-_]+$", collection_name)): - raise ValueError( - f'Invalid collection name: "{collection_name}". Index name can only contain letters, numbers, hyphens, or underscores.' - ) - + msg = f'Invalid collection name: "{collection_name}". It can only contain letters, numbers, -, or _.' + raise ValueError(msg) + self.mongo_connection_string = mongo_connection_string self.database_name = database_name self.collection_name = collection_name - self.similarity = similarity - self.embedding_field = embedding_field - self.embedding_dim = embedding_dim - self.index = collection_name self.recreate_index = recreate_index - self.vector_search_index = vector_search_index self.connection: MongoClient = MongoClient( self.mongo_connection_string, driver=DriverInfo(name="MongoDBAtlasHaystackIntegration") @@ -84,9 +61,6 @@ def __init__( self.database.create_collection(self.collection_name) self._get_collection().create_index("id", unique=True) - def _create_document_field_map(self) -> Dict: - return {self.embedding_field: "embedding"} - def _get_collection(self) -> Collection: """ Returns the collection named by index or returns the collection specified when the @@ -103,10 +77,6 @@ def to_dict(self) -> Dict[str, Any]: mongo_connection_string=self.mongo_connection_string, database_name=self.database_name, collection_name=self.collection_name, - vector_search_index=self.vector_search_index, - embedding_dim=self.embedding_dim, - similarity=self.similarity, - embedding_field=self.embedding_field, recreate_index=self.recreate_index, ) @@ -115,7 +85,7 @@ def count_documents(self, filters: Optional[Dict[str, Any]] = None) -> int: Returns how many documents are present in the document store. """ return self._get_collection().count_documents({} if filters is None else filters) - + def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]: """ Returns the documents that match the filters provided. @@ -166,12 +136,11 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D else: operations = [ReplaceOne({"id": doc["id"]}, upsert=True, replacement=doc) for doc in mongo_documents] - print(operations) - try: collection.bulk_write(operations) except BulkWriteError as e: - raise DuplicateDocumentError(f"Duplicate documents found: {e.details['writeErrors']}") + msg = f"Duplicate documents found: {e.details['writeErrors']}" + raise DuplicateDocumentError(msg) from e return written_docs @@ -184,7 +153,3 @@ def delete_documents(self, document_ids: List[str]) -> None: if not document_ids: return self._get_collection().delete_many(filter={"id": {"$in": document_ids}}) - - - - diff --git a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/errors.py b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/errors.py index a15e69cd1..132156bd0 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/errors.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/errors.py @@ -1,8 +1,4 @@ -from typing import Optional - - class MongoDBAtlasDocumentStoreError(Exception): """Exception for issues that occur in a MongoDBAtlas document store""" - def __init__(self, message: Optional[str] = None): - super().__init__(message=message) + pass diff --git a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/filters.py b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/filters.py index 9ad1ff42b..ce2eae518 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/filters.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/filters.py @@ -1,7 +1,9 @@ -import warnings +import logging +logger = logging.getLogger(__name__) -def haystack_filters_to_mongo(filters): + +def haystack_filters_to_mongo(_): # TODO - warnings.warn("Filtering not yet implemented for MongoDBAtlasDocumentStore!") - return {} \ No newline at end of file + logger.warning("Filtering not yet implemented for MongoDBAtlasDocumentStore") + return {} diff --git a/integrations/mongodb_atlas/tests/test_document_store.py b/integrations/mongodb_atlas/tests/test_document_store.py index e95928bc1..daf60a855 100644 --- a/integrations/mongodb_atlas/tests/test_document_store.py +++ b/integrations/mongodb_atlas/tests/test_document_store.py @@ -12,16 +12,13 @@ from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore from pandas import DataFrame -import pytest -from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore - @pytest.fixture def document_store(): store = MongoDBAtlasDocumentStore( mongo_connection_string=os.environ["MONGO_CONNECTION_STRING"], database_name="ClusterTest", - collection_name="test" + collection_name="test", ) yield store store._get_collection().drop() @@ -54,11 +51,11 @@ def test_write_dataframe(self, document_store: MongoDBAtlasDocumentStore): assert retrieved_docs == docs @patch("haystack_integrations.document_stores.mongodb_atlas.document_store.MongoClient") - def test_to_dict(self, client_mock): + def test_to_dict(self, _): document_store = MongoDBAtlasDocumentStore( mongo_connection_string="mongo_connection_string", database_name="database_name", - collection_name="collection_name" + collection_name="collection_name", ) assert document_store.to_dict() == { "type": "haystack_integrations.document_stores.mongodb_atlas.document_store.MongoDBAtlasDocumentStore",