Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add default sort parameter to MongoStore #758

Merged
merged 4 commits into from
Jan 9, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 31 additions & 41 deletions src/maggma/stores/mongolike.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def __init__(
safe_update: bool = False,
auth_source: Optional[str] = None,
mongoclient_kwargs: Optional[Dict] = None,
default_sort: Optional[Dict[str, Union[Sort, int]]] = None,
**kwargs,
):
"""
Expand All @@ -135,6 +136,8 @@ def __init__(
password: Password to connect with
safe_update: fail gracefully on DocumentTooLarge errors on update
auth_source: The database to authenticate on. Defaults to the database name.
default_sort: Default sort field and direction to use when querying. Can be used to
ensure determinacy in query results.
"""
self.database = database
self.collection_name = collection_name
Expand All @@ -144,6 +147,7 @@ def __init__(
self.password = password
self.ssh_tunnel = ssh_tunnel
self.safe_update = safe_update
self.default_sort = default_sort
self._coll = None # type: Any
self.kwargs = kwargs

Expand Down Expand Up @@ -226,9 +230,7 @@ def from_launchpad_file(cls, lp_file, collection_name, **kwargs):

return cls(**db_creds, **kwargs)

def distinct(
self, field: str, criteria: Optional[Dict] = None, all_exist: bool = False
) -> List:
def distinct(self, field: str, criteria: Optional[Dict] = None, all_exist: bool = False) -> List:
"""
Get all distinct values for a field

Expand All @@ -242,10 +244,7 @@ def distinct(
distinct_vals = self._collection.distinct(field, criteria)
except (OperationFailure, DocumentTooLarge):
distinct_vals = [
d["_id"]
for d in self._collection.aggregate(
[{"$match": criteria}, {"$group": {"_id": f"${field}"}}]
)
d["_id"] for d in self._collection.aggregate([{"$match": criteria}, {"$group": {"_id": f"${field}"}}])
]
if all(isinstance(d, list) for d in filter(None, distinct_vals)): # type: ignore
distinct_vals = list(chain.from_iterable(filter(None, distinct_vals)))
Expand Down Expand Up @@ -344,12 +343,7 @@ def count(
criteria = criteria if criteria else {}

hint_list = (
[
(k, Sort(v).value) if isinstance(v, int) else (k, v.value)
for k, v in hint.items()
]
if hint
else None
[(k, Sort(v).value) if isinstance(v, int) else (k, v.value) for k, v in hint.items()] if hint else None
)

if hint_list is not None: # pragma: no cover
Expand All @@ -365,7 +359,7 @@ def query( # type: ignore
hint: Optional[Dict[str, Union[Sort, int]]] = None,
skip: int = 0,
limit: int = 0,
**kwargs
**kwargs,
) -> Iterator[Dict]:
"""
Queries the Store for a set of documents
Expand All @@ -384,32 +378,25 @@ def query( # type: ignore
if isinstance(properties, list):
properties = {p: 1 for p in properties}

sort_list = (
[
(k, Sort(v).value) if isinstance(v, int) else (k, v.value)
for k, v in sort.items()
default_sort_formatted = None

if self.default_sort is not None:
default_sort_formatted = [
(k, Sort(v).value) if isinstance(v, int) else (k, v.value) for k, v in self.default_sort.items()
]

sort_list = (
[(k, Sort(v).value) if isinstance(v, int) else (k, v.value) for k, v in sort.items()]
if sort
else [("_id", 1)]
else default_sort_formatted
)

hint_list = (
[
(k, Sort(v).value) if isinstance(v, int) else (k, v.value)
for k, v in hint.items()
]
if hint
else None
[(k, Sort(v).value) if isinstance(v, int) else (k, v.value) for k, v in hint.items()] if hint else None
)

for d in self._collection.find(
filter=criteria,
projection=properties,
skip=skip,
limit=limit,
sort=sort_list,
hint=hint_list,
**kwargs
filter=criteria, projection=properties, skip=skip, limit=limit, sort=sort_list, hint=hint_list, **kwargs
):
yield d

Expand Down Expand Up @@ -531,25 +518,27 @@ def __init__(
database: str = None,
ssh_tunnel: Optional[SSHTunnel] = None,
mongoclient_kwargs: Optional[Dict] = None,
default_sort: Optional[Dict[str, Union[Sort, int]]] = None,
**kwargs,
):
"""
Args:
uri: MongoDB+SRV URI
database: database to connect to
collection_name: The collection name
default_sort: Default sort field and direction to use when querying. Can be used to
ensure determinacy in query results.
"""
self.uri = uri
self.ssh_tunnel = ssh_tunnel
self.default_sort = default_sort
self.mongoclient_kwargs = mongoclient_kwargs or {}

# parse the dbname from the uri
if database is None:
d_uri = uri_parser.parse_uri(uri)
if d_uri["database"] is None:
raise ConfigurationError(
"If database name is not supplied, a database must be set in the uri"
)
raise ConfigurationError("If database name is not supplied, a database must be set in the uri")
self.database = d_uri["database"]
else:
self.database = database
Expand Down Expand Up @@ -590,6 +579,7 @@ def __init__(self, collection_name: str = "memory_db", **kwargs):
collection_name: name for the collection in memory
"""
self.collection_name = collection_name
self.default_sort = None
self._coll = None
self.kwargs = kwargs
super(MongoStore, self).__init__(**kwargs) # noqa
Expand Down Expand Up @@ -648,9 +638,7 @@ def groupby(
properties = list(properties.keys())

data = [
doc
for doc in self.query(properties=keys + properties, criteria=criteria)
if all(has(doc, k) for k in keys)
doc for doc in self.query(properties=keys + properties, criteria=criteria) if all(has(doc, k) for k in keys)
]

def grouping_keys(doc):
Expand Down Expand Up @@ -722,16 +710,17 @@ def __init__(
self.kwargs = kwargs

if not self.read_only and len(paths) > 1:
raise RuntimeError(
"Cannot instantiate file-writable JSONStore with multiple JSON files."
)
raise RuntimeError("Cannot instantiate file-writable JSONStore with multiple JSON files.")

# create the .json file if it does not exist
if not self.read_only and not Path(self.paths[0]).exists():
with zopen(self.paths[0], "w") as f:
data: List[dict] = []
bytesdata = orjson.dumps(data)
f.write(bytesdata.decode("utf-8"))

self.default_sort = None

super().__init__(**kwargs)

def connect(self, force_reset=False):
Expand Down Expand Up @@ -888,6 +877,7 @@ def __init__(
self.database_name = database_name
self.collection_name = collection_name
self._coll = None
self.default_sort = None
self.ssh_tunnel = None # This is to fix issues with the tunnel on close
self.kwargs = kwargs
self.storage = storage
Expand Down