Merge pull request #901 from jmmshn/lint
pre-commit run --all-files
rkingsbury authored Dec 20, 2023
2 parents 5720fa3 + 2842448 commit 2f3d741
Showing 36 changed files with 119 additions and 134 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/testing.yml
@@ -32,7 +32,7 @@ jobs:
- name: Run pre-commit
run: |
pip install pre-commit
-pre-commit run
+pre-commit run --all-files
test:
needs: lint
16 changes: 8 additions & 8 deletions docs/getting_started/using_ssh_tunnel.md
@@ -1,8 +1,8 @@
-# Using `SSHTunnel` to conect to remote database
+# Using `SSHTunnel` to connect to remote database

-One of the typical scenarios to use `maggma` is to connect to a remote database that is behind a firewall and thus cannot be accessed directly from your local computer (as shown below, [image credits](https://github.com/pahaz/sshtunnel/)).
+One of the typical scenarios to use `maggma` is to connect to a remote database that is behind a firewall and thus cannot be accessed directly from your local computer (as shown below, [image credits](https://github.com/pahaz/sshtunnel/)).

-In this case, you can use `SSHTunnel` to first connect to the remote server, and then connect to the database from the server.
+In this case, you can use `SSHTunnel` to first connect to the remote server, and then connect to the database from the server.

```
----------------------------------------------------------------------
@@ -17,10 +17,10 @@
----------------------------------------------------------------------
-Note, the `local` indicates that the connction to the PRIVATE SERVER can only be made from the REMOTE SERVER.
+Note, the `local` indicates that the connection to the PRIVATE SERVER can only be made from the REMOTE SERVER.
```

-## Example usage with `S3Store`
+## Example usage with `S3Store`

Below is an example of how to use `SSHTunnel` to connect to an AWS `S3Store` hosted on a private server.

@@ -45,17 +45,17 @@ tunnel = SSHTunnel(
tunnel_server_address = "<REMOTE_SERVER_ADDRESS>:22",
username = "<USERNAME>",
password= "<USER_CREDENTIAL>",
-remote_server_address = "COMPUTE_NODE_1:9000",
+remote_server_address = "COMPUTE_NODE_1:9000",
local_port = 9000,
)
```
and then pass it to the `S3Store` to connect to the database. The arguments of the `SSHTunnel` are self-explanatory, but `local_port` needs more explanation. We assume that on the local computer, we want to connect to the localhost address `http://127.0.0.1`, so we do not need to provide the address, but only the port number (`9000` in this case.)

-In essence, `SSHTunnel` allows the connection to the database at `COMPUTE_NODE_1:9000` on the private server from the localhost address `http://127.0.0.1:9000` on the local computer as if the database is hosted on the local computer.
+In essence, `SSHTunnel` allows the connection to the database at `COMPUTE_NODE_1:9000` on the private server from the localhost address `http://127.0.0.1:9000` on the local computer as if the database is hosted on the local computer.

## Other use cases

Alternative to using `username` and `password` for authentication with the remote server, `SSHTunnel` also supports authentication using SSH keys. In this case, you will need to provide your SSH credentials using the `private_key` argument. Read the docs of the `SSHTunnel` for more information.


-`SSHTunnel` can also be used with other stores such as `MongoStore`, `MongoURIStore`, and `GridFSStore`. The usage is similar to the example above, but you might need to adjust the arguments to the `SSHTunnel` to match the use case.
+`SSHTunnel` can also be used with other stores such as `MongoStore`, `MongoURIStore`, and `GridFSStore`. The usage is similar to the example above, but you might need to adjust the arguments to the `SSHTunnel` to match the use case.
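For reference, a minimal sketch of the key-based variant described in the hunks above, wired to a `MongoStore`. The `maggma.stores.ssh_tunnel` import path, the `ssh_tunnel` keyword on `MongoStore`, and the key path are assumptions based on the documentation being edited, not confirmed by this diff.

```python
from maggma.stores import MongoStore
from maggma.stores.ssh_tunnel import SSHTunnel  # assumed import path

# Key-based authentication instead of username/password.
tunnel = SSHTunnel(
    tunnel_server_address="<REMOTE_SERVER_ADDRESS>:22",
    username="<USERNAME>",
    private_key="~/.ssh/id_rsa",  # hypothetical key path
    remote_server_address="COMPUTE_NODE_1:27017",  # MongoDB's default port
    local_port=27017,
)

# Assumed wiring: the store connects to 127.0.0.1:27017 locally, and the
# tunnel forwards that traffic to COMPUTE_NODE_1:27017 behind the firewall.
store = MongoStore(
    database="my_db",
    collection_name="my_collection",
    host="127.0.0.1",
    port=27017,
    ssh_tunnel=tunnel,  # assumed keyword argument
)
store.connect()
```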
6 changes: 3 additions & 3 deletions mkdocs.yml
@@ -48,6 +48,6 @@ plugins:
- search
- minify
- mkdocstrings:
-handlers:
-python:
-paths: [src]
+handlers:
+python:
+paths: [src]
3 changes: 0 additions & 3 deletions setup.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-
-
from pathlib import Path

from setuptools import find_packages, setup
44 changes: 9 additions & 35 deletions src/maggma/api/query_operator/dynamic.py
@@ -26,9 +26,7 @@ def __init__(
self.excluded_fields = excluded_fields

all_fields: Dict[str, FieldInfo] = model.model_fields
-param_fields = fields or list(
-    set(all_fields.keys()) - set(excluded_fields or [])
-)
+param_fields = fields or list(set(all_fields.keys()) - set(excluded_fields or []))

# Convert the fields into operator tuples
ops = [
@@ -49,9 +47,7 @@ def query(**kwargs) -> STORE_PARAMS:
try:
criteria.append(self.mapping[k](v))
except KeyError:
-raise KeyError(
-    f"Cannot find key {k} in current query to database mapping"
-)
+raise KeyError(f"Cannot find key {k} in current query to database mapping")

final_crit = {}
for entry in criteria:
@@ -82,9 +78,7 @@ def query(self):
"Stub query function for abstract class"

@abstractmethod
-def field_to_operator(
-    self, name: str, field: FieldInfo
-) -> List[Tuple[str, Any, Query, Callable[..., Dict]]]:
+def field_to_operator(self, name: str, field: FieldInfo) -> List[Tuple[str, Any, Query, Callable[..., Dict]]]:
"""
Converts a PyDantic FieldInfo into a Tuple with the
- query param name,
@@ -113,9 +107,7 @@ def as_dict(self) -> Dict:
class NumericQuery(DynamicQueryOperator):
"Query Operator to enable searching on numeric fields"

-def field_to_operator(
-    self, name: str, field: FieldInfo
-) -> List[Tuple[str, Any, Query, Callable[..., Dict]]]:
+def field_to_operator(self, name: str, field: FieldInfo) -> List[Tuple[str, Any, Query, Callable[..., Dict]]]:
"""
Converts a PyDantic FieldInfo into a Tuple with the
query_param name,
@@ -179,11 +171,7 @@ def field_to_operator(
default=None,
description=f"Query for {title} being any of these values. Provide a comma separated list.",
),
-lambda val: {
-    f"{title}": {
-        "$in": [int(entry.strip()) for entry in val.split(",")]
-    }
-},
+lambda val: {f"{title}": {"$in": [int(entry.strip()) for entry in val.split(",")]}},
),
(
f"{title}_neq_any",
@@ -193,11 +181,7 @@
description=f"Query for {title} being not any of these values. \
Provide a comma separated list.",
),
-lambda val: {
-    f"{title}": {
-        "$nin": [int(entry.strip()) for entry in val.split(",")]
-    }
-},
+lambda val: {f"{title}": {"$nin": [int(entry.strip()) for entry in val.split(",")]}},
),
]
)
@@ -208,9 +192,7 @@ def field_to_operator(
class StringQueryOperator(DynamicQueryOperator):
"Query Operator to enable searching on numeric fields"

-def field_to_operator(
-    self, name: str, field: FieldInfo
-) -> List[Tuple[str, Any, Query, Callable[..., Dict]]]:
+def field_to_operator(self, name: str, field: FieldInfo) -> List[Tuple[str, Any, Query, Callable[..., Dict]]]:
"""
Converts a PyDantic FieldInfo into a Tuple with the
query_param name,
@@ -251,11 +233,7 @@ def field_to_operator(
default=None,
description=f"Query for {title} being any of these values. Provide a comma separated list.",
),
-lambda val: {
-    f"{title}": {
-        "$in": [entry.strip() for entry in val.split(",")]
-    }
-},
+lambda val: {f"{title}": {"$in": [entry.strip() for entry in val.split(",")]}},
),
(
f"{title}_neq_any",
Expand All @@ -264,11 +242,7 @@ def field_to_operator(
default=None,
description=f"Query for {title} being not any of these values. Provide a comma separated list",
),
-lambda val: {
-    f"{title}": {
-        "$nin": [entry.strip() for entry in val.split(",")]
-    }
-},
+lambda val: {f"{title}": {"$nin": [entry.strip() for entry in val.split(",")]}},
),
]

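Judging from the hunks above, using `NumericQuery` might look like the sketch below. The `nelements_neq_any` parameter name and the shape of the returned `criteria` are read off the reformatted lambdas; the model and the `maggma.api.query_operator` export path are assumptions.

```python
from pydantic import BaseModel

from maggma.api.query_operator import NumericQuery  # assumed export path

class Material(BaseModel):
    nelements: int

op = NumericQuery(model=Material)

# Each numeric field generates query params such as nelements_neq_any; the
# lambdas above split the comma-separated string and cast each entry to int.
print(op.query(nelements_neq_any="1, 2"))
# expected to resemble: {'criteria': {'nelements': {'$nin': [1, 2]}}}
```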
7 changes: 3 additions & 4 deletions src/maggma/api/resource/aggregation.py
@@ -1,6 +1,6 @@
from typing import Any, Dict, List, Optional, Type
-import orjson

+import orjson
from fastapi import HTTPException, Request, Response
from pydantic import BaseModel
from pymongo import timeout as query_timeout
@@ -11,8 +11,7 @@
from maggma.api.query_operator import QueryOperator
from maggma.api.resource import HeaderProcessor, Resource
from maggma.api.resource.utils import attach_query_ops
-from maggma.api.utils import STORE_PARAMS, merge_queries
-from maggma.api.utils import serialization_helper
+from maggma.api.utils import STORE_PARAMS, merge_queries, serialization_helper
from maggma.core import Store


@@ -69,7 +68,7 @@ def build_dynamic_model_search(self):

def search(**queries: Dict[str, STORE_PARAMS]) -> Dict:
request: Request = queries.pop("request") # type: ignore
-temp_response: Response = queries.pop("temp_response")  # type: ignore
+queries.pop("temp_response")  # type: ignore

query: Dict[Any, Any] = merge_queries(list(queries.values())) # type: ignore

14 changes: 5 additions & 9 deletions src/maggma/api/resource/read_resource.py
@@ -226,16 +226,17 @@ def search(**queries: Dict[str, STORE_PARAMS]) -> Union[Dict, Response]:

try:
with query_timeout(self.timeout):

if isinstance(self.store, S3Store):
-count = self.store.count(criteria = query.get("criteria")) # type: ignore
+count = self.store.count(criteria=query.get("criteria"))  # type: ignore

if self.query_disk_use:
data = list(self.store.query(**query, allow_disk_use=True)) # type: ignore
else:
data = list(self.store.query(**query))
else:
-count = self.store.count(criteria = query.get("criteria"), hint = query.get("count_hint")) # type: ignore
+count = self.store.count(
+    criteria=query.get("criteria"), hint=query.get("count_hint")
+)  # type: ignore

pipeline = generate_query_pipeline(query, self.store)

Expand All @@ -244,12 +245,7 @@ def search(**queries: Dict[str, STORE_PARAMS]) -> Union[Dict, Response]:
if query.get("agg_hint"):
agg_kwargs["hint"] = query["agg_hint"]

-data = list(
-    self.store._collection.aggregate(
-        pipeline,
-        **agg_kwargs
-    )
-)
+data = list(self.store._collection.aggregate(pipeline, **agg_kwargs))

except (NetworkTimeout, PyMongoError) as e:
if e.timeout:
40 changes: 16 additions & 24 deletions src/maggma/api/utils.py
@@ -1,20 +1,23 @@
import base64
import inspect
-import sys
-from typing import Any, Callable, Dict, List, Optional, Type
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Type,
+    get_args,  # pragma: no cover
+)

from bson.objectid import ObjectId
from monty.json import MSONable
from pydantic import BaseModel
-from pydantic.fields import FieldInfo
-from maggma.utils import get_flat_models_from_model
from pydantic._internal._utils import lenient_issubclass
+from pydantic.fields import FieldInfo
from typing_extensions import Literal, Union

-if sys.version_info >= (3, 8):
-    from typing import get_args
-else:
-    from typing_extensions import get_args  # pragma: no cover
+from maggma.utils import get_flat_models_from_model

QUERY_PARAMS = ["criteria", "properties", "skip", "limit"]
STORE_PARAMS = Dict[
@@ -44,12 +47,7 @@ def merge_queries(queries: List[STORE_PARAMS]) -> STORE_PARAMS:
if "properties" in sub_query:
properties.extend(sub_query["properties"])

-remainder = {
-    k: v
-    for query in queries
-    for k, v in query.items()
-    if k not in ["criteria", "properties"]
-}
+remainder = {k: v for query in queries for k, v in query.items() if k not in ["criteria", "properties"]}

return {
"criteria": criteria,
@@ -94,7 +92,7 @@ def attach_signature(function: Callable, defaults: Dict, annotations: Dict):

def api_sanitize(
pydantic_model: BaseModel,
-fields_to_leave: Union[str, None] = None,
+fields_to_leave: Optional[Union[str, None]] = None,
allow_dict_msonable=False,
):
"""Function to clean up pydantic models for the API by:
@@ -112,9 +110,7 @@
"""

models = [
-    model
-    for model in get_flat_models_from_model(pydantic_model)
-    if issubclass(model, BaseModel)
+    model for model in get_flat_models_from_model(pydantic_model) if issubclass(model, BaseModel)
] # type: list[BaseModel]

fields_to_leave = fields_to_leave or []
Expand All @@ -136,9 +132,7 @@ def api_sanitize(
allow_msonable_dict(sub_type)

if name not in model_fields_to_leave:
-new_field = FieldInfo.from_annotated_attribute(
-    Optional[field_type], None
-)
+new_field = FieldInfo.from_annotated_attribute(Optional[field_type], None)
model.model_fields[name] = new_field

model.model_rebuild(force=True)
@@ -167,9 +161,7 @@ def validate_monty(cls, v, _):
errors.append("@class")

if len(errors) > 0:
-raise ValueError(
-    "Missing Monty serialization fields in dictionary: {errors}"
-)
+raise ValueError("Missing Monty serialization fields in dictionary: {errors}")

return v
else:
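Based on the behavior visible in the `api_sanitize` hunks above (every field not in `fields_to_leave` is rebuilt as `Optional` with a `None` default, and the model is rebuilt in place), a minimal usage sketch might look like this; the model here is hypothetical.

```python
from pydantic import BaseModel

from maggma.api.utils import api_sanitize

class Material(BaseModel):
    task_id: str
    nelements: int

api_sanitize(Material)  # mutates the model in place: all fields become Optional

# nelements is no longer required, so this now validates.
print(Material(task_id="mp-149"))
```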
3 changes: 2 additions & 1 deletion src/maggma/stores/mongolike.py
@@ -674,7 +674,8 @@ def connect(self, force_reset: bool = False):
Loads the files into the collection in memory
Args:
-force_reset: whether to reset the connection or not. If False (default) and .connect() has been called previously, the .json file will not be read in again. This can improve performance
+force_reset: whether to reset the connection or not. If False (default) and .connect()
+    has been called previously, the .json file will not be read in again. This can improve performance
on systems with slow storage when multiple connect / disconnects are performed.
"""
if self._coll is None or force_reset:
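The rewrapped docstring above describes caching on `connect()`; a sketch of that behavior follows. The hunk shows only the docstring, so the store class and file path are assumptions.

```python
from maggma.stores import JSONStore  # assumed store with this connect() behavior

store = JSONStore("docs.json")   # hypothetical path
store.connect()                  # reads the .json file from disk
store.connect()                  # second call: the file is not read again
store.connect(force_reset=True)  # forces a re-read from disk
```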
10 changes: 5 additions & 5 deletions src/maggma/stores/shared_stores.py
@@ -158,7 +158,7 @@ def groupby(
sort: Optional[Dict[str, Union[Sort, int]]] = None,
skip: int = 0,
limit: int = 0,
-**kwargs
+**kwargs,
) -> Iterator[Tuple[Dict, List[Dict]]]:
"""
Simple grouping function that will group documents
@@ -194,7 +194,7 @@ def query_one(
criteria: Optional[Dict] = None,
properties: Union[Dict, List, None] = None,
sort: Optional[Dict[str, Union[Sort, int]]] = None,
-**kwargs
+**kwargs,
):
"""
Queries the Store for a single document
@@ -409,7 +409,7 @@ def query(
sort: Optional[Dict[str, Union[Sort, int]]] = None,
skip: int = 0,
limit: int = 0,
-**kwargs
+**kwargs,
) -> List[Dict]:
"""
Queries the Store for a set of documents
@@ -467,7 +467,7 @@ def groupby(
sort: Optional[Dict[str, Union[Sort, int]]] = None,
skip: int = 0,
limit: int = 0,
-**kwargs
+**kwargs,
) -> Iterator[Tuple[Dict, List[Dict]]]:
"""
Simple grouping function that will group documents
@@ -506,7 +506,7 @@ def query_one(
criteria: Optional[Dict] = None,
properties: Union[Dict, List, None] = None,
sort: Optional[Dict[str, Union[Sort, int]]] = None,
-**kwargs
+**kwargs,
):
"""
Queries the Store for a single document
(Diff truncated: the remaining changed files are not shown.)