Skip to content

Commit

Permalink
[python-fastapi] support oneOf the pydantic v2 way
Browse files Browse the repository at this point in the history
* Support oneOf and anyOf schemas the pydantic v2 way by generating them as Unions.
* Generate model constructor that forcefully sets the discriminator field to ensure it is included in the marshalled representation.
  • Loading branch information
mgoltzsche committed Oct 16, 2024
1 parent 0ed2dbf commit 25c7fb1
Show file tree
Hide file tree
Showing 10 changed files with 122 additions and 316 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -863,6 +863,8 @@ public Map<String, ModelsMap> postProcessAllModels(Map<String, ModelsMap> objs)
codegenModelMap.put(cm.classname, ModelUtils.getModelByName(entry.getKey(), objs));
}

propagateDiscriminatorValuesToProperties(processed);

// create circular import
for (String m : codegenModelMap.keySet()) {
createImportMapOfSet(m, codegenModelMap);
Expand Down Expand Up @@ -1056,6 +1058,50 @@ private ModelsMap postProcessModelsMap(ModelsMap objs) {
return objs;
}

private void propagateDiscriminatorValuesToProperties(Map<String, ModelsMap> objMap) {
HashMap<String, CodegenModel> modelMap = new HashMap<>();
for (Map.Entry<String, ModelsMap> entry : objMap.entrySet()) {
for (ModelMap m : entry.getValue().getModels()) {
modelMap.put(String.format("#/components/schemas/%s", entry.getKey()), m.getModel());
}
}

for (Map.Entry<String, ModelsMap> entry : objMap.entrySet()) {
for (ModelMap m : entry.getValue().getModels()) {
CodegenModel model = m.getModel();
if (model.discriminator != null && !model.oneOf.isEmpty()) {
// Populate default, implicit discriminator values
for (String typeName : model.oneOf) {
ModelsMap obj = objMap.get(typeName);
if (obj == null) {
continue;
}
for (ModelMap m1 : obj.getModels()) {
for (CodegenProperty p : m1.getModel().vars) {
if (p.baseName.equals(model.discriminator.getPropertyBaseName())) {
p.isDiscriminator = true;
p.discriminatorValue = typeName;
}
}
}
}
// Populate explicit discriminator values from mapping, overwriting default values
for (Map.Entry<String, String> discrEntry : model.discriminator.getMapping().entrySet()) {
CodegenModel resolved = modelMap.get(discrEntry.getValue());
if (resolved != null) {
for (CodegenProperty p : resolved.vars) {
if (p.baseName.equals(model.discriminator.getPropertyBaseName())) {
p.isDiscriminator = true;
p.discriminatorValue = discrEntry.getKey();
}
}
}
}
}
}
}
}


/*
* Gets the pydantic type given a Codegen Property
Expand Down Expand Up @@ -2129,7 +2175,16 @@ private PythonType getType(CodegenProperty cp) {
}

private String finalizeType(CodegenProperty cp, PythonType pt) {
if (!cp.required || cp.isNullable) {
if (cp.isDiscriminator) {
moduleImports.add("typing", "Literal");
PythonType literal = new PythonType("Literal");
String literalValue = String.format("'%s'", cp.discriminatorValue);
PythonType valueType = new PythonType(literalValue);
literal.addTypeParam(valueType);
literal.setDefaultValue(literalValue);
cp.setDefaultValue(literalValue);
pt = literal;
} else if (!cp.required || cp.isNullable) {
moduleImports.add("typing", "Optional");
PythonType opt = new PythonType("Optional");
opt.addTypeParam(pt);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,173 +14,51 @@ import re # noqa: F401
{{/vendorExtensions.x-py-model-imports}}
from typing import Union, Any, List, TYPE_CHECKING, Optional, Dict
from typing_extensions import Literal
from pydantic import StrictStr, Field
from pydantic import StrictStr, Field, RootModel
try:
from typing import Self
except ImportError:
from typing_extensions import Self

{{#lambda.uppercase}}{{{classname}}}{{/lambda.uppercase}}_ANY_OF_SCHEMAS = [{{#anyOf}}"{{.}}"{{^-last}}, {{/-last}}{{/anyOf}}]

class {{classname}}({{#parent}}{{{.}}}{{/parent}}{{^parent}}BaseModel{{/parent}}):
class {{classname}}({{#parent}}{{{.}}}{{/parent}}{{^parent}}RootModel{{/parent}}):
"""
{{{description}}}{{^description}}{{{classname}}}{{/description}}
"""

{{#composedSchemas.anyOf}}
# data type: {{{dataType}}}
{{vendorExtensions.x-py-name}}: {{{vendorExtensions.x-py-typing}}}
{{/composedSchemas.anyOf}}
if TYPE_CHECKING:
actual_instance: Optional[Union[{{#anyOf}}{{{.}}}{{^-last}}, {{/-last}}{{/anyOf}}]] = None
else:
actual_instance: Any = None
any_of_schemas: List[str] = Literal[{{#lambda.uppercase}}{{{classname}}}{{/lambda.uppercase}}_ANY_OF_SCHEMAS]
root: Union[{{#anyOf}}{{{.}}}{{^-last}}, {{/-last}}{{/anyOf}}] = None

model_config = {
"validate_assignment": True,
"protected_namespaces": (),
}
{{#discriminator}}

discriminator_value_class_map: Dict[str, str] = {
{{#children}}
'{{^vendorExtensions.x-discriminator-value}}{{name}}{{/vendorExtensions.x-discriminator-value}}{{#vendorExtensions.x-discriminator-value}}{{{vendorExtensions.x-discriminator-value}}}{{/vendorExtensions.x-discriminator-value}}': '{{{classname}}}'{{^-last}},{{/-last}}
{{/children}}
}
{{/discriminator}}

def __init__(self, *args, **kwargs) -> None:
if args:
if len(args) > 1:
raise ValueError("If a position argument is used, only 1 is allowed to set `actual_instance`")
if kwargs:
raise ValueError("If a position argument is used, keyword arguments cannot be used.")
super().__init__(actual_instance=args[0])
else:
super().__init__(**kwargs)

@field_validator('actual_instance')
def actual_instance_must_validate_anyof(cls, v):
{{#isNullable}}
if v is None:
return v

{{/isNullable}}
instance = {{{classname}}}.model_construct()
error_messages = []
{{#composedSchemas.anyOf}}
# validate data type: {{{dataType}}}
{{#isContainer}}
try:
instance.{{vendorExtensions.x-py-name}} = v
return v
except (ValidationError, ValueError) as e:
error_messages.append(str(e))
{{/isContainer}}
{{^isContainer}}
{{#isPrimitiveType}}
try:
instance.{{vendorExtensions.x-py-name}} = v
return v
except (ValidationError, ValueError) as e:
error_messages.append(str(e))
{{/isPrimitiveType}}
{{^isPrimitiveType}}
if not isinstance(v, {{{dataType}}}):
error_messages.append(f"Error! Input type `{type(v)}` is not `{{{dataType}}}`")
else:
return v

{{/isPrimitiveType}}
{{/isContainer}}
{{/composedSchemas.anyOf}}
if error_messages:
# no match
raise ValueError("No match found when setting the actual_instance in {{{classname}}} with anyOf schemas: {{#anyOf}}{{{.}}}{{^-last}}, {{/-last}}{{/anyOf}}. Details: " + ", ".join(error_messages))
else:
return v
def to_str(self) -> str:
"""Returns the string representation of the model using alias"""
return pprint.pformat(self.model_dump(by_alias=True))

@classmethod
def from_dict(cls, obj: dict) -> Self:
return cls.from_json(json.dumps(obj))
def to_json(self) -> str:
"""Returns the JSON representation of the model using alias"""
return self.model_dump_json(by_alias=True, exclude_unset=True)

@classmethod
def from_json(cls, json_str: str) -> Self:
"""Returns the object represented by the json string"""
instance = cls.model_construct()
{{#isNullable}}
if json_str is None:
return instance

{{/isNullable}}
error_messages = []
{{#composedSchemas.anyOf}}
{{#isContainer}}
# deserialize data into {{{dataType}}}
try:
# validation
instance.{{vendorExtensions.x-py-name}} = json.loads(json_str)
# assign value to actual_instance
instance.actual_instance = instance.{{vendorExtensions.x-py-name}}
return instance
except (ValidationError, ValueError) as e:
error_messages.append(str(e))
{{/isContainer}}
{{^isContainer}}
{{#isPrimitiveType}}
# deserialize data into {{{dataType}}}
try:
# validation
instance.{{vendorExtensions.x-py-name}} = json.loads(json_str)
# assign value to actual_instance
instance.actual_instance = instance.{{vendorExtensions.x-py-name}}
return instance
except (ValidationError, ValueError) as e:
error_messages.append(str(e))
{{/isPrimitiveType}}
{{^isPrimitiveType}}
# {{vendorExtensions.x-py-name}}: {{{vendorExtensions.x-py-typing}}}
try:
instance.actual_instance = {{{dataType}}}.from_json(json_str)
return instance
except (ValidationError, ValueError) as e:
error_messages.append(str(e))
{{/isPrimitiveType}}
{{/isContainer}}
{{/composedSchemas.anyOf}}
def from_json(cls, json_str: str) -> {{^hasChildren}}Self{{/hasChildren}}{{#hasChildren}}{{#discriminator}}Union[{{#children}}Self{{^-last}}, {{/-last}}{{/children}}]{{/discriminator}}{{^discriminator}}Self{{/discriminator}}{{/hasChildren}}:
"""Create an instance of {{{classname}}} from a JSON string"""
return cls.from_dict(json.loads(json_str))

if error_messages:
# no match
raise ValueError("No match found when deserializing the JSON string into {{{classname}}} with anyOf schemas: {{#anyOf}}{{{.}}}{{^-last}}, {{/-last}}{{/anyOf}}. Details: " + ", ".join(error_messages))
else:
return instance

def to_json(self) -> str:
"""Returns the JSON representation of the actual instance"""
if self.actual_instance is None:
return "null"
def to_dict(self) -> Dict[str, Any]:
"""Return the dictionary representation of the model using alias"""
return self.model_dump(by_alias=True, exclude_unset=True)

to_json = getattr(self.actual_instance, "to_json", None)
if callable(to_json):
return self.actual_instance.to_json()
else:
return json.dumps(self.actual_instance)

def to_dict(self) -> Dict:
"""Returns the dict representation of the actual instance"""
if self.actual_instance is None:
return "null"
@classmethod
def from_dict(cls, obj: Dict) -> {{^hasChildren}}Self{{/hasChildren}}{{#hasChildren}}{{#discriminator}}Union[{{#children}}Self{{^-last}}, {{/-last}}{{/children}}]{{/discriminator}}{{^discriminator}}Self{{/discriminator}}{{/hasChildren}}:
"""Create an instance of {{{classname}}} from a dict"""
if obj is None:
return None

to_json = getattr(self.actual_instance, "to_json", None)
if callable(to_json):
return self.actual_instance.to_dict()
else:
return json.dumps(self.actual_instance)
if not isinstance(obj, dict):
return cls.model_validate(obj)

def to_str(self) -> str:
"""Returns the string representation of the actual instance"""
return pprint.pformat(self.model_dump())
return cls.parse_obj(obj)

{{#vendorExtensions.x-py-postponed-model-imports.size}}
{{#vendorExtensions.x-py-postponed-model-imports}}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,13 @@ class {{classname}}({{#parent}}{{{.}}}{{/parent}}{{^parent}}BaseModel{{/parent}}
{{/isAdditionalPropertiesTrue}}
}

def __init__(self, *a, **kw):
super().__init__(*a, **kw)
{{#vars}}
{{#isDiscriminator}}
self.{{name}} = self.{{name}}
{{/isDiscriminator}}
{{/vars}}

def to_str(self) -> str:
"""Returns the string representation of the model using alias"""
Expand Down
Loading

0 comments on commit 25c7fb1

Please sign in to comment.