Skip to content

Commit

Permalink
feat: support renaming model to display in swagger UI (#631)
Browse files Browse the repository at this point in the history
Co-authored-by: Eason Lin <[email protected]>
  • Loading branch information
sdghchj and easonlin404 authored Feb 28, 2020
1 parent be7ef6b commit 7290e9b
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 29 deletions.
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,13 @@ generate swagger doc as follows:
}
}
```
### Rename model to display
```golang
type Resp struct {
Code int
}//@name Response
```
### How to using security annotations
Expand Down
25 changes: 12 additions & 13 deletions operation.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,22 +257,22 @@ func (operation *Operation) ParseParamComment(commentLine string, astFile *ast.F
if IsPrimitiveType(refType) {
param.Schema.Items.Schema.Type = spec.StringOrArray{refType}
} else {
var err error
refType, _, err = operation.registerSchemaType(refType, astFile)
refType, typeSpec, err := operation.registerSchemaType(refType, astFile)
if err != nil {
return err
}
param.Schema.Items.Schema.Ref = spec.Ref{Ref: jsonreference.MustCreateRef("#/definitions/" + refType)}
param.Schema.Items.Schema.Ref = spec.Ref{
Ref: jsonreference.MustCreateRef("#/definitions/" + TypeDocName(refType, typeSpec)),
}
}
case "object":
var err error
refType, _, err = operation.registerSchemaType(refType, astFile)
refType, typeSpec, err := operation.registerSchemaType(refType, astFile)
if err != nil {
return err
}
param.Schema.Type = []string{}
param.Schema.Ref = spec.Ref{
Ref: jsonreference.MustCreateRef("#/definitions/" + refType),
Ref: jsonreference.MustCreateRef("#/definitions/" + TypeDocName(refType, typeSpec)),
}
}
default:
Expand All @@ -291,7 +291,7 @@ func (operation *Operation) registerSchemaType(schemaType string, astFile *ast.F
if astFile == nil {
return schemaType, nil, fmt.Errorf("no package name for type %s", schemaType)
}
schemaType = astFile.Name.String() + "." + schemaType
schemaType = fullTypeName(astFile.Name.String(), schemaType)
}
refSplit := strings.Split(schemaType, ".")
pkgName := refSplit[0]
Expand Down Expand Up @@ -637,10 +637,11 @@ func (operation *Operation) ParseResponseComment(commentLine string, astFile *as
schemaType := strings.Trim(matches[2], "{}")
refType := matches[3]

var typeSpec *ast.TypeSpec
if !IsGolangPrimitiveType(refType) {
if operation.parser != nil { // checking refType has existing in 'TypeDefinitions'
var err error
if refType, _, err = operation.registerSchemaType(refType, astFile); err != nil {
if refType, typeSpec, err = operation.registerSchemaType(refType, astFile); err != nil {
return err
}
}
Expand All @@ -652,11 +653,9 @@ func (operation *Operation) ParseResponseComment(commentLine string, astFile *as
if schemaType == "object" {
response.Schema.SchemaProps = spec.SchemaProps{}
response.Schema.Ref = spec.Ref{
Ref: jsonreference.MustCreateRef("#/definitions/" + refType),
Ref: jsonreference.MustCreateRef("#/definitions/" + TypeDocName(refType, typeSpec)),
}
}

if schemaType == "array" {
} else if schemaType == "array" {
refType = TransToValidSchemeType(refType)
if IsPrimitiveType(refType) {
response.Schema.Items = &spec.SchemaOrArray{
Expand All @@ -670,7 +669,7 @@ func (operation *Operation) ParseResponseComment(commentLine string, astFile *as
response.Schema.Items = &spec.SchemaOrArray{
Schema: &spec.Schema{
SchemaProps: spec.SchemaProps{
Ref: spec.Ref{Ref: jsonreference.MustCreateRef("#/definitions/" + refType)},
Ref: spec.Ref{Ref: jsonreference.MustCreateRef("#/definitions/" + TypeDocName(refType, typeSpec))},
},
},
}
Expand Down
27 changes: 15 additions & 12 deletions parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,7 @@ func (parser *Parser) parseDefinitions() error {
// given name and package, and populates swagger schema definitions registry
// with a schema for the given type
func (parser *Parser) ParseDefinition(pkgName, typeName string, typeSpec *ast.TypeSpec) error {
refTypeName := fullTypeName(pkgName, typeName)
refTypeName := TypeDocName(pkgName, typeSpec)

if typeSpec == nil {
Println("Skipping '" + refTypeName + "', pkg '" + pkgName + "' not found, try add flag --parseDependency or --parseVendor.")
Expand Down Expand Up @@ -654,9 +654,11 @@ func (parser *Parser) parseTypeExpr(pkgName, typeName string, typeExpr ast.Expr)
switch expr := typeExpr.(type) {
// type Foo struct {...}
case *ast.StructType:
refTypeName := fullTypeName(pkgName, typeName)
if schema, isParsed := parser.swagger.Definitions[refTypeName]; isParsed {
return &schema, nil
if typedef, ok := parser.TypeDefinitions[pkgName][typeName]; ok {
refTypeName := TypeDocName(pkgName, typedef)
if schema, isParsed := parser.swagger.Definitions[refTypeName]; isParsed {
return &schema, nil
}
}

return parser.parseStruct(pkgName, expr.Fields)
Expand All @@ -671,11 +673,13 @@ func (parser *Parser) parseTypeExpr(pkgName, typeName string, typeExpr ast.Expr)
}, nil
}
refTypeName := fullTypeName(pkgName, expr.Name)
if _, isParsed := parser.swagger.Definitions[refTypeName]; !isParsed {
if typedef, ok := parser.TypeDefinitions[pkgName][expr.Name]; ok {
if typedef, ok := parser.TypeDefinitions[pkgName][expr.Name]; ok {
refTypeName = TypeDocName(pkgName, typedef)
if _, isParsed := parser.swagger.Definitions[refTypeName]; !isParsed {
parser.ParseDefinition(pkgName, expr.Name, typedef)
}
}

return &spec.Schema{
SchemaProps: spec.SchemaProps{
Ref: spec.Ref{
Expand Down Expand Up @@ -895,10 +899,9 @@ func (parser *Parser) parseStructField(pkgName string, field *ast.Field) (map[st
return fillObject(&src.VendorExtensible, &dest.VendorExtensible)
}

if _, ok := parser.TypeDefinitions[pkgName][structField.schemaType]; ok { // user type field
if typeSpec, ok := parser.TypeDefinitions[pkgName][structField.schemaType]; ok { // user type field
// write definition if not yet present
parser.ParseDefinition(pkgName, structField.schemaType,
parser.TypeDefinitions[pkgName][structField.schemaType])
parser.ParseDefinition(pkgName, structField.schemaType, typeSpec)
required := make([]string, 0)
if structField.isRequired {
required = append(required, structField.name)
Expand All @@ -909,7 +912,7 @@ func (parser *Parser) parseStructField(pkgName string, field *ast.Field) (map[st
Description: structField.desc,
Required: required,
Ref: spec.Ref{
Ref: jsonreference.MustCreateRef("#/definitions/" + pkgName + "." + structField.schemaType),
Ref: jsonreference.MustCreateRef("#/definitions/" + TypeDocName(pkgName, typeSpec)),
},
},
SwaggerSchemaProps: spec.SwaggerSchemaProps{
Expand All @@ -918,7 +921,7 @@ func (parser *Parser) parseStructField(pkgName string, field *ast.Field) (map[st
}
} else if structField.schemaType == "array" { // array field type
// if defined -- ref it
if _, ok := parser.TypeDefinitions[pkgName][structField.arrayType]; ok { // user type in array
if typeSpec, ok := parser.TypeDefinitions[pkgName][structField.arrayType]; ok { // user type in array
parser.ParseDefinition(pkgName, structField.arrayType,
parser.TypeDefinitions[pkgName][structField.arrayType])
required := make([]string, 0)
Expand All @@ -934,7 +937,7 @@ func (parser *Parser) parseStructField(pkgName string, field *ast.Field) (map[st
Schema: &spec.Schema{
SchemaProps: spec.SchemaProps{
Ref: spec.Ref{
Ref: jsonreference.MustCreateRef("#/definitions/" + pkgName + "." + structField.arrayType),
Ref: jsonreference.MustCreateRef("#/definitions/" + TypeDocName(pkgName, typeSpec)),
},
},
},
Expand Down
45 changes: 45 additions & 0 deletions parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2991,3 +2991,48 @@ func Fun() {

assert.Equal(t, 3, len(p.swagger.Paths.Paths["/test"].Get.Parameters))
}

func TestParseRenamedStructDefinition(t *testing.T) {
src := `
package main
type Child struct {
Name string
}//@name Student
type Parent struct {
Name string
Child Child
}//@name Teacher
// @Param request body Parent true "query params"
// @Success 200 {object} Parent
// @Router /test [get]
func Fun() {
}
`
f, err := goparser.ParseFile(token.NewFileSet(), "", src, goparser.ParseComments)
assert.NoError(t, err)

p := New()
p.ParseType(f)
err = p.ParseRouterAPIInfo("", f)
assert.NoError(t, err)

err = p.ParseRouterAPIInfo("", f)
assert.NoError(t, err)

err = p.parseDefinitions()
assert.NoError(t, err)
teacher, ok := p.swagger.Definitions["Teacher"]
assert.True(t, ok)
ref := teacher.Properties["child"].SchemaProps.Ref
assert.Equal(t, "#/definitions/Student", ref.String())
_, ok = p.swagger.Definitions["Student"]
assert.True(t, ok)
path, ok := p.swagger.Paths.Paths["/test"]
assert.Equal(t, "#/definitions/Teacher", path.Get.Parameters[0].Schema.Ref.String())
ref = path.Get.Responses.ResponsesProps.StatusCodeResponses[200].ResponseProps.Schema.Ref
assert.Equal(t, "#/definitions/Teacher", ref.String())
}
6 changes: 3 additions & 3 deletions property.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func getPropertyName(pkgName string, expr ast.Expr, parser *Parser) (propertyNam
case *ast.Ident:
name := tp.Name
// check if it is a custom type
if actualPrimitiveType, isCustomType := parser.CustomPrimitiveTypes[pkgName+"."+name]; isCustomType {
if actualPrimitiveType, isCustomType := parser.CustomPrimitiveTypes[fullTypeName(pkgName, name)]; isCustomType {
return propertyName{SchemaType: actualPrimitiveType, ArrayType: actualPrimitiveType}, nil
}

Expand All @@ -121,15 +121,15 @@ func getArrayPropertyName(pkgName string, astTypeArrayElt ast.Expr, parser *Pars
return parseFieldSelectorExpr(elt, parser, newArrayProperty)
case *ast.Ident:
name := elt.Name
if actualPrimitiveType, isCustomType := parser.CustomPrimitiveTypes[pkgName+"."+name]; isCustomType {
if actualPrimitiveType, isCustomType := parser.CustomPrimitiveTypes[fullTypeName(pkgName, name)]; isCustomType {
name = actualPrimitiveType
} else {
name = TransToValidSchemeType(elt.Name)
}
return propertyName{SchemaType: "array", ArrayType: name}
default:
name := fmt.Sprintf("%s", astTypeArrayElt)
if actualPrimitiveType, isCustomType := parser.CustomPrimitiveTypes[pkgName+"."+name]; isCustomType {
if actualPrimitiveType, isCustomType := parser.CustomPrimitiveTypes[fullTypeName(pkgName, name)]; isCustomType {
name = actualPrimitiveType
} else {
name = TransToValidSchemeType(name)
Expand Down
28 changes: 27 additions & 1 deletion schema.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package swag

import "fmt"
import (
"fmt"
"go/ast"
"strings"
)

// CheckSchemaType checks if typeName is not a name of primitive type
func CheckSchemaType(typeName string) error {
Expand Down Expand Up @@ -79,3 +83,25 @@ func IsGolangPrimitiveType(typeName string) bool {
return false
}
}

// TypeDocName get alias from comment '// @name ', otherwise the original type name to display in doc
func TypeDocName(pkgName string, spec *ast.TypeSpec) string {
if spec != nil {
if spec.Comment != nil {
for _, comment := range spec.Comment.List {
text := strings.TrimSpace(comment.Text)
text = strings.TrimLeft(text, "//")
text = strings.TrimSpace(text)
texts := strings.Split(text, " ")
if len(texts) > 1 && strings.ToLower(texts[0]) == "@name" {
return texts[1]
}
}
}
if spec.Name != nil {
return fullTypeName(strings.Split(pkgName, ".")[0], spec.Name.Name)
}
}

return pkgName
}

0 comments on commit 7290e9b

Please sign in to comment.