Skip to content

Commit

Permalink
Fixes #843: Bind to embedded struct method or field
Browse files Browse the repository at this point in the history
  • Loading branch information
matiasanaya committed Nov 8, 2019
1 parent 1172128 commit a745dc7
Show file tree
Hide file tree
Showing 8 changed files with 492 additions and 57 deletions.
148 changes: 100 additions & 48 deletions codegen/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,75 +184,111 @@ func (b *builder) bindField(obj *Object, f *Field) (errret error) {
}
}

// findField attempts to match the name to a struct field with the following
// priorites:
// 1. Any method with a matching name
// 2. Any Fields with a struct tag (see config.StructTag)
// 3. Any fields with a matching name
// 4. Same logic again for embedded fields
// findBindTarget attempts to match the name to a struct field or method
// with the following priorites:
// 1. Any Fields with a struct tag (see config.StructTag). Errors if more than one match is found
// 2. Any method or field with a matching name. Errors if more than one match is found
// 3. Same logic again for embedded fields
func (b *builder) findBindTarget(named *types.Named, name string) (types.Object, error) {
strukt, isStruct := named.Underlying().(*types.Struct)
if isStruct {
// NOTE: a struct tag will override both methods and fields
// Bind to struct tag
found, err := b.findBindStructTagTarget(strukt, name)
if found != nil || err != nil {
return found, err
}
}

// Search for a method to bind to
var foundMethod types.Object
for i := 0; i < named.NumMethods(); i++ {
method := named.Method(i)
if !method.Exported() {
if !method.Exported() || !strings.EqualFold(method.Name(), name) {
continue
}

if !strings.EqualFold(method.Name(), name) {
continue
if foundMethod != nil {
return nil, errors.Errorf("found more than one matching method to bind for %s", name)
}

return method, nil
}

strukt, ok := named.Underlying().(*types.Struct)
if !ok {
return nil, fmt.Errorf("not a struct")
foundMethod = method
}
return b.findBindStructTarget(strukt, name)
}

func (b *builder) findBindStructTarget(strukt *types.Struct, name string) (types.Object, error) {
// struct tags have the highest priority
if b.Config.StructTag != "" {
var foundField *types.Var
for i := 0; i < strukt.NumFields(); i++ {
field := strukt.Field(i)
if !field.Exported() {
continue
}
tags := reflect.StructTag(strukt.Tag(i))
if val, ok := tags.Lookup(b.Config.StructTag); ok && equalFieldName(val, name) {
if foundField != nil {
return nil, errors.Errorf("tag %s is ambigious; multiple fields have the same tag value of %s", b.Config.StructTag, val)
}

foundField = field
}
// Search for a field to bind to
if isStruct {
foundField, err := b.findBindFieldTarget(strukt, name)
if err != nil {
return nil, err
}
if foundField != nil {

switch {
case foundField == nil && foundMethod == nil:
// Search embeds
return b.findBindEmbedsTarget(strukt, name)
case foundField == nil && foundMethod != nil:
// Bind to method
return foundMethod, nil
case foundField != nil && foundMethod == nil:
// Bind to field
return foundField, nil
case foundField != nil && foundMethod != nil:
// Error
return nil, errors.Errorf("found more than one way to bind for %s", name)
}
}

// Then matching field names
// Bind to method or don't bind at all
return foundMethod, nil
}

func (b *builder) findBindStructTagTarget(strukt *types.Struct, name string) (types.Object, error) {
if b.Config.StructTag == "" {
return nil, nil
}

var found types.Object
for i := 0; i < strukt.NumFields(); i++ {
field := strukt.Field(i)
if !field.Exported() {
if !field.Exported() || field.Embedded() {
continue
}
if equalFieldName(field.Name(), name) { // aqui!
return field, nil
tags := reflect.StructTag(strukt.Tag(i))
if val, ok := tags.Lookup(b.Config.StructTag); ok && equalFieldName(val, name) {
if found != nil {
return nil, errors.Errorf("tag %s is ambigious; multiple fields have the same tag value of %s", b.Config.StructTag, val)
}

found = field
}
}

// Then look in embedded structs
return found, nil
}

func (b *builder) findBindFieldTarget(strukt *types.Struct, name string) (types.Object, error) {
var found types.Object
for i := 0; i < strukt.NumFields(); i++ {
field := strukt.Field(i)
if !field.Exported() {
if !field.Exported() || !equalFieldName(field.Name(), name) {
continue
}

if !field.Anonymous() {
if found != nil {
return nil, errors.Errorf("found more than one matching field to bind for %s", name)
}

found = field
}

return found, nil
}

func (b *builder) findBindEmbedsTarget(strukt *types.Struct, name string) (types.Object, error) {
var found types.Object
for i := 0; i < strukt.NumFields(); i++ {
field := strukt.Field(i)
if !field.Embedded() {
continue
}

Expand All @@ -267,23 +303,39 @@ func (b *builder) findBindStructTarget(strukt *types.Struct, name string) (types
if err != nil {
return nil, err
}
if f != nil && found != nil {
return nil, errors.Errorf("found more than one way to bind for %s", name)
}
if f != nil {
return f, nil
found = f
}
case *types.Struct:
f, err := b.findBindStructTarget(fieldType, name)
f, err := b.findBindStructTagTarget(fieldType, name)
if err != nil {
return nil, err
}
if f != nil && found != nil {
return nil, errors.Errorf("found more than one way to bind for %s", name)
}
if f != nil {
found = f
continue
}

f, err = b.findBindFieldTarget(fieldType, name)
if err != nil {
return nil, err
}
if f != nil && found != nil {
return nil, errors.Errorf("found more than one way to bind for %s", name)
}
if f != nil {
return f, nil
found = f
}
default:
panic(fmt.Errorf("unknown embedded field type %T", field.Type()))
}
}

return nil, nil
return found, nil
}

func (f *Field) HasDirectives() bool {
Expand Down
18 changes: 9 additions & 9 deletions codegen/field_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,15 @@ type Embed struct {
scope, err := parseScope(input, "test")
require.NoError(t, err)

std := scope.Lookup("Std").Type().Underlying().(*types.Struct)
anon := scope.Lookup("Anon").Type().Underlying().(*types.Struct)
tags := scope.Lookup("Tags").Type().Underlying().(*types.Struct)
amb := scope.Lookup("Amb").Type().Underlying().(*types.Struct)
embed := scope.Lookup("Embed").Type().Underlying().(*types.Struct)
std := scope.Lookup("Std").Type().(*types.Named)
anon := scope.Lookup("Anon").Type().(*types.Named)
tags := scope.Lookup("Tags").Type().(*types.Named)
amb := scope.Lookup("Amb").Type().(*types.Named)
embed := scope.Lookup("Embed").Type().(*types.Named)

tests := []struct {
Name string
Struct *types.Struct
Named *types.Named
Field string
Tag string
Expected string
Expand All @@ -65,13 +65,13 @@ type Embed struct {

for _, tt := range tests {
b := builder{Config: &config.Config{StructTag: tt.Tag}}
field, err := b.findBindStructTarget(tt.Struct, tt.Field)
target, err := b.findBindTarget(tt.Named, tt.Field)
if tt.ShouldError {
require.Nil(t, field, tt.Name)
require.Nil(t, target, tt.Name)
require.Error(t, err, tt.Name)
} else {
require.NoError(t, err, tt.Name)
require.Equal(t, tt.Expected, field.Name(), tt.Name)
require.Equal(t, tt.Expected, target.Name(), tt.Name)
}
}
}
Expand Down
30 changes: 30 additions & 0 deletions codegen/testserver/embedded.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package testserver

// EmbeddedCase1 model
type EmbeddedCase1 struct {
Empty
*ExportedEmbeddedPointerAfterInterface
}

// Empty interface
type Empty interface{}

// ExportedEmbeddedPointerAfterInterface model
type ExportedEmbeddedPointerAfterInterface struct{}

// ExportedEmbeddedPointerExportedMethod method
func (*ExportedEmbeddedPointerAfterInterface) ExportedEmbeddedPointerExportedMethod() string {
return "ExportedEmbeddedPointerExportedMethodResponse"
}

// EmbeddedCase2 model
type EmbeddedCase2 struct {
*unexportedEmbeddedPointer
}

type unexportedEmbeddedPointer struct{}

// UnexportedEmbeddedPointerExportedMethod method
func (*unexportedEmbeddedPointer) UnexportedEmbeddedPointerExportedMethod() string {
return "UnexportedEmbeddedPointerExportedMethodResponse"
}
12 changes: 12 additions & 0 deletions codegen/testserver/embedded.graphql
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
extend type Query {
embeddedCase1: EmbeddedCase1
embeddedCase2: EmbeddedCase2
}

type EmbeddedCase1 @goModel(model:"testserver.EmbeddedCase1") {
exportedEmbeddedPointerExportedMethod: String!
}

type EmbeddedCase2 @goModel(model:"testserver.EmbeddedCase2") {
unexportedEmbeddedPointerExportedMethod: String!
}
46 changes: 46 additions & 0 deletions codegen/testserver/embedded_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package testserver

import (
"context"
"testing"

"github.com/99designs/gqlgen/client"
"github.com/99designs/gqlgen/handler"
"github.com/stretchr/testify/require"
)

func TestEmbedded(t *testing.T) {
resolver := &Stub{}
resolver.QueryResolver.EmbeddedCase1 = func(ctx context.Context) (*EmbeddedCase1, error) {
return &EmbeddedCase1{}, nil
}
resolver.QueryResolver.EmbeddedCase2 = func(ctx context.Context) (*EmbeddedCase2, error) {
return &EmbeddedCase2{&unexportedEmbeddedPointer{}}, nil
}

c := client.New(handler.GraphQL(
NewExecutableSchema(Config{Resolvers: resolver}),
))

t.Run("embedded case 1", func(t *testing.T) {
var resp struct {
EmbeddedCase1 struct {
ExportedEmbeddedPointerExportedMethod string
}
}
err := c.Post(`query { embeddedCase1 { exportedEmbeddedPointerExportedMethod } }`, &resp)
require.NoError(t, err)
require.Equal(t, resp.EmbeddedCase1.ExportedEmbeddedPointerExportedMethod, "ExportedEmbeddedPointerExportedMethodResponse")
})

t.Run("embedded case 2", func(t *testing.T) {
var resp struct {
EmbeddedCase2 struct {
UnexportedEmbeddedPointerExportedMethod string
}
}
err := c.Post(`query { embeddedCase2 { unexportedEmbeddedPointerExportedMethod } }`, &resp)
require.NoError(t, err)
require.Equal(t, resp.EmbeddedCase2.UnexportedEmbeddedPointerExportedMethod, "UnexportedEmbeddedPointerExportedMethodResponse")
})
}
Loading

0 comments on commit a745dc7

Please sign in to comment.