Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add distinct function to MongoHook in apache-airflow-providers-mongo #34466

Merged
merged 8 commits into from
Oct 25, 2023
19 changes: 19 additions & 0 deletions airflow/providers/mongo/hooks/mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,3 +366,22 @@ def delete_many(
collection = self.get_collection(mongo_collection, mongo_db=mongo_db)

return collection.delete_many(filter_doc, **kwargs)

def distinct(
self, mongo_collection: str, distinct_key: str, filter_doc: dict | None, mongo_db: str | None = None, **kwargs
) -> list[Any]:
"""
Returns a list of distinct values for the given key across a collection.

https://pymongo.readthedocs.io/en/stable/api/pymongo/collection.html#pymongo.collection.Collection.distinct

:param mongo_collection: The name of the collection to perform distinct on.
:param distinct_key: The field to return distinct values from.
:param filter_doc: A query that matches the documents get distinct values from.
Can be omitted; then will cover the entire collection.
:param mongo_db: The name of the database to use.
Can be omitted; then the database from the connection string is used.
"""
collection = self.get_collection(mongo_collection, mongo_db=mongo_db)

return collection.distinct(distinct_key, filter=filter_doc, **kwargs)
26 changes: 26 additions & 0 deletions tests/providers/mongo/hooks/test_mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,32 @@ def test_aggregate(self):
results = self.hook.aggregate(collection, aggregate_query)
assert len(list(results)) == 2

def test_distinct(self):
collection = mongomock.MongoClient().db.collection
objs = [
{"test_id": "1", "test_status": "success"},
{"test_id": "2", "test_status": "failure"},
{"test_id": "3", "test_status": "success"},
]

collection.insert_many(objs)

results = self.hook.distinct(collection, "test_status")
assert len(results) == 2

def test_distinct_with_filter(self):
collection = mongomock.MongoClient().db.collection
objs = [
{"test_id": "1", "test_status": "success"},
{"test_id": "2", "test_status": "failure"},
{"test_id": "3", "test_status": "success"},
]

collection.insert_many(objs)

results = self.hook.distinct(collection, "test_id", {"test_status": "failure"})
assert len(results) == 1


def test_context_manager():
with MongoHook(conn_id="mongo_default", mongo_db="default") as ctx_hook:
Expand Down