Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
Andri Oskarsson committed Aug 27, 2018
2 parents 076f9ea + f6a733a commit 2488e1b
Show file tree
Hide file tree
Showing 11 changed files with 203 additions and 72 deletions.
26 changes: 19 additions & 7 deletions Gopkg.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

40 changes: 0 additions & 40 deletions cmd/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"github.com/99designs/gqlgen/codegen"
"github.com/pkg/errors"
"github.com/spf13/cobra"
"gopkg.in/yaml.v2"
)

func init() {
Expand Down Expand Up @@ -39,7 +38,6 @@ var genCmd = &cobra.Command{
}

// overwrite by commandline options
var emitYamlGuidance bool
if schemaFilename != "" {
config.SchemaFilename = schemaFilename
}
Expand All @@ -55,10 +53,6 @@ var genCmd = &cobra.Command{
if modelPackageName != "" {
config.Model.Package = modelPackageName
}
if typemap != "" {
config.Models = loadModelMap()
emitYamlGuidance = true
}

schemaRaw, err := ioutil.ReadFile(config.SchemaFilename)
if err != nil {
Expand All @@ -72,44 +66,10 @@ var genCmd = &cobra.Command{
os.Exit(1)
}

if emitYamlGuidance {
var b []byte
b, err = yaml.Marshal(config)
if err != nil {
fmt.Fprintln(os.Stderr, "unable to marshal yaml: "+err.Error())
os.Exit(1)
}

fmt.Fprintf(os.Stderr, "DEPRECATION WARNING: we are moving away from the json typemap, instead create a gqlgen.yml with the following content:\n\n%s\n", string(b))
}

err = codegen.Generate(*config)
if err != nil {
fmt.Fprintln(os.Stderr, err.Error())
os.Exit(2)
}
},
}

func loadModelMap() codegen.TypeMap {
var goTypes map[string]string
b, err := ioutil.ReadFile(typemap)
if err != nil {
fmt.Fprintln(os.Stderr, "unable to open typemap: "+err.Error())
return nil
}

if err = yaml.Unmarshal(b, &goTypes); err != nil {
fmt.Fprintln(os.Stderr, "unable to parse typemap: "+err.Error())
os.Exit(1)
}

typeMap := make(codegen.TypeMap)
for typeName, entityPath := range goTypes {
typeMap[typeName] = codegen.TypeMapEntry{
Model: entityPath,
}
}

return typeMap
}
3 changes: 0 additions & 3 deletions cmd/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,6 @@ func initConfig() *codegen.Config {
if modelPackageName != "" {
config.Model.Package = modelPackageName
}
if typemap != "" {
config.Models = loadModelMap()
}

var buf bytes.Buffer
buf.WriteString(strings.TrimSpace(configComment))
Expand Down
2 changes: 0 additions & 2 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ var verbose bool
var output string
var models string
var schemaFilename string
var typemap string
var packageName string
var modelPackageName string
var serverFilename string
Expand All @@ -29,7 +28,6 @@ func init() {
rootCmd.PersistentFlags().StringVar(&output, "out", "", "the file to write to")
rootCmd.PersistentFlags().StringVar(&models, "models", "", "the file to write the models to")
rootCmd.PersistentFlags().StringVar(&schemaFilename, "schema", "", "the graphql schema to generate types from")
rootCmd.PersistentFlags().StringVar(&typemap, "typemap", "", "a json map going from graphql to golang types")
rootCmd.PersistentFlags().StringVar(&packageName, "package", "", "the package name")
rootCmd.PersistentFlags().StringVar(&modelPackageName, "modelpackage", "", "the package name to use for models")
}
Expand Down
1 change: 1 addition & 0 deletions codegen/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ type Config struct {
Model PackageConfig `yaml:"model"`
Resolver PackageConfig `yaml:"resolver,omitempty"`
Models TypeMap `yaml:"models,omitempty"`
StructTag string `yaml:"struct_tag,omitempty"`

FilePath string `yaml:"-"`

Expand Down
2 changes: 1 addition & 1 deletion codegen/input_build.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func (cfg *Config) buildInputs(namedTypes NamedTypes, prog *loader.Program, impo
}
if def != nil {
input.Marshaler = buildInputMarshaler(typ, def)
bindErrs := bindObject(def.Type(), input, imports)
bindErrs := bindObject(def.Type(), input, imports, cfg.StructTag)
if len(bindErrs) > 0 {
return nil, bindErrs
}
Expand Down
2 changes: 1 addition & 1 deletion codegen/object_build.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func (cfg *Config) buildObjects(types NamedTypes, prog *loader.Program, imports
return nil, err
}
if def != nil {
for _, bindErr := range bindObject(def.Type(), obj, imports) {
for _, bindErr := range bindObject(def.Type(), obj, imports, cfg.StructTag) {
log.Println(bindErr.Error())
log.Println(" Adding resolver method")
}
Expand Down
65 changes: 51 additions & 14 deletions codegen/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package codegen
import (
"fmt"
"go/types"
"reflect"
"regexp"
"strings"

Expand Down Expand Up @@ -104,19 +105,50 @@ func findMethod(typ *types.Named, name string) *types.Func {
return nil
}

func findField(typ *types.Struct, name string) *types.Var {
// findField attempts to match the name to a struct field with the following
// priorites:
// 1. If struct tag is passed then struct tag has highest priority
// 2. Field in an embedded struct
// 3. Actual Field name
func findField(typ *types.Struct, name, structTag string) (*types.Var, error) {
var foundField *types.Var
foundFieldWasTag := false

for i := 0; i < typ.NumFields(); i++ {
field := typ.Field(i)

if structTag != "" {
tags := reflect.StructTag(typ.Tag(i))
if val, ok := tags.Lookup(structTag); ok {
if strings.EqualFold(val, name) {
if foundField != nil && foundFieldWasTag {
return nil, errors.Errorf("tag %s is ambigious; multiple fields have the same tag value of %s", structTag, val)
}

foundField = field
foundFieldWasTag = true
}
}
}

if field.Anonymous() {
if named, ok := field.Type().(*types.Struct); ok {
if f := findField(named, name); f != nil {
return f
f, err := findField(named, name, structTag)
if err != nil && !strings.HasPrefix(err.Error(), "no field named") {
return nil, err
}
if f != nil && foundField == nil {
foundField = f
}
}

if named, ok := field.Type().Underlying().(*types.Struct); ok {
if f := findField(named, name); f != nil {
return f
f, err := findField(named, name, structTag)
if err != nil && !strings.HasPrefix(err.Error(), "no field named") {
return nil, err
}
if f != nil && foundField == nil {
foundField = f
}
}
}
Expand All @@ -125,11 +157,16 @@ func findField(typ *types.Struct, name string) *types.Var {
continue
}

if strings.EqualFold(field.Name(), name) {
return field
if strings.EqualFold(field.Name(), name) && foundField == nil {
foundField = field
}
}
return nil

if foundField == nil {
return nil, fmt.Errorf("no field named %s", name)
}

return foundField, nil
}

type BindError struct {
Expand Down Expand Up @@ -161,7 +198,7 @@ func (b BindErrors) Error() string {
return strings.Join(errs, "\n\n")
}

func bindObject(t types.Type, object *Object, imports *Imports) BindErrors {
func bindObject(t types.Type, object *Object, imports *Imports, structTag string) BindErrors {
var errs BindErrors
for i := range object.Fields {
field := &object.Fields[i]
Expand All @@ -177,7 +214,7 @@ func bindObject(t types.Type, object *Object, imports *Imports) BindErrors {
}

// otherwise try binding to a var
varErr := bindVar(imports, t, field)
varErr := bindVar(imports, t, field, structTag)

if varErr != nil {
errs = append(errs, BindError{
Expand Down Expand Up @@ -231,7 +268,7 @@ func bindMethod(imports *Imports, t types.Type, field *Field) error {
return nil
}

func bindVar(imports *Imports, t types.Type, field *Field) error {
func bindVar(imports *Imports, t types.Type, field *Field, structTag string) error {
underlying, ok := t.Underlying().(*types.Struct)
if !ok {
return fmt.Errorf("not a struct")
Expand All @@ -241,9 +278,9 @@ func bindVar(imports *Imports, t types.Type, field *Field) error {
if field.GoFieldName != "" {
goName = field.GoFieldName
}
structField := findField(underlying, goName)
if structField == nil {
return fmt.Errorf("no field named %s", field.GQLName)
structField, err := findField(underlying, goName, structTag)
if err != nil {
return err
}

if err := validateTypeBinding(imports, field, structField.Type()); err != nil {
Expand Down
Loading

0 comments on commit 2488e1b

Please sign in to comment.