Skip to content

Commit

Permalink
Merge pull request #10 from vektah/newtypes
Browse files Browse the repository at this point in the history
User defined custom types
  • Loading branch information
vektah authored Feb 16, 2018
2 parents 3e7d80d + 5d86eeb commit d0244d2
Show file tree
Hide file tree
Showing 43 changed files with 2,250 additions and 1,614 deletions.
4 changes: 4 additions & 0 deletions codegen/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
type Build struct {
PackageName string
Objects Objects
Inputs Objects
Interfaces []*Interface
Imports Imports
QueryRoot *Object
Expand All @@ -29,10 +30,13 @@ func Bind(schema *schema.Schema, userTypes map[string]string, destDir string) (*
return nil, err
}

bindTypes(imports, namedTypes, prog)

b := &Build{
PackageName: filepath.Base(destDir),
Objects: buildObjects(namedTypes, schema, prog),
Interfaces: buildInterfaces(namedTypes, schema),
Inputs: buildInputs(namedTypes, schema, prog),
Imports: imports,
}

Expand Down
5 changes: 1 addition & 4 deletions codegen/import_build.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,13 @@ func buildImports(types NamedTypes, destDir string) Imports {
{"io", "io"},
{"strconv", "strconv"},
{"time", "time"},
{"reflect", "reflect"},
{"strings", "strings"},
{"sync", "sync"},
{"mapstructure", "github.com/mitchellh/mapstructure"},
{"introspection", "github.com/vektah/gqlgen/neelance/introspection"},
{"errors", "github.com/vektah/gqlgen/neelance/errors"},
{"query", "github.com/vektah/gqlgen/neelance/query"},
{"schema", "github.com/vektah/gqlgen/neelance/schema"},
{"validation", "github.com/vektah/gqlgen/neelance/validation"},
{"jsonw", "github.com/vektah/gqlgen/jsonw"},
{"graphql", "github.com/vektah/gqlgen/graphql"},
}

for _, t := range types {
Expand Down
70 changes: 70 additions & 0 deletions codegen/input_build.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package codegen

import (
"fmt"
"go/types"
"os"
"sort"
"strings"

"github.com/vektah/gqlgen/neelance/schema"
"golang.org/x/tools/go/loader"
)

func buildInputs(namedTypes NamedTypes, s *schema.Schema, prog *loader.Program) Objects {
var inputs Objects

for _, typ := range s.Types {
switch typ := typ.(type) {
case *schema.InputObject:
input := buildInput(namedTypes, typ)

def, err := findGoType(prog, input.Package, input.GoType)
if err != nil {
fmt.Fprintf(os.Stderr, err.Error())
}
if def != nil {
input.Marshaler = buildInputMarshaler(typ, def)
bindObject(def.Type(), input)
}

inputs = append(inputs, input)
}
}

sort.Slice(inputs, func(i, j int) bool {
return strings.Compare(inputs[i].GQLType, inputs[j].GQLType) == -1
})

return inputs
}

func buildInput(types NamedTypes, typ *schema.InputObject) *Object {
obj := &Object{NamedType: types[typ.TypeName()]}

for _, field := range typ.Values {
obj.Fields = append(obj.Fields, Field{
GQLName: field.Name.Name,
Type: types.getType(field.Type),
Object: obj,
})
}
return obj
}

// if user has implemented an UnmarshalGQL method on the input type manually, use it
// otherwise we will generate one.
func buildInputMarshaler(typ *schema.InputObject, def types.Object) *Ref {
switch def := def.(type) {
case *types.TypeName:
namedType := def.Type().(*types.Named)
for i := 0; i < namedType.NumMethods(); i++ {
method := namedType.Method(i)
if method.Name() == "UnmarshalGQL" {
return nil
}
}
}

return &Ref{GoType: typ.Name}
}
28 changes: 15 additions & 13 deletions codegen/object.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ type FieldArgument struct {
GQLName string // The name of the argument in graphql
}

type Objects []*Object

func (o *Object) GetField(name string) *Field {
for i, field := range o.Fields {
if strings.EqualFold(field.GQLName, name) {
Expand Down Expand Up @@ -105,7 +107,7 @@ func (f *Field) doWriteJson(res string, val string, remainingMods []string, isPt
case len(remainingMods) > 0 && remainingMods[0] == modPtr:
return tpl(`
if {{.val}} == nil {
{{.res}} = jsonw.Null
{{.res}} = graphql.Null
} else {
{{.next}}
}`, map[string]interface{}{
Expand All @@ -123,9 +125,9 @@ func (f *Field) doWriteJson(res string, val string, remainingMods []string, isPt
var index = "idx" + strconv.Itoa(depth)

return tpl(`
{{.arr}} := jsonw.Array{}
{{.arr}} := graphql.Array{}
for {{.index}} := range {{.val}} {
var {{.tmp}} jsonw.Writer
var {{.tmp}} graphql.Marshaler
{{.next}}
{{.arr}} = append({{.arr}}, {{.tmp}})
}
Expand All @@ -142,7 +144,7 @@ func (f *Field) doWriteJson(res string, val string, remainingMods []string, isPt
if isPtr {
val = "*" + val
}
return fmt.Sprintf("%s = jsonw.%s(%s)", res, ucFirst(f.GoType), val)
return f.Marshal(res, val)

default:
if !isPtr {
Expand All @@ -152,21 +154,21 @@ func (f *Field) doWriteJson(res string, val string, remainingMods []string, isPt
}
}

func (os Objects) ByName(name string) *Object {
for i, o := range os {
if strings.EqualFold(o.GQLType, name) {
return os[i]
}
}
return nil
}

func tpl(tpl string, vars map[string]interface{}) string {
b := &bytes.Buffer{}
template.Must(template.New("inline").Parse(tpl)).Execute(b, vars)
return b.String()
}

func ucFirst(s string) string {
if s == "" {
return ""
}
r := []rune(s)
r[0] = unicode.ToUpper(r[0])
return string(r)
}

func lcFirst(s string) string {
if s == "" {
return ""
Expand Down
146 changes: 8 additions & 138 deletions codegen/object_build.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package codegen

import (
"fmt"
"go/types"
"os"
"sort"
"strings"
Expand All @@ -11,16 +10,21 @@ import (
"golang.org/x/tools/go/loader"
)

type Objects []*Object

func buildObjects(types NamedTypes, s *schema.Schema, prog *loader.Program) Objects {
var objects Objects

for _, typ := range s.Types {
switch typ := typ.(type) {
case *schema.Object:
obj := buildObject(types, typ)
bindObject(prog, obj)

def, err := findGoType(prog, obj.Package, obj.GoType)
if err != nil {
fmt.Fprintf(os.Stderr, err.Error())
}
if def != nil {
bindObject(def.Type(), obj)
}

objects = append(objects, obj)
}
Expand All @@ -41,15 +45,6 @@ func buildObjects(types NamedTypes, s *schema.Schema, prog *loader.Program) Obje
return objects
}

func (os Objects) ByName(name string) *Object {
for i, o := range os {
if strings.EqualFold(o.GQLType, name) {
return os[i]
}
}
return nil
}

func buildObject(types NamedTypes, typ *schema.Object) *Object {
obj := &Object{NamedType: types[typ.TypeName()]}

Expand All @@ -75,128 +70,3 @@ func buildObject(types NamedTypes, typ *schema.Object) *Object {
}
return obj
}

func bindObject(prog *loader.Program, obj *Object) {
if obj.Package == "" {
return
}
pkgName, err := resolvePkg(obj.Package)
if err != nil {
fmt.Fprintf(os.Stderr, "unable to resolve package for %s: %s\n", obj.GQLType, err.Error())
return
}

pkg := prog.Imported[pkgName]
if pkg == nil {
fmt.Fprintf(os.Stderr, "required package was not loaded: %s", pkgName)
return
}

for astNode, object := range pkg.Defs {
if astNode.Name != obj.GoType {
continue
}

if findBindTargets(object.Type(), obj) {
return
}
}
}

func findBindTargets(t types.Type, object *Object) bool {
switch t := t.(type) {
case *types.Named:
for i := 0; i < t.NumMethods(); i++ {
method := t.Method(i)
if !method.Exported() {
continue
}

if methodField := object.GetField(method.Name()); methodField != nil {
methodField.GoMethodName = "it." + method.Name()
sig := method.Type().(*types.Signature)

methodField.Type.Modifiers = modifiersFromGoType(sig.Results().At(0).Type())

// check arg order matches code, not gql

var newArgs []FieldArgument
l2:
for j := 0; j < sig.Params().Len(); j++ {
param := sig.Params().At(j)
for _, oldArg := range methodField.Args {
if strings.EqualFold(oldArg.GQLName, param.Name()) {
oldArg.Type.Modifiers = modifiersFromGoType(param.Type())
newArgs = append(newArgs, oldArg)
continue l2
}
}
fmt.Fprintln(os.Stderr, "cannot match argument "+param.Name()+" to any argument in "+t.String())
}
methodField.Args = newArgs

if sig.Results().Len() == 1 {
methodField.NoErr = true
} else if sig.Results().Len() != 2 {
fmt.Fprintf(os.Stderr, "weird number of results on %s. expected either (result), or (result, error)\n", method.Name())
}
}
}

findBindTargets(t.Underlying(), object)
return true

case *types.Struct:
for i := 0; i < t.NumFields(); i++ {
field := t.Field(i)
// Todo: struct tags, name and - at least

if !field.Exported() {
continue
}

// Todo: check for type matches before binding too?
if objectField := object.GetField(field.Name()); objectField != nil {
objectField.GoVarName = "it." + field.Name()
objectField.Type.Modifiers = modifiersFromGoType(field.Type())
}
}
t.Underlying()
return true
}

return false
}

func mutationRoot(schema *schema.Schema) string {
if mu, ok := schema.EntryPoints["mutation"]; ok {
return mu.TypeName()
}
return ""
}

func queryRoot(schema *schema.Schema) string {
if mu, ok := schema.EntryPoints["mutation"]; ok {
return mu.TypeName()
}
return ""
}

func modifiersFromGoType(t types.Type) []string {
var modifiers []string
for {
switch val := t.(type) {
case *types.Pointer:
modifiers = append(modifiers, modPtr)
t = val.Elem()
case *types.Array:
modifiers = append(modifiers, modList)
t = val.Elem()
case *types.Slice:
modifiers = append(modifiers, modList)
t = val.Elem()
default:
return modifiers
}
}
}
Loading

0 comments on commit d0244d2

Please sign in to comment.