Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes #843: Bind to embedded struct method or field #919

Merged
merged 1 commit into from
Nov 11, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 100 additions & 48 deletions codegen/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,75 +184,111 @@ func (b *builder) bindField(obj *Object, f *Field) (errret error) {
}
}

// findField attempts to match the name to a struct field with the following
// priorites:
// 1. Any method with a matching name
// 2. Any Fields with a struct tag (see config.StructTag)
// 3. Any fields with a matching name
// 4. Same logic again for embedded fields
// findBindTarget attempts to match the name to a struct field or method
// with the following priorites:
// 1. Any Fields with a struct tag (see config.StructTag). Errors if more than one match is found
// 2. Any method or field with a matching name. Errors if more than one match is found
// 3. Same logic again for embedded fields
func (b *builder) findBindTarget(named *types.Named, name string) (types.Object, error) {
strukt, isStruct := named.Underlying().(*types.Struct)
if isStruct {
// NOTE: a struct tag will override both methods and fields
// Bind to struct tag
found, err := b.findBindStructTagTarget(strukt, name)
if found != nil || err != nil {
return found, err
}
}

// Search for a method to bind to
var foundMethod types.Object
for i := 0; i < named.NumMethods(); i++ {
method := named.Method(i)
if !method.Exported() {
if !method.Exported() || !strings.EqualFold(method.Name(), name) {
continue
}

if !strings.EqualFold(method.Name(), name) {
continue
if foundMethod != nil {
return nil, errors.Errorf("found more than one matching method to bind for %s", name)
}

return method, nil
}

strukt, ok := named.Underlying().(*types.Struct)
if !ok {
return nil, fmt.Errorf("not a struct")
foundMethod = method
}
return b.findBindStructTarget(strukt, name)
}

func (b *builder) findBindStructTarget(strukt *types.Struct, name string) (types.Object, error) {
// struct tags have the highest priority
if b.Config.StructTag != "" {
var foundField *types.Var
for i := 0; i < strukt.NumFields(); i++ {
field := strukt.Field(i)
if !field.Exported() {
continue
}
tags := reflect.StructTag(strukt.Tag(i))
if val, ok := tags.Lookup(b.Config.StructTag); ok && equalFieldName(val, name) {
if foundField != nil {
return nil, errors.Errorf("tag %s is ambigious; multiple fields have the same tag value of %s", b.Config.StructTag, val)
}

foundField = field
}
// Search for a field to bind to
if isStruct {
foundField, err := b.findBindFieldTarget(strukt, name)
if err != nil {
return nil, err
}
if foundField != nil {

switch {
case foundField == nil && foundMethod == nil:
// Search embeds
return b.findBindEmbedsTarget(strukt, name)
case foundField == nil && foundMethod != nil:
// Bind to method
return foundMethod, nil
case foundField != nil && foundMethod == nil:
// Bind to field
return foundField, nil
case foundField != nil && foundMethod != nil:
// Error
return nil, errors.Errorf("found more than one way to bind for %s", name)
}
}

// Then matching field names
// Bind to method or don't bind at all
return foundMethod, nil
}

func (b *builder) findBindStructTagTarget(strukt *types.Struct, name string) (types.Object, error) {
if b.Config.StructTag == "" {
return nil, nil
}

var found types.Object
for i := 0; i < strukt.NumFields(); i++ {
field := strukt.Field(i)
if !field.Exported() {
if !field.Exported() || field.Embedded() {
continue
}
if equalFieldName(field.Name(), name) { // aqui!
return field, nil
tags := reflect.StructTag(strukt.Tag(i))
if val, ok := tags.Lookup(b.Config.StructTag); ok && equalFieldName(val, name) {
if found != nil {
return nil, errors.Errorf("tag %s is ambigious; multiple fields have the same tag value of %s", b.Config.StructTag, val)
}

found = field
}
}

// Then look in embedded structs
return found, nil
}

func (b *builder) findBindFieldTarget(strukt *types.Struct, name string) (types.Object, error) {
var found types.Object
for i := 0; i < strukt.NumFields(); i++ {
field := strukt.Field(i)
if !field.Exported() {
if !field.Exported() || !equalFieldName(field.Name(), name) {
continue
}

if !field.Anonymous() {
if found != nil {
return nil, errors.Errorf("found more than one matching field to bind for %s", name)
}

found = field
}

return found, nil
}

func (b *builder) findBindEmbedsTarget(strukt *types.Struct, name string) (types.Object, error) {
var found types.Object
for i := 0; i < strukt.NumFields(); i++ {
field := strukt.Field(i)
if !field.Embedded() {
continue
}

Expand All @@ -267,23 +303,39 @@ func (b *builder) findBindStructTarget(strukt *types.Struct, name string) (types
if err != nil {
return nil, err
}
if f != nil && found != nil {
return nil, errors.Errorf("found more than one way to bind for %s", name)
}
if f != nil {
return f, nil
found = f
}
case *types.Struct:
f, err := b.findBindStructTarget(fieldType, name)
f, err := b.findBindStructTagTarget(fieldType, name)
if err != nil {
return nil, err
}
if f != nil && found != nil {
return nil, errors.Errorf("found more than one way to bind for %s", name)
}
if f != nil {
found = f
continue
}

f, err = b.findBindFieldTarget(fieldType, name)
if err != nil {
return nil, err
}
if f != nil && found != nil {
return nil, errors.Errorf("found more than one way to bind for %s", name)
}
if f != nil {
return f, nil
found = f
}
default:
panic(fmt.Errorf("unknown embedded field type %T", field.Type()))
}
}

return nil, nil
return found, nil
}

func (f *Field) HasDirectives() bool {
Expand Down
18 changes: 9 additions & 9 deletions codegen/field_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,15 @@ type Embed struct {
scope, err := parseScope(input, "test")
require.NoError(t, err)

std := scope.Lookup("Std").Type().Underlying().(*types.Struct)
anon := scope.Lookup("Anon").Type().Underlying().(*types.Struct)
tags := scope.Lookup("Tags").Type().Underlying().(*types.Struct)
amb := scope.Lookup("Amb").Type().Underlying().(*types.Struct)
embed := scope.Lookup("Embed").Type().Underlying().(*types.Struct)
std := scope.Lookup("Std").Type().(*types.Named)
anon := scope.Lookup("Anon").Type().(*types.Named)
tags := scope.Lookup("Tags").Type().(*types.Named)
amb := scope.Lookup("Amb").Type().(*types.Named)
embed := scope.Lookup("Embed").Type().(*types.Named)

tests := []struct {
Name string
Struct *types.Struct
Named *types.Named
Field string
Tag string
Expected string
Expand All @@ -65,13 +65,13 @@ type Embed struct {

for _, tt := range tests {
b := builder{Config: &config.Config{StructTag: tt.Tag}}
field, err := b.findBindStructTarget(tt.Struct, tt.Field)
target, err := b.findBindTarget(tt.Named, tt.Field)
if tt.ShouldError {
require.Nil(t, field, tt.Name)
require.Nil(t, target, tt.Name)
require.Error(t, err, tt.Name)
} else {
require.NoError(t, err, tt.Name)
require.Equal(t, tt.Expected, field.Name(), tt.Name)
require.Equal(t, tt.Expected, target.Name(), tt.Name)
}
}
}
Expand Down
30 changes: 30 additions & 0 deletions codegen/testserver/embedded.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package testserver

// EmbeddedCase1 model
type EmbeddedCase1 struct {
Empty
*ExportedEmbeddedPointerAfterInterface
}

// Empty interface
type Empty interface{}

// ExportedEmbeddedPointerAfterInterface model
type ExportedEmbeddedPointerAfterInterface struct{}

// ExportedEmbeddedPointerExportedMethod method
func (*ExportedEmbeddedPointerAfterInterface) ExportedEmbeddedPointerExportedMethod() string {
return "ExportedEmbeddedPointerExportedMethodResponse"
}

// EmbeddedCase2 model
type EmbeddedCase2 struct {
*unexportedEmbeddedPointer
}

type unexportedEmbeddedPointer struct{}

// UnexportedEmbeddedPointerExportedMethod method
func (*unexportedEmbeddedPointer) UnexportedEmbeddedPointerExportedMethod() string {
return "UnexportedEmbeddedPointerExportedMethodResponse"
}
12 changes: 12 additions & 0 deletions codegen/testserver/embedded.graphql
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
extend type Query {
embeddedCase1: EmbeddedCase1
embeddedCase2: EmbeddedCase2
}

type EmbeddedCase1 @goModel(model:"testserver.EmbeddedCase1") {
exportedEmbeddedPointerExportedMethod: String!
}

type EmbeddedCase2 @goModel(model:"testserver.EmbeddedCase2") {
unexportedEmbeddedPointerExportedMethod: String!
}
46 changes: 46 additions & 0 deletions codegen/testserver/embedded_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package testserver

import (
"context"
"testing"

"github.com/99designs/gqlgen/client"
"github.com/99designs/gqlgen/handler"
"github.com/stretchr/testify/require"
)

func TestEmbedded(t *testing.T) {
resolver := &Stub{}
resolver.QueryResolver.EmbeddedCase1 = func(ctx context.Context) (*EmbeddedCase1, error) {
return &EmbeddedCase1{}, nil
}
resolver.QueryResolver.EmbeddedCase2 = func(ctx context.Context) (*EmbeddedCase2, error) {
return &EmbeddedCase2{&unexportedEmbeddedPointer{}}, nil
}

c := client.New(handler.GraphQL(
NewExecutableSchema(Config{Resolvers: resolver}),
))

t.Run("embedded case 1", func(t *testing.T) {
var resp struct {
EmbeddedCase1 struct {
ExportedEmbeddedPointerExportedMethod string
}
}
err := c.Post(`query { embeddedCase1 { exportedEmbeddedPointerExportedMethod } }`, &resp)
require.NoError(t, err)
require.Equal(t, resp.EmbeddedCase1.ExportedEmbeddedPointerExportedMethod, "ExportedEmbeddedPointerExportedMethodResponse")
})

t.Run("embedded case 2", func(t *testing.T) {
var resp struct {
EmbeddedCase2 struct {
UnexportedEmbeddedPointerExportedMethod string
}
}
err := c.Post(`query { embeddedCase2 { unexportedEmbeddedPointerExportedMethod } }`, &resp)
require.NoError(t, err)
require.Equal(t, resp.EmbeddedCase2.UnexportedEmbeddedPointerExportedMethod, "UnexportedEmbeddedPointerExportedMethodResponse")
})
}
Loading