diff --git a/modules/openapi-generator/src/main/java/org/openapitools/codegen/languages/AbstractPythonPydanticV1Codegen.java b/modules/openapi-generator/src/main/java/org/openapitools/codegen/languages/AbstractPythonPydanticV1Codegen.java new file mode 100644 index 000000000000..5fc6023360f9 --- /dev/null +++ b/modules/openapi-generator/src/main/java/org/openapitools/codegen/languages/AbstractPythonPydanticV1Codegen.java @@ -0,0 +1,1997 @@ +/* + * Copyright 2018 OpenAPI-Generator Contributors (https://openapi-generator.tech) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.openapitools.codegen.languages; + +import com.github.curiousoddman.rgxgen.RgxGen; +import io.swagger.v3.oas.models.examples.Example; +import io.swagger.v3.oas.models.media.ArraySchema; +import io.swagger.v3.oas.models.media.Schema; +import io.swagger.v3.oas.models.parameters.Parameter; +import org.apache.commons.io.FilenameUtils; +import org.apache.commons.lang3.StringUtils; +import org.openapitools.codegen.*; +import org.openapitools.codegen.meta.features.SecurityFeature; +import org.openapitools.codegen.model.ModelMap; +import org.openapitools.codegen.model.ModelsMap; +import org.openapitools.codegen.model.OperationMap; +import org.openapitools.codegen.model.OperationsMap; +import org.openapitools.codegen.utils.ModelUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.IOException; +import java.util.*; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; + +import static org.openapitools.codegen.utils.StringUtils.*; + +public abstract class AbstractPythonPydanticV1Codegen extends DefaultCodegen implements CodegenConfig { + private final Logger LOGGER = LoggerFactory.getLogger(AbstractPythonPydanticV1Codegen.class); + + public static final String MAP_NUMBER_TO = "mapNumberTo"; + + protected String packageName = "openapi_client"; + protected String packageVersion = "1.0.0"; + protected String projectName; // for setup.py, e.g. petstore-api + protected boolean hasModelsToImport = Boolean.FALSE; + protected String mapNumberTo = "Union[StrictFloat, StrictInt]"; + protected Map regexModifiers; + + private Map schemaKeyToModelNameCache = new HashMap<>(); + // map of set (model imports) + private HashMap> circularImports = new HashMap<>(); + // map of codegen models + private HashMap codegenModelMap = new HashMap<>(); + + public AbstractPythonPydanticV1Codegen() { + super(); + + modifyFeatureSet(features -> features.securityFeatures(EnumSet.of( + SecurityFeature.BasicAuth, + SecurityFeature.BearerToken, + SecurityFeature.ApiKey, + SecurityFeature.OAuth2_Implicit + ))); + + // from https://docs.python.org/3/reference/lexical_analysis.html#keywords + setReservedWordsLowerCase( + Arrays.asList( + // local variable name used in API methods (endpoints) + "all_params", "resource_path", "path_params", "query_params", + "header_params", "form_params", "local_var_files", "body_params", "auth_settings", + // @property + "property", + // python reserved words + "and", "del", "from", "not", "while", "as", "elif", "global", "or", "with", + "assert", "else", "if", "pass", "yield", "break", "except", "import", + "print", "class", "exec", "in", "raise", "continue", "finally", "is", + "return", "def", "for", "lambda", "try", "self", "nonlocal", "None", "True", + "False", "async", "await")); + + languageSpecificPrimitives.clear(); + languageSpecificPrimitives.add("int"); + languageSpecificPrimitives.add("float"); + languageSpecificPrimitives.add("list"); + languageSpecificPrimitives.add("dict"); + languageSpecificPrimitives.add("List"); + languageSpecificPrimitives.add("Dict"); + languageSpecificPrimitives.add("bool"); + languageSpecificPrimitives.add("str"); + languageSpecificPrimitives.add("datetime"); + languageSpecificPrimitives.add("date"); + languageSpecificPrimitives.add("object"); + // TODO file and binary is mapped as `file` + languageSpecificPrimitives.add("file"); + languageSpecificPrimitives.add("bytes"); + + typeMapping.clear(); + typeMapping.put("integer", "int"); + typeMapping.put("float", "float"); + typeMapping.put("number", "float"); + typeMapping.put("long", "int"); + typeMapping.put("double", "float"); + typeMapping.put("array", "list"); + typeMapping.put("set", "list"); + typeMapping.put("map", "dict"); + typeMapping.put("boolean", "bool"); + typeMapping.put("string", "str"); + typeMapping.put("date", "date"); + typeMapping.put("DateTime", "datetime"); + typeMapping.put("object", "object"); + typeMapping.put("AnyType", "object"); + typeMapping.put("file", "file"); + // TODO binary should be mapped to byte array + // mapped to String as a workaround + typeMapping.put("binary", "str"); + typeMapping.put("ByteArray", "str"); + // map uuid to string for the time being + typeMapping.put("UUID", "str"); + typeMapping.put("URI", "str"); + typeMapping.put("null", "none_type"); + + regexModifiers = new HashMap(); + regexModifiers.put('i', "IGNORECASE"); + regexModifiers.put('l', "LOCALE"); + regexModifiers.put('m', "MULTILINE"); + regexModifiers.put('s', "DOTALL"); + regexModifiers.put('u', "UNICODE"); + regexModifiers.put('x', "VERBOSE"); + } + + @Override + public void processOpts() { + super.processOpts(); + + if (StringUtils.isEmpty(System.getenv("PYTHON_POST_PROCESS_FILE"))) { + LOGGER.info("Environment variable PYTHON_POST_PROCESS_FILE not defined so the Python code may not be properly formatted. To define it, try 'export PYTHON_POST_PROCESS_FILE=\"/usr/local/bin/yapf -i\"' (Linux/Mac)"); + LOGGER.info("NOTE: To enable file post-processing, 'enablePostProcessFile' must be set to `true` (--enable-post-process-file for CLI)."); + } + } + + @Override + public String escapeReservedWord(String name) { + if (this.reservedWordsMappings().containsKey(name)) { + return this.reservedWordsMappings().get(name); + } + return "_" + name; + } + + + /** + * Return the default value of the property + * + * @param p OpenAPI property object + * @return string presentation of the default value of the property + */ + @Override + public String toDefaultValue(Schema p) { + if (ModelUtils.isBooleanSchema(p)) { + if (p.getDefault() != null) { + if (!Boolean.valueOf(p.getDefault().toString())) + return "False"; + else + return "True"; + } + } else if (ModelUtils.isDateSchema(p)) { + // TODO + } else if (ModelUtils.isDateTimeSchema(p)) { + // TODO + } else if (ModelUtils.isNumberSchema(p)) { + if (p.getDefault() != null) { + return p.getDefault().toString(); + } + } else if (ModelUtils.isIntegerSchema(p)) { + if (p.getDefault() != null) { + return p.getDefault().toString(); + } + } else if (ModelUtils.isStringSchema(p)) { + if (p.getDefault() != null) { + String defaultValue = String.valueOf(p.getDefault()); + if (defaultValue != null) { + defaultValue = defaultValue.replace("\\", "\\\\") + .replace("'", "\'"); + if (Pattern.compile("\r\n|\r|\n").matcher(defaultValue).find()) { + return "'''" + defaultValue + "'''"; + } else { + return "'" + defaultValue + "'"; + } + } + } + } else if (ModelUtils.isArraySchema(p)) { + if (p.getDefault() != null) { + return p.getDefault().toString(); + } else { + return null; + } + } + + return null; + } + + + @Override + public String toVarName(String name) { + // obtain the name from nameMapping directly if provided + if (nameMapping.containsKey(name)) { + return nameMapping.get(name); + } + + // sanitize name + name = sanitizeName(name); // FIXME: a parameter should not be assigned. Also declare the methods parameters as 'final'. + + // remove dollar sign + name = name.replace("$", ""); + + // if it's all upper case, convert to lower case + if (name.matches("^[A-Z_]*$")) { + name = name.toLowerCase(Locale.ROOT); + } + + // underscore the variable name + // petId => pet_id + name = underscore(name); + + // remove leading underscore + name = name.replaceAll("^_*", ""); + + // for reserved word or word starting with number, append _ + if (isReservedWord(name) || name.matches("^\\d.*")) { + name = escapeReservedWord(name); + } + + return name; + } + + @Override + public String toRegularExpression(String pattern) { + return addRegularExpressionDelimiter(pattern); + } + + @Override + public String toParamName(String name) { + // obtain the name from parameterNameMapping directly if provided + if (parameterNameMapping.containsKey(name)) { + return parameterNameMapping.get(name); + } + + // to avoid conflicts with 'callback' parameter for async call + if ("callback".equals(name)) { + return "param_callback"; + } + + // should be the same as variable name + return toVarName(name); + } + + @Override + public String toOperationId(String operationId) { + // throw exception if method name is empty (should not occur as an auto-generated method name will be used) + if (StringUtils.isEmpty(operationId)) { + throw new RuntimeException("Empty method name (operationId) not allowed"); + } + + // method name cannot use reserved keyword, e.g. return + if (isReservedWord(operationId)) { + LOGGER.warn("{} (reserved word) cannot be used as method name. Renamed to {}", operationId, underscore(sanitizeName("call_" + operationId))); + operationId = "call_" + operationId; + } + + // operationId starts with a number + if (operationId.matches("^\\d.*")) { + LOGGER.warn("{} (starting with a number) cannot be used as method name. Renamed to {}", operationId, underscore(sanitizeName("call_" + operationId))); + operationId = "call_" + operationId; + } + + return underscore(sanitizeName(operationId)); + } + + @Override + public String escapeQuotationMark(String input) { + // remove ' to avoid code injection + return input.replace("'", ""); + } + + @Override + public String escapeUnsafeCharacters(String input) { + // remove multiline comment + return input.replace("'''", "'_'_'"); + } + + @Override + public void postProcessFile(File file, String fileType) { + if (file == null) { + return; + } + String pythonPostProcessFile = System.getenv("PYTHON_POST_PROCESS_FILE"); + if (StringUtils.isEmpty(pythonPostProcessFile)) { + return; // skip if PYTHON_POST_PROCESS_FILE env variable is not defined + } + + // only process files with py extension + if ("py".equals(FilenameUtils.getExtension(file.toString()))) { + String command = pythonPostProcessFile + " " + file; + try { + Process p = Runtime.getRuntime().exec(command); + int exitValue = p.waitFor(); + if (exitValue != 0) { + LOGGER.error("Error running the command ({}). Exit value: {}", command, exitValue); + } else { + LOGGER.info("Successfully executed: {}", command); + } + } catch (InterruptedException | IOException e) { + LOGGER.error("Error running the command ({}). Exception: {}", command, e.getMessage()); + // Restore interrupted state + Thread.currentThread().interrupt(); + } + } + } + + @Override + public String toExampleValue(Schema schema) { + return toExampleValueRecursive(schema, new ArrayList<>(), 5); + } + + private String toExampleValueRecursive(Schema schema, List includedSchemas, int indentation) { + boolean cycleFound = includedSchemas.stream().filter(s -> schema.equals(s)).count() > 1; + if (cycleFound) { + return ""; + } + String indentationString = ""; + for (int i = 0; i < indentation; i++) indentationString += " "; + String example = null; + if (schema.getExample() != null) { + example = schema.getExample().toString(); + } + + if (ModelUtils.isNullType(schema) && null != example) { + // The 'null' type is allowed in OAS 3.1 and above. It is not supported by OAS 3.0.x, + // though this tooling supports it. + return "None"; + } + // correct "true"s into "True"s, since super.toExampleValue uses "toString()" on Java booleans + if (ModelUtils.isBooleanSchema(schema) && null != example) { + if ("false".equalsIgnoreCase(example)) example = "False"; + else example = "True"; + } + + // correct "'"s into "'"s after toString() + if (ModelUtils.isStringSchema(schema) && schema.getDefault() != null && !ModelUtils.isDateSchema(schema) && !ModelUtils.isDateTimeSchema(schema)) { + example = String.valueOf(schema.getDefault()); + } + + if (StringUtils.isNotBlank(example) && !"null".equals(example)) { + if (ModelUtils.isStringSchema(schema)) { + example = "'" + example + "'"; + } + return example; + } + + if (schema.getEnum() != null && !schema.getEnum().isEmpty()) { + // Enum case: + example = schema.getEnum().get(0).toString(); + if (ModelUtils.isStringSchema(schema)) { + example = "'" + escapeText(example) + "'"; + } + if (null == example) + LOGGER.warn("Empty enum. Cannot built an example!"); + + return example; + } else if (null != schema.get$ref()) { + // $ref case: + Map allDefinitions = ModelUtils.getSchemas(this.openAPI); + String ref = ModelUtils.getSimpleRef(schema.get$ref()); + if (allDefinitions != null) { + Schema refSchema = allDefinitions.get(ref); + if (null == refSchema) { + return "None"; + } else { + String refTitle = refSchema.getTitle(); + if (StringUtils.isBlank(refTitle) || "null".equals(refTitle)) { + refSchema.setTitle(ref); + } + if (StringUtils.isNotBlank(schema.getTitle()) && !"null".equals(schema.getTitle())) { + includedSchemas.add(schema); + } + return toExampleValueRecursive(refSchema, includedSchemas, indentation); + } + } else { + LOGGER.warn("allDefinitions not defined in toExampleValue!\n"); + } + } + if (ModelUtils.isDateSchema(schema)) { + example = "datetime.datetime.strptime('1975-12-30', '%Y-%m-%d').date()"; + return example; + } else if (ModelUtils.isDateTimeSchema(schema)) { + example = "datetime.datetime.strptime('2013-10-20 19:20:30.00', '%Y-%m-%d %H:%M:%S.%f')"; + return example; + } else if (ModelUtils.isBinarySchema(schema)) { + example = "bytes(b'blah')"; + return example; + } else if (ModelUtils.isByteArraySchema(schema)) { + example = "YQ=="; + } else if (ModelUtils.isStringSchema(schema)) { + // a BigDecimal: + if ("Number".equalsIgnoreCase(schema.getFormat())) { + return "1"; + } + if (StringUtils.isNotBlank(schema.getPattern())) { + String pattern = schema.getPattern(); + RgxGen rgxGen = new RgxGen(patternCorrection(pattern)); + // this seed makes it so if we have [a-z] we pick a + Random random = new Random(18); + String sample = rgxGen.generate(random); + // omit leading / and trailing /, omit trailing /i + Pattern valueExtractor = Pattern.compile("^/\\^?(.+?)\\$?/.?$"); + Matcher m = valueExtractor.matcher(sample); + if (m.find()) { + example = m.group(m.groupCount()); + } else { + example = sample; + } + } + if (example == null) { + example = ""; + } + int len = 0; + if (null != schema.getMinLength()) { + len = schema.getMinLength().intValue(); + if (len < 1) { + example = ""; + } else { + for (int i = 0; i < len; i++) example += i; + } + } + } else if (ModelUtils.isIntegerSchema(schema)) { + if (schema.getMinimum() != null) + example = schema.getMinimum().toString(); + else + example = "56"; + } else if (ModelUtils.isNumberSchema(schema)) { + if (schema.getMinimum() != null) + example = schema.getMinimum().toString(); + else + example = "1.337"; + } else if (ModelUtils.isBooleanSchema(schema)) { + example = "True"; + } else if (ModelUtils.isArraySchema(schema)) { + if (StringUtils.isNotBlank(schema.getTitle()) && !"null".equals(schema.getTitle())) { + includedSchemas.add(schema); + } + ArraySchema arrayschema = (ArraySchema) schema; + example = "[\n" + indentationString + toExampleValueRecursive(arrayschema.getItems(), includedSchemas, indentation + 1) + "\n" + indentationString + "]"; + } else if (ModelUtils.isMapSchema(schema)) { + if (StringUtils.isNotBlank(schema.getTitle()) && !"null".equals(schema.getTitle())) { + includedSchemas.add(schema); + } + Object additionalObject = schema.getAdditionalProperties(); + if (additionalObject instanceof Schema) { + Schema additional = (Schema) additionalObject; + String theKey = "'key'"; + if (additional.getEnum() != null && !additional.getEnum().isEmpty()) { + theKey = additional.getEnum().get(0).toString(); + if (ModelUtils.isStringSchema(additional)) { + theKey = "'" + escapeText(theKey) + "'"; + } + } + example = "{\n" + indentationString + theKey + " : " + toExampleValueRecursive(additional, includedSchemas, indentation + 1) + "\n" + indentationString + "}"; + } else { + example = "{ }"; + } + } else if (ModelUtils.isObjectSchema(schema)) { + if (StringUtils.isBlank(schema.getTitle())) { + example = "None"; + return example; + } + + // I remove any property that is a discriminator, since it is not well supported by the python generator + String toExclude = null; + if (schema.getDiscriminator() != null) { + toExclude = schema.getDiscriminator().getPropertyName(); + } + + example = packageName + ".models." + underscore(schema.getTitle()) + "." + schema.getTitle() + "("; + + // if required only: + // List reqs = schema.getRequired(); + + // if required and optionals + List reqs = new ArrayList<>(); + if (schema.getProperties() != null && !schema.getProperties().isEmpty()) { + for (Object toAdd : schema.getProperties().keySet()) { + reqs.add((String) toAdd); + } + + Map properties = schema.getProperties(); + Set propkeys = null; + if (properties != null) propkeys = properties.keySet(); + if (toExclude != null && reqs.contains(toExclude)) { + reqs.remove(toExclude); + } + for (String toRemove : includedSchemas.stream().map(Schema::getTitle).collect(Collectors.toList())) { + if (reqs.contains(toRemove)) { + reqs.remove(toRemove); + } + } + if (StringUtils.isNotBlank(schema.getTitle()) && !"null".equals(schema.getTitle())) { + includedSchemas.add(schema); + } + if (null != schema.getRequired()) for (Object toAdd : schema.getRequired()) { + reqs.add((String) toAdd); + } + if (null != propkeys) for (String propname : propkeys) { + Schema schema2 = properties.get(propname); + if (reqs.contains(propname)) { + String refTitle = schema2.getTitle(); + if (StringUtils.isBlank(refTitle) || "null".equals(refTitle)) { + schema2.setTitle(propname); + } + example += "\n" + indentationString + underscore(propname) + " = " + + toExampleValueRecursive(schema2, includedSchemas, indentation + 1) + ", "; + } + } + } + example += ")"; + } else { + LOGGER.debug("Type {} not handled properly in toExampleValue", schema.getType()); + } + + if (ModelUtils.isStringSchema(schema)) { + example = "'" + escapeText(example) + "'"; + } + + return example; + } + + @Override + public void setParameterExampleValue(CodegenParameter p) { + String example; + + if (p.defaultValue == null) { + example = p.example; + } else { + p.example = p.defaultValue; + return; + } + + String type = p.baseType; + if (type == null) { + type = p.dataType; + } + + if ("String".equalsIgnoreCase(type) || "str".equalsIgnoreCase(type)) { + if (example == null) { + example = p.paramName + "_example"; + } + example = "'" + escapeText(example) + "'"; + } else if ("Integer".equals(type) || "int".equals(type)) { + if (example == null) { + example = "56"; + } + } else if ("Float".equalsIgnoreCase(type) || "Double".equalsIgnoreCase(type)) { + if (example == null) { + example = "3.4"; + } + } else if ("BOOLEAN".equalsIgnoreCase(type) || "bool".equalsIgnoreCase(type)) { + if (example == null) { + example = "True"; + } + } else if ("file".equalsIgnoreCase(type)) { + if (example == null) { + example = "/path/to/file"; + } + example = "'" + escapeText(example) + "'"; + } else if ("Date".equalsIgnoreCase(type)) { + if (example == null) { + example = "2013-10-20"; + } + example = "'" + escapeText(example) + "'"; + } else if ("DateTime".equalsIgnoreCase(type)) { + if (example == null) { + example = "2013-10-20T19:20:30+01:00"; + } + example = "'" + escapeText(example) + "'"; + } else if (!languageSpecificPrimitives.contains(type)) { + // type is a model class, e.g. User + example = this.packageName + "." + type + "()"; + } else { + LOGGER.debug("Type {} not handled properly in setParameterExampleValue", type); + } + + if (example == null) { + example = "None"; + } else if (Boolean.TRUE.equals(p.isArray)) { + example = "[" + example + "]"; + } else if (Boolean.TRUE.equals(p.isMap)) { + example = "{'key': " + example + "}"; + } + + p.example = example; + } + + @Override + public void setParameterExampleValue(CodegenParameter codegenParameter, Parameter parameter) { + Schema schema = parameter.getSchema(); + + if (parameter.getExample() != null) { + codegenParameter.example = parameter.getExample().toString(); + } else if (parameter.getExamples() != null && !parameter.getExamples().isEmpty()) { + Example example = parameter.getExamples().values().iterator().next(); + if (example.getValue() != null) { + codegenParameter.example = example.getValue().toString(); + } + } else if (schema != null && schema.getExample() != null) { + codegenParameter.example = schema.getExample().toString(); + } + + setParameterExampleValue(codegenParameter); + } + + @Override + public String sanitizeTag(String tag) { + return sanitizeName(tag); + } + + public String patternCorrection(String pattern) { + // Java does not recognize starting and ending forward slashes and mode modifiers + // It considers them as characters with no special meaning and tries to find them in the match string + boolean checkEnding = pattern.endsWith("/i") || pattern.endsWith("/g") || pattern.endsWith("/m"); + if (checkEnding) pattern = pattern.substring(0, pattern.length() - 2); + if (pattern.endsWith("/")) pattern = pattern.substring(0, pattern.length() - 1); + if (pattern.startsWith("/")) pattern = pattern.substring(1); + return pattern; + } + + public void setPackageName(String packageName) { + this.packageName = packageName; + additionalProperties.put(CodegenConstants.PACKAGE_NAME, this.packageName); + } + + public void setProjectName(String projectName) { + this.projectName = projectName; + } + + public void setPackageVersion(String packageVersion) { + this.packageVersion = packageVersion; + } + + @Override + public String getTypeDeclaration(Schema p) { + p = ModelUtils.unaliasSchema(openAPI, p); + + if (ModelUtils.isArraySchema(p)) { + ArraySchema ap = (ArraySchema) p; + Schema inner = ap.getItems(); + return getSchemaType(p) + "[" + getTypeDeclaration(inner) + "]"; + } else if (ModelUtils.isMapSchema(p)) { + Schema inner = ModelUtils.getAdditionalProperties(p); + return getSchemaType(p) + "[str, " + getTypeDeclaration(inner) + "]"; + } + + String openAPIType = getSchemaType(p); + if (typeMapping.containsKey(openAPIType)) { + return typeMapping.get(openAPIType); + } + + if (languageSpecificPrimitives.contains(openAPIType)) { + return openAPIType; + } + + return toModelName(openAPIType); + } + + @Override + public String getSchemaType(Schema p) { + String openAPIType = super.getSchemaType(p); + String type; + + if (openAPIType == null) { + LOGGER.error("OpenAPI Type for {} is null. Default to UNKNOWN_OPENAPI_TYPE instead.", p.getName()); + openAPIType = "UNKNOWN_OPENAPI_TYPE"; + } + + if (typeMapping.containsKey(openAPIType)) { + type = typeMapping.get(openAPIType); + if (type != null) { + return type; + } + } else { + type = openAPIType; + } + + return toModelName(type); + } + + @Override + public String toModelName(String name) { + // obtain the name from modelNameMapping directly if provided + if (modelNameMapping.containsKey(name)) { + return modelNameMapping.get(name); + } + + // check if schema-mapping has a different model for this class, so we can use it + // instead of the auto-generated one. + if (schemaMapping.containsKey(name)) { + return schemaMapping.get(name); + } + + // memoization + String origName = name; + if (schemaKeyToModelNameCache.containsKey(origName)) { + return schemaKeyToModelNameCache.get(origName); + } + + String sanitizedName = sanitizeName(name); // FIXME: a parameter should not be assigned. Also declare the methods parameters as 'final'. + // remove dollar sign + sanitizedName = sanitizedName.replace("$", ""); + // remove whitespace + sanitizedName = sanitizedName.replaceAll("\\s+", ""); + + String nameWithPrefixSuffix = sanitizedName; + if (!StringUtils.isEmpty(modelNamePrefix)) { + // add '_' so that model name can be camelized correctly + nameWithPrefixSuffix = modelNamePrefix + "_" + nameWithPrefixSuffix; + } + + if (!StringUtils.isEmpty(modelNameSuffix)) { + // add '_' so that model name can be camelized correctly + nameWithPrefixSuffix = nameWithPrefixSuffix + "_" + modelNameSuffix; + } + + // camelize the model name + // phone_number => PhoneNumber + String camelizedName = camelize(nameWithPrefixSuffix); + + // model name cannot use reserved keyword, e.g. return + if (isReservedWord(camelizedName)) { + String modelName = "Model" + camelizedName; // e.g. return => ModelReturn (after camelize) + LOGGER.warn("{} (reserved word) cannot be used as model name. Renamed to {}", camelizedName, modelName); + schemaKeyToModelNameCache.put(origName, modelName); + return modelName; + } + + // model name starts with number + if (camelizedName.matches("^\\d.*")) { + String modelName = "Model" + camelizedName; // e.g. return => ModelReturn (after camelize) + LOGGER.warn("{} (model name starts with number) cannot be used as model name. Renamed to {}", camelizedName, modelName); + schemaKeyToModelNameCache.put(origName, modelName); + return modelName; + } + + schemaKeyToModelNameCache.put(origName, camelizedName); + return camelizedName; + } + + @Override + public String toModelFilename(String name) { + // underscore the model file name + // PhoneNumber => phone_number + return underscore(dropDots(toModelName(name))); + } + + @Override + public String toModelTestFilename(String name) { + return "test_" + toModelFilename(name); + } + + @Override + public String toApiFilename(String name) { + // e.g. PhoneNumberApi.py => phone_number_api.py + return underscore(toApiName(name)); + } + + @Override + public String toApiTestFilename(String name) { + return "test_" + toApiFilename(name); + } + + @Override + public String toApiName(String name) { + return super.toApiName(name); + } + + @Override + public String toApiVarName(String name) { + return underscore(toApiName(name)); + } + + protected static String dropDots(String str) { + return str.replaceAll("\\.", "_"); + } + + @Override + public GeneratorLanguage generatorLanguage() { + return GeneratorLanguage.PYTHON; + } + + @Override + public Map postProcessAllModels(Map objs) { + final Map processed = super.postProcessAllModels(objs); + + for (Map.Entry entry : objs.entrySet()) { + // create hash map of codegen model + CodegenModel cm = ModelUtils.getModelByName(entry.getKey(), objs); + codegenModelMap.put(cm.classname, ModelUtils.getModelByName(entry.getKey(), objs)); + } + + // create circular import + for (String m : codegenModelMap.keySet()) { + createImportMapOfSet(m, codegenModelMap); + } + + for (Map.Entry entry : processed.entrySet()) { + entry.setValue(postProcessModelsMap(entry.getValue())); + } + + return processed; + } + + private ModelsMap postProcessModelsMap(ModelsMap objs) { + // process enum in models + objs = postProcessModelsEnum(objs); + + TreeSet typingImports = new TreeSet<>(); + TreeSet pydanticImports = new TreeSet<>(); + TreeSet datetimeImports = new TreeSet<>(); + TreeSet modelImports = new TreeSet<>(); + TreeSet postponedModelImports = new TreeSet<>(); + + for (ModelMap m : objs.getModels()) { + TreeSet exampleImports = new TreeSet<>(); + TreeSet postponedExampleImports = new TreeSet<>(); + List readOnlyFields = new ArrayList<>(); + hasModelsToImport = false; + int property_count = 1; + typingImports.clear(); + pydanticImports.clear(); + datetimeImports.clear(); + + CodegenModel model = m.getModel(); + + // handle null type in oneOf + if (model.getComposedSchemas() != null && model.getComposedSchemas().getOneOf() != null + && !model.getComposedSchemas().getOneOf().isEmpty()) { + int index = 0; + List oneOfs = model.getComposedSchemas().getOneOf(); + for (CodegenProperty oneOf : oneOfs) { + if ("none_type".equals(oneOf.dataType)) { + oneOfs.remove(index); + break; // return earlier assuming there's only 1 null type defined + } + index++; + } + } + + List codegenProperties = null; + if (!model.oneOf.isEmpty()) { // oneOfValidationError + codegenProperties = model.getComposedSchemas().getOneOf(); + typingImports.add("Any"); + typingImports.add("List"); + pydanticImports.add("Field"); + pydanticImports.add("StrictStr"); + pydanticImports.add("ValidationError"); + pydanticImports.add("validator"); + } else if (!model.anyOf.isEmpty()) { // anyOF + codegenProperties = model.getComposedSchemas().getAnyOf(); + pydanticImports.add("Field"); + pydanticImports.add("StrictStr"); + pydanticImports.add("ValidationError"); + pydanticImports.add("validator"); + } else { // typical model + codegenProperties = model.vars; + + // if super class + if (model.getDiscriminator() != null && model.getDiscriminator().getMappedModels() != null) { + typingImports.add("Union"); + Set discriminator = model.getDiscriminator().getMappedModels(); + for (CodegenDiscriminator.MappedModel mappedModel : discriminator) { + postponedModelImports.add(mappedModel.getMappingName()); + } + } + } + + if (!model.allOf.isEmpty()) { // allOf + for (CodegenProperty cp : model.allVars) { + if (!cp.isPrimitiveType || cp.isModel) { + if (cp.isArray){ // if array + modelImports.add(cp.items.dataType); + }else{ // if model + modelImports.add(cp.dataType); + } + } + } + } + + // if model_generic.mustache is used and support additionalProperties + if (model.oneOf.isEmpty() && model.anyOf.isEmpty() + && !model.isEnum + && !this.disallowAdditionalPropertiesIfNotPresent) { + typingImports.add("Dict"); + typingImports.add("Any"); + } + + //loop through properties/schemas to set up typing, pydantic + for (CodegenProperty cp : codegenProperties) { + String typing = getPydanticType(cp, typingImports, pydanticImports, datetimeImports, modelImports, exampleImports, postponedModelImports, postponedExampleImports, model.classname); + List fields = new ArrayList<>(); + String firstField = ""; + + // is readOnly? + if (cp.isReadOnly) { + readOnlyFields.add(cp.name); + } + + if (!cp.required) { //optional + firstField = "None"; + typing = "Optional[" + typing + "]"; + typingImports.add("Optional"); + } else { // required + firstField = "..."; + if (cp.isNullable) { + typing = "Optional[" + typing + "]"; + typingImports.add("Optional"); + } + } + + // field + if (cp.baseName != null && !cp.baseName.equals(cp.name)) { // base name not the same as name + fields.add(String.format(Locale.ROOT, "alias=\"%s\"", cp.baseName)); + } + + if (!StringUtils.isEmpty(cp.description)) { // has description + fields.add(String.format(Locale.ROOT, "description=\"%s\"", cp.description)); + } + + /* TODO review as example may break the build + if (!StringUtils.isEmpty(cp.getExample())) { // has example + fields.add(String.format(Locale.ROOT, "example=%s", cp.getExample())); + }*/ + + String fieldCustomization; + if ("None".equals(firstField)) { + if (cp.defaultValue == null) { + fieldCustomization = "None"; + } else { + if (cp.isArray || cp.isMap) { + // TODO handle default value for array/map + fieldCustomization = "None"; + } else { + fieldCustomization = cp.defaultValue; + } + } + } else { // required field + fieldCustomization = firstField; + } + + if (!fields.isEmpty()) { + fields.add(0, fieldCustomization); + pydanticImports.add("Field"); + fieldCustomization = String.format(Locale.ROOT, "Field(%s)", StringUtils.join(fields, ", ")); + } + + if ("...".equals(fieldCustomization)) { + // use Field() to avoid pylint warnings + pydanticImports.add("Field"); + fieldCustomization = "Field(...)"; + } + + cp.vendorExtensions.put("x-py-typing", typing + " = " + fieldCustomization); + + // setup x-py-name for each oneOf/anyOf schema + if (!model.oneOf.isEmpty()) { // oneOf + cp.vendorExtensions.put("x-py-name", String.format(Locale.ROOT, "oneof_schema_%d_validator", property_count++)); + } else if (!model.anyOf.isEmpty()) { // anyOf + cp.vendorExtensions.put("x-py-name", String.format(Locale.ROOT, "anyof_schema_%d_validator", property_count++)); + } + } + + // add parent model to import + if (!StringUtils.isEmpty(model.parent)) { + modelImports.add(model.parent); + } else if (!model.isEnum) { + pydanticImports.add("BaseModel"); + } + + // set enum type in extensions and update `name` in enumVars + if (model.isEnum) { + for (Map enumVars : (List>) model.getAllowableValues().get("enumVars")) { + if ((Boolean) enumVars.get("isString")) { + model.vendorExtensions.putIfAbsent("x-py-enum-type", "str"); + // update `name`, e.g. + enumVars.put("name", toEnumVariableName((String) enumVars.get("value"), "str")); + } else { + model.vendorExtensions.putIfAbsent("x-py-enum-type", "int"); + enumVars.put("name", toEnumVariableName((String) enumVars.get("value"), "int")); + } + } + } + + // set the extensions if the key is absent + model.getVendorExtensions().putIfAbsent("x-py-typing-imports", typingImports); + model.getVendorExtensions().putIfAbsent("x-py-pydantic-imports", pydanticImports); + model.getVendorExtensions().putIfAbsent("x-py-datetime-imports", datetimeImports); + model.getVendorExtensions().putIfAbsent("x-py-readonly", readOnlyFields); + + // import models one by one + if (!modelImports.isEmpty()) { + Set modelsToImport = new TreeSet<>(); + for (String modelImport : modelImports) { + if (modelImport.equals(model.classname)) { + // skip self import + continue; + } + modelsToImport.add("from " + packageName + ".models." + underscore(modelImport) + " import " + modelImport); + } + + model.getVendorExtensions().putIfAbsent("x-py-model-imports", modelsToImport); + } + + if (!postponedModelImports.isEmpty()) { + Set modelsToImport = new TreeSet<>(); + for (String modelImport : postponedModelImports) { + if (modelImport.equals(model.classname)) { + // skip self import + continue; + } + modelsToImport.add("from " + packageName + ".models." + underscore(modelImport) + " import " + modelImport); + } + + model.getVendorExtensions().putIfAbsent("x-py-postponed-model-imports", modelsToImport); + } + + } + + return objs; + } + + + /* + * Gets the pydantic type given a Codegen Parameter + * + * @param cp codegen parameter + * @param typingImports typing imports + * @param pydantic pydantic imports + * @param datetimeImports datetime imports + * @param modelImports model imports + * @param exampleImports example imports + * @param postponedModelImports postponed model imports + * @param postponedExampleImports postponed example imports + * @param classname class name + * @return pydantic type + * + */ + private String getPydanticType(CodegenParameter cp, + Set typingImports, + Set pydanticImports, + Set datetimeImports, + Set modelImports, + Set exampleImports, + Set postponedModelImports, + Set postponedExampleImports, + String classname) { + if (cp == null) { + // if codegen parameter (e.g. map/dict of undefined type) is null, default to string + LOGGER.warn("Codegen property is null (e.g. map/dict of undefined type). Default to typing.Any."); + typingImports.add("Any"); + return "Any"; + } + + if (cp.isArray) { + String constraints = ""; + if (cp.maxItems != null) { + constraints += String.format(Locale.ROOT, ", max_items=%d", cp.maxItems); + } + if (cp.minItems != null) { + constraints += String.format(Locale.ROOT, ", min_items=%d", cp.minItems); + } + if (cp.getUniqueItems()) { + constraints += ", unique_items=True"; + } + pydanticImports.add("conlist"); + return String.format(Locale.ROOT, "conlist(%s%s)", + getPydanticType(cp.items, typingImports, pydanticImports, datetimeImports, modelImports, exampleImports, postponedModelImports, postponedExampleImports, classname), + constraints); + } else if (cp.isMap) { + typingImports.add("Dict"); + return String.format(Locale.ROOT, "Dict[str, %s]", + getPydanticType(cp.items, typingImports, pydanticImports, datetimeImports, modelImports, exampleImports, postponedModelImports, postponedExampleImports, classname)); + } else if (cp.isString) { + if (cp.hasValidation) { + List fieldCustomization = new ArrayList<>(); + // e.g. constr(regex=r'/[a-z]/i', strict=True) + fieldCustomization.add("strict=True"); + if (cp.getMaxLength() != null) { + fieldCustomization.add("max_length=" + cp.getMaxLength()); + } + if (cp.getMinLength() != null) { + fieldCustomization.add("min_length=" + cp.getMinLength()); + } + if (cp.getPattern() != null) { + pydanticImports.add("validator"); + // use validator instead as regex doesn't support flags, e.g. IGNORECASE + //fieldCustomization.add(String.format(Locale.ROOT, "regex=r'%s'", cp.getPattern())); + } + pydanticImports.add("constr"); + return String.format(Locale.ROOT, "constr(%s)", StringUtils.join(fieldCustomization, ", ")); + } else { + if ("password".equals(cp.getFormat())) { // TDOO avoid using format, use `is` boolean flag instead + pydanticImports.add("SecretStr"); + return "SecretStr"; + } else { + pydanticImports.add("StrictStr"); + return "StrictStr"; + } + } + } else if (cp.isNumber || cp.isFloat || cp.isDouble) { + if (cp.hasValidation) { + List fieldCustomization = new ArrayList<>(); + List intFieldCustomization = new ArrayList<>(); + + // e.g. confloat(ge=10, le=100, strict=True) + if (cp.getMaximum() != null) { + if (cp.getExclusiveMaximum()) { + fieldCustomization.add("lt=" + cp.getMaximum()); + intFieldCustomization.add("lt=" + Math.ceil(Double.valueOf(cp.getMaximum()))); // e.g. < 7.59 becomes < 8 + } else { + fieldCustomization.add("le=" + cp.getMaximum()); + intFieldCustomization.add("le=" + Math.floor(Double.valueOf(cp.getMaximum()))); // e.g. <= 7.59 becomes <= 7 + } + } + if (cp.getMinimum() != null) { + if (cp.getExclusiveMinimum()) { + fieldCustomization.add("gt=" + cp.getMinimum()); + intFieldCustomization.add("gt=" + Math.floor(Double.valueOf(cp.getMinimum()))); // e.g. > 7.59 becomes > 7 + } else { + fieldCustomization.add("ge=" + cp.getMinimum()); + intFieldCustomization.add("ge=" + Math.ceil(Double.valueOf(cp.getMinimum()))); // e.g. >= 7.59 becomes >= 8 + } + } + if (cp.getMultipleOf() != null) { + fieldCustomization.add("multiple_of=" + cp.getMultipleOf()); + } + + if ("Union[StrictFloat, StrictInt]".equals(mapNumberTo)) { + fieldCustomization.add("strict=True"); + intFieldCustomization.add("strict=True"); + pydanticImports.add("confloat"); + pydanticImports.add("conint"); + typingImports.add("Union"); + return String.format(Locale.ROOT, "Union[%s(%s), %s(%s)]", "confloat", + StringUtils.join(fieldCustomization, ", "), + "conint", + StringUtils.join(intFieldCustomization, ", ") + ); + } else if ("StrictFloat".equals(mapNumberTo)) { + fieldCustomization.add("strict=True"); + pydanticImports.add("confloat"); + return String.format(Locale.ROOT, "%s(%s)", "confloat", + StringUtils.join(fieldCustomization, ", ")); + } else { // float + pydanticImports.add("confloat"); + return String.format(Locale.ROOT, "%s(%s)", "confloat", + StringUtils.join(fieldCustomization, ", ")); + } + } else { + if ("Union[StrictFloat, StrictInt]".equals(mapNumberTo)) { + typingImports.add("Union"); + pydanticImports.add("StrictFloat"); + pydanticImports.add("StrictInt"); + return "Union[StrictFloat, StrictInt]"; + } else if ("StrictFloat".equals(mapNumberTo)) { + pydanticImports.add("StrictFloat"); + return "StrictFloat"; + } else { + return "float"; + } + } + } else if (cp.isInteger || cp.isLong || cp.isShort || cp.isUnboundedInteger) { + if (cp.hasValidation) { + List fieldCustomization = new ArrayList<>(); + // e.g. conint(ge=10, le=100, strict=True) + fieldCustomization.add("strict=True"); + if (cp.getMaximum() != null) { + if (cp.getExclusiveMaximum()) { + fieldCustomization.add("lt=" + cp.getMaximum()); + } else { + fieldCustomization.add("le=" + cp.getMaximum()); + } + } + if (cp.getMinimum() != null) { + if (cp.getExclusiveMinimum()) { + fieldCustomization.add("gt=" + cp.getMinimum()); + } else { + fieldCustomization.add("ge=" + cp.getMinimum()); + } + } + if (cp.getMultipleOf() != null) { + fieldCustomization.add("multiple_of=" + cp.getMultipleOf()); + } + + pydanticImports.add("conint"); + return String.format(Locale.ROOT, "%s(%s)", "conint", + StringUtils.join(fieldCustomization, ", ")); + } else { + pydanticImports.add("StrictInt"); + return "StrictInt"; + } + } else if (cp.isBinary || cp.isByteArray) { + if (cp.hasValidation) { + List fieldCustomization = new ArrayList<>(); + // e.g. conbytes(min_length=2, max_length=10) + fieldCustomization.add("strict=True"); + if (cp.getMinLength() != null) { + fieldCustomization.add("min_length=" + cp.getMinLength()); + } + if (cp.getMaxLength() != null) { + fieldCustomization.add("max_length=" + cp.getMaxLength()); + } + if (cp.getPattern() != null) { + pydanticImports.add("validator"); + // use validator instead as regex doesn't support flags, e.g. IGNORECASE + //fieldCustomization.add(Locale.ROOT, String.format(Locale.ROOT, "regex=r'%s'", cp.getPattern())); + } + + pydanticImports.add("conbytes"); + pydanticImports.add("constr"); + typingImports.add("Union"); + return String.format(Locale.ROOT, "Union[conbytes(%s), constr(% fieldCustomization = new ArrayList<>(); + // e.g. condecimal(ge=10, le=100, strict=True) + fieldCustomization.add("strict=True"); + if (cp.getMaximum() != null) { + if (cp.getExclusiveMaximum()) { + fieldCustomization.add("gt=" + cp.getMaximum()); + } else { + fieldCustomization.add("ge=" + cp.getMaximum()); + } + } + if (cp.getMinimum() != null) { + if (cp.getExclusiveMinimum()) { + fieldCustomization.add("lt=" + cp.getMinimum()); + } else { + fieldCustomization.add("le=" + cp.getMinimum()); + } + } + if (cp.getMultipleOf() != null) { + fieldCustomization.add("multiple_of=" + cp.getMultipleOf()); + } + pydanticImports.add("condecimal"); + return String.format(Locale.ROOT, "%s(%s)", "condecimal", StringUtils.join(fieldCustomization, ", ")); + } else { + pydanticImports.add("condecimal"); + return "condecimal()"; + } + } else if (cp.getIsAnyType()) { + typingImports.add("Any"); + return "Any"; + } else if (cp.isDate || cp.isDateTime) { + if (cp.isDate) { + datetimeImports.add("date"); + } + if (cp.isDateTime) { + datetimeImports.add("datetime"); + } + + return cp.dataType; + } else if (cp.isUuid) { + return cp.dataType; + } else if (cp.isFreeFormObject) { // type: object + typingImports.add("Dict"); + typingImports.add("Any"); + return "Dict[str, Any]"; + } else if (!cp.isPrimitiveType) { + // add model prefix + hasModelsToImport = true; + modelImports.add(cp.dataType); + exampleImports.add(cp.dataType); + return cp.dataType; + } else if (cp.getContent() != null) { + LinkedHashMap contents = cp.getContent(); + for (String key : contents.keySet()) { + CodegenMediaType cmt = contents.get(key); + // TODO process the first one only at the moment + if (cmt != null) + return getPydanticType(cmt.getSchema(), typingImports, pydanticImports, datetimeImports, modelImports, exampleImports, postponedModelImports, postponedExampleImports, classname); + } + throw new RuntimeException("Error! Failed to process getPydanticType when getting the content: " + cp); + } else { + throw new RuntimeException("Error! Codegen Parameter not yet supported in getPydanticType: " + cp); + } + } + + + /* + * Gets the pydantic type given a Codegen Property + * + * @param cp codegen property + * @param typingImports typing imports + * @param pydantic pydantic imports + * @param datetimeImports datetime imports + * @param modelImports model imports + * @param exampleImports example imports + * @param postponedModelImports postponed model imports + * @param postponedExampleImports postponed example imports + * @param classname class name + * @return pydantic type + * + */ + private String getPydanticType(CodegenProperty cp, + Set typingImports, + Set pydanticImports, + Set datetimeImports, + Set modelImports, + Set exampleImports, + Set postponedModelImports, + Set postponedExampleImports, + String classname) { + if (cp == null) { + // if codegen property (e.g. map/dict of undefined type) is null, default to string + LOGGER.warn("Codegen property is null (e.g. map/dict of undefined type). Default to typing.Any."); + typingImports.add("Any"); + return "Any"; + } + + if (cp.isEnum) { + pydanticImports.add("validator"); + } + + /* comment out the following since Literal requires python 3.8 + also need to put cp.isEnum check after isArray, isMap check + if (cp.isEnum) { + // use Literal for inline enum + typingImports.add("Literal"); + List values = new ArrayList<>(); + List> enumVars = (List>) cp.allowableValues.get("enumVars"); + if (enumVars != null) { + for (Map enumVar : enumVars) { + values.add((String) enumVar.get("value")); + } + } + return String.format(Locale.ROOT, "%sEnum", cp.nameInCamelCase); + } else*/ + if (cp.isArray) { + String constraints = ""; + if (cp.maxItems != null) { + constraints += String.format(Locale.ROOT, ", max_items=%d", cp.maxItems); + } + if (cp.minItems != null) { + constraints += String.format(Locale.ROOT, ", min_items=%d", cp.minItems); + } + if (cp.getUniqueItems()) { + constraints += ", unique_items=True"; + } + pydanticImports.add("conlist"); + typingImports.add("List"); // for return type + return String.format(Locale.ROOT, "conlist(%s%s)", + getPydanticType(cp.items, typingImports, pydanticImports, datetimeImports, modelImports, exampleImports, postponedModelImports, postponedExampleImports, classname), + constraints); + } else if (cp.isMap) { + typingImports.add("Dict"); + return String.format(Locale.ROOT, "Dict[str, %s]", getPydanticType(cp.items, typingImports, pydanticImports, datetimeImports, modelImports, exampleImports, postponedModelImports, postponedExampleImports, classname)); + } else if (cp.isString) { + if (cp.hasValidation) { + List fieldCustomization = new ArrayList<>(); + // e.g. constr(regex=r'/[a-z]/i', strict=True) + fieldCustomization.add("strict=True"); + if (cp.getMaxLength() != null) { + fieldCustomization.add("max_length=" + cp.getMaxLength()); + } + if (cp.getMinLength() != null) { + fieldCustomization.add("min_length=" + cp.getMinLength()); + } + if (cp.getPattern() != null) { + pydanticImports.add("validator"); + // use validator instead as regex doesn't support flags, e.g. IGNORECASE + //fieldCustomization.add(Locale.ROOT, String.format(Locale.ROOT, "regex=r'%s'", cp.getPattern())); + } + pydanticImports.add("constr"); + return String.format(Locale.ROOT, "constr(%s)", StringUtils.join(fieldCustomization, ", ")); + } else { + if ("password".equals(cp.getFormat())) { // TDOO avoid using format, use `is` boolean flag instead + pydanticImports.add("SecretStr"); + return "SecretStr"; + } else { + pydanticImports.add("StrictStr"); + return "StrictStr"; + } + } + } else if (cp.isNumber || cp.isFloat || cp.isDouble) { + if (cp.hasValidation) { + List fieldCustomization = new ArrayList<>(); + List intFieldCustomization = new ArrayList<>(); + + // e.g. confloat(ge=10, le=100, strict=True) + if (cp.getMaximum() != null) { + if (cp.getExclusiveMaximum()) { + fieldCustomization.add("lt=" + cp.getMaximum()); + intFieldCustomization.add("lt=" + (int) Math.ceil(Double.valueOf(cp.getMaximum()))); // e.g. < 7.59 => < 8 + } else { + fieldCustomization.add("le=" + cp.getMaximum()); + intFieldCustomization.add("le=" + (int) Math.floor(Double.valueOf(cp.getMaximum()))); // e.g. <= 7.59 => <= 7 + } + } + if (cp.getMinimum() != null) { + if (cp.getExclusiveMinimum()) { + fieldCustomization.add("gt=" + cp.getMinimum()); + intFieldCustomization.add("gt=" + (int) Math.floor(Double.valueOf(cp.getMinimum()))); // e.g. > 7.59 => > 7 + } else { + fieldCustomization.add("ge=" + cp.getMinimum()); + intFieldCustomization.add("ge=" + (int) Math.ceil(Double.valueOf(cp.getMinimum()))); // e.g. >= 7.59 => >= 8 + } + } + if (cp.getMultipleOf() != null) { + fieldCustomization.add("multiple_of=" + cp.getMultipleOf()); + } + + if ("Union[StrictFloat, StrictInt]".equals(mapNumberTo)) { + fieldCustomization.add("strict=True"); + intFieldCustomization.add("strict=True"); + pydanticImports.add("confloat"); + pydanticImports.add("conint"); + typingImports.add("Union"); + return String.format(Locale.ROOT, "Union[%s(%s), %s(%s)]", "confloat", + StringUtils.join(fieldCustomization, ", "), + "conint", + StringUtils.join(intFieldCustomization, ", ") + ); + } else if ("StrictFloat".equals(mapNumberTo)) { + fieldCustomization.add("strict=True"); + pydanticImports.add("confloat"); + return String.format(Locale.ROOT, "%s(%s)", "confloat", + StringUtils.join(fieldCustomization, ", ")); + } else { // float + pydanticImports.add("confloat"); + return String.format(Locale.ROOT, "%s(%s)", "confloat", + StringUtils.join(fieldCustomization, ", ")); + } + } else { + if ("Union[StrictFloat, StrictInt]".equals(mapNumberTo)) { + typingImports.add("Union"); + pydanticImports.add("StrictFloat"); + pydanticImports.add("StrictInt"); + return "Union[StrictFloat, StrictInt]"; + } else if ("StrictFloat".equals(mapNumberTo)) { + pydanticImports.add("StrictFloat"); + return "StrictFloat"; + } else { + return "float"; + } + } + } else if (cp.isInteger || cp.isLong || cp.isShort || cp.isUnboundedInteger) { + if (cp.hasValidation) { + List fieldCustomization = new ArrayList<>(); + // e.g. conint(ge=10, le=100, strict=True) + fieldCustomization.add("strict=True"); + if (cp.getMaximum() != null) { + if (cp.getExclusiveMaximum()) { + fieldCustomization.add("lt=" + cp.getMaximum()); + } else { + fieldCustomization.add("le=" + cp.getMaximum()); + } + } + if (cp.getMinimum() != null) { + if (cp.getExclusiveMinimum()) { + fieldCustomization.add("gt=" + cp.getMinimum()); + } else { + fieldCustomization.add("ge=" + cp.getMinimum()); + } + } + if (cp.getMultipleOf() != null) { + fieldCustomization.add("multiple_of=" + cp.getMultipleOf()); + } + + pydanticImports.add("conint"); + return String.format(Locale.ROOT, "%s(%s)", "conint", + StringUtils.join(fieldCustomization, ", ")); + } else { + pydanticImports.add("StrictInt"); + return "StrictInt"; + } + } else if (cp.isBinary || cp.isByteArray) { + if (cp.hasValidation) { + List fieldCustomization = new ArrayList<>(); + // e.g. conbytes(min_length=2, max_length=10) + fieldCustomization.add("strict=True"); + if (cp.getMinLength() != null) { + fieldCustomization.add("min_length=" + cp.getMinLength()); + } + if (cp.getMaxLength() != null) { + fieldCustomization.add("max_length=" + cp.getMaxLength()); + } + if (cp.getPattern() != null) { + pydanticImports.add("validator"); + // use validator instead as regex doesn't support flags, e.g. IGNORECASE + //fieldCustomization.add(Locale.ROOT, String.format(Locale.ROOT, "regex=r'%s'", cp.getPattern())); + } + + pydanticImports.add("conbytes"); + pydanticImports.add("constr"); + typingImports.add("Union"); + return String.format(Locale.ROOT, "Union[conbytes(%s), constr(% fieldCustomization = new ArrayList<>(); + // e.g. condecimal(ge=10, le=100, strict=True) + fieldCustomization.add("strict=True"); + if (cp.getMaximum() != null) { + if (cp.getExclusiveMaximum()) { + fieldCustomization.add("gt=" + cp.getMaximum()); + } else { + fieldCustomization.add("ge=" + cp.getMaximum()); + } + } + if (cp.getMinimum() != null) { + if (cp.getExclusiveMinimum()) { + fieldCustomization.add("lt=" + cp.getMinimum()); + } else { + fieldCustomization.add("le=" + cp.getMinimum()); + } + } + if (cp.getMultipleOf() != null) { + fieldCustomization.add("multiple_of=" + cp.getMultipleOf()); + } + pydanticImports.add("condecimal"); + return String.format(Locale.ROOT, "%s(%s)", "condecimal", StringUtils.join(fieldCustomization, ", ")); + } else { + pydanticImports.add("condecimal"); + return "condecimal()"; + } + } else if (cp.getIsAnyType()) { + typingImports.add("Any"); + return "Any"; + } else if (cp.isDate || cp.isDateTime) { + if (cp.isDate) { + datetimeImports.add("date"); + } + if (cp.isDateTime) { + datetimeImports.add("datetime"); + } + return cp.dataType; + } else if (cp.isUuid) { + return cp.dataType; + } else if (cp.isFreeFormObject) { // type: object + typingImports.add("Dict"); + typingImports.add("Any"); + return "Dict[str, Any]"; + } else if (!cp.isPrimitiveType || cp.isModel) { // model + // skip import if it's a circular reference + if (classname == null) { + // for parameter model, import directly + hasModelsToImport = true; + modelImports.add(cp.dataType); + exampleImports.add(cp.dataType); + } else { + if (circularImports.containsKey(cp.dataType)) { + if (circularImports.get(cp.dataType).contains(classname)) { + hasModelsToImport = true; + postponedModelImports.add(cp.dataType); + postponedExampleImports.add(cp.dataType); + // cp.dataType import map of set contains this model (classname), don't import + LOGGER.debug("Skipped importing {} in {} due to circular import.", cp.dataType, classname); + } else { + // not circular import, so ok to import it + hasModelsToImport = true; + modelImports.add(cp.dataType); + exampleImports.add(cp.dataType); + } + } else { + LOGGER.error("Failed to look up {} from the imports (map of set) of models.", cp.dataType); + } + } + return cp.dataType; + } else { + throw new RuntimeException("Error! Codegen Property not yet supported in getPydanticType: " + cp); + } + } + + public void setMapNumberTo(String mapNumberTo) { + if ("Union[StrictFloat, StrictInt]".equals(mapNumberTo) + || "StrictFloat".equals(mapNumberTo) + || "float".equals(mapNumberTo)) { + this.mapNumberTo = mapNumberTo; + } else { + throw new IllegalArgumentException("mapNumberTo value must be Union[StrictFloat, StrictInt], StrictStr or float"); + } + } + + public String toEnumVariableName(String name, String datatype) { + if ("int".equals(datatype)) { + return "NUMBER_" + name.replace("-", "MINUS_"); + } + + // remove quote e.g. 'abc' => abc + name = name.substring(1, name.length() - 1); + + if (name.length() == 0) { + return "EMPTY"; + } + + if (" ".equals(name)) { + return "SPACE"; + } + + if ("_".equals(name)) { + return "UNDERSCORE"; + } + + if (reservedWords.contains(name)) { + name = name.toUpperCase(Locale.ROOT); + } else if (((CharSequence) name).chars().anyMatch(character -> specialCharReplacements.keySet().contains(String.valueOf((char) character)))) { + name = underscore(escape(name, specialCharReplacements, Collections.singletonList("_"), "_")).toUpperCase(Locale.ROOT); + } else { + name = name.toUpperCase(Locale.ROOT); + } + + name = name.replace(" ", "_"); + name = name.replaceFirst("^_", ""); + name = name.replaceFirst("_$", ""); + + if (name.matches("\\d.*")) { + name = "ENUM_" + name.toUpperCase(Locale.ROOT); + } + + return name; + } + + /** + * Update circularImports with the model name (key) and its imports gathered recursively + * + * @param modelName model name + * @param codegenModelMap a map of CodegenModel + */ + void createImportMapOfSet(String modelName, Map codegenModelMap) { + HashSet imports = new HashSet<>(); + circularImports.put(modelName, imports); + + CodegenModel cm = codegenModelMap.get(modelName); + + if (cm == null) { + LOGGER.warn("Failed to lookup model in createImportMapOfSet: " + modelName); + return; + } + + List codegenProperties = null; + if (cm.oneOf != null && !cm.oneOf.isEmpty()) { // oneOf + codegenProperties = cm.getComposedSchemas().getOneOf(); + } else if (cm.anyOf != null && !cm.anyOf.isEmpty()) { // anyOF + codegenProperties = cm.getComposedSchemas().getAnyOf(); + } else { // typical model + codegenProperties = cm.vars; + } + + for (CodegenProperty cp : codegenProperties) { + String modelNameFromDataType = getModelNameFromDataType(cp); + if (modelNameFromDataType != null) { // model + imports.add(modelNameFromDataType); // update import + // go through properties or sub-schemas of the model recursively to identify more (model) import if any + updateImportsFromCodegenModel(modelNameFromDataType, codegenModelMap.get(modelNameFromDataType), imports); + } + } + } + + /** + * Returns the model name (if any) from data type of codegen property. + * Returns null if it's not a model. + * + * @param cp Codegen property + * @return model name + */ + private String getModelNameFromDataType(CodegenProperty cp) { + if (cp.isArray) { + return getModelNameFromDataType(cp.items); + } else if (cp.isMap) { + return getModelNameFromDataType(cp.items); + } else if (!cp.isPrimitiveType || cp.isModel) { + return cp.dataType; + } else { + return null; + } + } + + /** + * Update set of imports from codegen model recursivly + * + * @param modelName model name + * @param cm codegen model + * @param imports set of imports + */ + public void updateImportsFromCodegenModel(String modelName, CodegenModel cm, Set imports) { + if (cm == null) { + LOGGER.warn("Failed to lookup model in createImportMapOfSet " + modelName); + return; + } + + List codegenProperties = null; + if (cm.oneOf != null && !cm.oneOf.isEmpty()) { // oneOfValidationError + codegenProperties = cm.getComposedSchemas().getOneOf(); + } else if (cm.anyOf != null && !cm.anyOf.isEmpty()) { // anyOF + codegenProperties = cm.getComposedSchemas().getAnyOf(); + } else { // typical model + codegenProperties = cm.vars; + } + + for (CodegenProperty cp : codegenProperties) { + String modelNameFromDataType = getModelNameFromDataType(cp); + if (modelNameFromDataType != null) { // model + if (modelName.equals(modelNameFromDataType)) { // self referencing + continue; + } else if (imports.contains(modelNameFromDataType)) { // circular import + continue; + } else { + imports.add(modelNameFromDataType); // update import + // go through properties of the model recursively to identify more (model) import if any + updateImportsFromCodegenModel(modelNameFromDataType, codegenModelMap.get(modelNameFromDataType), imports); + } + } + } + } + + @Override + public OperationsMap postProcessOperationsWithModels(OperationsMap objs, List allModels) { + hasModelsToImport = false; + boolean importAnnotated = false; + TreeSet typingImports = new TreeSet<>(); + TreeSet pydanticImports = new TreeSet<>(); + TreeSet datetimeImports = new TreeSet<>(); + TreeSet modelImports = new TreeSet<>(); + TreeSet postponedModelImports = new TreeSet<>(); + + OperationMap objectMap = objs.getOperations(); + List operations = objectMap.getOperation(); + for (CodegenOperation operation : operations) { + TreeSet exampleImports = new TreeSet<>(); // import for each operation to be show in sample code + TreeSet postponedExampleImports = new TreeSet<>(); // import for each operation to be show in sample code + List params = operation.allParams; + + for (CodegenParameter param : params) { + String typing = getPydanticType(param, typingImports, pydanticImports, datetimeImports, modelImports, exampleImports, postponedModelImports, postponedExampleImports, null); + List fields = new ArrayList<>(); + String firstField = ""; + + if (!param.required) { //optional + firstField = "None"; + typing = "Optional[" + typing + "]"; + typingImports.add("Optional"); + } else { // required + firstField = "..."; + if (param.isNullable) { + typing = "Optional[" + typing + "]"; + typingImports.add("Optional"); + } + } + + if (!StringUtils.isEmpty(param.description)) { // has description + fields.add(String.format(Locale.ROOT, "description=\"%s\"", param.description)); + } + + /* TODO support example + if (!StringUtils.isEmpty(cp.getExample())) { // has example + fields.add(String.format(Locale.ROOT, "example=%s", cp.getExample())); + }*/ + + String fieldCustomization; + if ("None".equals(firstField)) { + fieldCustomization = null; + } else { // required field + fieldCustomization = firstField; + } + + if (!fields.isEmpty()) { + if (fieldCustomization != null) { + fields.add(0, fieldCustomization); + } + pydanticImports.add("Field"); + fieldCustomization = String.format(Locale.ROOT, "Field(%s)", StringUtils.join(fields, ", ")); + } else { + fieldCustomization = "Field()"; + } + + if ("Field()".equals(fieldCustomization)) { + param.vendorExtensions.put("x-py-typing", typing); + } else { + param.vendorExtensions.put("x-py-typing", String.format(Locale.ROOT, "Annotated[%s, %s]", typing, fieldCustomization)); + importAnnotated = true; + } + } + + // update typing import for operation return type + if (!StringUtils.isEmpty(operation.returnType)) { + String typing = getPydanticType(operation.returnProperty, typingImports, + new TreeSet<>() /* skip pydantic import for return type */, datetimeImports, modelImports, exampleImports, postponedModelImports, postponedExampleImports, null); + } + + // add import for code samples + // import models one by one + if (!exampleImports.isEmpty()) { + List imports = new ArrayList<>(); + for (String exampleImport : exampleImports) { + imports.add("from " + packageName + ".models." + underscore(exampleImport) + " import " + exampleImport); + } + operation.vendorExtensions.put("x-py-example-import", imports); + } + + if (!postponedExampleImports.isEmpty()) { + List imports = new ArrayList<>(); + for (String exampleImport : postponedExampleImports) { + imports.add("from " + packageName + ".models." + underscore(exampleImport) + " import " + + exampleImport); + } + operation.vendorExtensions.put("x-py-example-import", imports); + } + } + + List> newImports = new ArrayList<>(); + + if (importAnnotated) { + Map item = new HashMap<>(); + item.put("import", String.format(Locale.ROOT, String.format(Locale.ROOT, "from typing_extensions import Annotated"))); + newImports.add(item); + } + + // need datetime import + if (!datetimeImports.isEmpty()) { + Map item = new HashMap<>(); + item.put("import", String.format(Locale.ROOT, "from datetime import %s\n", StringUtils.join(datetimeImports, ", "))); + newImports.add(item); + } + + // need pydantic imports + if (!pydanticImports.isEmpty()) { + Map item = new HashMap<>(); + item.put("import", String.format(Locale.ROOT, "from pydantic import %s\n", StringUtils.join(pydanticImports, ", "))); + newImports.add(item); + } + + // need typing imports + if (!typingImports.isEmpty()) { + Map item = new HashMap<>(); + item.put("import", String.format(Locale.ROOT, "from typing import %s\n", StringUtils.join(typingImports, ", "))); + newImports.add(item); + } + + // import models one by one + if (!modelImports.isEmpty()) { + for (String modelImport : modelImports) { + Map item = new HashMap<>(); + item.put("import", "from " + packageName + ".models." + underscore(modelImport) + " import " + modelImport); + newImports.add(item); + } + } + + if (!postponedModelImports.isEmpty()) { + for (String modelImport : postponedModelImports) { + Map item = new HashMap<>(); + item.put("import", "from " + packageName + ".models." + underscore(modelImport) + " import " + modelImport); + newImports.add(item); + } + } + + // reset imports with newImports + objs.setImports(newImports); + return objs; + } + + + @Override + public void postProcessParameter(CodegenParameter parameter) { + postProcessPattern(parameter.pattern, parameter.vendorExtensions); + } + + @Override + public void postProcessModelProperty(CodegenModel model, CodegenProperty property) { + postProcessPattern(property.pattern, property.vendorExtensions); + } + + /* + * The OpenAPI pattern spec follows the Perl convention and style of modifiers. Python + * does not support this in as natural a way so it needs to convert it. See + * https://docs.python.org/2/howto/regex.html#compilation-flags for details. + * + * @param pattern (the String pattern to convert from python to Perl convention) + * @param vendorExtensions (list of custom x-* properties for extra functionality-see https://swagger.io/docs/specification/openapi-extensions/) + * @return void + * @throws IllegalArgumentException if pattern does not follow the Perl /pattern/modifiers convention + * + * Includes fix for issue #6675 + */ + public void postProcessPattern(String pattern, Map vendorExtensions) { + if (pattern != null) { + int i = pattern.lastIndexOf('/'); + + // TOOD update the check below follow python convention + //Must follow Perl /pattern/modifiers convention + if (pattern.charAt(0) != '/' || i < 2) { + throw new IllegalArgumentException("Pattern must follow the Perl " + + "/pattern/modifiers convention. " + pattern + " is not valid."); + } + + String regex = pattern.substring(1, i).replace("'", "\\'"); + List modifiers = new ArrayList(); + + for (char c : pattern.substring(i).toCharArray()) { + if (regexModifiers.containsKey(c)) { + String modifier = regexModifiers.get(c); + modifiers.add(modifier); + } + } + + vendorExtensions.put("x-regex", regex.replace("\"", "\\\"")); + vendorExtensions.put("x-pattern", pattern.replace("\"", "\\\"")); + vendorExtensions.put("x-modifiers", modifiers); + } + } + + @Override + public String addRegularExpressionDelimiter(String pattern) { + if (StringUtils.isEmpty(pattern)) { + return pattern; + } + + if (!pattern.matches("^/.*")) { + // Perform a negative lookbehind on each `/` to ensure that it is escaped. + return "/" + pattern.replaceAll("(?partial_header}} - -import io -import json -import logging -import re -import ssl - -import aiohttp -from urllib.parse import urlencode, quote_plus - -from {{packageName}}.exceptions import ApiException, ApiValueError - -logger = logging.getLogger(__name__) - - -class RESTResponse(io.IOBase): - - def __init__(self, resp, data) -> None: - self.aiohttp_response = resp - self.status = resp.status - self.reason = resp.reason - self.data = data - - def getheaders(self): - """Returns a CIMultiDictProxy of the response headers.""" - return self.aiohttp_response.headers - - def getheader(self, name, default=None): - """Returns a given response header.""" - return self.aiohttp_response.headers.get(name, default) - - -class RESTClientObject: - - def __init__(self, configuration, pools_size=4, maxsize=None) -> None: - - # maxsize is number of requests to host that are allowed in parallel - if maxsize is None: - maxsize = configuration.connection_pool_maxsize - - ssl_context = ssl.create_default_context(cafile=configuration.ssl_ca_cert) - if configuration.cert_file: - ssl_context.load_cert_chain( - configuration.cert_file, keyfile=configuration.key_file - ) - - if not configuration.verify_ssl: - ssl_context.check_hostname = False - ssl_context.verify_mode = ssl.CERT_NONE - - connector = aiohttp.TCPConnector( - limit=maxsize, - ssl=ssl_context - ) - - self.proxy = configuration.proxy - self.proxy_headers = configuration.proxy_headers - - # https pool manager - self.pool_manager = aiohttp.ClientSession( - connector=connector, - trust_env=True - ) - - async def close(self): - await self.pool_manager.close() - - async def request(self, method, url, query_params=None, headers=None, - body=None, post_params=None, _preload_content=True, - _request_timeout=None): - """Execute request - - :param method: http request method - :param url: http request url - :param query_params: query parameters in the url - :param headers: http request headers - :param body: request json body, for `application/json` - :param post_params: request post parameters, - `application/x-www-form-urlencoded` - and `multipart/form-data` - :param _preload_content: this is a non-applicable field for - the AiohttpClient. - :param _request_timeout: timeout setting for this request. If one - number provided, it will be total request - timeout. It can also be a pair (tuple) of - (connection, read) timeouts. - """ - method = method.upper() - assert method in ['GET', 'HEAD', 'DELETE', 'POST', 'PUT', - 'PATCH', 'OPTIONS'] - - if post_params and body: - raise ApiValueError( - "body parameter cannot be used with post_params parameter." - ) - - post_params = post_params or {} - headers = headers or {} - # url already contains the URL query string - # so reset query_params to empty dict - query_params = {} - timeout = _request_timeout or 5 * 60 - - if 'Content-Type' not in headers: - headers['Content-Type'] = 'application/json' - - args = { - "method": method, - "url": url, - "timeout": timeout, - "headers": headers - } - - if self.proxy: - args["proxy"] = self.proxy - if self.proxy_headers: - args["proxy_headers"] = self.proxy_headers - - if query_params: - args["url"] += '?' + urlencode(query_params) - - # For `POST`, `PUT`, `PATCH`, `OPTIONS`, `DELETE` - if method in ['POST', 'PUT', 'PATCH', 'OPTIONS', 'DELETE']: - if re.search('json', headers['Content-Type'], re.IGNORECASE): - if body is not None: - body = json.dumps(body) - args["data"] = body - elif headers['Content-Type'] == 'application/x-www-form-urlencoded': # noqa: E501 - args["data"] = aiohttp.FormData(post_params) - elif headers['Content-Type'] == 'multipart/form-data': - # must del headers['Content-Type'], or the correct - # Content-Type which generated by aiohttp - del headers['Content-Type'] - data = aiohttp.FormData() - for param in post_params: - k, v = param - if isinstance(v, tuple) and len(v) == 3: - data.add_field(k, - value=v[1], - filename=v[0], - content_type=v[2]) - else: - data.add_field(k, v) - args["data"] = data - - # Pass a `bytes` parameter directly in the body to support - # other content types than Json when `body` argument is provided - # in serialized form - elif isinstance(body, bytes): - args["data"] = body - else: - # Cannot generate the request from given parameters - msg = """Cannot prepare a request message for provided - arguments. Please check that your arguments match - declared content type.""" - raise ApiException(status=0, reason=msg) - - r = await self.pool_manager.request(**args) - if _preload_content: - - data = await r.read() - r = RESTResponse(r, data) - - # log response body - logger.debug("response body: %s", r.data) - - if not 200 <= r.status <= 299: - raise ApiException(http_resp=r) - - return r - - async def get_request(self, url, headers=None, query_params=None, - _preload_content=True, _request_timeout=None): - return (await self.request("GET", url, - headers=headers, - _preload_content=_preload_content, - _request_timeout=_request_timeout, - query_params=query_params)) - - async def head_request(self, url, headers=None, query_params=None, - _preload_content=True, _request_timeout=None): - return (await self.request("HEAD", url, - headers=headers, - _preload_content=_preload_content, - _request_timeout=_request_timeout, - query_params=query_params)) - - async def options_request(self, url, headers=None, query_params=None, - post_params=None, body=None, _preload_content=True, - _request_timeout=None): - return (await self.request("OPTIONS", url, - headers=headers, - query_params=query_params, - post_params=post_params, - _preload_content=_preload_content, - _request_timeout=_request_timeout, - body=body)) - - async def delete_request(self, url, headers=None, query_params=None, body=None, - _preload_content=True, _request_timeout=None): - return (await self.request("DELETE", url, - headers=headers, - query_params=query_params, - _preload_content=_preload_content, - _request_timeout=_request_timeout, - body=body)) - - async def post_request(self, url, headers=None, query_params=None, - post_params=None, body=None, _preload_content=True, - _request_timeout=None): - return (await self.request("POST", url, - headers=headers, - query_params=query_params, - post_params=post_params, - _preload_content=_preload_content, - _request_timeout=_request_timeout, - body=body)) - - async def put_request(self, url, headers=None, query_params=None, post_params=None, - body=None, _preload_content=True, _request_timeout=None): - return (await self.request("PUT", url, - headers=headers, - query_params=query_params, - post_params=post_params, - _preload_content=_preload_content, - _request_timeout=_request_timeout, - body=body)) - - async def patch_request(self, url, headers=None, query_params=None, - post_params=None, body=None, _preload_content=True, - _request_timeout=None): - return (await self.request("PATCH", url, - headers=headers, - query_params=query_params, - post_params=post_params, - _preload_content=_preload_content, - _request_timeout=_request_timeout, - body=body)) diff --git a/modules/openapi-generator/src/main/resources/python-pydantic-v1/tornado/rest.mustache b/modules/openapi-generator/src/main/resources/python-pydantic-v1/tornado/rest.mustache deleted file mode 100644 index 6b91cdef986e..000000000000 --- a/modules/openapi-generator/src/main/resources/python-pydantic-v1/tornado/rest.mustache +++ /dev/null @@ -1,223 +0,0 @@ -# coding: utf-8 - -{{>partial_header}} - -import io -import json -import logging -import re - -from urllib.parse import urlencode, quote_plus -import tornado -import tornado.gen -from tornado import httpclient -from urllib3.filepost import encode_multipart_formdata - -from {{packageName}}.exceptions import ApiException, ApiValueError - -logger = logging.getLogger(__name__) - - -class RESTResponse(io.IOBase): - - def __init__(self, resp) -> None: - self.tornado_response = resp - self.status = resp.code - self.reason = resp.reason - - if resp.body: - self.data = resp.body - else: - self.data = None - - def getheaders(self): - """Returns a CIMultiDictProxy of the response headers.""" - return self.tornado_response.headers - - def getheader(self, name, default=None): - """Returns a given response header.""" - return self.tornado_response.headers.get(name, default) - - -class RESTClientObject: - - def __init__(self, configuration, pools_size=4, maxsize=4) -> None: - # maxsize is number of requests to host that are allowed in parallel - - self.ca_certs = configuration.ssl_ca_cert - self.client_key = configuration.key_file - self.client_cert = configuration.cert_file - - self.proxy_port = self.proxy_host = None - - # https pool manager - if configuration.proxy: - self.proxy_port = 80 - self.proxy_host = configuration.proxy - - self.pool_manager = httpclient.AsyncHTTPClient() - - @tornado.gen.coroutine - def request(self, method, url, query_params=None, headers=None, body=None, - post_params=None, _preload_content=True, - _request_timeout=None): - """Execute Request - - :param method: http request method - :param url: http request url - :param query_params: query parameters in the url - :param headers: http request headers - :param body: request json body, for `application/json` - :param post_params: request post parameters, - `application/x-www-form-urlencoded` - and `multipart/form-data` - :param _preload_content: this is a non-applicable field for - the AiohttpClient. - :param _request_timeout: timeout setting for this request. If one - number provided, it will be total request - timeout. It can also be a pair (tuple) of - (connection, read) timeouts. - """ - method = method.upper() - assert method in ['GET', 'HEAD', 'DELETE', 'POST', 'PUT', - 'PATCH', 'OPTIONS'] - - if post_params and body: - raise ApiValueError( - "body parameter cannot be used with post_params parameter." - ) - - request = httpclient.HTTPRequest(url) - request.allow_nonstandard_methods = True - request.ca_certs = self.ca_certs - request.client_key = self.client_key - request.client_cert = self.client_cert - request.proxy_host = self.proxy_host - request.proxy_port = self.proxy_port - request.method = method - if headers: - request.headers = headers - if 'Content-Type' not in headers: - request.headers['Content-Type'] = 'application/json' - request.request_timeout = _request_timeout or 5 * 60 - - post_params = post_params or {} - - if query_params: - request.url += '?' + urlencode(query_params) - - # For `POST`, `PUT`, `PATCH`, `OPTIONS`, `DELETE` - if method in ['POST', 'PUT', 'PATCH', 'OPTIONS', 'DELETE']: - if re.search('json', headers['Content-Type'], re.IGNORECASE): - if body: - body = json.dumps(body) - request.body = body - elif headers['Content-Type'] == 'application/x-www-form-urlencoded': # noqa: E501 - request.body = urlencode(post_params) - elif headers['Content-Type'] == 'multipart/form-data': - multipart = encode_multipart_formdata(post_params) - request.body, headers['Content-Type'] = multipart - # Pass a `bytes` parameter directly in the body to support - # other content types than Json when `body` argument is provided - # in serialized form - elif isinstance(body, bytes): - request.body = body - else: - # Cannot generate the request from given parameters - msg = """Cannot prepare a request message for provided - arguments. Please check that your arguments match - declared content type.""" - raise ApiException(status=0, reason=msg) - - r = yield self.pool_manager.fetch(request, raise_error=False) - - if _preload_content: - - r = RESTResponse(r) - - # log response body - logger.debug("response body: %s", r.data) - - if not 200 <= r.status <= 299: - raise ApiException(http_resp=r) - - raise tornado.gen.Return(r) - - @tornado.gen.coroutine - def GET(self, url, headers=None, query_params=None, _preload_content=True, - _request_timeout=None): - result = yield self.request("GET", url, - headers=headers, - _preload_content=_preload_content, - _request_timeout=_request_timeout, - query_params=query_params) - raise tornado.gen.Return(result) - - @tornado.gen.coroutine - def HEAD(self, url, headers=None, query_params=None, _preload_content=True, - _request_timeout=None): - result = yield self.request("HEAD", url, - headers=headers, - _preload_content=_preload_content, - _request_timeout=_request_timeout, - query_params=query_params) - raise tornado.gen.Return(result) - - @tornado.gen.coroutine - def OPTIONS(self, url, headers=None, query_params=None, post_params=None, - body=None, _preload_content=True, _request_timeout=None): - result = yield self.request("OPTIONS", url, - headers=headers, - query_params=query_params, - post_params=post_params, - _preload_content=_preload_content, - _request_timeout=_request_timeout, - body=body) - raise tornado.gen.Return(result) - - @tornado.gen.coroutine - def DELETE(self, url, headers=None, query_params=None, body=None, - _preload_content=True, _request_timeout=None): - result = yield self.request("DELETE", url, - headers=headers, - query_params=query_params, - _preload_content=_preload_content, - _request_timeout=_request_timeout, - body=body) - raise tornado.gen.Return(result) - - @tornado.gen.coroutine - def POST(self, url, headers=None, query_params=None, post_params=None, - body=None, _preload_content=True, _request_timeout=None): - result = yield self.request("POST", url, - headers=headers, - query_params=query_params, - post_params=post_params, - _preload_content=_preload_content, - _request_timeout=_request_timeout, - body=body) - raise tornado.gen.Return(result) - - @tornado.gen.coroutine - def PUT(self, url, headers=None, query_params=None, post_params=None, - body=None, _preload_content=True, _request_timeout=None): - result = yield self.request("PUT", url, - headers=headers, - query_params=query_params, - post_params=post_params, - _preload_content=_preload_content, - _request_timeout=_request_timeout, - body=body) - raise tornado.gen.Return(result) - - @tornado.gen.coroutine - def PATCH(self, url, headers=None, query_params=None, post_params=None, - body=None, _preload_content=True, _request_timeout=None): - result = yield self.request("PATCH", url, - headers=headers, - query_params=query_params, - post_params=post_params, - _preload_content=_preload_content, - _request_timeout=_request_timeout, - body=body) - raise tornado.gen.Return(result) diff --git a/modules/openapi-generator/src/test/java/org/openapitools/codegen/python/PythonPydanticV1ClientCodegenTest.java b/modules/openapi-generator/src/test/java/org/openapitools/codegen/python/PythonPydanticV1ClientCodegenTest.java new file mode 100644 index 000000000000..837c7ad3f7da --- /dev/null +++ b/modules/openapi-generator/src/test/java/org/openapitools/codegen/python/PythonPydanticV1ClientCodegenTest.java @@ -0,0 +1,504 @@ +/* + * Copyright 2018 OpenAPI-Generator Contributors (https://openapi-generator.tech) + * Copyright 2018 SmartBear Software + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.openapitools.codegen.python; + +import com.google.common.collect.Sets; +import io.swagger.parser.OpenAPIParser; +import io.swagger.v3.oas.models.OpenAPI; +import io.swagger.v3.oas.models.Operation; +import io.swagger.v3.oas.models.media.*; +import io.swagger.v3.parser.core.models.ParseOptions; +import io.swagger.v3.parser.util.SchemaTypeUtil; +import org.openapitools.codegen.*; +import org.openapitools.codegen.languages.PythonPydanticV1ClientCodegen; +import org.openapitools.codegen.languages.features.CXFServerFeatures; +import static org.openapitools.codegen.TestUtils.assertFileContains; +import static org.openapitools.codegen.TestUtils.assertFileExists; +import org.openapitools.codegen.TestUtils; +import org.testng.Assert; +import org.testng.annotations.Test; +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +public class PythonPydanticV1ClientCodegenTest { + + @Test + public void testInitialConfigValues() throws Exception { + final PythonPydanticV1ClientCodegen codegen = new PythonPydanticV1ClientCodegen(); + codegen.processOpts(); + + Assert.assertEquals(codegen.additionalProperties().get(CodegenConstants.HIDE_GENERATION_TIMESTAMP), Boolean.TRUE); + Assert.assertEquals(codegen.isHideGenerationTimestamp(), true); + } + + @Test + public void testSettersForConfigValues() throws Exception { + final PythonPydanticV1ClientCodegen codegen = new PythonPydanticV1ClientCodegen(); + codegen.setHideGenerationTimestamp(false); + codegen.processOpts(); + + Assert.assertEquals(codegen.additionalProperties().get(CodegenConstants.HIDE_GENERATION_TIMESTAMP), Boolean.FALSE); + Assert.assertEquals(codegen.isHideGenerationTimestamp(), false); + } + + @Test + public void testAdditionalPropertiesPutForConfigValues() throws Exception { + final PythonPydanticV1ClientCodegen codegen = new PythonPydanticV1ClientCodegen(); + codegen.additionalProperties().put(CodegenConstants.HIDE_GENERATION_TIMESTAMP, false); + codegen.processOpts(); + + Assert.assertEquals(codegen.additionalProperties().get(CodegenConstants.HIDE_GENERATION_TIMESTAMP), Boolean.FALSE); + Assert.assertEquals(codegen.isHideGenerationTimestamp(), false); + } + + @Test(description = "test enum null/nullable patterns") + public void testEnumNull() { + final OpenAPI openAPI = TestUtils.parseFlattenSpec("src/test/resources/3_0/issue_1997.yaml"); + + StringSchema prop = (StringSchema) openAPI.getComponents().getSchemas().get("Type").getProperties().get("prop"); + ArrayList expected = new ArrayList<>(Arrays.asList("A", "B", "C")); + assert prop.getNullable(); + assert prop.getEnum().equals(expected); + } + + @Test(description = "test regex patterns") + public void testRegularExpressionOpenAPISchemaVersion3() { + final OpenAPI openAPI = TestUtils.parseFlattenSpec("src/test/resources/3_0/issue_1517.yaml"); + final PythonPydanticV1ClientCodegen codegen = new PythonPydanticV1ClientCodegen(); + codegen.setOpenAPI(openAPI); + final String path = "/ping"; + final Operation p = openAPI.getPaths().get(path).getGet(); + final CodegenOperation op = codegen.fromOperation(path, "get", p, null); + // pattern_no_forward_slashes '^pattern$' + Assert.assertEquals(op.allParams.get(0).pattern, "/^pattern$/"); + // pattern_two_slashes '/^pattern$/' + Assert.assertEquals(op.allParams.get(1).pattern, "/^pattern$/"); + // pattern_dont_escape_backslash '/^pattern\d{3}$/' + Assert.assertEquals(op.allParams.get(2).pattern, "/^pattern\\d{3}$/"); + // pattern_dont_escape_escaped_forward_slash '/^pattern\/\d{3}$/' + Assert.assertEquals(op.allParams.get(3).pattern, "/^pattern\\/\\d{3}$/"); + // pattern_escape_unescaped_forward_slash '^pattern/\d{3}$' + Assert.assertEquals(op.allParams.get(4).pattern, "/^pattern\\/\\d{3}$/"); + // pattern_with_modifiers '/^pattern\d{3}$/i + Assert.assertEquals(op.allParams.get(5).pattern, "/^pattern\\d{3}$/i"); + // pattern_with_backslash_after_bracket '/^[\pattern\d{3}$/i' + // added to test fix for issue #6675 + // removed because "/^[\\pattern\\d{3}$/i" is invalid regex because [ is not escaped and there is no closing ] + // Assert.assertEquals(op.allParams.get(6).pattern, "/^[\\pattern\\d{3}$/i"); + + } + + + @Test(description = "test generated example values for string properties") + public void testGeneratedExampleValues() { + final OpenAPI openAPI = TestUtils.parseFlattenSpec("src/test/resources/3_0/examples.yaml"); + final PythonPydanticV1ClientCodegen codegen = new PythonPydanticV1ClientCodegen(); + codegen.setOpenAPI(openAPI); + final Schema dummyUserSchema = openAPI.getComponents().getSchemas().get("DummyUser"); + final Schema nameSchema = (Schema) dummyUserSchema.getProperties().get("name"); + final Schema numberSchema = (Schema) dummyUserSchema.getProperties().get("number"); + final Schema addressSchema = (Schema) dummyUserSchema.getProperties().get("address"); + final String namePattern = codegen.patternCorrection(nameSchema.getPattern()); + final String numberPattern = codegen.patternCorrection(numberSchema.getPattern()); + final String addressPattern = codegen.patternCorrection(addressSchema.getPattern()); + Assert.assertTrue(codegen.escapeQuotationMark(codegen.toExampleValue(nameSchema)).matches(namePattern)); + Assert.assertTrue(codegen.escapeQuotationMark(codegen.toExampleValue(numberSchema)).matches(numberPattern)); + Assert.assertTrue(codegen.escapeQuotationMark(codegen.toExampleValue(addressSchema)).matches(addressPattern)); + } + + @Test(description = "test single quotes escape") + public void testSingleQuotes() { + final PythonPydanticV1ClientCodegen codegen = new PythonPydanticV1ClientCodegen(); + StringSchema schema = new StringSchema(); + schema.setDefault("Text containing 'single' quote"); + String defaultValue = codegen.toDefaultValue(schema); + Assert.assertEquals("'Text containing \'single\' quote'", defaultValue); + } + + @Test(description = "test backslash default") + public void testBackslashDefault() { + final PythonPydanticV1ClientCodegen codegen = new PythonPydanticV1ClientCodegen(); + StringSchema schema = new StringSchema(); + schema.setDefault("\\"); + String defaultValue = codegen.toDefaultValue(schema); + Assert.assertEquals("'\\\\'", defaultValue); + } + + @Test(description = "convert a python model with dots") + public void modelTest() { + final OpenAPI openAPI = TestUtils.parseFlattenSpec("src/test/resources/3_0/v1beta3.yaml"); + final DefaultCodegen codegen = new PythonPydanticV1ClientCodegen(); + codegen.setOpenAPI(openAPI); + + codegen.setOpenAPI(openAPI); + final CodegenModel simpleName = codegen.fromModel("v1beta3.Binding", openAPI.getComponents().getSchemas().get("v1beta3.Binding")); + Assert.assertEquals(simpleName.name, "v1beta3.Binding"); + Assert.assertEquals(simpleName.classname, "V1beta3Binding"); + Assert.assertEquals(simpleName.classVarName, "v1beta3_binding"); + + codegen.setOpenAPI(openAPI); + final CodegenModel compoundName = codegen.fromModel("v1beta3.ComponentStatus", openAPI.getComponents().getSchemas().get("v1beta3.ComponentStatus")); + Assert.assertEquals(compoundName.name, "v1beta3.ComponentStatus"); + Assert.assertEquals(compoundName.classname, "V1beta3ComponentStatus"); + Assert.assertEquals(compoundName.classVarName, "v1beta3_component_status"); + + final String path = "/api/v1beta3/namespaces/{namespaces}/bindings"; + final Operation operation = openAPI.getPaths().get(path).getPost(); + final CodegenOperation codegenOperation = codegen.fromOperation(path, "get", operation, null); + Assert.assertEquals(codegenOperation.returnType, "V1beta3Binding"); + Assert.assertEquals(codegenOperation.returnBaseType, "V1beta3Binding"); + } + + @Test(description = "convert a simple java model") + public void simpleModelTest() { + final Schema schema = new Schema() + .description("a sample model") + .addProperties("id", new IntegerSchema().format(SchemaTypeUtil.INTEGER64_FORMAT)) + .addProperties("name", new StringSchema()) + .addProperties("createdAt", new DateTimeSchema()) + .addRequiredItem("id") + .addRequiredItem("name"); + final DefaultCodegen codegen = new PythonPydanticV1ClientCodegen(); + OpenAPI openAPI = TestUtils.createOpenAPIWithOneSchema("sample", schema); + codegen.setOpenAPI(openAPI); + final CodegenModel cm = codegen.fromModel("sample", schema); + + Assert.assertEquals(cm.name, "sample"); + Assert.assertEquals(cm.classname, "Sample"); + Assert.assertEquals(cm.description, "a sample model"); + Assert.assertEquals(cm.vars.size(), 3); + + final CodegenProperty property1 = cm.vars.get(0); + Assert.assertEquals(property1.baseName, "id"); + Assert.assertEquals(property1.dataType, "int"); + Assert.assertEquals(property1.name, "id"); + Assert.assertNull(property1.defaultValue); + Assert.assertEquals(property1.baseType, "int"); + Assert.assertTrue(property1.required); + Assert.assertTrue(property1.isPrimitiveType); + + final CodegenProperty property2 = cm.vars.get(1); + Assert.assertEquals(property2.baseName, "name"); + Assert.assertEquals(property2.dataType, "str"); + Assert.assertEquals(property2.name, "name"); + Assert.assertNull(property2.defaultValue); + Assert.assertEquals(property2.baseType, "str"); + Assert.assertTrue(property2.required); + Assert.assertTrue(property2.isPrimitiveType); + + final CodegenProperty property3 = cm.vars.get(2); + Assert.assertEquals(property3.baseName, "createdAt"); + Assert.assertEquals(property3.dataType, "datetime"); + Assert.assertEquals(property3.name, "created_at"); + Assert.assertNull(property3.defaultValue); + Assert.assertEquals(property3.baseType, "datetime"); + Assert.assertFalse(property3.required); + } + + @Test(description = "convert a model with list property") + public void listPropertyTest() { + final Schema model = new Schema() + .description("a sample model") + .addProperties("id", new IntegerSchema().format(SchemaTypeUtil.INTEGER64_FORMAT)) + .addProperties("urls", new ArraySchema() + .items(new StringSchema())) + .addRequiredItem("id"); + final DefaultCodegen codegen = new PythonPydanticV1ClientCodegen(); + OpenAPI openAPI = TestUtils.createOpenAPIWithOneSchema("sample", model); + codegen.setOpenAPI(openAPI); + final CodegenModel cm = codegen.fromModel("sample", model); + + Assert.assertEquals(cm.name, "sample"); + Assert.assertEquals(cm.classname, "Sample"); + Assert.assertEquals(cm.description, "a sample model"); + Assert.assertEquals(cm.vars.size(), 2); + + final CodegenProperty property1 = cm.vars.get(0); + Assert.assertEquals(property1.baseName, "id"); + Assert.assertEquals(property1.dataType, "int"); + Assert.assertEquals(property1.name, "id"); + Assert.assertNull(property1.defaultValue); + Assert.assertEquals(property1.baseType, "int"); + Assert.assertTrue(property1.required); + Assert.assertTrue(property1.isPrimitiveType); + + final CodegenProperty property2 = cm.vars.get(1); + Assert.assertEquals(property2.baseName, "urls"); + Assert.assertEquals(property2.dataType, "List[str]"); + Assert.assertEquals(property2.name, "urls"); + Assert.assertNull(property2.defaultValue); + Assert.assertEquals(property2.baseType, "List"); + Assert.assertEquals(property2.containerType, "array"); + Assert.assertFalse(property2.required); + Assert.assertTrue(property2.isPrimitiveType); + Assert.assertTrue(property2.isContainer); + } + + @Test(description = "convert a model with a map property") + public void mapPropertyTest() { + final Schema model = new Schema() + .description("a sample model") + .addProperties("translations", new MapSchema() + .additionalProperties(new StringSchema())) + .addRequiredItem("id"); + final DefaultCodegen codegen = new PythonPydanticV1ClientCodegen(); + OpenAPI openAPI = TestUtils.createOpenAPIWithOneSchema("sample", model); + codegen.setOpenAPI(openAPI); + final CodegenModel cm = codegen.fromModel("sample", model); + + Assert.assertEquals(cm.name, "sample"); + Assert.assertEquals(cm.classname, "Sample"); + Assert.assertEquals(cm.description, "a sample model"); + Assert.assertEquals(cm.vars.size(), 1); + + final CodegenProperty property1 = cm.vars.get(0); + Assert.assertEquals(property1.baseName, "translations"); + Assert.assertEquals(property1.dataType, "Dict[str, str]"); + Assert.assertEquals(property1.name, "translations"); + Assert.assertEquals(property1.baseType, "Dict"); + Assert.assertEquals(property1.containerType, "map"); + Assert.assertFalse(property1.required); + Assert.assertTrue(property1.isContainer); + Assert.assertTrue(property1.isPrimitiveType); + } + + @Test(description = "convert a model with complex property") + public void complexPropertyTest() { + final Schema model = new Schema() + .description("a sample model") + .addProperties("children", new Schema().$ref("#/definitions/Children")); + final DefaultCodegen codegen = new PythonPydanticV1ClientCodegen(); + OpenAPI openAPI = TestUtils.createOpenAPIWithOneSchema("sample", model); + codegen.setOpenAPI(openAPI); + final CodegenModel cm = codegen.fromModel("sample", model); + + Assert.assertEquals(cm.name, "sample"); + Assert.assertEquals(cm.classname, "Sample"); + Assert.assertEquals(cm.description, "a sample model"); + Assert.assertEquals(cm.vars.size(), 1); + + final CodegenProperty property1 = cm.vars.get(0); + Assert.assertEquals(property1.baseName, "children"); + Assert.assertEquals(property1.dataType, "Children"); + Assert.assertEquals(property1.name, "children"); + Assert.assertEquals(property1.baseType, "Children"); + Assert.assertFalse(property1.required); + Assert.assertFalse(property1.isContainer); + } + + @Test(description = "convert a model with complex list property") + public void complexListPropertyTest() { + final Schema model = new Schema() + .description("a sample model") + .addProperties("children", new ArraySchema() + .items(new Schema().$ref("#/definitions/Children"))); + final DefaultCodegen codegen = new PythonPydanticV1ClientCodegen(); + OpenAPI openAPI = TestUtils.createOpenAPIWithOneSchema("sample", model); + codegen.setOpenAPI(openAPI); + final CodegenModel cm = codegen.fromModel("sample", model); + + Assert.assertEquals(cm.name, "sample"); + Assert.assertEquals(cm.classname, "Sample"); + Assert.assertEquals(cm.description, "a sample model"); + Assert.assertEquals(cm.vars.size(), 1); + + final CodegenProperty property1 = cm.vars.get(0); + Assert.assertEquals(property1.baseName, "children"); + Assert.assertEquals(property1.complexType, "Children"); + Assert.assertEquals(property1.dataType, "List[Children]"); + Assert.assertEquals(property1.name, "children"); + Assert.assertEquals(property1.baseType, "List"); + Assert.assertEquals(property1.containerType, "array"); + Assert.assertFalse(property1.required); + Assert.assertTrue(property1.isContainer); + } + + @Test(description = "convert a model with complex map property") + public void complexMapPropertyTest() { + final Schema model = new Schema() + .description("a sample model") + .addProperties("children", new MapSchema() + .additionalProperties(new Schema().$ref("#/definitions/Children"))); + final DefaultCodegen codegen = new PythonPydanticV1ClientCodegen(); + OpenAPI openAPI = TestUtils.createOpenAPIWithOneSchema("sample", model); + codegen.setOpenAPI(openAPI); + final CodegenModel cm = codegen.fromModel("sample", model); + + Assert.assertEquals(cm.name, "sample"); + Assert.assertEquals(cm.classname, "Sample"); + Assert.assertEquals(cm.description, "a sample model"); + Assert.assertEquals(cm.vars.size(), 1); + Assert.assertEquals(Sets.intersection(cm.imports, Sets.newHashSet("Children")).size(), 1); + + final CodegenProperty property1 = cm.vars.get(0); + Assert.assertEquals(property1.baseName, "children"); + Assert.assertEquals(property1.complexType, "Children"); + Assert.assertEquals(property1.dataType, "Dict[str, Children]"); + Assert.assertEquals(property1.name, "children"); + Assert.assertEquals(property1.baseType, "Dict"); + Assert.assertEquals(property1.containerType, "map"); + Assert.assertFalse(property1.required); + Assert.assertTrue(property1.isContainer); + } + + + // should not start with 'null'. need help from the community to investigate further + @Test(description = "convert an array model") + public void arrayModelTest() { + final Schema model = new ArraySchema() + //.description() + .items(new Schema().$ref("#/definitions/Children")) + .description("an array model"); + final DefaultCodegen codegen = new PythonPydanticV1ClientCodegen(); + OpenAPI openAPI = TestUtils.createOpenAPIWithOneSchema("sample", model); + codegen.setOpenAPI(openAPI); + final CodegenModel cm = codegen.fromModel("sample", model); + + Assert.assertEquals(cm.name, "sample"); + Assert.assertEquals(cm.classname, "Sample"); + Assert.assertEquals(cm.description, "an array model"); + Assert.assertEquals(cm.vars.size(), 0); + Assert.assertEquals(cm.parent, "null"); + Assert.assertEquals(cm.imports.size(), 1); + Assert.assertEquals(Sets.intersection(cm.imports, Sets.newHashSet("Children")).size(), 1); + } + + // should not start with 'null'. need help from the community to investigate further + @Test(description = "convert a map model") + public void mapModelTest() { + final Schema model = new Schema() + .description("a map model") + .additionalProperties(new Schema().$ref("#/definitions/Children")); + final DefaultCodegen codegen = new PythonPydanticV1ClientCodegen(); + OpenAPI openAPI = TestUtils.createOpenAPIWithOneSchema("sample", model); + codegen.setOpenAPI(openAPI); + final CodegenModel cm = codegen.fromModel("sample", model); + + Assert.assertEquals(cm.name, "sample"); + Assert.assertEquals(cm.classname, "Sample"); + Assert.assertEquals(cm.description, "a map model"); + Assert.assertEquals(cm.vars.size(), 0); + Assert.assertEquals(cm.parent, null); + Assert.assertEquals(cm.imports.size(), 0); + } + @Test(description ="check API example has input param(configuration) when it creates api_client") + public void apiExampleDocTest() throws Exception { + final DefaultCodegen codegen = new PythonPydanticV1ClientCodegen(); + final String outputPath = generateFiles(codegen, "src/test/resources/3_0/generic.yaml"); + final Path p = Paths.get(outputPath + "docs/DefaultApi.md"); + + assertFileExists(p); + assertFileContains(p, "openapi_client.ApiClient(configuration) as api_client"); + } + + // Helper function, intended to reduce boilerplate + static private String generateFiles(DefaultCodegen codegen, String filePath) throws IOException { + final File output = Files.createTempDirectory("test").toFile().getCanonicalFile(); + output.deleteOnExit(); + final String outputPath = output.getAbsolutePath().replace('\\', '/'); + + codegen.setOutputDir(output.getAbsolutePath()); + codegen.additionalProperties().put(CXFServerFeatures.LOAD_TEST_DATA_FROM_FILE, "true"); + + final ClientOptInput input = new ClientOptInput(); + final OpenAPI openAPI = new OpenAPIParser().readLocation(filePath, null, new ParseOptions()).getOpenAPI(); + input.openAPI(openAPI); + input.config(codegen); + + final DefaultGenerator generator = new DefaultGenerator(); + final List files = generator.opts(input).generate(); + + Assert.assertTrue(files.size() > 0); + return outputPath + "/"; + } + + @Test(description = "test containerType in parameters") + public void testContainerType() { + final OpenAPI openAPI = TestUtils.parseFlattenSpec("src/test/resources/3_0/petstore.yaml"); + final PythonPydanticV1ClientCodegen codegen = new PythonPydanticV1ClientCodegen(); + codegen.setOpenAPI(openAPI); + // path parameter + String path = "/store/order/{orderId}"; + Operation p = openAPI.getPaths().get(path).getGet(); + CodegenOperation op = codegen.fromOperation(path, "get", p, null); + Assert.assertEquals(op.allParams.get(0).containerType, null); + Assert.assertEquals(op.allParams.get(0).baseName, "orderId"); + + // query parameter + path = "/user/login"; + p = openAPI.getPaths().get(path).getGet(); + op = codegen.fromOperation(path, "get", p, null); + Assert.assertEquals(op.allParams.get(0).containerType, null); + Assert.assertEquals(op.allParams.get(0).baseName, "username"); + Assert.assertEquals(op.allParams.get(1).containerType, null); + Assert.assertEquals(op.allParams.get(1).baseName, "password"); + + // body parameter + path = "/user/createWithList"; + p = openAPI.getPaths().get(path).getPost(); + op = codegen.fromOperation(path, "post", p, null); + Assert.assertEquals(op.allParams.get(0).baseName, "User"); + Assert.assertEquals(op.allParams.get(0).containerType, "array"); + Assert.assertEquals(op.allParams.get(0).containerTypeMapped, "List"); + + path = "/pet"; + p = openAPI.getPaths().get(path).getPost(); + op = codegen.fromOperation(path, "post", p, null); + Assert.assertEquals(op.allParams.get(0).baseName, "Pet"); + Assert.assertEquals(op.allParams.get(0).containerType, null); + Assert.assertEquals(op.allParams.get(0).containerTypeMapped, null); + + } + + @Test(description = "test containerType (dict) in parameters") + public void testContainerTypeForDict() { + final OpenAPI openAPI = TestUtils.parseFlattenSpec("src/test/resources/3_0/dict_query_parameter.yaml"); + final PythonPydanticV1ClientCodegen codegen = new PythonPydanticV1ClientCodegen(); + codegen.setOpenAPI(openAPI); + // query parameter + String path = "/query_parameter_dict"; + Operation p = openAPI.getPaths().get(path).getGet(); + CodegenOperation op = codegen.fromOperation(path, "get", p, null); + Assert.assertEquals(op.allParams.get(0).containerType, "map"); + Assert.assertEquals(op.allParams.get(0).containerTypeMapped, "Dict"); + Assert.assertEquals(op.allParams.get(0).baseName, "dict_string_integer"); + } + + @Test(description = "convert a model with dollar signs") + public void modelTestDollarSign() { + final OpenAPI openAPI = TestUtils.parseFlattenSpec("src/test/resources/3_0/dollar-in-names-pull14359.yaml"); + final DefaultCodegen codegen = new PythonPydanticV1ClientCodegen(); + + codegen.setOpenAPI(openAPI); + final CodegenModel simpleName = codegen.fromModel("$DollarModel$", openAPI.getComponents().getSchemas().get("$DollarModel$")); + Assert.assertEquals(simpleName.name, "$DollarModel$"); + Assert.assertEquals(simpleName.classname, "DollarModel"); + Assert.assertEquals(simpleName.classVarName, "dollar_model"); + + List vars = simpleName.getVars(); + Assert.assertEquals(vars.size(), 1); + CodegenProperty property = vars.get(0); + Assert.assertEquals(property.name, "dollar_value"); + } +} diff --git a/pom.xml b/pom.xml index 174fdcc71489..2ac0d20cd1d4 100644 --- a/pom.xml +++ b/pom.xml @@ -1217,7 +1217,9 @@ samples/openapi3/client/petstore/python + samples/openapi3/client/petstore/python-pydantic-v1 samples/openapi3/client/petstore/python-aiohttp + samples/openapi3/client/petstore/python-pydantic-v1-aiohttp