diff --git a/src/maggma/api/API.py b/src/maggma/api/API.py index 3feee135d..9a7e0bdca 100644 --- a/src/maggma/api/API.py +++ b/src/maggma/api/API.py @@ -3,9 +3,9 @@ import uvicorn from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware from monty.json import MSONable from starlette.responses import RedirectResponse -from fastapi.middleware.cors import CORSMiddleware from maggma import __version__ from maggma.api.resource import Resource @@ -82,7 +82,7 @@ def app(self): @app.get("/heartbeat", include_in_schema=False) def heartbeat(): - """ API Heartbeat for Load Balancing """ + """API Heartbeat for Load Balancing""" return { "status": "OK", @@ -93,7 +93,7 @@ def heartbeat(): @app.get("/", include_in_schema=False) def redirect_docs(): - """ Redirects the root end point to the docs """ + """Redirects the root end point to the docs""" return RedirectResponse(url=app.docs_url, status_code=301) return app diff --git a/src/maggma/api/query_operator/__init__.py b/src/maggma/api/query_operator/__init__.py index b3b1e8a30..fb2d853b9 100644 --- a/src/maggma/api/query_operator/__init__.py +++ b/src/maggma/api/query_operator/__init__.py @@ -1,6 +1,6 @@ from maggma.api.query_operator.core import QueryOperator from maggma.api.query_operator.dynamic import NumericQuery, StringQueryOperator from maggma.api.query_operator.pagination import PaginationQuery -from maggma.api.query_operator.sparse_fields import SparseFieldsQuery from maggma.api.query_operator.sorting import SortQuery +from maggma.api.query_operator.sparse_fields import SparseFieldsQuery from maggma.api.query_operator.submission import SubmissionQuery diff --git a/src/maggma/api/query_operator/dynamic.py b/src/maggma/api/query_operator/dynamic.py index cf47b7c5b..aac6567b4 100644 --- a/src/maggma/api/query_operator/dynamic.py +++ b/src/maggma/api/query_operator/dynamic.py @@ -1,7 +1,6 @@ import inspect -from typing import Type from abc import abstractmethod -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple, Type from fastapi.params import Query from monty.json import MontyDecoder @@ -81,7 +80,7 @@ def query(**kwargs) -> STORE_PARAMS: self.query = query # type: ignore def query(self): - " Stub query function for abstract class " + "Stub query function for abstract class" pass @abstractmethod @@ -115,7 +114,7 @@ def as_dict(self) -> Dict: class NumericQuery(DynamicQueryOperator): - " Query Operator to enable searching on numeric fields" + "Query Operator to enable searching on numeric fields" def field_to_operator( self, name: str, field: ModelField @@ -140,7 +139,8 @@ def field_to_operator( f"{field.name}_max", field_type, Query( - default=None, description=f"Query for maximum value of {title}", + default=None, + description=f"Query for maximum value of {title}", ), lambda val: {f"{field.name}": {"$lte": val}}, ), @@ -148,7 +148,8 @@ def field_to_operator( f"{field.name}_min", field_type, Query( - default=None, description=f"Query for minimum value of {title}", + default=None, + description=f"Query for minimum value of {title}", ), lambda val: {f"{field.name}": {"$gte": val}}, ), @@ -209,7 +210,7 @@ def field_to_operator( class StringQueryOperator(DynamicQueryOperator): - " Query Operator to enable searching on numeric fields" + "Query Operator to enable searching on numeric fields" def field_to_operator( self, name: str, field: ModelField diff --git a/src/maggma/api/query_operator/pagination.py b/src/maggma/api/query_operator/pagination.py index 0b700cb7f..0d5b5b850 100644 --- a/src/maggma/api/query_operator/pagination.py +++ b/src/maggma/api/query_operator/pagination.py @@ -9,7 +9,9 @@ class PaginationQuery(QueryOperator): """Query opertators to provides Pagination""" - def __init__(self, default_skip: int = 0, default_limit: int = 100, max_limit: int = 1000): + def __init__( + self, default_skip: int = 0, default_limit: int = 100, max_limit: int = 1000 + ): """ Args: default_skip: the default number of documents to skip @@ -21,10 +23,13 @@ def __init__(self, default_skip: int = 0, default_limit: int = 100, max_limit: i self.max_limit = max_limit def query( - skip: int = Query(default_skip, description="Number of entries to skip in the search"), + skip: int = Query( + default_skip, description="Number of entries to skip in the search" + ), limit: int = Query( default_limit, - description="Max number of entries to return in a single query." f" Limited to {max_limit}", + description="Max number of entries to return in a single query." + f" Limited to {max_limit}", ), ) -> STORE_PARAMS: """ @@ -41,7 +46,7 @@ def query( self.query = query # type: ignore def query(self): - " Stub query function for abstract class " + "Stub query function for abstract class" pass def meta(self) -> Dict: diff --git a/src/maggma/api/query_operator/sorting.py b/src/maggma/api/query_operator/sorting.py index 96fa44ce8..bd3a7e0e2 100644 --- a/src/maggma/api/query_operator/sorting.py +++ b/src/maggma/api/query_operator/sorting.py @@ -1,7 +1,9 @@ from typing import Optional + +from fastapi import HTTPException, Query + from maggma.api.query_operator import QueryOperator from maggma.api.utils import STORE_PARAMS -from fastapi import HTTPException, Query class SortQuery(QueryOperator): @@ -12,7 +14,10 @@ class SortQuery(QueryOperator): def query( self, field: Optional[str] = Query(None, description="Field to sort with"), - ascending: Optional[bool] = Query(None, description="Whether the sorting should be ascending",), + ascending: Optional[bool] = Query( + None, + description="Whether the sorting should be ascending", + ), ) -> STORE_PARAMS: sort = {} @@ -22,7 +27,8 @@ def query( elif field or ascending is not None: raise HTTPException( - status_code=400, detail="Must specify both a field and order for sorting.", + status_code=400, + detail="Must specify both a field and order for sorting.", ) return {"sort": sort} diff --git a/src/maggma/api/query_operator/sparse_fields.py b/src/maggma/api/query_operator/sparse_fields.py index 7c3462abe..7fb8e2324 100644 --- a/src/maggma/api/query_operator/sparse_fields.py +++ b/src/maggma/api/query_operator/sparse_fields.py @@ -49,7 +49,7 @@ def query( self.query = query # type: ignore def query(self): - " Stub query function for abstract class " + "Stub query function for abstract class" pass def meta(self) -> Dict: diff --git a/src/maggma/api/query_operator/submission.py b/src/maggma/api/query_operator/submission.py index 3e4a764a1..7b1c29469 100644 --- a/src/maggma/api/query_operator/submission.py +++ b/src/maggma/api/query_operator/submission.py @@ -1,8 +1,10 @@ +from datetime import datetime from typing import Optional + +from fastapi import Query + from maggma.api.query_operator import QueryOperator from maggma.api.utils import STORE_PARAMS -from fastapi import Query -from datetime import datetime class SubmissionQuery(QueryOperator): @@ -19,7 +21,8 @@ def query( None, description="Latest status of the submission" ), last_updated: Optional[datetime] = Query( - None, description="Minimum datetime of status update for submission", + None, + description="Minimum datetime of status update for submission", ), ) -> STORE_PARAMS: @@ -45,5 +48,5 @@ def query( self.query = query def query(self): - " Stub query function for abstract class " + "Stub query function for abstract class" pass diff --git a/src/maggma/api/resource/__init__.py b/src/maggma/api/resource/__init__.py index d2a5f1a2c..e4e2f2277 100644 --- a/src/maggma/api/resource/__init__.py +++ b/src/maggma/api/resource/__init__.py @@ -1,5 +1,9 @@ +# isort: off from maggma.api.resource.core import Resource -from maggma.api.resource.read_resource import ReadOnlyResource, attach_query_ops -from maggma.api.resource.submission import SubmissionResource + +# isort: on + from maggma.api.resource.aggregation import AggregationResource from maggma.api.resource.post_resource import PostOnlyResource +from maggma.api.resource.read_resource import ReadOnlyResource, attach_query_ops +from maggma.api.resource.submission import SubmissionResource diff --git a/src/maggma/api/resource/aggregation.py b/src/maggma/api/resource/aggregation.py index 42f7aff4c..e34e2e9c3 100644 --- a/src/maggma/api/resource/aggregation.py +++ b/src/maggma/api/resource/aggregation.py @@ -7,10 +7,7 @@ from maggma.api.query_operator import QueryOperator from maggma.api.resource import Resource from maggma.api.resource.utils import attach_query_ops -from maggma.api.utils import ( - STORE_PARAMS, - merge_queries, -) +from maggma.api.utils import STORE_PARAMS, merge_queries from maggma.core import Store diff --git a/src/maggma/api/resource/post_resource.py b/src/maggma/api/resource/post_resource.py index 0809da008..3e156e4e0 100644 --- a/src/maggma/api/resource/post_resource.py +++ b/src/maggma/api/resource/post_resource.py @@ -1,21 +1,14 @@ -from typing import Any, Dict, List, Optional, Type from inspect import signature +from typing import Any, Dict, List, Optional, Type from fastapi import HTTPException, Request from pydantic import BaseModel from maggma.api.models import Meta, Response -from maggma.api.query_operator import ( - PaginationQuery, - QueryOperator, - SparseFieldsQuery, -) +from maggma.api.query_operator import PaginationQuery, QueryOperator, SparseFieldsQuery from maggma.api.resource import Resource from maggma.api.resource.utils import attach_query_ops -from maggma.api.utils import ( - STORE_PARAMS, - merge_queries, -) +from maggma.api.utils import STORE_PARAMS, merge_queries from maggma.core import Store diff --git a/src/maggma/api/resource/read_resource.py b/src/maggma/api/resource/read_resource.py index f6f84d54a..e9b2c33f0 100644 --- a/src/maggma/api/resource/read_resource.py +++ b/src/maggma/api/resource/read_resource.py @@ -1,21 +1,14 @@ -from typing import Any, Dict, List, Optional, Type from inspect import signature +from typing import Any, Dict, List, Optional, Type from fastapi import Depends, HTTPException, Path, Request from pydantic import BaseModel from maggma.api.models import Meta, Response -from maggma.api.query_operator import ( - PaginationQuery, - QueryOperator, - SparseFieldsQuery, -) +from maggma.api.query_operator import PaginationQuery, QueryOperator, SparseFieldsQuery from maggma.api.resource import Resource from maggma.api.resource.utils import attach_query_ops -from maggma.api.utils import ( - STORE_PARAMS, - merge_queries, -) +from maggma.api.utils import STORE_PARAMS, merge_queries from maggma.core import Store @@ -104,7 +97,9 @@ def field_input(): async def get_by_key( key: str = Path( - ..., alias=key_name, title=f"The {key_name} of the {model_name} to get", + ..., + alias=key_name, + title=f"The {key_name} of the {model_name} to get", ), fields: STORE_PARAMS = Depends(field_input), ): diff --git a/src/maggma/api/resource/submission.py b/src/maggma/api/resource/submission.py index 4bd1f37a8..6058870f1 100644 --- a/src/maggma/api/resource/submission.py +++ b/src/maggma/api/resource/submission.py @@ -1,23 +1,18 @@ -from typing import Any, List, Optional, Type -from inspect import signature from datetime import datetime +from enum import Enum +from inspect import signature +from typing import Any, List, Optional, Type from uuid import uuid4 from fastapi import HTTPException, Path, Request +from pydantic import BaseModel, Field, create_model -from maggma.api.models import Response, Meta - +from maggma.api.models import Meta, Response from maggma.api.query_operator import QueryOperator, SubmissionQuery - from maggma.api.resource import Resource from maggma.api.resource.utils import attach_query_ops -from maggma.api.utils import ( - STORE_PARAMS, - merge_queries, -) +from maggma.api.utils import STORE_PARAMS, merge_queries from maggma.core import Store -from enum import Enum -from pydantic import create_model, Field, BaseModel class SubmissionResource(Resource): @@ -73,9 +68,7 @@ def __init__( self.tags = tags or [] self.post_query_operators = post_query_operators self.get_query_operators = ( - [ - op for op in get_query_operators if op is not None # type: ignore - ] + [op for op in get_query_operators if op is not None] # type: ignore + [SubmissionQuery(state_enum)] if state_enum is not None else get_query_operators @@ -281,7 +274,8 @@ async def post_data(**queries: STORE_PARAMS): self.store.update(docs=query["criteria"]) # type: ignore except Exception: raise HTTPException( - status_code=400, detail="Problem when trying to post data.", + status_code=400, + detail="Problem when trying to post data.", ) response = { diff --git a/src/maggma/api/resource/utils.py b/src/maggma/api/resource/utils.py index 0edb0f34c..8d279e652 100644 --- a/src/maggma/api/resource/utils.py +++ b/src/maggma/api/resource/utils.py @@ -1,5 +1,7 @@ -from typing import Callable, List, Dict -from fastapi import Request, Depends +from typing import Callable, Dict, List + +from fastapi import Depends, Request + from maggma.api.query_operator import QueryOperator from maggma.api.utils import STORE_PARAMS, attach_signature diff --git a/src/maggma/api/utils.py b/src/maggma/api/utils.py index 7ad505353..7b1e0893c 100644 --- a/src/maggma/api/utils.py +++ b/src/maggma/api/utils.py @@ -1,4 +1,5 @@ import inspect +import sys from typing import Any, Callable, Dict, List, Optional, Type from monty.json import MSONable @@ -7,6 +8,12 @@ from pydantic.utils import lenient_issubclass from typing_extensions import Literal +if sys.version_info >= (3, 8): + from typing import get_args +else: + from typing_extensions import get_args + + QUERY_PARAMS = ["criteria", "properties", "skip", "limit"] STORE_PARAMS = Dict[ Literal["criteria", "properties", "sort", "skip", "limit", "request", "pipeline"], @@ -110,12 +117,13 @@ def api_sanitize( field.required = False field.field_info.default = None - if ( - field_type is not None - and lenient_issubclass(field_type, MSONable) - and allow_dict_msonable - ): - field.type_ = allow_msonable_dict(field_type) + if field_type is not None and allow_dict_msonable: + if lenient_issubclass(field_type, MSONable): + field.type_ = allow_msonable_dict(field_type) + else: + for sub_type in get_args(field_type): + if lenient_issubclass(sub_type, MSONable): + allow_msonable_dict(sub_type) field.populate_validators() return pydantic_model diff --git a/src/maggma/core/store.py b/src/maggma/core/store.py index d95cc94ef..d9ca5bbe6 100644 --- a/src/maggma/core/store.py +++ b/src/maggma/core/store.py @@ -18,14 +18,14 @@ class Sort(Enum): - """ Enumeration for sorting order """ + """Enumeration for sorting order""" Ascending = 1 Descending = -1 class DateTimeFormat(Enum): - """ Datetime format in store document """ + """Datetime format in store document""" DateTime = "datetime" IsoFormat = "isoformat" @@ -354,7 +354,7 @@ def __exit__(self, exception_type, exception_value, traceback): class StoreError(Exception): - """ General Store-related error """ + """General Store-related error""" def __init__(self, *args, **kwargs): super().__init__(self, *args, **kwargs) diff --git a/src/maggma/stores/mongolike.py b/src/maggma/stores/mongolike.py index fdc4988bd..db92362d4 100644 --- a/src/maggma/stores/mongolike.py +++ b/src/maggma/stores/mongolike.py @@ -161,7 +161,7 @@ def connect(self, force_reset: bool = False): self._collection = db[self.collection_name] def __hash__(self) -> int: - """ Hash for MongoStore """ + """Hash for MongoStore""" return hash((self.database, self.collection_name, self.last_updated_field)) @classmethod @@ -177,7 +177,9 @@ def from_db_file(cls, filename: str): kwargs.pop("aliases", None) return cls(**kwargs) - def distinct(self, field: str, criteria: Optional[Dict] = None, all_exist: bool = False) -> List: + def distinct( + self, field: str, criteria: Optional[Dict] = None, all_exist: bool = False + ) -> List: """ Get all distinct values for a field @@ -190,7 +192,10 @@ def distinct(self, field: str, criteria: Optional[Dict] = None, all_exist: bool try: distinct_vals = self._collection.distinct(field, criteria) except (OperationFailure, DocumentTooLarge): - distinct_vals = [d["_id"] for d in self._collection.aggregate([{"$group": {"_id": f"${field}"}}])] + distinct_vals = [ + d["_id"] + for d in self._collection.aggregate([{"$group": {"_id": f"${field}"}}]) + ] if all(isinstance(d, list) for d in filter(None, distinct_vals)): # type: ignore distinct_vals = list(chain.from_iterable(filter(None, distinct_vals))) @@ -267,7 +272,7 @@ def from_collection(cls, collection): @property # type: ignore @deprecated(message="This will be removed in the future") def collection(self): - """ Property referring to underlying pymongo collection """ + """Property referring to underlying pymongo collection""" if self._collection is None: raise StoreError("Must connect Mongo-like store before attemping to use it") return self._collection @@ -306,10 +311,21 @@ def query( properties = {p: 1 for p in properties} sort_list = ( - [(k, Sort(v).value) if isinstance(v, int) else (k, v.value) for k, v in sort.items()] if sort else None + [ + (k, Sort(v).value) if isinstance(v, int) else (k, v.value) + for k, v in sort.items() + ] + if sort + else None ) - for d in self._collection.find(filter=criteria, projection=properties, skip=skip, limit=limit, sort=sort_list,): + for d in self._collection.find( + filter=criteria, + projection=properties, + skip=skip, + limit=limit, + sort=sort_list, + ): yield d def ensure_index(self, key: str, unique: Optional[bool] = False) -> bool: @@ -398,7 +414,7 @@ def remove_docs(self, criteria: Dict): self._collection.delete_many(filter=criteria) def close(self): - """ Close up all collections """ + """Close up all collections""" self._collection.database.client.close() if self.ssh_tunnel is not None: self.ssh_tunnel.stop() @@ -423,7 +439,12 @@ class MongoURIStore(MongoStore): """ def __init__( - self, uri: str, collection_name: str, database: str = None, ssh_tunnel: Optional[SSHTunnel] = None, **kwargs + self, + uri: str, + collection_name: str, + database: str = None, + ssh_tunnel: Optional[SSHTunnel] = None, + **kwargs, ): """ Args: @@ -438,7 +459,9 @@ def __init__( if database is None: d_uri = uri_parser.parse_uri(uri) if d_uri["database"] is None: - raise ConfigurationError("If database name is not supplied, a database must be set in the uri") + raise ConfigurationError( + "If database name is not supplied, a database must be set in the uri" + ) self.database = d_uri["database"] else: self.database = database @@ -494,11 +517,11 @@ def connect(self, force_reset: bool = False): @property def name(self): - """ Name for the store """ + """Name for the store""" return f"mem://{self.collection_name}" def __hash__(self): - """ Hash for the store """ + """Hash for the store""" return hash((self.name, self.last_updated_field)) def groupby( @@ -527,7 +550,11 @@ def groupby( generator returning tuples of (key, list of elemnts) """ keys = keys if isinstance(keys, list) else [keys] - data = [doc for doc in self.query(properties=keys, criteria=criteria) if all(has(doc, k) for k in keys)] + data = [ + doc + for doc in self.query(properties=keys, criteria=criteria) + if all(has(doc, k) for k in keys) + ] def grouping_keys(doc): return tuple(get(doc, k) for k in keys) diff --git a/tests/api/test_utils.py b/tests/api/test_utils.py index 8b05fa621..8fe48b3de 100644 --- a/tests/api/test_utils.py +++ b/tests/api/test_utils.py @@ -5,7 +5,8 @@ from monty.json import MSONable from pydantic import BaseModel, Field -from maggma.api.utils import api_sanitize, merge_queries +from maggma.api.utils import api_sanitize +from typing import Union class SomeEnum(Enum): @@ -20,6 +21,17 @@ def __init__(self, name, age): self.age = age +class AnotherPet(MSONable): + def __init__(self, name, age): + self.name = name + self.age = age + + +class AnotherOwner(BaseModel): + name: str = Field(..., description="Ower name") + weight_or_pet: Union[float, AnotherPet] = Field(..., title="Owners weight or Pet") + + class Owner(BaseModel): name: str = Field(..., title="Owner's name") age: int = Field(..., title="Owne'r Age") @@ -71,3 +83,9 @@ def test_api_sanitize(): # This should work assert isinstance(Owner(pet=temp_pet_dict).pet, dict) + + # This should work evne though AnotherPet is inside the Union type + api_sanitize(AnotherOwner, allow_dict_msonable=True) + temp_pet_dict = AnotherPet(name="fido", age=3).as_dict() + + assert isinstance(AnotherPet.validate_monty(temp_pet_dict), dict)