Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Generic Fields does not handle Arrays #1311

Merged
merged 2 commits into from
Aug 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 84 additions & 33 deletions generics.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,13 @@ package swag
import (
"errors"
"fmt"
"github.com/go-openapi/spec"
"go/ast"
"strings"
"sync"
)

var genericDefinitionsMutex = &sync.RWMutex{}
var genericsDefinitions = map[*TypeSpecDef]map[string]*TypeSpecDef{}

type genericTypeSpec struct {
Expand Down Expand Up @@ -55,9 +58,12 @@ func typeSpecFullName(typeSpecDef *TypeSpecDef) string {
return fullName
}

func (pkgDefs *PackagesDefinitions) parametrizeStruct(original *TypeSpecDef, fullGenericForm string, parseDependency bool) *TypeSpecDef {
if spec, ok := genericsDefinitions[original][fullGenericForm]; ok {
return spec
func (pkgDefs *PackagesDefinitions) parametrizeStruct(file *ast.File, original *TypeSpecDef, fullGenericForm string, parseDependency bool) *TypeSpecDef {
genericDefinitionsMutex.RLock()
tSpec, ok := genericsDefinitions[original][fullGenericForm]
genericDefinitionsMutex.RUnlock()
if ok {
return tSpec
}

pkgName := strings.Split(fullGenericForm, ".")[0]
Expand All @@ -81,7 +87,10 @@ func (pkgDefs *PackagesDefinitions) parametrizeStruct(original *TypeSpecDef, ful
arrayDepth++
}

tdef := pkgDefs.FindTypeSpec(genericParam, original.File, parseDependency)
tdef := pkgDefs.FindTypeSpec(genericParam, file, parseDependency)
if tdef != nil && !strings.Contains(genericParam, ".") {
genericParam = fullTypeName(file.Name.Name, genericParam)
}
genericParamTypeDefs[original.TypeSpec.TypeParams.List[i].Names[0].Name] = &genericTypeSpec{
ArrayDepth: arrayDepth,
TypeSpec: tdef,
Expand Down Expand Up @@ -156,6 +165,8 @@ func (pkgDefs *PackagesDefinitions) parametrizeStruct(original *TypeSpecDef, ful
newStructTypeDef.Fields.List = append(newStructTypeDef.Fields.List, newField)
}

genericDefinitionsMutex.Lock()
defer genericDefinitionsMutex.Unlock()
parametrizedTypeSpec.TypeSpec.Type = newStructTypeDef
if genericsDefinitions[original] == nil {
genericsDefinitions[original] = map[string]*TypeSpecDef{}
Expand Down Expand Up @@ -225,78 +236,118 @@ func resolveType(expr ast.Expr, field *ast.Field, genericParamTypeDefs map[strin
return field.Type
}

func getExtendedGenericFieldType(file *ast.File, field ast.Expr) (string, error) {
switch fieldType := field.(type) {
case *ast.ArrayType:
fieldName, err := getExtendedGenericFieldType(file, fieldType.Elt)
return "[]" + fieldName, err
case *ast.StarExpr:
return getExtendedGenericFieldType(file, fieldType.X)
default:
return getFieldType(file, field)
}
}

func getGenericFieldType(file *ast.File, field ast.Expr) (string, error) {
var fullName string
var baseName string
var err error
switch fieldType := field.(type) {
case *ast.IndexListExpr:
fullName, err := getGenericTypeName(file, fieldType.X)
baseName, err = getGenericTypeName(file, fieldType.X)
if err != nil {
return "", err
}
fullName += "["
fullName = baseName + "["

for _, index := range fieldType.Indices {
var fieldName string
var err error

switch item := index.(type) {
case *ast.ArrayType:
fieldName, err = getFieldType(file, item.Elt)
fieldName = "[]" + fieldName
default:
fieldName, err = getFieldType(file, index)
}

fieldName, err := getExtendedGenericFieldType(file, index)
if err != nil {
return "", err
}

fullName += fieldName + ","
}

return strings.TrimRight(fullName, ",") + "]", nil
fullName = strings.TrimRight(fullName, ",") + "]"
case *ast.IndexExpr:
x, err := getFieldType(file, fieldType.X)
baseName, err = getGenericTypeName(file, fieldType.X)
if err != nil {
return "", err
}

i, err := getFieldType(file, fieldType.Index)
indexName, err := getExtendedGenericFieldType(file, fieldType.Index)
if err != nil {
return "", err
}

packageName := ""
if !strings.Contains(x, ".") {
if file.Name == nil {
return "", errors.New("file name is nil")
}
packageName, _ = getFieldType(file, file.Name)
}
fullName = fmt.Sprintf("%s[%s]", baseName, indexName)
}

return strings.TrimLeft(fmt.Sprintf("%s.%s[%s]", packageName, x, i), "."), nil
if fullName == "" {
return "", fmt.Errorf("unknown field type %#v", field)
}

var packageName string
if !strings.Contains(baseName, ".") {
if file.Name == nil {
return "", errors.New("file name is nil")
}
packageName, _ = getFieldType(file, file.Name)
}

return "", fmt.Errorf("unknown field type %#v", field)
return strings.TrimLeft(fmt.Sprintf("%s.%s", packageName, fullName), "."), nil
}

func getGenericTypeName(file *ast.File, field ast.Expr) (string, error) {
switch indexType := field.(type) {
case *ast.Ident:
spec := &TypeSpecDef{
if indexType.Obj == nil {
return getFieldType(file, field)
}

tSpec := &TypeSpecDef{
File: file,
TypeSpec: indexType.Obj.Decl.(*ast.TypeSpec),
PkgPath: file.Name.Name,
}
return spec.FullName(), nil
return tSpec.FullName(), nil
case *ast.ArrayType:
spec := &TypeSpecDef{
tSpec := &TypeSpecDef{
File: file,
TypeSpec: indexType.Elt.(*ast.Ident).Obj.Decl.(*ast.TypeSpec),
PkgPath: file.Name.Name,
}
return spec.FullName(), nil
return tSpec.FullName(), nil
case *ast.SelectorExpr:
return fmt.Sprintf("%s.%s", indexType.X.(*ast.Ident).Name, indexType.Sel.Name), nil
}
return "", fmt.Errorf("unknown type %#v", field)
}

func (parser *Parser) parseGenericTypeExpr(file *ast.File, typeExpr ast.Expr) (*spec.Schema, error) {
switch expr := typeExpr.(type) {
// suppress debug messages for these types
case *ast.InterfaceType:
case *ast.StructType:
case *ast.Ident:
case *ast.StarExpr:
case *ast.SelectorExpr:
case *ast.ArrayType:
case *ast.MapType:
case *ast.FuncType:
case *ast.IndexExpr:
name, err := getExtendedGenericFieldType(file, expr)
if err == nil {
if schema, err := parser.getTypeSchema(name, file, false); err == nil {
return spec.MapProperty(schema), nil
}
}

parser.debug.Printf("Type definition of type '%T' is not supported yet. Using 'object' instead. (%s)\n", typeExpr, err)
default:
parser.debug.Printf("Type definition of type '%T' is not supported yet. Using 'object' instead.\n", typeExpr)
}

return PrimitiveSchema(OBJECT), nil
}
21 changes: 20 additions & 1 deletion generics_other.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,36 @@ package swag

import (
"fmt"
"github.com/go-openapi/spec"
"go/ast"
)

func typeSpecFullName(typeSpecDef *TypeSpecDef) string {
return typeSpecDef.FullName()
}

func (pkgDefs *PackagesDefinitions) parametrizeStruct(original *TypeSpecDef, fullGenericForm string, parseDependency bool) *TypeSpecDef {
func (pkgDefs *PackagesDefinitions) parametrizeStruct(file *ast.File, original *TypeSpecDef, fullGenericForm string, parseDependency bool) *TypeSpecDef {
return original
}

func getGenericFieldType(file *ast.File, field ast.Expr) (string, error) {
return "", fmt.Errorf("unknown field type %#v", field)
}

func (parser *Parser) parseGenericTypeExpr(file *ast.File, typeExpr ast.Expr) (*spec.Schema, error) {
switch typeExpr.(type) {
// suppress debug messages for these types
case *ast.InterfaceType:
case *ast.StructType:
case *ast.Ident:
case *ast.StarExpr:
case *ast.SelectorExpr:
case *ast.ArrayType:
case *ast.MapType:
case *ast.FuncType:
default:
parser.debug.Printf("Type definition of type '%T' is not supported yet. Using 'object' instead.\n", typeExpr)
}

return PrimitiveSchema(OBJECT), nil
}
67 changes: 67 additions & 0 deletions generics_other_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
//go:build !go1.18
// +build !go1.18

package swag

import (
"fmt"
"github.com/stretchr/testify/assert"
"go/ast"
"testing"
)

type testLogger struct {
Messages []string
}

func (t *testLogger) Printf(format string, v ...interface{}) {
t.Messages = append(t.Messages, fmt.Sprintf(format, v...))
}

func TestParametrizeStruct(t *testing.T) {
t.Parallel()

pd := PackagesDefinitions{
packages: make(map[string]*PackageDefinitions),
}

tSpec := &TypeSpecDef{
TypeSpec: &ast.TypeSpec{
Name: &ast.Ident{Name: "Field"},
Type: &ast.StructType{Struct: 100, Fields: &ast.FieldList{Opening: 101, Closing: 102}},
},
}

tr := pd.parametrizeStruct(&ast.File{}, tSpec, "", false)
assert.Equal(t, tr, tSpec)

tr = pd.parametrizeStruct(&ast.File{}, tSpec, "", true)
assert.Equal(t, tr, tSpec)
}

func TestParseGenericTypeExpr(t *testing.T) {
t.Parallel()

parser := New()
logger := &testLogger{}
SetDebugger(logger)(parser)

_, _ = parser.parseGenericTypeExpr(&ast.File{}, &ast.InterfaceType{})
assert.Empty(t, logger.Messages)
_, _ = parser.parseGenericTypeExpr(&ast.File{}, &ast.StructType{})
assert.Empty(t, logger.Messages)
_, _ = parser.parseGenericTypeExpr(&ast.File{}, &ast.Ident{})
assert.Empty(t, logger.Messages)
_, _ = parser.parseGenericTypeExpr(&ast.File{}, &ast.StarExpr{})
assert.Empty(t, logger.Messages)
_, _ = parser.parseGenericTypeExpr(&ast.File{}, &ast.SelectorExpr{})
assert.Empty(t, logger.Messages)
_, _ = parser.parseGenericTypeExpr(&ast.File{}, &ast.ArrayType{})
assert.Empty(t, logger.Messages)
_, _ = parser.parseGenericTypeExpr(&ast.File{}, &ast.MapType{})
assert.Empty(t, logger.Messages)
_, _ = parser.parseGenericTypeExpr(&ast.File{}, &ast.FuncType{})
assert.Empty(t, logger.Messages)
_, _ = parser.parseGenericTypeExpr(&ast.File{}, &ast.BadExpr{})
assert.NotEmpty(t, logger.Messages)
}
Loading