Skip to content

Commit

Permalink
Merge object build and bind
Browse files Browse the repository at this point in the history
  • Loading branch information
Adam Scarr committed Feb 4, 2019
1 parent 97764ae commit 4e49d48
Show file tree
Hide file tree
Showing 27 changed files with 308 additions and 264 deletions.
61 changes: 0 additions & 61 deletions codegen/build_bind.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,67 +10,6 @@ import (
"github.com/pkg/errors"
)

type BindError struct {
object *Object
field *Field
typ types.Type
methodErr error
varErr error
}

func (b BindError) Error() string {
return fmt.Sprintf(
"\nunable to bind %s.%s to %s\n %s\n %s",
b.object.Definition.GQLDefinition.Name,
b.field.GQLName,
b.typ.String(),
b.methodErr.Error(),
b.varErr.Error(),
)
}

type BindErrors []BindError

func (b BindErrors) Error() string {
var errs []string
for _, err := range b {
errs = append(errs, err.Error())
}
return strings.Join(errs, "\n\n")
}

func (b *builder) bindObject(object *Object) BindErrors {
var errs BindErrors
for _, field := range object.Fields {
if field.IsResolver {
continue
}

// first try binding to a method
methodErr := b.bindMethod(object.Definition.GoType, field)
if methodErr == nil {
continue
}

// otherwise try binding to a var
varErr := b.bindVar(object.Definition.GoType, field)

// if both failed, add a resolver
if varErr != nil {
field.IsResolver = true

errs = append(errs, BindError{
object: object,
typ: object.Definition.GoType,
field: field,
varErr: varErr,
methodErr: methodErr,
})
}
}
return errs
}

func (b *builder) bindMethod(t types.Type, field *Field) error {
namedType, err := findGoNamedType(t)
if err != nil {
Expand Down
66 changes: 1 addition & 65 deletions codegen/build_object.go
Original file line number Diff line number Diff line change
@@ -1,75 +1,11 @@
package codegen

import (
"go/types"
"log"
"strings"

"github.com/99designs/gqlgen/codegen/templates"

"github.com/pkg/errors"
"github.com/vektah/gqlparser/ast"
)

func (b *builder) buildObject(typ *ast.Definition) (*Object, error) {
dirs, err := b.getDirectives(typ.Directives)
if err != nil {
return nil, errors.Wrap(err, typ.Name)
}

isRoot := typ == b.Schema.Query || typ == b.Schema.Mutation || typ == b.Schema.Subscription

obj := &Object{
Definition: b.NamedTypes[typ.Name],
InTypemap: b.Config.Models.UserDefined(typ.Name) || isRoot,
Root: isRoot,
DisableConcurrency: typ == b.Schema.Mutation,
Stream: typ == b.Schema.Subscription,
Directives: dirs,
ResolverInterface: types.NewNamed(
types.NewTypeName(0, b.Config.Exec.Pkg(), typ.Name+"Resolver", nil),
nil,
nil,
),
}

for _, intf := range b.Schema.GetImplements(typ) {
obj.Implements = append(obj.Implements, b.NamedTypes[intf.Name])
}

for _, field := range typ.Fields {
if strings.HasPrefix(field.Name, "__") {
continue
}

f, err := b.buildField(obj, field)
if err != nil {
return nil, errors.Wrap(err, typ.Name+"."+field.Name)
}

if typ.Kind == ast.InputObject && !f.TypeReference.Definition.GQLDefinition.IsInputType() {
return nil, errors.Errorf(
"%s.%s: cannot use %s because %s is not a valid input type",
typ.Name,
field.Name,
f.Definition.GQLDefinition.Name,
f.TypeReference.Definition.GQLDefinition.Kind,
)
}

obj.Fields = append(obj.Fields, f)
}

if obj.InTypemap && !isMap(obj.Definition.GoType) {
for _, bindErr := range b.bindObject(obj) {
log.Println(bindErr.Error())
log.Println(" Adding resolver method")
}
}

return obj, nil
}

func (b *builder) buildField(obj *Object, field *ast.FieldDefinition) (*Field, error) {
dirs, err := b.getDirectives(field.Directives)
if err != nil {
Expand All @@ -94,7 +30,7 @@ func (b *builder) buildField(obj *Object, field *ast.FieldDefinition) (*Field, e
}
}

typeEntry, entryExists := b.Config.Models[obj.Definition.GQLDefinition.Name]
typeEntry, entryExists := b.Config.Models[obj.Definition.Name]
if entryExists {
if typeField, ok := typeEntry.Fields[field.Name]; ok {
if typeField.Resolver {
Expand Down
12 changes: 9 additions & 3 deletions codegen/build_typedef.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ import (
"fmt"
"go/types"

"github.com/99designs/gqlgen/internal/code"

"github.com/99designs/gqlgen/codegen/config"
"github.com/99designs/gqlgen/codegen/templates"
"github.com/99designs/gqlgen/internal/code"
"github.com/pkg/errors"
"github.com/vektah/gqlparser/ast"
)
Expand All @@ -20,7 +20,13 @@ func (b *builder) buildTypeDef(schemaType *ast.Definition) (*TypeDefinition, err
if userEntry, ok := b.Config.Models[t.GQLDefinition.Name]; ok && userEntry.Model != "" {
// special case for maps
if userEntry.Model == "map[string]interface{}" {
t.GoType = types.NewMap(types.Typ[types.String], types.NewInterfaceType(nil, nil).Complete())
t.GoType = config.MapType

return t, nil
}

if userEntry.Model == "interface{}" {
t.GoType = config.InterfaceType

return t, nil
}
Expand Down
41 changes: 37 additions & 4 deletions codegen/config/binder.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@ import (
"strings"

"github.com/99designs/gqlgen/internal/code"
"github.com/vektah/gqlparser/ast"

"github.com/pkg/errors"
"github.com/vektah/gqlparser/ast"
"golang.org/x/tools/go/packages"
)

Expand Down Expand Up @@ -52,6 +51,36 @@ func (b *Binder) getPkg(find string) *packages.Package {
return nil
}

var MapType = types.NewMap(types.Typ[types.String], types.NewInterfaceType(nil, nil).Complete())
var InterfaceType = types.NewInterfaceType(nil, nil)

func (b *Binder) FindUserObject(name string) (types.Type, error) {
userEntry, ok := b.types[name]
if !ok {
return nil, fmt.Errorf(name + " not found")
}

if userEntry.Model == "map[string]interface{}" {
return MapType, nil
}

if userEntry.Model == "interface{}" {
return InterfaceType, nil
}

pkgName, typeName := code.PkgAndType(userEntry.Model)
if pkgName == "" {
return nil, fmt.Errorf("missing package name for %s", name)
}

obj, err := b.FindObject(pkgName, typeName)
if err != nil {
return nil, err
}

return obj.Type(), nil
}

func (b *Binder) FindObject(pkgName string, typeName string) (types.Object, error) {
if pkgName == "" {
return nil, fmt.Errorf("package cannot be nil")
Expand Down Expand Up @@ -95,12 +124,16 @@ func (b *Binder) FindBackingType(schemaType *ast.Type) (types.Type, error) {
if userEntry, ok := b.types[schemaType.Name()]; ok && userEntry.Model != "" {
// special case for maps
if userEntry.Model == "map[string]interface{}" {
return types.NewMap(types.Typ[types.String], types.NewInterfaceType(nil, nil).Complete()), nil
return MapType, nil
}

if userEntry.Model == "interface{}" {
return InterfaceType, nil
}

pkgName, typeName = code.PkgAndType(userEntry.Model)
if pkgName == "" {
return nil, fmt.Errorf("missing package name for %s", schemaType.Name)
return nil, fmt.Errorf("missing package name for %s", schemaType.Name())
}

} else {
Expand Down
13 changes: 8 additions & 5 deletions codegen/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,14 @@ func (c *Config) normalize() error {
}
}

if c.Models == nil {
c.Models = TypeMap{}
}

return nil
}

func (c *Config) InjectBuiltins(s *ast.Schema) {
builtins := TypeMap{
"__Directive": {Model: "github.com/99designs/gqlgen/graphql/introspection.Directive"},
"__DirectiveLocation": {Model: "github.com/99designs/gqlgen/graphql.String"},
Expand All @@ -311,16 +319,11 @@ func (c *Config) normalize() error {
"Map": {Model: "github.com/99designs/gqlgen/graphql.Map"},
}

if c.Models == nil {
c.Models = TypeMap{}
}
for typeName, entry := range builtins {
if !c.Models.Exists(typeName) {
c.Models[typeName] = entry
}
}

return nil
}

func (c *TypeMapEntry) PkgAndType() (string, string) {
Expand Down
6 changes: 4 additions & 2 deletions codegen/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ func BuildData(cfg *config.Config) (*Data, error) {
return nil, err
}

cfg.InjectBuiltins(b.Schema)

b.Binder, err = b.Config.NewBinder()
if err != nil {
return nil, err
Expand Down Expand Up @@ -124,11 +126,11 @@ func BuildData(cfg *config.Config) (*Data, error) {
}

sort.Slice(s.Objects, func(i, j int) bool {
return s.Objects[i].Definition.GQLDefinition.Name < s.Objects[j].Definition.GQLDefinition.Name
return s.Objects[i].Definition.Name < s.Objects[j].Definition.Name
})

sort.Slice(s.Inputs, func(i, j int) bool {
return s.Inputs[i].Definition.GQLDefinition.Name < s.Inputs[j].Definition.GQLDefinition.Name
return s.Inputs[i].Definition.Name < s.Inputs[j].Definition.Name
})

sort.Slice(s.Interfaces, func(i, j int) bool {
Expand Down
8 changes: 4 additions & 4 deletions codegen/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,23 +53,23 @@ func (f *Field) GoNameUnexported() string {
}

func (f *Field) ShortInvocation() string {
return fmt.Sprintf("%s().%s(%s)", f.Object.Definition.GQLDefinition.Name, f.GoFieldName, f.CallArgs())
return fmt.Sprintf("%s().%s(%s)", f.Object.Definition.Name, f.GoFieldName, f.CallArgs())
}

func (f *Field) ArgsFunc() string {
if len(f.Args) == 0 {
return ""
}

return "field_" + f.Object.Definition.GQLDefinition.Name + "_" + f.GQLName + "_args"
return "field_" + f.Object.Definition.Name + "_" + f.GQLName + "_args"
}

func (f *Field) ResolverType() string {
if !f.IsResolver {
return ""
}

return fmt.Sprintf("%s().%s(%s)", f.Object.Definition.GQLDefinition.Name, f.GoFieldName, f.CallArgs())
return fmt.Sprintf("%s().%s(%s)", f.Object.Definition.Name, f.GoFieldName, f.CallArgs())
}

func (f *Field) ShortResolverDeclaration() string {
Expand All @@ -79,7 +79,7 @@ func (f *Field) ShortResolverDeclaration() string {
res := fmt.Sprintf("%s(ctx context.Context", f.GoFieldName)

if !f.Object.Root {
res += fmt.Sprintf(", obj *%s", templates.CurrentImports.LookupType(f.Object.Definition.GoType))
res += fmt.Sprintf(", obj *%s", templates.CurrentImports.LookupType(f.Object.Type))
}
for _, arg := range f.Args {
res += fmt.Sprintf(", %s %s", arg.GoVarName, templates.CurrentImports.LookupType(arg.GoType))
Expand Down
6 changes: 3 additions & 3 deletions codegen/field.gotpl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{{- range $object := .Objects }}{{- range $field := $object.Fields }}

{{- if $object.Stream }}
func (ec *executionContext) _{{$object.Definition.GQLDefinition.Name}}_{{$field.GQLName}}(ctx context.Context, field graphql.CollectedField) func() graphql.Marshaler {
func (ec *executionContext) _{{$object.Name}}_{{$field.GQLName}}(ctx context.Context, field graphql.CollectedField) func() graphql.Marshaler {
ctx = graphql.WithResolverContext(ctx, &graphql.ResolverContext{
Field: field,
Args: nil,
Expand Down Expand Up @@ -40,11 +40,11 @@
}
{{ else }}
// nolint: vetshadow
func (ec *executionContext) _{{$object.Definition.GQLDefinition.Name}}_{{$field.GQLName}}(ctx context.Context, field graphql.CollectedField, {{if not $object.Root}}obj *{{$object.Definition.GoType | ref}}{{end}}) graphql.Marshaler {
func (ec *executionContext) _{{$object.Name}}_{{$field.GQLName}}(ctx context.Context, field graphql.CollectedField{{ if not $object.Root }}, obj *{{$object.Type | ref}}{{end}}) graphql.Marshaler {
ctx = ec.Tracer.StartFieldExecution(ctx, field)
defer func () { ec.Tracer.EndFieldExecution(ctx) }()
rctx := &graphql.ResolverContext{
Object: {{$object.Definition.GQLDefinition.Name|quote}},
Object: {{$object.Name|quote}},
Field: field,
Args: nil,
}
Expand Down
Loading

0 comments on commit 4e49d48

Please sign in to comment.