Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure all fields are properly sanitized #690

Merged
merged 1 commit into from
Jul 25, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 26 additions & 6 deletions src/maggma/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@

QUERY_PARAMS = ["criteria", "properties", "skip", "limit"]
STORE_PARAMS = Dict[
Literal["criteria", "properties", "sort", "skip", "limit", "request", "pipeline", "hint"], Any,
Literal[
"criteria", "properties", "sort", "skip", "limit", "request", "pipeline", "hint"
],
Any,
]


Expand All @@ -33,7 +36,12 @@ def merge_queries(queries: List[STORE_PARAMS]) -> STORE_PARAMS:
if "properties" in sub_query:
properties.extend(sub_query["properties"])

remainder = {k: v for query in queries for k, v in query.items() if k not in ["criteria", "properties"]}
remainder = {
k: v
for query in queries
for k, v in query.items()
if k not in ["criteria", "properties"]
}

return {
"criteria": criteria,
Expand Down Expand Up @@ -73,11 +81,15 @@ def attach_signature(function: Callable, defaults: Dict, annotations: Dict):
for param in defaults.keys()
]

setattr(function, "__signature__", inspect.Signature(required_params + optional_params))
setattr(
function, "__signature__", inspect.Signature(required_params + optional_params)
)


def api_sanitize(
pydantic_model: Type[BaseModel], fields_to_leave: Optional[List[str]] = None, allow_dict_msonable=False,
pydantic_model: Type[BaseModel],
fields_to_leave: Optional[List[str]] = None,
allow_dict_msonable=False,
):
"""
Function to clean up pydantic models for the API by:
Expand All @@ -91,7 +103,9 @@ def api_sanitize(
"""

models = [
model for model in get_flat_models_from_model(pydantic_model) if issubclass(model, BaseModel)
model
for model in get_flat_models_from_model(pydantic_model)
if issubclass(model, BaseModel)
] # type: List[Type[BaseModel]]

fields_to_leave = fields_to_leave or []
Expand All @@ -105,7 +119,11 @@ def api_sanitize(

if name not in model_fields_to_leave:
field.required = False
field.default = None
field.default_factory = None
field.allow_none = True
field.field_info.default = None
field.field_info.default_factory = None

if field_type is not None and allow_dict_msonable:
if lenient_issubclass(field_type, MSONable):
Expand Down Expand Up @@ -140,7 +158,9 @@ def validate_monty(cls, v):
errors.append("@class")

if len(errors) > 0:
raise ValueError("Missing Monty seriailzation fields in dictionary: {errors}")
raise ValueError(
"Missing Monty seriailzation fields in dictionary: {errors}"
)

return v
else:
Expand Down