From fc95c1d7aa62e4966a2fa5878e914ff7fb8247d5 Mon Sep 17 00:00:00 2001 From: Jason Munro Date: Mon, 9 Jan 2023 14:24:12 -0800 Subject: [PATCH 1/4] Add default sort parameter to MongoStore --- src/maggma/stores/mongolike.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/maggma/stores/mongolike.py b/src/maggma/stores/mongolike.py index e66beb55a..c36a4e9a8 100644 --- a/src/maggma/stores/mongolike.py +++ b/src/maggma/stores/mongolike.py @@ -123,6 +123,7 @@ def __init__( safe_update: bool = False, auth_source: Optional[str] = None, mongoclient_kwargs: Optional[Dict] = None, + default_sort: Optional[Dict[str, Union[Sort, int]]] = None, **kwargs, ): """ @@ -135,6 +136,8 @@ def __init__( password: Password to connect with safe_update: fail gracefully on DocumentTooLarge errors on update auth_source: The database to authenticate on. Defaults to the database name. + default_sort: Default sort field and direction to use when querying. Can be used to + ensure determinacy in query results. """ self.database = database self.collection_name = collection_name @@ -144,6 +147,7 @@ def __init__( self.password = password self.ssh_tunnel = ssh_tunnel self.safe_update = safe_update + self.default_sort = default_sort self._coll = None # type: Any self.kwargs = kwargs @@ -384,13 +388,18 @@ def query( # type: ignore if isinstance(properties, list): properties = {p: 1 for p in properties} + default_sort_formatted = None + + if self.default_sort is not None: + default_sort_formatted = [(k, Sort(v).value) if isinstance(v, int) else (k, v.value) for k, v in self.default_sort.items()] + sort_list = ( [ (k, Sort(v).value) if isinstance(v, int) else (k, v.value) for k, v in sort.items() ] if sort - else [("_id", 1)] + else default_sort_formatted ) hint_list = ( @@ -531,6 +540,7 @@ def __init__( database: str = None, ssh_tunnel: Optional[SSHTunnel] = None, mongoclient_kwargs: Optional[Dict] = None, + default_sort: Optional[Dict[str, Union[Sort, int]]] = None, **kwargs, ): """ @@ -538,9 +548,12 @@ def __init__( uri: MongoDB+SRV URI database: database to connect to collection_name: The collection name + default_sort: Default sort field and direction to use when querying. Can be used to + ensure determinacy in query results. """ self.uri = uri self.ssh_tunnel = ssh_tunnel + self.default_sort = default_sort self.mongoclient_kwargs = mongoclient_kwargs or {} # parse the dbname from the uri From 9954f389d66237e9dddd75a282ef3f6819b90ef0 Mon Sep 17 00:00:00 2001 From: Jason Munro Date: Mon, 9 Jan 2023 14:29:30 -0800 Subject: [PATCH 2/4] Linting --- src/maggma/stores/mongolike.py | 54 ++++++++-------------------------- 1 file changed, 13 insertions(+), 41 deletions(-) diff --git a/src/maggma/stores/mongolike.py b/src/maggma/stores/mongolike.py index c36a4e9a8..63ae846ee 100644 --- a/src/maggma/stores/mongolike.py +++ b/src/maggma/stores/mongolike.py @@ -230,9 +230,7 @@ def from_launchpad_file(cls, lp_file, collection_name, **kwargs): return cls(**db_creds, **kwargs) - def distinct( - self, field: str, criteria: Optional[Dict] = None, all_exist: bool = False - ) -> List: + def distinct(self, field: str, criteria: Optional[Dict] = None, all_exist: bool = False) -> List: """ Get all distinct values for a field @@ -246,10 +244,7 @@ def distinct( distinct_vals = self._collection.distinct(field, criteria) except (OperationFailure, DocumentTooLarge): distinct_vals = [ - d["_id"] - for d in self._collection.aggregate( - [{"$match": criteria}, {"$group": {"_id": f"${field}"}}] - ) + d["_id"] for d in self._collection.aggregate([{"$match": criteria}, {"$group": {"_id": f"${field}"}}]) ] if all(isinstance(d, list) for d in filter(None, distinct_vals)): # type: ignore distinct_vals = list(chain.from_iterable(filter(None, distinct_vals))) @@ -348,12 +343,7 @@ def count( criteria = criteria if criteria else {} hint_list = ( - [ - (k, Sort(v).value) if isinstance(v, int) else (k, v.value) - for k, v in hint.items() - ] - if hint - else None + [(k, Sort(v).value) if isinstance(v, int) else (k, v.value) for k, v in hint.items()] if hint else None ) if hint_list is not None: # pragma: no cover @@ -369,7 +359,7 @@ def query( # type: ignore hint: Optional[Dict[str, Union[Sort, int]]] = None, skip: int = 0, limit: int = 0, - **kwargs + **kwargs, ) -> Iterator[Dict]: """ Queries the Store for a set of documents @@ -391,34 +381,22 @@ def query( # type: ignore default_sort_formatted = None if self.default_sort is not None: - default_sort_formatted = [(k, Sort(v).value) if isinstance(v, int) else (k, v.value) for k, v in self.default_sort.items()] + default_sort_formatted = [ + (k, Sort(v).value) if isinstance(v, int) else (k, v.value) for k, v in self.default_sort.items() + ] sort_list = ( - [ - (k, Sort(v).value) if isinstance(v, int) else (k, v.value) - for k, v in sort.items() - ] + [(k, Sort(v).value) if isinstance(v, int) else (k, v.value) for k, v in sort.items()] if sort else default_sort_formatted ) hint_list = ( - [ - (k, Sort(v).value) if isinstance(v, int) else (k, v.value) - for k, v in hint.items() - ] - if hint - else None + [(k, Sort(v).value) if isinstance(v, int) else (k, v.value) for k, v in hint.items()] if hint else None ) for d in self._collection.find( - filter=criteria, - projection=properties, - skip=skip, - limit=limit, - sort=sort_list, - hint=hint_list, - **kwargs + filter=criteria, projection=properties, skip=skip, limit=limit, sort=sort_list, hint=hint_list, **kwargs ): yield d @@ -560,9 +538,7 @@ def __init__( if database is None: d_uri = uri_parser.parse_uri(uri) if d_uri["database"] is None: - raise ConfigurationError( - "If database name is not supplied, a database must be set in the uri" - ) + raise ConfigurationError("If database name is not supplied, a database must be set in the uri") self.database = d_uri["database"] else: self.database = database @@ -661,9 +637,7 @@ def groupby( properties = list(properties.keys()) data = [ - doc - for doc in self.query(properties=keys + properties, criteria=criteria) - if all(has(doc, k) for k in keys) + doc for doc in self.query(properties=keys + properties, criteria=criteria) if all(has(doc, k) for k in keys) ] def grouping_keys(doc): @@ -735,9 +709,7 @@ def __init__( self.kwargs = kwargs if not self.read_only and len(paths) > 1: - raise RuntimeError( - "Cannot instantiate file-writable JSONStore with multiple JSON files." - ) + raise RuntimeError("Cannot instantiate file-writable JSONStore with multiple JSON files.") # create the .json file if it does not exist if not self.read_only and not Path(self.paths[0]).exists(): From f255b7035236fd7b92b2bd04fea988e6a0957e5e Mon Sep 17 00:00:00 2001 From: Jason Munro Date: Mon, 9 Jan 2023 14:44:17 -0800 Subject: [PATCH 3/4] Fix JSON and Memory store attributes --- src/maggma/stores/mongolike.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/maggma/stores/mongolike.py b/src/maggma/stores/mongolike.py index 63ae846ee..2a596ba03 100644 --- a/src/maggma/stores/mongolike.py +++ b/src/maggma/stores/mongolike.py @@ -579,6 +579,7 @@ def __init__(self, collection_name: str = "memory_db", **kwargs): collection_name: name for the collection in memory """ self.collection_name = collection_name + self.default_sort = None self._coll = None self.kwargs = kwargs super(MongoStore, self).__init__(**kwargs) # noqa @@ -717,6 +718,9 @@ def __init__( data: List[dict] = [] bytesdata = orjson.dumps(data) f.write(bytesdata.decode("utf-8")) + + self.default_sort = None + super().__init__(**kwargs) def connect(self, force_reset=False): From d446906d968cc0dd25664e2b188132a9b3408eb3 Mon Sep 17 00:00:00 2001 From: Jason Munro Date: Mon, 9 Jan 2023 14:53:12 -0800 Subject: [PATCH 4/4] FIx MontyStore attributes --- src/maggma/stores/mongolike.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/maggma/stores/mongolike.py b/src/maggma/stores/mongolike.py index 2a596ba03..922d40dc5 100644 --- a/src/maggma/stores/mongolike.py +++ b/src/maggma/stores/mongolike.py @@ -877,6 +877,7 @@ def __init__( self.database_name = database_name self.collection_name = collection_name self._coll = None + self.default_sort = None self.ssh_tunnel = None # This is to fix issues with the tunnel on close self.kwargs = kwargs self.storage = storage