From 146c65380cec9e5f5bf59641d84faafecb31c07b Mon Sep 17 00:00:00 2001 From: Adam Scarr Date: Fri, 16 Feb 2018 19:35:28 +1100 Subject: [PATCH] generate input object unpackers --- codegen/build.go | 2 + codegen/input_build.go | 64 +++++++++++++++++++++++++ codegen/object.go | 20 ++++---- codegen/object_build.go | 100 +--------------------------------------- codegen/type.go | 4 ++ codegen/util.go | 85 ++++++++++++++++++++++++++++++++++ graphql/unmarshal.go | 25 ---------- templates.go | 2 +- templates/args.go | 10 +--- templates/file.go | 4 ++ templates/input.go | 26 +++++++++++ templates/object.go | 2 +- templates/templates.go | 4 +- 13 files changed, 202 insertions(+), 146 deletions(-) create mode 100644 codegen/input_build.go delete mode 100644 graphql/unmarshal.go create mode 100644 templates/input.go diff --git a/codegen/build.go b/codegen/build.go index 0ffd341adae..797c7a58596 100644 --- a/codegen/build.go +++ b/codegen/build.go @@ -12,6 +12,7 @@ import ( type Build struct { PackageName string Objects Objects + Inputs Objects Interfaces []*Interface Imports Imports QueryRoot *Object @@ -35,6 +36,7 @@ func Bind(schema *schema.Schema, userTypes map[string]string, destDir string) (* PackageName: filepath.Base(destDir), Objects: buildObjects(namedTypes, schema, prog), Interfaces: buildInterfaces(namedTypes, schema), + Inputs: buildInputs(namedTypes, schema, prog), Imports: imports, } diff --git a/codegen/input_build.go b/codegen/input_build.go new file mode 100644 index 00000000000..344c8a28210 --- /dev/null +++ b/codegen/input_build.go @@ -0,0 +1,64 @@ +package codegen + +import ( + "go/types" + "sort" + "strings" + + "github.com/vektah/gqlgen/neelance/schema" + "golang.org/x/tools/go/loader" +) + +func buildInputs(namedTypes NamedTypes, s *schema.Schema, prog *loader.Program) Objects { + var inputs Objects + + for _, typ := range s.Types { + switch typ := typ.(type) { + case *schema.InputObject: + input := buildInput(namedTypes, typ) + + if def := findGoType(prog, input.Package, input.GoType); def != nil { + input.Marshaler = buildInputMarshaler(typ, def) + bindObject(def.Type(), input) + } + + inputs = append(inputs, input) + } + } + + sort.Slice(inputs, func(i, j int) bool { + return strings.Compare(inputs[i].GQLType, inputs[j].GQLType) == -1 + }) + + return inputs +} + +func buildInput(types NamedTypes, typ *schema.InputObject) *Object { + obj := &Object{NamedType: types[typ.TypeName()]} + + for _, field := range typ.Values { + obj.Fields = append(obj.Fields, Field{ + GQLName: field.Name.Name, + Type: types.getType(field.Type), + Object: obj, + }) + } + return obj +} + +// if user has implemented an UnmarshalGQL method on the input type manually, use it +// otherwise we will generate one. +func buildInputMarshaler(typ *schema.InputObject, def types.Object) *Ref { + switch def := def.(type) { + case *types.TypeName: + namedType := def.Type().(*types.Named) + for i := 0; i < namedType.NumMethods(); i++ { + method := namedType.Method(i) + if method.Name() == "UnmarshalGQL" { + return nil + } + } + } + + return &Ref{GoType: typ.Name} +} diff --git a/codegen/object.go b/codegen/object.go index e1069f9ff6a..64a7d1858fc 100644 --- a/codegen/object.go +++ b/codegen/object.go @@ -35,6 +35,8 @@ type FieldArgument struct { GQLName string // The name of the argument in graphql } +type Objects []*Object + func (o *Object) GetField(name string) *Field { for i, field := range o.Fields { if strings.EqualFold(field.GQLName, name) { @@ -152,21 +154,21 @@ func (f *Field) doWriteJson(res string, val string, remainingMods []string, isPt } } +func (os Objects) ByName(name string) *Object { + for i, o := range os { + if strings.EqualFold(o.GQLType, name) { + return os[i] + } + } + return nil +} + func tpl(tpl string, vars map[string]interface{}) string { b := &bytes.Buffer{} template.Must(template.New("inline").Parse(tpl)).Execute(b, vars) return b.String() } -func ucFirst(s string) string { - if s == "" { - return "" - } - r := []rune(s) - r[0] = unicode.ToUpper(r[0]) - return string(r) -} - func lcFirst(s string) string { if s == "" { return "" diff --git a/codegen/object_build.go b/codegen/object_build.go index e96fa193432..a3812c51d99 100644 --- a/codegen/object_build.go +++ b/codegen/object_build.go @@ -1,9 +1,6 @@ package codegen import ( - "fmt" - "go/types" - "os" "sort" "strings" @@ -11,8 +8,6 @@ import ( "golang.org/x/tools/go/loader" ) -type Objects []*Object - func buildObjects(types NamedTypes, s *schema.Schema, prog *loader.Program) Objects { var objects Objects @@ -22,7 +17,7 @@ func buildObjects(types NamedTypes, s *schema.Schema, prog *loader.Program) Obje obj := buildObject(types, typ) if def := findGoType(prog, obj.Package, obj.GoType); def != nil { - findBindTargets(def.Type(), obj) + bindObject(def.Type(), obj) } objects = append(objects, obj) @@ -44,15 +39,6 @@ func buildObjects(types NamedTypes, s *schema.Schema, prog *loader.Program) Obje return objects } -func (os Objects) ByName(name string) *Object { - for i, o := range os { - if strings.EqualFold(o.GQLType, name) { - return os[i] - } - } - return nil -} - func buildObject(types NamedTypes, typ *schema.Object) *Object { obj := &Object{NamedType: types[typ.TypeName()]} @@ -78,87 +64,3 @@ func buildObject(types NamedTypes, typ *schema.Object) *Object { } return obj } - -func findBindTargets(t types.Type, object *Object) bool { - switch t := t.(type) { - case *types.Named: - for i := 0; i < t.NumMethods(); i++ { - method := t.Method(i) - if !method.Exported() { - continue - } - - if methodField := object.GetField(method.Name()); methodField != nil { - methodField.GoMethodName = "it." + method.Name() - sig := method.Type().(*types.Signature) - - methodField.Type.Modifiers = modifiersFromGoType(sig.Results().At(0).Type()) - - // check arg order matches code, not gql - - var newArgs []FieldArgument - l2: - for j := 0; j < sig.Params().Len(); j++ { - param := sig.Params().At(j) - for _, oldArg := range methodField.Args { - if strings.EqualFold(oldArg.GQLName, param.Name()) { - oldArg.Type.Modifiers = modifiersFromGoType(param.Type()) - newArgs = append(newArgs, oldArg) - continue l2 - } - } - fmt.Fprintln(os.Stderr, "cannot match argument "+param.Name()+" to any argument in "+t.String()) - } - methodField.Args = newArgs - - if sig.Results().Len() == 1 { - methodField.NoErr = true - } else if sig.Results().Len() != 2 { - fmt.Fprintf(os.Stderr, "weird number of results on %s. expected either (result), or (result, error)\n", method.Name()) - } - } - } - - findBindTargets(t.Underlying(), object) - return true - - case *types.Struct: - for i := 0; i < t.NumFields(); i++ { - field := t.Field(i) - // Todo: struct tags, name and - at least - - if !field.Exported() { - continue - } - - // Todo: check for type matches before binding too? - if objectField := object.GetField(field.Name()); objectField != nil { - objectField.GoVarName = "it." + field.Name() - objectField.Type.Modifiers = modifiersFromGoType(field.Type()) - } - } - t.Underlying() - return true - } - - return false -} - -func modifiersFromGoType(t types.Type) []string { - var modifiers []string - for { - switch val := t.(type) { - case *types.Pointer: - modifiers = append(modifiers, modPtr) - t = val.Elem() - case *types.Array: - modifiers = append(modifiers, modList) - t = val.Elem() - case *types.Slice: - modifiers = append(modifiers, modList) - t = val.Elem() - default: - return modifiers - } - } -} diff --git a/codegen/type.go b/codegen/type.go index a927f27e07e..507d2a4dca0 100644 --- a/codegen/type.go +++ b/codegen/type.go @@ -54,6 +54,10 @@ func (t Type) IsSlice() bool { return len(t.Modifiers) > 0 && t.Modifiers[0] == modList } +func (t NamedType) IsMarshaled() bool { + return t.Marshaler != nil +} + func (t Type) Unmarshal(result, raw string) string { if t.Marshaler != nil { return result + ", err := " + t.Marshaler.pkgDot() + "Unmarshal" + t.Marshaler.GoType + "(" + raw + ")" diff --git a/codegen/util.go b/codegen/util.go index 661d20fa513..36969b32083 100644 --- a/codegen/util.go +++ b/codegen/util.go @@ -4,6 +4,7 @@ import ( "fmt" "go/types" "os" + "strings" "golang.org/x/tools/go/loader" ) @@ -45,3 +46,87 @@ func isMethod(t types.Object) bool { return f.Type().(*types.Signature).Recv() != nil } + +func bindObject(t types.Type, object *Object) bool { + switch t := t.(type) { + case *types.Named: + for i := 0; i < t.NumMethods(); i++ { + method := t.Method(i) + if !method.Exported() { + continue + } + + if methodField := object.GetField(method.Name()); methodField != nil { + methodField.GoMethodName = "it." + method.Name() + sig := method.Type().(*types.Signature) + + methodField.Type.Modifiers = modifiersFromGoType(sig.Results().At(0).Type()) + + // check arg order matches code, not gql + + var newArgs []FieldArgument + l2: + for j := 0; j < sig.Params().Len(); j++ { + param := sig.Params().At(j) + for _, oldArg := range methodField.Args { + if strings.EqualFold(oldArg.GQLName, param.Name()) { + oldArg.Type.Modifiers = modifiersFromGoType(param.Type()) + newArgs = append(newArgs, oldArg) + continue l2 + } + } + fmt.Fprintln(os.Stderr, "cannot match argument "+param.Name()+" to any argument in "+t.String()) + } + methodField.Args = newArgs + + if sig.Results().Len() == 1 { + methodField.NoErr = true + } else if sig.Results().Len() != 2 { + fmt.Fprintf(os.Stderr, "weird number of results on %s. expected either (result), or (result, error)\n", method.Name()) + } + } + } + + bindObject(t.Underlying(), object) + return true + + case *types.Struct: + for i := 0; i < t.NumFields(); i++ { + field := t.Field(i) + // Todo: struct tags, name and - at least + + if !field.Exported() { + continue + } + + // Todo: check for type matches before binding too? + if objectField := object.GetField(field.Name()); objectField != nil { + objectField.GoVarName = "it." + field.Name() + objectField.Type.Modifiers = modifiersFromGoType(field.Type()) + } + } + t.Underlying() + return true + } + + return false +} + +func modifiersFromGoType(t types.Type) []string { + var modifiers []string + for { + switch val := t.(type) { + case *types.Pointer: + modifiers = append(modifiers, modPtr) + t = val.Elem() + case *types.Array: + modifiers = append(modifiers, modList) + t = val.Elem() + case *types.Slice: + modifiers = append(modifiers, modList) + t = val.Elem() + default: + return modifiers + } + } +} diff --git a/graphql/unmarshal.go b/graphql/unmarshal.go deleted file mode 100644 index 82a35b1ad8a..00000000000 --- a/graphql/unmarshal.go +++ /dev/null @@ -1,25 +0,0 @@ -package graphql - -import ( - "reflect" - - "github.com/mitchellh/mapstructure" -) - -func UnmarshalComplexArg(result interface{}, data interface{}) error { - decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ - TagName: "graphql", - ErrorUnused: true, - Result: result, - DecodeHook: decodeHook, - }) - if err != nil { - panic(err) - } - - return decoder.Decode(data) -} - -func decodeHook(sourceType reflect.Type, destType reflect.Type, value interface{}) (interface{}, error) { - return value, nil -} diff --git a/templates.go b/templates.go index a64cc060089..b16faf33728 100644 --- a/templates.go +++ b/templates.go @@ -15,7 +15,7 @@ func runTemplate(e *codegen.Build) (*bytes.Buffer, error) { "ucFirst": ucFirst, "lcFirst": lcFirst, "quote": strconv.Quote, - }).Parse(templates.String()) + }).Parse(templates.All) if err != nil { return nil, err } diff --git a/templates/args.go b/templates/args.go index c72694888ed..4e2b365425a 100644 --- a/templates/args.go +++ b/templates/args.go @@ -1,6 +1,6 @@ package templates -var argsTpl = ` +const argsTpl = ` {{- define "args" }} {{- range $i, $arg := . }} var arg{{$i}} {{$arg.Signature }} @@ -13,7 +13,7 @@ var argsTpl = ` arg{{$i}} = tmp.({{$arg.GoType}}) {{- end }} } - {{- else if $arg.IsScalar }} + {{- else}} if tmp, ok := field.Args[{{$arg.GQLName|quote}}]; ok { {{$arg.Unmarshal "tmp2" "tmp" }} if err != nil { @@ -22,12 +22,6 @@ var argsTpl = ` } arg{{$i}} = {{if $arg.Type.IsPtr}}&{{end}}tmp2 } - {{- else }} - err := unpackComplexArg(&arg{{$i}}, field.Args[{{$arg.GQLName|quote}}]) - if err != nil { - ec.Error(err) - continue - } {{- end}} {{- end }} {{- end }} diff --git a/templates/file.go b/templates/file.go index 6dcbf4d4380..d893453df84 100644 --- a/templates/file.go +++ b/templates/file.go @@ -86,6 +86,10 @@ type executionContext struct { {{ template "interface" $interface }} {{- end }} +{{- range $input := .Inputs }} + {{ template "input" $input }} +{{- end }} + var parsedSchema = schema.MustParse({{.SchemaRaw|quote}}) func (ec *executionContext) introspectSchema() *introspection.Schema { diff --git a/templates/input.go b/templates/input.go new file mode 100644 index 00000000000..6f6f81361d5 --- /dev/null +++ b/templates/input.go @@ -0,0 +1,26 @@ +package templates + +const inputTpl = ` +{{- define "input" }} + {{- if .IsMarshaled }} + func Unmarshal{{ .GQLType }}(v interface{}) ({{.FullName}}, error) { + var it {{.FullName}} + + for k, v := range v.(map[string]interface{}) { + switch k { + {{- range $field := .Fields }} + case {{$field.GQLName|quote}}: + {{$field.Unmarshal "val" "v" }} + if err != nil { + return it, err + } + {{$field.GoVarName}} = {{if $field.Type.IsPtr}}&{{end}}val + {{- end }} + } + } + + return it, nil + } + {{- end }} +{{- end }} +` diff --git a/templates/object.go b/templates/object.go index 24be6f000c8..a6c217f6f52 100644 --- a/templates/object.go +++ b/templates/object.go @@ -16,7 +16,7 @@ func (ec *executionContext) _{{$object.GQLType|lcFirst}}(sel []query.Selection, switch field.Name { case "__typename": - out.Values[i] = jsonw.String({{$object.GQLType|quote}}) + out.Values[i] = graphql.String({{$object.GQLType|quote}}) {{- range $field := $object.Fields }} case "{{$field.GQLName}}": {{- template "args" $field.Args }} diff --git a/templates/templates.go b/templates/templates.go index aa28ccff401..23ee30363e7 100644 --- a/templates/templates.go +++ b/templates/templates.go @@ -1,5 +1,3 @@ package templates -func String() string { - return argsTpl + fileTpl + interfaceTpl + objectTpl -} +const All = argsTpl + fileTpl + interfaceTpl + objectTpl + inputTpl