Skip to content

Commit

Permalink
Fix type binding validation for slices of pointers like []*foo
Browse files Browse the repository at this point in the history
  • Loading branch information
Steven Normore committed Nov 20, 2018
1 parent f7932b4 commit 827dac5
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 2 deletions.
8 changes: 6 additions & 2 deletions codegen/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -337,14 +337,14 @@ func validateTypeBinding(imports *Imports, field *Field, goType types.Type) erro
gqlType := normalizeVendor(field.Type.FullSignature())
goTypeStr := normalizeVendor(goType.String())

if goTypeStr == gqlType || "*"+goTypeStr == gqlType || goTypeStr == "*"+gqlType {
if equalTypes(goTypeStr, gqlType) {
field.Type.Modifiers = modifiersFromGoType(goType)
return nil
}

// deal with type aliases
underlyingStr := normalizeVendor(goType.Underlying().String())
if underlyingStr == gqlType || "*"+underlyingStr == gqlType || underlyingStr == "*"+gqlType {
if equalTypes(underlyingStr, gqlType) {
field.Type.Modifiers = modifiersFromGoType(goType)
pkg, typ := pkgAndType(goType.String())
imp := imports.findByPath(pkg)
Expand Down Expand Up @@ -382,3 +382,7 @@ func normalizeVendor(pkg string) string {
parts := strings.Split(pkg, "/vendor/")
return modifiers + parts[len(parts)-1]
}

func equalTypes(goType string, gqlType string) bool {
return goType == gqlType || "*"+goType == gqlType || goType == "*"+gqlType || strings.Replace(goType, "[]*", "[]", -1) == gqlType
}
22 changes: 22 additions & 0 deletions codegen/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ func TestNormalizeVendor(t *testing.T) {
require.Equal(t, "[]bar/baz", normalizeVendor("[]foo/vendor/bar/baz"))
require.Equal(t, "*bar/baz", normalizeVendor("*foo/vendor/bar/baz"))
require.Equal(t, "*[]*bar/baz", normalizeVendor("*[]*foo/vendor/bar/baz"))
require.Equal(t, "[]*bar/baz", normalizeVendor("[]*foo/vendor/bar/baz"))
}

func TestFindField(t *testing.T) {
Expand Down Expand Up @@ -120,3 +121,24 @@ func TestEqualFieldName(t *testing.T) {
})
}
}

func TestEqualTypes(t *testing.T) {
tt := []struct {
Name string
Source string
Target string
Expected bool
}{
{Name: "basic", Source: "bar/baz", Target: "bar/baz", Expected: true},
{Name: "basic slice", Source: "[]bar/baz", Target: "[]bar/baz", Expected: true},
{Name: "pointer", Source: "*bar/baz", Target: "bar/baz", Expected: true},
{Name: "pointer slice", Source: "[]*bar/baz", Target: "[]bar/baz", Expected: true},
}

for _, tc := range tt {
t.Run(tc.Name, func(t *testing.T) {
result := equalTypes(tc.Source, tc.Target)
require.Equal(t, tc.Expected, result)
})
}
}

0 comments on commit 827dac5

Please sign in to comment.