diff --git a/operation.go b/operation.go index 490d855e4..d966b36ce 100644 --- a/operation.go +++ b/operation.go @@ -175,7 +175,66 @@ func (operation *Operation) ParseParamComment(commentLine string, astFile *ast.F }, } case "object": - return fmt.Errorf("%s is not supported type for %s", refType, paramType) + refType, typeSpec, err := operation.registerSchemaType(refType, astFile) + if err != nil { + return err + } + structType, ok := typeSpec.Type.(*ast.StructType) + if !ok { + return fmt.Errorf("%s is not supported type for %s", refType, paramType) + } + refSplit := strings.Split(refType, ".") + schema, err := operation.parser.parseStruct(refSplit[0], structType.Fields) + if err != nil { + return err + } + if len(schema.Properties) == 0 { + return nil + } + find := func(arr []string, target string) bool { + for _, str := range arr { + if str == target { + return true + } + } + return false + } + for name, prop := range schema.Properties { + if len(prop.Type) == 0 { + continue + } + if prop.Type[0] == "array" && + prop.Items.Schema != nil && + len(prop.Items.Schema.Type) > 0 && + IsSimplePrimitiveType(prop.Items.Schema.Type[0]) { + param = createParameter(paramType, prop.Description, name, prop.Type[0], find(schema.Required, name)) + param.SimpleSchema.Type = prop.Type[0] + param.SimpleSchema.Items = &spec.Items{ + SimpleSchema: spec.SimpleSchema{ + Type: prop.Items.Schema.Type[0], + }, + } + } else if IsSimplePrimitiveType(prop.Type[0]) { + param = createParameter(paramType, prop.Description, name, prop.Type[0], find(schema.Required, name)) + } else { + Println(fmt.Sprintf("skip field [%s] in %s is not supported type for %s", name, refType, paramType)) + continue + } + param.CommonValidations.Maximum = prop.Maximum + param.CommonValidations.Minimum = prop.Minimum + param.CommonValidations.ExclusiveMaximum = prop.ExclusiveMaximum + param.CommonValidations.ExclusiveMinimum = prop.ExclusiveMinimum + param.CommonValidations.MaxLength = prop.MaxLength + param.CommonValidations.MinLength = prop.MinLength + param.CommonValidations.Pattern = prop.Pattern + param.CommonValidations.MaxItems = prop.MaxItems + param.CommonValidations.MinItems = prop.MinItems + param.CommonValidations.UniqueItems = prop.UniqueItems + param.CommonValidations.MultipleOf = prop.MultipleOf + param.CommonValidations.Enum = prop.Enum + operation.Operation.Parameters = append(operation.Operation.Parameters, param) + } + return nil } case "body": switch objectType { @@ -192,19 +251,17 @@ func (operation *Operation) ParseParamComment(commentLine string, astFile *ast.F if IsPrimitiveType(refType) { param.Schema.Items.Schema.Type = spec.StringOrArray{refType} } else { - if !strings.Contains(refType, ".") { - refType = astFile.Name.String() + "." + refType - } - if err := operation.registerSchemaType(refType, astFile); err != nil { + var err error + refType, _, err = operation.registerSchemaType(refType, astFile) + if err != nil { return err } param.Schema.Items.Schema.Ref = spec.Ref{Ref: jsonreference.MustCreateRef("#/definitions/" + refType)} } case "object": - if !strings.Contains(refType, ".") { - refType = astFile.Name.String() + "." + refType - } - if err := operation.registerSchemaType(refType, astFile); err != nil { + var err error + refType, _, err = operation.registerSchemaType(refType, astFile) + if err != nil { return err } param.Schema.Type = []string{} @@ -223,20 +280,23 @@ func (operation *Operation) ParseParamComment(commentLine string, astFile *ast.F return nil } -func (operation *Operation) registerSchemaType(schemaType string, astFile *ast.File) error { - refSplit := strings.Split(schemaType, ".") - if len(refSplit) != 2 { - return nil +func (operation *Operation) registerSchemaType(schemaType string, astFile *ast.File) (string, *ast.TypeSpec, error) { + if !strings.ContainsRune(schemaType, '.') { + if astFile == nil { + return schemaType, nil, fmt.Errorf("no package name for type %s", schemaType) + } + schemaType = astFile.Name.String() + "." + schemaType } + refSplit := strings.Split(schemaType, ".") pkgName := refSplit[0] typeName := refSplit[1] if typeSpec, ok := operation.parser.TypeDefinitions[pkgName][typeName]; ok { operation.parser.registerTypes[schemaType] = typeSpec - return nil + return schemaType, typeSpec, nil } var typeSpec *ast.TypeSpec if astFile == nil { - return fmt.Errorf("can not register schema type: %q reason: astFile == nil", schemaType) + return schemaType, nil, fmt.Errorf("can not register schema type: %q reason: astFile == nil", schemaType) } for _, imp := range astFile.Imports { if imp.Name != nil && imp.Name.Name == pkgName { // the import had an alias that matched @@ -247,14 +307,14 @@ func (operation *Operation) registerSchemaType(schemaType string, astFile *ast.F var err error typeSpec, err = findTypeDef(impPath, typeName) if err != nil { - return fmt.Errorf("can not find type def: %q error: %s", schemaType, err) + return schemaType, nil, fmt.Errorf("can not find type def: %q error: %s", schemaType, err) } break } } if typeSpec == nil { - return fmt.Errorf("can not find schema type: %q", schemaType) + return schemaType, nil, fmt.Errorf("can not find schema type: %q", schemaType) } if _, ok := operation.parser.TypeDefinitions[pkgName]; !ok { @@ -263,7 +323,7 @@ func (operation *Operation) registerSchemaType(schemaType string, astFile *ast.F operation.parser.TypeDefinitions[pkgName][typeName] = typeSpec operation.parser.registerTypes[schemaType] = typeSpec - return nil + return schemaType, typeSpec, nil } var regexAttributes = map[string]*regexp.Regexp{ @@ -571,14 +631,12 @@ func (operation *Operation) ParseResponseComment(commentLine string, astFile *as schemaType := strings.Trim(matches[2], "{}") refType := matches[3] - if !IsGolangPrimitiveType(refType) && !strings.Contains(refType, ".") { - currentPkgName := astFile.Name.String() - refType = currentPkgName + "." + refType - } - - if operation.parser != nil { // checking refType has existing in 'TypeDefinitions' - if err := operation.registerSchemaType(refType, astFile); err != nil { - return err + if !IsGolangPrimitiveType(refType) { + if operation.parser != nil { // checking refType has existing in 'TypeDefinitions' + var err error + if refType, _, err = operation.registerSchemaType(refType, astFile); err != nil { + return err + } } } diff --git a/operation_test.go b/operation_test.go index 160bd89ef..64951a3a0 100644 --- a/operation_test.go +++ b/operation_test.go @@ -944,7 +944,6 @@ func TestParseDeprecationDescription(t *testing.T) { func TestRegisterSchemaType(t *testing.T) { operation := NewOperation() - assert.NoError(t, operation.registerSchemaType("string", nil)) fset := token.NewFileSet() astFile, err := goparser.ParseFile(fset, "main.go", `package main @@ -954,7 +953,8 @@ func TestRegisterSchemaType(t *testing.T) { assert.NoError(t, err) operation.parser = New() - assert.Error(t, operation.registerSchemaType("timer.Location", astFile)) + _, _, err = operation.registerSchemaType("timer.Location", astFile) + assert.Error(t, err) } func TestParseExtentions(t *testing.T) { diff --git a/parser_test.go b/parser_test.go index c1bca74ba..063f938af 100644 --- a/parser_test.go +++ b/parser_test.go @@ -2962,3 +2962,32 @@ func TestParseOutsideDependencies(t *testing.T) { t.Error("Failed to parse api: " + err.Error()) } } + +func TestParseStructParamCommentByQueryType(t *testing.T) { + src := ` +package main + +type Student struct { + Name string + Age int + Teachers []string + SkipField map[string]string +} + +// @Param request query Student true "query params" +// @Success 200 +// @Router /test [get] +func Fun() { + +} +` + f, err := goparser.ParseFile(token.NewFileSet(), "", src, goparser.ParseComments) + assert.NoError(t, err) + + p := New() + p.ParseType(f) + err = p.ParseRouterAPIInfo("", f) + assert.NoError(t, err) + + assert.Equal(t, 3, len(p.swagger.Paths.Paths["/test"].Get.Parameters)) +} diff --git a/schema.go b/schema.go index 2ebef1678..ba5fe19e1 100644 --- a/schema.go +++ b/schema.go @@ -10,6 +10,16 @@ func CheckSchemaType(typeName string) error { return nil } +// IsSimplePrimitiveType determine whether the type name is a simple primitive type +func IsSimplePrimitiveType(typeName string) bool { + switch typeName { + case "string", "number", "integer", "boolean": + return true + default: + return false + } +} + // IsPrimitiveType determine whether the type name is a primitive type func IsPrimitiveType(typeName string) bool { switch typeName {