From 2ce7d4f9b7241433cc45f700339ee9ed3b8f92d9 Mon Sep 17 00:00:00 2001 From: Jason Munro Date: Mon, 25 Jul 2022 12:02:56 -0700 Subject: [PATCH] Ensure all fields are properly sanitized --- src/maggma/api/utils.py | 32 ++++++++++++++++++++++++++------ 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/src/maggma/api/utils.py b/src/maggma/api/utils.py index c86d94f22..95c3b2bda 100644 --- a/src/maggma/api/utils.py +++ b/src/maggma/api/utils.py @@ -18,7 +18,10 @@ QUERY_PARAMS = ["criteria", "properties", "skip", "limit"] STORE_PARAMS = Dict[ - Literal["criteria", "properties", "sort", "skip", "limit", "request", "pipeline", "hint"], Any, + Literal[ + "criteria", "properties", "sort", "skip", "limit", "request", "pipeline", "hint" + ], + Any, ] @@ -33,7 +36,12 @@ def merge_queries(queries: List[STORE_PARAMS]) -> STORE_PARAMS: if "properties" in sub_query: properties.extend(sub_query["properties"]) - remainder = {k: v for query in queries for k, v in query.items() if k not in ["criteria", "properties"]} + remainder = { + k: v + for query in queries + for k, v in query.items() + if k not in ["criteria", "properties"] + } return { "criteria": criteria, @@ -73,11 +81,15 @@ def attach_signature(function: Callable, defaults: Dict, annotations: Dict): for param in defaults.keys() ] - setattr(function, "__signature__", inspect.Signature(required_params + optional_params)) + setattr( + function, "__signature__", inspect.Signature(required_params + optional_params) + ) def api_sanitize( - pydantic_model: Type[BaseModel], fields_to_leave: Optional[List[str]] = None, allow_dict_msonable=False, + pydantic_model: Type[BaseModel], + fields_to_leave: Optional[List[str]] = None, + allow_dict_msonable=False, ): """ Function to clean up pydantic models for the API by: @@ -91,7 +103,9 @@ def api_sanitize( """ models = [ - model for model in get_flat_models_from_model(pydantic_model) if issubclass(model, BaseModel) + model + for model in get_flat_models_from_model(pydantic_model) + if issubclass(model, BaseModel) ] # type: List[Type[BaseModel]] fields_to_leave = fields_to_leave or [] @@ -105,7 +119,11 @@ def api_sanitize( if name not in model_fields_to_leave: field.required = False + field.default = None + field.default_factory = None + field.allow_none = True field.field_info.default = None + field.field_info.default_factory = None if field_type is not None and allow_dict_msonable: if lenient_issubclass(field_type, MSONable): @@ -140,7 +158,9 @@ def validate_monty(cls, v): errors.append("@class") if len(errors) > 0: - raise ValueError("Missing Monty seriailzation fields in dictionary: {errors}") + raise ValueError( + "Missing Monty seriailzation fields in dictionary: {errors}" + ) return v else: