diff --git a/codegen/config.go b/codegen/config.go index 0ba1f65dabc..f5fc8abd348 100644 --- a/codegen/config.go +++ b/codegen/config.go @@ -64,6 +64,7 @@ type Config struct { Model PackageConfig `yaml:"model"` Resolver PackageConfig `yaml:"resolver,omitempty"` Models TypeMap `yaml:"models,omitempty"` + StructTag string `yaml:"struct_tag,omitempty"` FilePath string `yaml:"-"` diff --git a/codegen/input_build.go b/codegen/input_build.go index c333201522d..1059601a3c5 100644 --- a/codegen/input_build.go +++ b/codegen/input_build.go @@ -27,7 +27,7 @@ func (cfg *Config) buildInputs(namedTypes NamedTypes, prog *loader.Program, impo } if def != nil { input.Marshaler = buildInputMarshaler(typ, def) - bindErrs := bindObject(def.Type(), input, imports) + bindErrs := bindObject(def.Type(), input, imports, cfg.StructTag) if len(bindErrs) > 0 { return nil, bindErrs } diff --git a/codegen/object_build.go b/codegen/object_build.go index 95602a9138d..686b2bfe2a4 100644 --- a/codegen/object_build.go +++ b/codegen/object_build.go @@ -28,7 +28,7 @@ func (cfg *Config) buildObjects(types NamedTypes, prog *loader.Program, imports return nil, err } if def != nil { - for _, bindErr := range bindObject(def.Type(), obj, imports) { + for _, bindErr := range bindObject(def.Type(), obj, imports, cfg.StructTag) { log.Println(bindErr.Error()) log.Println(" Adding resolver method") } diff --git a/codegen/util.go b/codegen/util.go index fae94adeadd..0e880ec6601 100644 --- a/codegen/util.go +++ b/codegen/util.go @@ -3,6 +3,7 @@ package codegen import ( "fmt" "go/types" + "reflect" "regexp" "strings" @@ -104,19 +105,46 @@ func findMethod(typ *types.Named, name string) *types.Func { return nil } -func findField(typ *types.Struct, name string) *types.Var { +// findField attempts to match the name to a struct field with the following +// priorites: +// 1. If struct tag is passed then struct tag has highest priority +// 2. Field in an embedded struct +// 3. Actual Field name +func findField(typ *types.Struct, name, structTag string) (*types.Var, error) { + var foundField *types.Var + for i := 0; i < typ.NumFields(); i++ { field := typ.Field(i) + + if structTag != "" { + tags := reflect.StructTag(typ.Tag(i)) + if val, ok := tags.Lookup(structTag); ok { + if strings.EqualFold(val, name) { + if foundField != nil { + return nil, errors.Errorf("tag %s is ambigious; multiple fields have the same tag value of %s", structTag, val) + } + } + } + } + if field.Anonymous() { if named, ok := field.Type().(*types.Struct); ok { - if f := findField(named, name); f != nil { - return f + f, err := findField(named, name, structTag) + if err != nil { + return nil, err + } + if f != nil && foundField == nil { + foundField = f } } if named, ok := field.Type().Underlying().(*types.Struct); ok { - if f := findField(named, name); f != nil { - return f + f, err := findField(named, name, structTag) + if err != nil { + return nil, err + } + if f != nil && foundField == nil { + foundField = f } } } @@ -125,11 +153,16 @@ func findField(typ *types.Struct, name string) *types.Var { continue } - if strings.EqualFold(field.Name(), name) { - return field + if strings.EqualFold(field.Name(), name) && foundField == nil { + foundField = field } } - return nil + + if foundField == nil { + return nil, fmt.Errorf("no field named %s", name) + } + + return foundField, nil } type BindError struct { @@ -161,7 +194,7 @@ func (b BindErrors) Error() string { return strings.Join(errs, "\n\n") } -func bindObject(t types.Type, object *Object, imports *Imports) BindErrors { +func bindObject(t types.Type, object *Object, imports *Imports, structTag string) BindErrors { var errs BindErrors for i := range object.Fields { field := &object.Fields[i] @@ -177,7 +210,7 @@ func bindObject(t types.Type, object *Object, imports *Imports) BindErrors { } // otherwise try binding to a var - varErr := bindVar(imports, t, field) + varErr := bindVar(imports, t, field, structTag) if varErr != nil { errs = append(errs, BindError{ @@ -231,7 +264,7 @@ func bindMethod(imports *Imports, t types.Type, field *Field) error { return nil } -func bindVar(imports *Imports, t types.Type, field *Field) error { +func bindVar(imports *Imports, t types.Type, field *Field, structTag string) error { underlying, ok := t.Underlying().(*types.Struct) if !ok { return fmt.Errorf("not a struct") @@ -241,9 +274,9 @@ func bindVar(imports *Imports, t types.Type, field *Field) error { if field.GoFieldName != "" { goName = field.GoFieldName } - structField := findField(underlying, goName) - if structField == nil { - return fmt.Errorf("no field named %s", field.GQLName) + structField, err := findField(underlying, goName, structTag) + if err != nil { + return err } if err := validateTypeBinding(imports, field, structField.Type()); err != nil {