Skip to content

Commit

Permalink
feature: extend struct's simple members in query comment (#594)
Browse files Browse the repository at this point in the history
  • Loading branch information
sdghchj authored and easonlin404 committed Jan 3, 2020
1 parent 8cedf42 commit 74d6d6c
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 28 deletions.
110 changes: 84 additions & 26 deletions operation.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,66 @@ func (operation *Operation) ParseParamComment(commentLine string, astFile *ast.F
},
}
case "object":
return fmt.Errorf("%s is not supported type for %s", refType, paramType)
refType, typeSpec, err := operation.registerSchemaType(refType, astFile)
if err != nil {
return err
}
structType, ok := typeSpec.Type.(*ast.StructType)
if !ok {
return fmt.Errorf("%s is not supported type for %s", refType, paramType)
}
refSplit := strings.Split(refType, ".")
schema, err := operation.parser.parseStruct(refSplit[0], structType.Fields)
if err != nil {
return err
}
if len(schema.Properties) == 0 {
return nil
}
find := func(arr []string, target string) bool {
for _, str := range arr {
if str == target {
return true
}
}
return false
}
for name, prop := range schema.Properties {
if len(prop.Type) == 0 {
continue
}
if prop.Type[0] == "array" &&
prop.Items.Schema != nil &&
len(prop.Items.Schema.Type) > 0 &&
IsSimplePrimitiveType(prop.Items.Schema.Type[0]) {
param = createParameter(paramType, prop.Description, name, prop.Type[0], find(schema.Required, name))
param.SimpleSchema.Type = prop.Type[0]
param.SimpleSchema.Items = &spec.Items{
SimpleSchema: spec.SimpleSchema{
Type: prop.Items.Schema.Type[0],
},
}
} else if IsSimplePrimitiveType(prop.Type[0]) {
param = createParameter(paramType, prop.Description, name, prop.Type[0], find(schema.Required, name))
} else {
Println(fmt.Sprintf("skip field [%s] in %s is not supported type for %s", name, refType, paramType))
continue
}
param.CommonValidations.Maximum = prop.Maximum
param.CommonValidations.Minimum = prop.Minimum
param.CommonValidations.ExclusiveMaximum = prop.ExclusiveMaximum
param.CommonValidations.ExclusiveMinimum = prop.ExclusiveMinimum
param.CommonValidations.MaxLength = prop.MaxLength
param.CommonValidations.MinLength = prop.MinLength
param.CommonValidations.Pattern = prop.Pattern
param.CommonValidations.MaxItems = prop.MaxItems
param.CommonValidations.MinItems = prop.MinItems
param.CommonValidations.UniqueItems = prop.UniqueItems
param.CommonValidations.MultipleOf = prop.MultipleOf
param.CommonValidations.Enum = prop.Enum
operation.Operation.Parameters = append(operation.Operation.Parameters, param)
}
return nil
}
case "body":
switch objectType {
Expand All @@ -192,19 +251,17 @@ func (operation *Operation) ParseParamComment(commentLine string, astFile *ast.F
if IsPrimitiveType(refType) {
param.Schema.Items.Schema.Type = spec.StringOrArray{refType}
} else {
if !strings.Contains(refType, ".") {
refType = astFile.Name.String() + "." + refType
}
if err := operation.registerSchemaType(refType, astFile); err != nil {
var err error
refType, _, err = operation.registerSchemaType(refType, astFile)
if err != nil {
return err
}
param.Schema.Items.Schema.Ref = spec.Ref{Ref: jsonreference.MustCreateRef("#/definitions/" + refType)}
}
case "object":
if !strings.Contains(refType, ".") {
refType = astFile.Name.String() + "." + refType
}
if err := operation.registerSchemaType(refType, astFile); err != nil {
var err error
refType, _, err = operation.registerSchemaType(refType, astFile)
if err != nil {
return err
}
param.Schema.Type = []string{}
Expand All @@ -223,20 +280,23 @@ func (operation *Operation) ParseParamComment(commentLine string, astFile *ast.F
return nil
}

func (operation *Operation) registerSchemaType(schemaType string, astFile *ast.File) error {
refSplit := strings.Split(schemaType, ".")
if len(refSplit) != 2 {
return nil
func (operation *Operation) registerSchemaType(schemaType string, astFile *ast.File) (string, *ast.TypeSpec, error) {
if !strings.ContainsRune(schemaType, '.') {
if astFile == nil {
return schemaType, nil, fmt.Errorf("no package name for type %s", schemaType)
}
schemaType = astFile.Name.String() + "." + schemaType
}
refSplit := strings.Split(schemaType, ".")
pkgName := refSplit[0]
typeName := refSplit[1]
if typeSpec, ok := operation.parser.TypeDefinitions[pkgName][typeName]; ok {
operation.parser.registerTypes[schemaType] = typeSpec
return nil
return schemaType, typeSpec, nil
}
var typeSpec *ast.TypeSpec
if astFile == nil {
return fmt.Errorf("can not register schema type: %q reason: astFile == nil", schemaType)
return schemaType, nil, fmt.Errorf("can not register schema type: %q reason: astFile == nil", schemaType)
}
for _, imp := range astFile.Imports {
if imp.Name != nil && imp.Name.Name == pkgName { // the import had an alias that matched
Expand All @@ -247,14 +307,14 @@ func (operation *Operation) registerSchemaType(schemaType string, astFile *ast.F
var err error
typeSpec, err = findTypeDef(impPath, typeName)
if err != nil {
return fmt.Errorf("can not find type def: %q error: %s", schemaType, err)
return schemaType, nil, fmt.Errorf("can not find type def: %q error: %s", schemaType, err)
}
break
}
}

if typeSpec == nil {
return fmt.Errorf("can not find schema type: %q", schemaType)
return schemaType, nil, fmt.Errorf("can not find schema type: %q", schemaType)
}

if _, ok := operation.parser.TypeDefinitions[pkgName]; !ok {
Expand All @@ -263,7 +323,7 @@ func (operation *Operation) registerSchemaType(schemaType string, astFile *ast.F

operation.parser.TypeDefinitions[pkgName][typeName] = typeSpec
operation.parser.registerTypes[schemaType] = typeSpec
return nil
return schemaType, typeSpec, nil
}

var regexAttributes = map[string]*regexp.Regexp{
Expand Down Expand Up @@ -571,14 +631,12 @@ func (operation *Operation) ParseResponseComment(commentLine string, astFile *as
schemaType := strings.Trim(matches[2], "{}")
refType := matches[3]

if !IsGolangPrimitiveType(refType) && !strings.Contains(refType, ".") {
currentPkgName := astFile.Name.String()
refType = currentPkgName + "." + refType
}

if operation.parser != nil { // checking refType has existing in 'TypeDefinitions'
if err := operation.registerSchemaType(refType, astFile); err != nil {
return err
if !IsGolangPrimitiveType(refType) {
if operation.parser != nil { // checking refType has existing in 'TypeDefinitions'
var err error
if refType, _, err = operation.registerSchemaType(refType, astFile); err != nil {
return err
}
}
}

Expand Down
4 changes: 2 additions & 2 deletions operation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -944,7 +944,6 @@ func TestParseDeprecationDescription(t *testing.T) {

func TestRegisterSchemaType(t *testing.T) {
operation := NewOperation()
assert.NoError(t, operation.registerSchemaType("string", nil))

fset := token.NewFileSet()
astFile, err := goparser.ParseFile(fset, "main.go", `package main
Expand All @@ -954,7 +953,8 @@ func TestRegisterSchemaType(t *testing.T) {
assert.NoError(t, err)

operation.parser = New()
assert.Error(t, operation.registerSchemaType("timer.Location", astFile))
_, _, err = operation.registerSchemaType("timer.Location", astFile)
assert.Error(t, err)
}

func TestParseExtentions(t *testing.T) {
Expand Down
29 changes: 29 additions & 0 deletions parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2962,3 +2962,32 @@ func TestParseOutsideDependencies(t *testing.T) {
t.Error("Failed to parse api: " + err.Error())
}
}

func TestParseStructParamCommentByQueryType(t *testing.T) {
src := `
package main
type Student struct {
Name string
Age int
Teachers []string
SkipField map[string]string
}
// @Param request query Student true "query params"
// @Success 200
// @Router /test [get]
func Fun() {
}
`
f, err := goparser.ParseFile(token.NewFileSet(), "", src, goparser.ParseComments)
assert.NoError(t, err)

p := New()
p.ParseType(f)
err = p.ParseRouterAPIInfo("", f)
assert.NoError(t, err)

assert.Equal(t, 3, len(p.swagger.Paths.Paths["/test"].Get.Parameters))
}
10 changes: 10 additions & 0 deletions schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,16 @@ func CheckSchemaType(typeName string) error {
return nil
}

// IsSimplePrimitiveType determine whether the type name is a simple primitive type
func IsSimplePrimitiveType(typeName string) bool {
switch typeName {
case "string", "number", "integer", "boolean":
return true
default:
return false
}
}

// IsPrimitiveType determine whether the type name is a primitive type
func IsPrimitiveType(typeName string) bool {
switch typeName {
Expand Down

0 comments on commit 74d6d6c

Please sign in to comment.