Skip to content

Commit

Permalink
generate input object unpackers
Browse files Browse the repository at this point in the history
  • Loading branch information
vektah committed Feb 16, 2018
1 parent d94cfb1 commit 146c653
Show file tree
Hide file tree
Showing 13 changed files with 202 additions and 146 deletions.
2 changes: 2 additions & 0 deletions codegen/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
type Build struct {
PackageName string
Objects Objects
Inputs Objects
Interfaces []*Interface
Imports Imports
QueryRoot *Object
Expand All @@ -35,6 +36,7 @@ func Bind(schema *schema.Schema, userTypes map[string]string, destDir string) (*
PackageName: filepath.Base(destDir),
Objects: buildObjects(namedTypes, schema, prog),
Interfaces: buildInterfaces(namedTypes, schema),
Inputs: buildInputs(namedTypes, schema, prog),
Imports: imports,
}

Expand Down
64 changes: 64 additions & 0 deletions codegen/input_build.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package codegen

import (
"go/types"
"sort"
"strings"

"github.com/vektah/gqlgen/neelance/schema"
"golang.org/x/tools/go/loader"
)

func buildInputs(namedTypes NamedTypes, s *schema.Schema, prog *loader.Program) Objects {
var inputs Objects

for _, typ := range s.Types {
switch typ := typ.(type) {
case *schema.InputObject:
input := buildInput(namedTypes, typ)

if def := findGoType(prog, input.Package, input.GoType); def != nil {
input.Marshaler = buildInputMarshaler(typ, def)
bindObject(def.Type(), input)
}

inputs = append(inputs, input)
}
}

sort.Slice(inputs, func(i, j int) bool {
return strings.Compare(inputs[i].GQLType, inputs[j].GQLType) == -1
})

return inputs
}

func buildInput(types NamedTypes, typ *schema.InputObject) *Object {
obj := &Object{NamedType: types[typ.TypeName()]}

for _, field := range typ.Values {
obj.Fields = append(obj.Fields, Field{
GQLName: field.Name.Name,
Type: types.getType(field.Type),
Object: obj,
})
}
return obj
}

// if user has implemented an UnmarshalGQL method on the input type manually, use it
// otherwise we will generate one.
func buildInputMarshaler(typ *schema.InputObject, def types.Object) *Ref {
switch def := def.(type) {
case *types.TypeName:
namedType := def.Type().(*types.Named)
for i := 0; i < namedType.NumMethods(); i++ {
method := namedType.Method(i)
if method.Name() == "UnmarshalGQL" {
return nil
}
}
}

return &Ref{GoType: typ.Name}
}
20 changes: 11 additions & 9 deletions codegen/object.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ type FieldArgument struct {
GQLName string // The name of the argument in graphql
}

type Objects []*Object

func (o *Object) GetField(name string) *Field {
for i, field := range o.Fields {
if strings.EqualFold(field.GQLName, name) {
Expand Down Expand Up @@ -152,21 +154,21 @@ func (f *Field) doWriteJson(res string, val string, remainingMods []string, isPt
}
}

func (os Objects) ByName(name string) *Object {
for i, o := range os {
if strings.EqualFold(o.GQLType, name) {
return os[i]
}
}
return nil
}

func tpl(tpl string, vars map[string]interface{}) string {
b := &bytes.Buffer{}
template.Must(template.New("inline").Parse(tpl)).Execute(b, vars)
return b.String()
}

func ucFirst(s string) string {
if s == "" {
return ""
}
r := []rune(s)
r[0] = unicode.ToUpper(r[0])
return string(r)
}

func lcFirst(s string) string {
if s == "" {
return ""
Expand Down
100 changes: 1 addition & 99 deletions codegen/object_build.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,13 @@
package codegen

import (
"fmt"
"go/types"
"os"
"sort"
"strings"

"github.com/vektah/gqlgen/neelance/schema"
"golang.org/x/tools/go/loader"
)

type Objects []*Object

func buildObjects(types NamedTypes, s *schema.Schema, prog *loader.Program) Objects {
var objects Objects

Expand All @@ -22,7 +17,7 @@ func buildObjects(types NamedTypes, s *schema.Schema, prog *loader.Program) Obje
obj := buildObject(types, typ)

if def := findGoType(prog, obj.Package, obj.GoType); def != nil {
findBindTargets(def.Type(), obj)
bindObject(def.Type(), obj)
}

objects = append(objects, obj)
Expand All @@ -44,15 +39,6 @@ func buildObjects(types NamedTypes, s *schema.Schema, prog *loader.Program) Obje
return objects
}

func (os Objects) ByName(name string) *Object {
for i, o := range os {
if strings.EqualFold(o.GQLType, name) {
return os[i]
}
}
return nil
}

func buildObject(types NamedTypes, typ *schema.Object) *Object {
obj := &Object{NamedType: types[typ.TypeName()]}

Expand All @@ -78,87 +64,3 @@ func buildObject(types NamedTypes, typ *schema.Object) *Object {
}
return obj
}

func findBindTargets(t types.Type, object *Object) bool {
switch t := t.(type) {
case *types.Named:
for i := 0; i < t.NumMethods(); i++ {
method := t.Method(i)
if !method.Exported() {
continue
}

if methodField := object.GetField(method.Name()); methodField != nil {
methodField.GoMethodName = "it." + method.Name()
sig := method.Type().(*types.Signature)

methodField.Type.Modifiers = modifiersFromGoType(sig.Results().At(0).Type())

// check arg order matches code, not gql

var newArgs []FieldArgument
l2:
for j := 0; j < sig.Params().Len(); j++ {
param := sig.Params().At(j)
for _, oldArg := range methodField.Args {
if strings.EqualFold(oldArg.GQLName, param.Name()) {
oldArg.Type.Modifiers = modifiersFromGoType(param.Type())
newArgs = append(newArgs, oldArg)
continue l2
}
}
fmt.Fprintln(os.Stderr, "cannot match argument "+param.Name()+" to any argument in "+t.String())
}
methodField.Args = newArgs

if sig.Results().Len() == 1 {
methodField.NoErr = true
} else if sig.Results().Len() != 2 {
fmt.Fprintf(os.Stderr, "weird number of results on %s. expected either (result), or (result, error)\n", method.Name())
}
}
}

findBindTargets(t.Underlying(), object)
return true

case *types.Struct:
for i := 0; i < t.NumFields(); i++ {
field := t.Field(i)
// Todo: struct tags, name and - at least

if !field.Exported() {
continue
}

// Todo: check for type matches before binding too?
if objectField := object.GetField(field.Name()); objectField != nil {
objectField.GoVarName = "it." + field.Name()
objectField.Type.Modifiers = modifiersFromGoType(field.Type())
}
}
t.Underlying()
return true
}

return false
}

func modifiersFromGoType(t types.Type) []string {
var modifiers []string
for {
switch val := t.(type) {
case *types.Pointer:
modifiers = append(modifiers, modPtr)
t = val.Elem()
case *types.Array:
modifiers = append(modifiers, modList)
t = val.Elem()
case *types.Slice:
modifiers = append(modifiers, modList)
t = val.Elem()
default:
return modifiers
}
}
}
4 changes: 4 additions & 0 deletions codegen/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ func (t Type) IsSlice() bool {
return len(t.Modifiers) > 0 && t.Modifiers[0] == modList
}

func (t NamedType) IsMarshaled() bool {
return t.Marshaler != nil
}

func (t Type) Unmarshal(result, raw string) string {
if t.Marshaler != nil {
return result + ", err := " + t.Marshaler.pkgDot() + "Unmarshal" + t.Marshaler.GoType + "(" + raw + ")"
Expand Down
85 changes: 85 additions & 0 deletions codegen/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"go/types"
"os"
"strings"

"golang.org/x/tools/go/loader"
)
Expand Down Expand Up @@ -45,3 +46,87 @@ func isMethod(t types.Object) bool {

return f.Type().(*types.Signature).Recv() != nil
}

func bindObject(t types.Type, object *Object) bool {
switch t := t.(type) {
case *types.Named:
for i := 0; i < t.NumMethods(); i++ {
method := t.Method(i)
if !method.Exported() {
continue
}

if methodField := object.GetField(method.Name()); methodField != nil {
methodField.GoMethodName = "it." + method.Name()
sig := method.Type().(*types.Signature)

methodField.Type.Modifiers = modifiersFromGoType(sig.Results().At(0).Type())

// check arg order matches code, not gql

var newArgs []FieldArgument
l2:
for j := 0; j < sig.Params().Len(); j++ {
param := sig.Params().At(j)
for _, oldArg := range methodField.Args {
if strings.EqualFold(oldArg.GQLName, param.Name()) {
oldArg.Type.Modifiers = modifiersFromGoType(param.Type())
newArgs = append(newArgs, oldArg)
continue l2
}
}
fmt.Fprintln(os.Stderr, "cannot match argument "+param.Name()+" to any argument in "+t.String())
}
methodField.Args = newArgs

if sig.Results().Len() == 1 {
methodField.NoErr = true
} else if sig.Results().Len() != 2 {
fmt.Fprintf(os.Stderr, "weird number of results on %s. expected either (result), or (result, error)\n", method.Name())
}
}
}

bindObject(t.Underlying(), object)
return true

case *types.Struct:
for i := 0; i < t.NumFields(); i++ {
field := t.Field(i)
// Todo: struct tags, name and - at least

if !field.Exported() {
continue
}

// Todo: check for type matches before binding too?
if objectField := object.GetField(field.Name()); objectField != nil {
objectField.GoVarName = "it." + field.Name()
objectField.Type.Modifiers = modifiersFromGoType(field.Type())
}
}
t.Underlying()
return true
}

return false
}

func modifiersFromGoType(t types.Type) []string {
var modifiers []string
for {
switch val := t.(type) {
case *types.Pointer:
modifiers = append(modifiers, modPtr)
t = val.Elem()
case *types.Array:
modifiers = append(modifiers, modList)
t = val.Elem()
case *types.Slice:
modifiers = append(modifiers, modList)
t = val.Elem()
default:
return modifiers
}
}
}
25 changes: 0 additions & 25 deletions graphql/unmarshal.go

This file was deleted.

2 changes: 1 addition & 1 deletion templates.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func runTemplate(e *codegen.Build) (*bytes.Buffer, error) {
"ucFirst": ucFirst,
"lcFirst": lcFirst,
"quote": strconv.Quote,
}).Parse(templates.String())
}).Parse(templates.All)
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit 146c653

Please sign in to comment.