Skip to content

Commit

Permalink
fix: spectacular integration with pydantic field (#1783)
Browse files Browse the repository at this point in the history
  • Loading branch information
holtgrewe authored Jul 8, 2024
1 parent 101c15f commit b1fa2b5
Show file tree
Hide file tree
Showing 5 changed files with 511 additions and 345 deletions.
4 changes: 1 addition & 3 deletions backend/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,4 @@ celery:
.PHONY: gen-api-schema
gen-api-schema:
pipenv run $(MANAGE) spectacular \
| grep -v ^Loading \
| grep -v '^The ' \
> ./varfish/tests/drf_openapi_schema/varfish_api_schema.yaml
--file ./varfish/tests/drf_openapi_schema/varfish_api_schema.yaml
111 changes: 59 additions & 52 deletions backend/varfish/spectacular_utils.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,68 @@
from enum import Enum
from inspect import isclass
import typing

from drf_spectacular.drainage import set_override, warn
from drf_spectacular.extensions import OpenApiSerializerExtension
from drf_spectacular.plumbing import ResolvedComponent, build_basic_type
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.drainage import set_override
from drf_spectacular.extensions import OpenApiSerializerFieldExtension
from drf_spectacular.plumbing import ResolvedComponent
import pydantic
from pydantic.json_schema import model_json_schema


class DjangoPydanticFieldFix(OpenApiSerializerExtension):
def pydantic_to_json_schema(schema_arg: typing.Any) -> typing.Dict[str, typing.Any]:
"""Convert a Python/pydantic schema to a JSON schema."""
if type(schema_arg) is type(int) or type(schema_arg) is type(float):
return {
"type": "number",
}
elif type(schema_arg) is type(str):
return {
"type": "string",
}
elif type(schema_arg) is type(None):
return {
"type": "null",
}
elif isclass(schema_arg) and issubclass(schema_arg, Enum):
return {
"type": "string",
"title": schema_arg.__name__,
"enum": [e.value for e in schema_arg],
}
elif typing.get_origin(schema_arg) is typing.Union: # is typing.Optional[X]
schema_arg = typing.get_args(schema_arg)[0]
one_ofs = [pydantic_to_json_schema(arg_inner) for arg_inner in typing.get_args(schema_arg)]
defs = {}
for one_of in one_ofs:
defs.update(one_of.pop("$defs", {}))
result = {"oneOf": one_ofs, "$defs": defs}
return result
elif typing.get_origin(schema_arg) is list:
inner_schema = pydantic_to_json_schema(typing.get_args(schema_arg)[0])
defs = inner_schema.pop("$defs", {})
return {
"type": "array",
"items": inner_schema,
"$defs": defs,
}
elif issubclass(schema_arg, Enum):
return {
"type": "string",
"title": schema_arg.__name__,
"enum": [e.value for e in schema_arg],
}
elif issubclass(schema_arg, pydantic.BaseModel):
return model_json_schema(schema_arg, ref_template="#/components/schemas/{model}")
else:
raise ValueError(f"Unsupported schema type: {schema_arg}")


class DjangoPydanticFieldFix(OpenApiSerializerFieldExtension):

target_class = "django_pydantic_field.v2.rest_framework.fields.SchemaField"
match_subclasses = True

def get_name(self, auto_schema, direction):
def get_name(self):
# due to the fact that it is complicated to pull out every field member BaseModel class
# of the entry model, we simply use the class name as string for object. This hack may
# create false positive warnings, so turn it off. However, this may suppress correct
Expand All @@ -23,53 +72,11 @@ def get_name(self, auto_schema, direction):
inner_type = typing.get_args(self.target.schema)[0]
return f"{inner_type.__name__}List"
else:
return super().get_name(auto_schema, direction)

def map_serializer(self, auto_schema, direction):
if typing.get_origin(self.target.schema) is list:
inner_type = typing.get_args(self.target.schema)[0]
if inner_type is str:
schema = {
"type": "array",
"items": {
"type": "string",
},
}
elif issubclass(inner_type, Enum):
inner_schema = {
"type": "string",
"title": inner_type.__name__,
"enum": [e.value for e in inner_type],
}
inner_schema_defs = inner_schema.pop("$defs", {})
schema = {
"type": "array",
"title": f"{inner_schema['title']}List",
"items": inner_schema,
}
schema.update({"$defs": inner_schema_defs})
else:
inner_schema = model_json_schema(
inner_type, ref_template="#/components/schemas/{model}"
)
inner_schema_defs = inner_schema.pop("$defs", {})
schema = {
"type": "array",
"title": f"{inner_schema['title']}List",
"items": inner_schema,
}
schema.update({"$defs": inner_schema_defs})
elif issubclass(self.target.schema, Enum):
return {
"type": "string",
"title": self.target.schema.__name__,
"enum": [e.value for e in self.target.schema],
}
else:
schema = model_json_schema(
self.target.schema, ref_template="#/components/schemas/{model}"
)
return super().get_name()

def map_serializer_field(self, auto_schema, direction):
_ = direction
schema = pydantic_to_json_schema(self.target.schema)
# pull out potential sub-schemas and put them into component section
for sub_name, sub_schema in schema.pop("$defs", {}).items():
component = ResolvedComponent(
Expand Down
Loading

0 comments on commit b1fa2b5

Please sign in to comment.