Skip to content

Commit

Permalink
adds binding by passed tag
Browse files Browse the repository at this point in the history
  • Loading branch information
codyleyhan committed Aug 24, 2018
1 parent 77e6955 commit c584992
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 16 deletions.
1 change: 1 addition & 0 deletions codegen/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ type Config struct {
Model PackageConfig `yaml:"model"`
Resolver PackageConfig `yaml:"resolver,omitempty"`
Models TypeMap `yaml:"models,omitempty"`
StructTag string `yaml:"struct_tag,omitempty"`

FilePath string `yaml:"-"`

Expand Down
2 changes: 1 addition & 1 deletion codegen/input_build.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func (cfg *Config) buildInputs(namedTypes NamedTypes, prog *loader.Program, impo
}
if def != nil {
input.Marshaler = buildInputMarshaler(typ, def)
bindErrs := bindObject(def.Type(), input, imports)
bindErrs := bindObject(def.Type(), input, imports, cfg.StructTag)
if len(bindErrs) > 0 {
return nil, bindErrs
}
Expand Down
2 changes: 1 addition & 1 deletion codegen/object_build.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func (cfg *Config) buildObjects(types NamedTypes, prog *loader.Program, imports
return nil, err
}
if def != nil {
for _, bindErr := range bindObject(def.Type(), obj, imports) {
for _, bindErr := range bindObject(def.Type(), obj, imports, cfg.StructTag) {
log.Println(bindErr.Error())
log.Println(" Adding resolver method")
}
Expand Down
61 changes: 47 additions & 14 deletions codegen/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package codegen
import (
"fmt"
"go/types"
"reflect"
"regexp"
"strings"

Expand Down Expand Up @@ -104,19 +105,46 @@ func findMethod(typ *types.Named, name string) *types.Func {
return nil
}

func findField(typ *types.Struct, name string) *types.Var {
// findField attempts to match the name to a struct field with the following
// priorites:
// 1. If struct tag is passed then struct tag has highest priority
// 2. Field in an embedded struct
// 3. Actual Field name
func findField(typ *types.Struct, name, structTag string) (*types.Var, error) {
var foundField *types.Var

for i := 0; i < typ.NumFields(); i++ {
field := typ.Field(i)

if structTag != "" {
tags := reflect.StructTag(typ.Tag(i))
if val, ok := tags.Lookup(structTag); ok {
if strings.EqualFold(val, name) {
if foundField != nil {
return nil, errors.Errorf("tag %s is ambigious; multiple fields have the same tag value of %s", structTag, val)
}
}
}
}

if field.Anonymous() {
if named, ok := field.Type().(*types.Struct); ok {
if f := findField(named, name); f != nil {
return f
f, err := findField(named, name, structTag)
if err != nil {
return nil, err
}
if f != nil && foundField == nil {
foundField = f
}
}

if named, ok := field.Type().Underlying().(*types.Struct); ok {
if f := findField(named, name); f != nil {
return f
f, err := findField(named, name, structTag)
if err != nil {
return nil, err
}
if f != nil && foundField == nil {
foundField = f
}
}
}
Expand All @@ -125,11 +153,16 @@ func findField(typ *types.Struct, name string) *types.Var {
continue
}

if strings.EqualFold(field.Name(), name) {
return field
if strings.EqualFold(field.Name(), name) && foundField == nil {
foundField = field
}
}
return nil

if foundField == nil {
return nil, fmt.Errorf("no field named %s", name)
}

return foundField, nil
}

type BindError struct {
Expand Down Expand Up @@ -161,7 +194,7 @@ func (b BindErrors) Error() string {
return strings.Join(errs, "\n\n")
}

func bindObject(t types.Type, object *Object, imports *Imports) BindErrors {
func bindObject(t types.Type, object *Object, imports *Imports, structTag string) BindErrors {
var errs BindErrors
for i := range object.Fields {
field := &object.Fields[i]
Expand All @@ -177,7 +210,7 @@ func bindObject(t types.Type, object *Object, imports *Imports) BindErrors {
}

// otherwise try binding to a var
varErr := bindVar(imports, t, field)
varErr := bindVar(imports, t, field, structTag)

if varErr != nil {
errs = append(errs, BindError{
Expand Down Expand Up @@ -231,7 +264,7 @@ func bindMethod(imports *Imports, t types.Type, field *Field) error {
return nil
}

func bindVar(imports *Imports, t types.Type, field *Field) error {
func bindVar(imports *Imports, t types.Type, field *Field, structTag string) error {
underlying, ok := t.Underlying().(*types.Struct)
if !ok {
return fmt.Errorf("not a struct")
Expand All @@ -241,9 +274,9 @@ func bindVar(imports *Imports, t types.Type, field *Field) error {
if field.GoFieldName != "" {
goName = field.GoFieldName
}
structField := findField(underlying, goName)
if structField == nil {
return fmt.Errorf("no field named %s", field.GQLName)
structField, err := findField(underlying, goName, structTag)
if err != nil {
return err
}

if err := validateTypeBinding(imports, field, structField.Type()); err != nil {
Expand Down

0 comments on commit c584992

Please sign in to comment.