From c89a8774650d41a4a99bf5b90d0c69c4a7a166a3 Mon Sep 17 00:00:00 2001 From: Adam Scarr Date: Wed, 14 Feb 2018 00:03:40 +1100 Subject: [PATCH] Cleanup schema binding code --- codegen/build.go | 90 ++++++++ codegen/import.go | 16 ++ codegen/import_build.go | 77 +++++++ codegen/interface.go | 7 + codegen/interface_build.go | 40 ++++ types.go => codegen/object.go | 160 +++++--------- codegen/object_build.go | 202 +++++++++++++++++ codegen/type.go | 44 ++++ codegen/type_build.go | 92 ++++++++ extractor.go | 406 ---------------------------------- main.go | 90 +++----- templates.go | 3 +- templates/args.go | 24 +- templates/file.go | 8 +- templates/interface.go | 12 +- templates/object.go | 20 +- 16 files changed, 684 insertions(+), 607 deletions(-) create mode 100644 codegen/build.go create mode 100644 codegen/import.go create mode 100644 codegen/import_build.go create mode 100644 codegen/interface.go create mode 100644 codegen/interface_build.go rename types.go => codegen/object.go (52%) create mode 100644 codegen/object_build.go create mode 100644 codegen/type.go create mode 100644 codegen/type_build.go delete mode 100644 extractor.go diff --git a/codegen/build.go b/codegen/build.go new file mode 100644 index 00000000000..1e4256eaf46 --- /dev/null +++ b/codegen/build.go @@ -0,0 +1,90 @@ +package codegen + +import ( + "go/build" + "os" + "path/filepath" + + "github.com/vektah/gqlgen/neelance/schema" + "golang.org/x/tools/go/loader" +) + +type Build struct { + PackageName string + Objects Objects + Interfaces []*Interface + Imports Imports + QueryRoot *Object + MutationRoot *Object + SchemaRaw string +} + +// Bind a schema together with some code to generate a Build +func Bind(schema *schema.Schema, userTypes map[string]string, destDir string) (*Build, error) { + namedTypes := buildNamedTypes(schema, userTypes) + + imports := buildImports(namedTypes, destDir) + prog, err := loadProgram(imports) + if err != nil { + return nil, err + } + + b := &Build{ + PackageName: filepath.Base(destDir), + Objects: buildObjects(namedTypes, schema, prog), + Interfaces: buildInterfaces(namedTypes, schema), + Imports: imports, + } + + if qr, ok := schema.EntryPoints["query"]; ok { + b.QueryRoot = b.Objects.ByName(qr.TypeName()) + } + + if mr, ok := schema.EntryPoints["mutation"]; ok { + b.MutationRoot = b.Objects.ByName(mr.TypeName()) + } + + // Poke a few magic methods into query + q := b.Objects.ByName(b.QueryRoot.GQLType) + q.Fields = append(q.Fields, Field{ + Type: &Type{namedTypes["__Schema"], []string{modPtr}}, + GQLName: "__schema", + NoErr: true, + GoMethodName: "ec.introspectSchema", + Object: q, + }) + q.Fields = append(q.Fields, Field{ + Type: &Type{namedTypes["__Type"], []string{modPtr}}, + GQLName: "__type", + NoErr: true, + GoMethodName: "ec.introspectType", + Args: []FieldArgument{ + {GQLName: "name", Type: &Type{namedTypes["String"], []string{}}}, + }, + Object: q, + }) + + return b, nil +} + +func loadProgram(imports Imports) (*loader.Program, error) { + var conf loader.Config + for _, imp := range imports { + if imp.Package != "" { + conf.Import(imp.Package) + } + } + + return conf.Load() +} + +func resolvePkg(pkgName string) (string, error) { + cwd, _ := os.Getwd() + + pkg, err := build.Default.Import(pkgName, cwd, build.FindOnly) + if err != nil { + return "", err + } + + return pkg.ImportPath, nil +} diff --git a/codegen/import.go b/codegen/import.go new file mode 100644 index 00000000000..ed1ab0780c5 --- /dev/null +++ b/codegen/import.go @@ -0,0 +1,16 @@ +package codegen + +import ( + "strconv" +) + +type Import struct { + Name string + Package string +} + +type Imports []*Import + +func (i *Import) Write() string { + return i.Name + " " + strconv.Quote(i.Package) +} diff --git a/codegen/import_build.go b/codegen/import_build.go new file mode 100644 index 00000000000..df1ead2d089 --- /dev/null +++ b/codegen/import_build.go @@ -0,0 +1,77 @@ +package codegen + +import ( + "path/filepath" + "strconv" + "strings" +) + +func buildImports(types NamedTypes, destDir string) Imports { + imports := Imports{ + {"context", "context"}, + {"fmt", "fmt"}, + {"io", "io"}, + {"strconv", "strconv"}, + {"time", "time"}, + {"reflect", "reflect"}, + {"strings", "strings"}, + {"sync", "sync"}, + {"mapstructure", "github.com/mitchellh/mapstructure"}, + {"introspection", "github.com/vektah/gqlgen/neelance/introspection"}, + {"errors", "github.com/vektah/gqlgen/neelance/errors"}, + {"query", "github.com/vektah/gqlgen/neelance/query"}, + {"schema", "github.com/vektah/gqlgen/neelance/schema"}, + {"validation", "github.com/vektah/gqlgen/neelance/validation"}, + {"jsonw", "github.com/vektah/gqlgen/jsonw"}, + } + + for _, t := range types { + if t.Package == "" { + continue + } + + if existing := imports.findByPkg(t.Package); existing != nil { + t.Import = existing + continue + } + + localName := "" + if !strings.HasSuffix(destDir, t.Package) { + localName = filepath.Base(t.Package) + i := 0 + for imp := imports.findByName(localName); imp != nil && imp.Package != t.Package; localName = filepath.Base(t.Package) + strconv.Itoa(i) { + i++ + if i > 10 { + panic("too many collisions") + } + } + } + + imp := &Import{ + Name: localName, + Package: t.Package, + } + t.Import = imp + imports = append(imports, imp) + } + + return imports +} + +func (i Imports) findByPkg(pkg string) *Import { + for _, imp := range i { + if imp.Package == pkg { + return imp + } + } + return nil +} + +func (i Imports) findByName(name string) *Import { + for _, imp := range i { + if imp.Name == name { + return imp + } + } + return nil +} diff --git a/codegen/interface.go b/codegen/interface.go new file mode 100644 index 00000000000..98c9bc17751 --- /dev/null +++ b/codegen/interface.go @@ -0,0 +1,7 @@ +package codegen + +type Interface struct { + *NamedType + + Implementors []*NamedType +} diff --git a/codegen/interface_build.go b/codegen/interface_build.go new file mode 100644 index 00000000000..2b981e0bee5 --- /dev/null +++ b/codegen/interface_build.go @@ -0,0 +1,40 @@ +package codegen + +import ( + "sort" + "strings" + + "github.com/vektah/gqlgen/neelance/schema" +) + +func buildInterfaces(types NamedTypes, s *schema.Schema) []*Interface { + var interfaces []*Interface + for _, typ := range s.Types { + switch typ := typ.(type) { + + case *schema.Union: + i := &Interface{NamedType: types[typ.TypeName()]} + + for _, implementor := range typ.PossibleTypes { + i.Implementors = append(i.Implementors, types[implementor.TypeName()]) + } + + interfaces = append(interfaces, i) + + case *schema.Interface: + i := &Interface{NamedType: types[typ.TypeName()]} + + for _, implementor := range typ.PossibleTypes { + i.Implementors = append(i.Implementors, types[implementor.TypeName()]) + } + + interfaces = append(interfaces, i) + } + } + + sort.Slice(interfaces, func(i, j int) bool { + return strings.Compare(interfaces[i].GQLType, interfaces[j].GQLType) == -1 + }) + + return interfaces +} diff --git a/types.go b/codegen/object.go similarity index 52% rename from types.go rename to codegen/object.go index e2ed6a40376..84bedd62f17 100644 --- a/types.go +++ b/codegen/object.go @@ -1,4 +1,4 @@ -package main +package codegen import ( "bytes" @@ -6,90 +6,54 @@ import ( "strconv" "strings" "text/template" + "unicode" ) -type kind struct { - GraphQLName string - Name string - Package string - ImportedAs string - Modifiers []string - Implementors []kind - Scalar bool -} - -func (t kind) Local() string { - return strings.Join(t.Modifiers, "") + t.FullName() -} +type Object struct { + *NamedType -func (t kind) Ptr() kind { - t.Modifiers = append(t.Modifiers, modPtr) - return t + Fields []Field + Satisfies []string + Root bool + DisableConcurrency bool } -func (t kind) IsPtr() bool { - return len(t.Modifiers) > 0 && t.Modifiers[0] == modPtr -} +type Field struct { + *Type -func (t kind) IsSlice() bool { - return len(t.Modifiers) > 0 && t.Modifiers[0] == modList + GQLName string // The name of the field in graphql + GoMethodName string // The name of the method in go, if any + GoVarName string // The name of the var in go, if any + Args []FieldArgument // A list of arguments to be passed to this field + NoErr bool // If this is bound to a go method, does that method have an error as the second argument + Object *Object // A link back to the parent object } -func (t kind) Elem() kind { - if len(t.Modifiers) == 0 { - return t - } - - t.Modifiers = t.Modifiers[1:] - return t -} +type FieldArgument struct { + *Type -func (t kind) ByRef(name string) string { - needPtr := len(t.Implementors) == 0 - if needPtr && !t.IsPtr() { - return "&" + name - } - if !needPtr && t.IsPtr() { - return "*" + name - } - return name + GQLName string // The name of the argument in graphql } -func (t kind) ByVal(name string) string { - if t.IsPtr() { - return "*" + name +func (o *Object) GetField(name string) *Field { + for i, field := range o.Fields { + if strings.EqualFold(field.GQLName, name) { + return &o.Fields[i] + } } - return name + return nil } -func (t kind) FullName() string { - if t.ImportedAs == "" { - return t.Name +func (o *Object) Implementors() string { + satisfiedBy := strconv.Quote(o.GQLType) + for _, s := range o.Satisfies { + satisfiedBy += ", " + strconv.Quote(s) } - return t.ImportedAs + "." + t.Name -} - -type object struct { - Name string - Fields []Field - Type kind - satisfies []string - Root bool - DisableConcurrency bool -} - -type Field struct { - GraphQLName string - MethodName string - VarName string - Type kind - Args []FieldArgument - NoErr bool - Object *object + return "[]string{" + satisfiedBy + "}" } func (f *Field) IsResolver() bool { - return f.MethodName == "" && f.VarName == "" + return f.GoMethodName == "" && f.GoVarName == "" } func (f *Field) IsConcurrent() bool { @@ -100,23 +64,23 @@ func (f *Field) ResolverDeclaration() string { if !f.IsResolver() { return "" } - res := fmt.Sprintf("%s_%s(ctx context.Context", f.Object.Name, f.GraphQLName) + res := fmt.Sprintf("%s_%s(ctx context.Context", f.Object.GQLType, f.GQLName) if !f.Object.Root { - res += fmt.Sprintf(", it *%s", f.Object.Type.Local()) + res += fmt.Sprintf(", it *%s", f.Object.FullName()) } for _, arg := range f.Args { - res += fmt.Sprintf(", %s %s", arg.Name, arg.Type.Local()) + res += fmt.Sprintf(", %s %s", arg.GQLName, arg.Signature()) } - res += fmt.Sprintf(") (%s, error)", f.Type.Local()) + res += fmt.Sprintf(") (%s, error)", f.Signature()) return res } func (f *Field) CallArgs() string { var args []string - if f.MethodName == "" { + if f.GoMethodName == "" { args = append(args, "ec.ctx") if !f.Object.Root { @@ -141,7 +105,7 @@ func (f *Field) doWriteJson(res string, val string, remainingMods []string, isPt case len(remainingMods) > 0 && remainingMods[0] == modPtr: return tpl(` if {{.val}} == nil { - {{.res}} = jsonw.Null + {{.res}} = jsonw.Null } else { {{.next}} }`, map[string]interface{}{ @@ -174,53 +138,41 @@ func (f *Field) doWriteJson(res string, val string, remainingMods []string, isPt "next": f.doWriteJson(tmp, val+"["+index+"]", remainingMods[1:], false, depth+1), }) - case f.Type.Scalar: + case f.IsScalar: if isPtr { val = "*" + val } - return fmt.Sprintf("%s = jsonw.%s(%s)", res, ucFirst(f.Type.Name), val) + return fmt.Sprintf("%s = jsonw.%s(%s)", res, ucFirst(f.GoType), val) default: if !isPtr { val = "&" + val } - return fmt.Sprintf("%s = ec._%s(field.Selections, %s)", res, lcFirst(f.Type.GraphQLName), val) + return fmt.Sprintf("%s = ec._%s(field.Selections, %s)", res, lcFirst(f.GQLType), val) } } -func (o *object) GetField(name string) *Field { - for i, field := range o.Fields { - if strings.EqualFold(field.GraphQLName, name) { - return &o.Fields[i] - } - } - return nil +func tpl(tpl string, vars map[string]interface{}) string { + b := &bytes.Buffer{} + template.Must(template.New("inline").Parse(tpl)).Execute(b, vars) + return b.String() } -func (o *object) Implementors() string { - satisfiedBy := strconv.Quote(o.Type.GraphQLName) - for _, s := range o.satisfies { - satisfiedBy += ", " + strconv.Quote(s) +func ucFirst(s string) string { + if s == "" { + return "" } - return "[]string{" + satisfiedBy + "}" + r := []rune(s) + r[0] = unicode.ToUpper(r[0]) + return string(r) } -func (e *extractor) GetObject(name string) *object { - for i, o := range e.Objects { - if strings.EqualFold(o.Name, name) { - return e.Objects[i] - } +func lcFirst(s string) string { + if s == "" { + return "" } - return nil -} - -type FieldArgument struct { - Name string - Type kind -} -func tpl(tpl string, vars map[string]interface{}) string { - b := &bytes.Buffer{} - template.Must(template.New("inline").Parse(tpl)).Execute(b, vars) - return b.String() + r := []rune(s) + r[0] = unicode.ToLower(r[0]) + return string(r) } diff --git a/codegen/object_build.go b/codegen/object_build.go new file mode 100644 index 00000000000..8c619f778be --- /dev/null +++ b/codegen/object_build.go @@ -0,0 +1,202 @@ +package codegen + +import ( + "fmt" + "go/types" + "os" + "sort" + "strings" + + "github.com/vektah/gqlgen/neelance/schema" + "golang.org/x/tools/go/loader" +) + +type Objects []*Object + +func buildObjects(types NamedTypes, s *schema.Schema, prog *loader.Program) Objects { + var objects Objects + + for _, typ := range s.Types { + switch typ := typ.(type) { + case *schema.Object: + obj := buildObject(types, typ) + bindObject(prog, obj) + + objects = append(objects, obj) + } + } + + for name, typ := range s.EntryPoints { + obj := typ.(*schema.Object) + objects.ByName(obj.Name).Root = true + if name == "mutation" { + objects.ByName(obj.Name).DisableConcurrency = true + } + } + + sort.Slice(objects, func(i, j int) bool { + return strings.Compare(objects[i].GQLType, objects[j].GQLType) == -1 + }) + + return objects +} + +func (os Objects) ByName(name string) *Object { + for i, o := range os { + if strings.EqualFold(o.GQLType, name) { + return os[i] + } + } + return nil +} + +func buildObject(types NamedTypes, typ *schema.Object) *Object { + obj := &Object{NamedType: types[typ.TypeName()]} + + for _, i := range typ.Interfaces { + obj.Satisfies = append(obj.Satisfies, i.Name) + } + + for _, field := range typ.Fields { + var args []FieldArgument + for _, arg := range field.Args { + args = append(args, FieldArgument{ + GQLName: arg.Name.Name, + Type: types.getType(arg.Type), + }) + } + + obj.Fields = append(obj.Fields, Field{ + GQLName: field.Name, + Type: types.getType(field.Type), + Args: args, + Object: obj, + }) + } + return obj +} + +func bindObject(prog *loader.Program, obj *Object) { + if obj.Package == "" { + return + } + pkgName, err := resolvePkg(obj.Package) + if err != nil { + fmt.Fprintf(os.Stderr, "unable to resolve package for %s: %s\n", obj.GQLType, err.Error()) + return + } + + pkg := prog.Imported[pkgName] + if pkg == nil { + fmt.Fprintf(os.Stderr, "required package was not loaded: %s", pkgName) + return + } + + for astNode, object := range pkg.Defs { + if astNode.Name != obj.GoType { + continue + } + + if findBindTargets(object.Type(), obj) { + return + } + } +} + +func findBindTargets(t types.Type, object *Object) bool { + switch t := t.(type) { + case *types.Named: + for i := 0; i < t.NumMethods(); i++ { + method := t.Method(i) + if !method.Exported() { + continue + } + + if methodField := object.GetField(method.Name()); methodField != nil { + methodField.GoMethodName = "it." + method.Name() + sig := method.Type().(*types.Signature) + + methodField.Type.Modifiers = modifiersFromGoType(sig.Results().At(0).Type()) + + // check arg order matches code, not gql + + var newArgs []FieldArgument + l2: + for j := 0; j < sig.Params().Len(); j++ { + param := sig.Params().At(j) + for _, oldArg := range methodField.Args { + if strings.EqualFold(oldArg.GQLName, param.Name()) { + oldArg.Type.Modifiers = modifiersFromGoType(param.Type()) + newArgs = append(newArgs, oldArg) + continue l2 + } + } + fmt.Fprintln(os.Stderr, "cannot match argument "+param.Name()+" to any argument in "+t.String()) + } + methodField.Args = newArgs + + if sig.Results().Len() == 1 { + methodField.NoErr = true + } else if sig.Results().Len() != 2 { + fmt.Fprintf(os.Stderr, "weird number of results on %s. expected either (result), or (result, error)\n", method.Name()) + } + } + } + + findBindTargets(t.Underlying(), object) + return true + + case *types.Struct: + for i := 0; i < t.NumFields(); i++ { + field := t.Field(i) + // Todo: struct tags, name and - at least + + if !field.Exported() { + continue + } + + // Todo: check for type matches before binding too? + if objectField := object.GetField(field.Name()); objectField != nil { + objectField.GoVarName = "it." + field.Name() + objectField.Type.Modifiers = modifiersFromGoType(field.Type()) + } + } + t.Underlying() + return true + } + + return false +} + +func mutationRoot(schema *schema.Schema) string { + if mu, ok := schema.EntryPoints["mutation"]; ok { + return mu.TypeName() + } + return "" +} + +func queryRoot(schema *schema.Schema) string { + if mu, ok := schema.EntryPoints["mutation"]; ok { + return mu.TypeName() + } + return "" +} + +func modifiersFromGoType(t types.Type) []string { + var modifiers []string + for { + switch val := t.(type) { + case *types.Pointer: + modifiers = append(modifiers, modPtr) + t = val.Elem() + case *types.Array: + modifiers = append(modifiers, modList) + t = val.Elem() + case *types.Slice: + modifiers = append(modifiers, modList) + t = val.Elem() + default: + return modifiers + } + } +} diff --git a/codegen/type.go b/codegen/type.go new file mode 100644 index 00000000000..4605b4fe38a --- /dev/null +++ b/codegen/type.go @@ -0,0 +1,44 @@ +package codegen + +import "strings" + +type NamedTypes map[string]*NamedType + +type NamedType struct { + IsScalar bool + IsInterface bool + GQLType string // Name of the graphql type + GoType string // Name of the go type + Package string // the package the go type lives in + Import *Import +} + +type Type struct { + *NamedType + + Modifiers []string +} + +const ( + modList = "[]" + modPtr = "*" +) + +func (t NamedType) FullName() string { + if t.Import == nil || t.Import.Name == "" { + return t.GoType + } + return t.Import.Name + "." + t.GoType +} + +func (t Type) Signature() string { + return strings.Join(t.Modifiers, "") + t.FullName() +} + +func (t Type) IsPtr() bool { + return len(t.Modifiers) > 0 && t.Modifiers[0] == modPtr +} + +func (t Type) IsSlice() bool { + return len(t.Modifiers) > 0 && t.Modifiers[0] == modList +} diff --git a/codegen/type_build.go b/codegen/type_build.go new file mode 100644 index 00000000000..b49968dbd47 --- /dev/null +++ b/codegen/type_build.go @@ -0,0 +1,92 @@ +package codegen + +import ( + "fmt" + "strings" + + "github.com/vektah/gqlgen/neelance/common" + "github.com/vektah/gqlgen/neelance/schema" +) + +// namedTypeFromSchema objects for every graphql type, including scalars. There should only be one instance of Type for each thing +func buildNamedTypes(s *schema.Schema, userTypes map[string]string) NamedTypes { + types := map[string]*NamedType{} + for _, schemaType := range s.Types { + t := namedTypeFromSchema(schemaType) + + userType := userTypes[t.GQLType] + if userType == "" { + if t.IsScalar { + userType = "string" + } else { + userType = "interface{}" + userTypes[t.GQLType] = "interface{}" + } + } + t.Package, t.GoType = pkgAndType(userType) + + types[t.GQLType] = t + } + return types +} + +// namedTypeFromSchema objects for every graphql type, including primitives. +// don't recurse into object fields or interfaces yet, lets make sure we have collected everything first. +func namedTypeFromSchema(schemaType schema.NamedType) *NamedType { + switch val := schemaType.(type) { + case *schema.Scalar, *schema.Enum: + return &NamedType{GQLType: val.TypeName(), IsScalar: true} + case *schema.Interface, *schema.Union: + return &NamedType{GQLType: val.TypeName(), IsInterface: true} + default: + return &NamedType{GQLType: val.TypeName()} + } +} + +// take a string in the form github.com/package/blah.Type and split it into package and type +func pkgAndType(name string) (string, string) { + parts := strings.Split(name, ".") + if len(parts) == 1 { + return "", name + } + + return strings.Join(parts[:len(parts)-1], "."), parts[len(parts)-1] +} + +func (n NamedTypes) getType(t common.Type) *Type { + var modifiers []string + usePtr := true + for { + if _, nonNull := t.(*common.NonNull); nonNull { + usePtr = false + } else if _, nonNull := t.(*common.List); nonNull { + usePtr = false + } else { + if usePtr { + modifiers = append(modifiers, modPtr) + } + usePtr = true + } + + switch val := t.(type) { + case *common.NonNull: + t = val.OfType + case *common.List: + modifiers = append(modifiers, modList) + t = val.OfType + case schema.NamedType: + t := &Type{ + NamedType: n[val.TypeName()], + Modifiers: modifiers, + } + + if t.IsInterface && t.Modifiers[len(t.Modifiers)-1] == modPtr { + t.Modifiers = t.Modifiers[0 : len(t.Modifiers)-1] + } + + return t + default: + panic(fmt.Errorf("unknown type %T", t)) + } + } +} diff --git a/extractor.go b/extractor.go deleted file mode 100644 index b545043ec4d..00000000000 --- a/extractor.go +++ /dev/null @@ -1,406 +0,0 @@ -package main - -import ( - "fmt" - "go/types" - "os" - "path/filepath" - "sort" - "strconv" - "strings" - - "github.com/vektah/gqlgen/neelance/common" - "github.com/vektah/gqlgen/neelance/schema" - - "go/build" - - "golang.org/x/tools/go/loader" -) - -type extractor struct { - Errors []string - PackageName string - Objects []*object - Interfaces []*object - goTypeMap map[string]string - Imports map[string]string // local -> full path - schema *schema.Schema - SchemaRaw string - QueryRoot string - MutationRoot string -} - -func (e *extractor) extract() { - for _, typ := range e.schema.Types { - switch typ := typ.(type) { - case *schema.Object: - obj := &object{ - Name: typ.Name, - Type: e.getType(typ.Name), - } - - for _, i := range typ.Interfaces { - obj.satisfies = append(obj.satisfies, i.Name) - } - - for _, field := range typ.Fields { - var args []FieldArgument - for _, arg := range field.Args { - args = append(args, FieldArgument{ - Name: arg.Name.Name, - Type: e.buildType(arg.Type), - }) - } - - obj.Fields = append(obj.Fields, Field{ - GraphQLName: field.Name, - Type: e.buildType(field.Type), - Args: args, - Object: obj, - }) - } - e.Objects = append(e.Objects, obj) - case *schema.Union: - obj := &object{ - Name: typ.Name, - Type: e.buildType(typ), - } - e.Interfaces = append(e.Interfaces, obj) - - case *schema.Interface: - obj := &object{ - Name: typ.Name, - Type: e.buildType(typ), - } - e.Interfaces = append(e.Interfaces, obj) - } - - } - - for name, typ := range e.schema.EntryPoints { - obj := typ.(*schema.Object) - e.GetObject(obj.Name).Root = true - if name == "query" { - e.QueryRoot = obj.Name - } - if name == "mutation" { - e.MutationRoot = obj.Name - e.GetObject(obj.Name).DisableConcurrency = true - } - } - - sort.Slice(e.Objects, func(i, j int) bool { - return strings.Compare(e.Objects[i].Name, e.Objects[j].Name) == -1 - }) - - sort.Slice(e.Interfaces, func(i, j int) bool { - return strings.Compare(e.Interfaces[i].Name, e.Interfaces[j].Name) == -1 - }) -} - -func resolvePkg(pkgName string) (string, error) { - cwd, _ := os.Getwd() - - pkg, err := build.Default.Import(pkgName, cwd, build.FindOnly) - if err != nil { - return "", err - } - - return pkg.ImportPath, nil -} - -func (e *extractor) introspect() error { - var conf loader.Config - for _, name := range e.Imports { - conf.Import(name) - } - - prog, err := conf.Load() - if err != nil { - return err - } - - for _, o := range e.Objects { - if o.Type.Package == "" { - continue - } - - pkgName, err := resolvePkg(o.Type.Package) - if err != nil { - return fmt.Errorf("unable to resolve package: %s", o.Type.Package) - } - pkg := prog.Imported[pkgName] - if pkg == nil { - return fmt.Errorf("required package was not loaded: %s", pkgName) - } - - for astNode, object := range pkg.Defs { - if astNode.Name != o.Type.Name { - continue - } - - if e.findBindTargets(object.Type(), o) { - break - } - } - } - - return nil -} - -func (e *extractor) errorf(format string, args ...interface{}) { - e.Errors = append(e.Errors, fmt.Sprintf(format, args...)) -} - -func isOwnPkg(pkg string) bool { - absPath, err := filepath.Abs(*output) - if err != nil { - panic(err) - } - - return strings.HasSuffix(filepath.Dir(absPath), pkg) -} - -// getType to put in a file for a given fully resolved type, and add any Imports required -// eg name = github.com/my/pkg.myType will return `pkg.myType` and add an import for `github.com/my/pkg` -func (e *extractor) getType(name string) kind { - if fieldType, ok := e.goTypeMap[name]; ok { - parts := strings.Split(fieldType, ".") - if len(parts) == 1 { - return kind{ - GraphQLName: name, - Name: parts[0], - } - } - - packageName := strings.Join(parts[:len(parts)-1], ".") - typeName := parts[len(parts)-1] - - localName := "" - if !isOwnPkg(packageName) { - localName = filepath.Base(packageName) - i := 0 - for pkg, found := e.Imports[localName]; found && pkg != packageName; localName = filepath.Base(packageName) + strconv.Itoa(i) { - i++ - if i > 10 { - panic("too many collisions") - } - } - } - e.Imports[localName] = packageName - return kind{ - GraphQLName: name, - ImportedAs: localName, - Name: typeName, - Package: packageName, - } - } - - isRoot := false - for _, s := range e.schema.EntryPoints { - if s.(*schema.Object).Name == name { - isRoot = true - break - } - } - - if !isRoot { - fmt.Fprintf(os.Stderr, "unknown go type for %s, using interface{}. you should add it to types.json\n", name) - } - e.goTypeMap[name] = "interface{}" - return kind{ - GraphQLName: name, - Name: "interface{}", - } -} - -func (e *extractor) buildType(t common.Type) kind { - var modifiers []string - usePtr := true - for { - if _, nonNull := t.(*common.NonNull); nonNull { - usePtr = false - } else if _, nonNull := t.(*common.List); nonNull { - usePtr = false - } else { - if usePtr { - modifiers = append(modifiers, modPtr) - } - usePtr = true - } - - switch val := t.(type) { - case *common.NonNull: - t = val.OfType - case *common.List: - modifiers = append(modifiers, modList) - t = val.OfType - case *schema.Scalar: - var goType string - - switch val.Name { - case "String": - goType = "string" - case "ID": - goType = "string" - case "Boolean": - goType = "bool" - case "Int": - goType = "int" - case "Float": - goType = "float64" - case "Time": - return kind{ - Scalar: true, - Modifiers: modifiers, - GraphQLName: val.Name, - Name: "Time", - Package: "time", - ImportedAs: "time", - } - default: - panic(fmt.Errorf("unknown scalar %s", val.Name)) - } - return kind{ - Scalar: true, - Modifiers: modifiers, - GraphQLName: val.Name, - Name: goType, - } - case *schema.Object: - t := e.getType(val.Name) - t.Modifiers = modifiers - return t - case *common.TypeName: - t := e.getType(val.Name) - t.Modifiers = modifiers - return t - case *schema.Interface: - t := e.getType(val.Name) - t.Modifiers = modifiers - if t.Modifiers[len(t.Modifiers)-1] == modPtr { - t.Modifiers = t.Modifiers[0 : len(t.Modifiers)-1] - } - - for _, implementor := range val.PossibleTypes { - t.Implementors = append(t.Implementors, e.getType(implementor.Name)) - } - - return t - case *schema.Union: - t := e.getType(val.Name) - t.Modifiers = modifiers - if t.Modifiers[len(t.Modifiers)-1] == modPtr { - t.Modifiers = t.Modifiers[0 : len(t.Modifiers)-1] - } - - for _, implementor := range val.PossibleTypes { - t.Implementors = append(t.Implementors, e.getType(implementor.Name)) - } - - return t - case *schema.InputObject: - t := e.getType(val.Name) - t.Modifiers = modifiers - return t - case *schema.Enum: - return kind{ - Scalar: true, - Modifiers: modifiers, - GraphQLName: val.Name, - Name: "string", - } - default: - panic(fmt.Errorf("unknown type %T", t)) - } - } -} - -func (e *extractor) modifiersFromGoType(t types.Type) []string { - var modifiers []string - for { - switch val := t.(type) { - case *types.Pointer: - modifiers = append(modifiers, modPtr) - t = val.Elem() - case *types.Array: - modifiers = append(modifiers, modList) - t = val.Elem() - case *types.Slice: - modifiers = append(modifiers, modList) - t = val.Elem() - default: - return modifiers - } - } -} - -func (e *extractor) findBindTargets(t types.Type, object *object) bool { - switch t := t.(type) { - case *types.Named: - for i := 0; i < t.NumMethods(); i++ { - method := t.Method(i) - if !method.Exported() { - continue - } - - if methodField := object.GetField(method.Name()); methodField != nil { - methodField.MethodName = "it." + method.Name() - sig := method.Type().(*types.Signature) - - methodField.Type.Modifiers = e.modifiersFromGoType(sig.Results().At(0).Type()) - - // check arg order matches code, not gql - - var newArgs []FieldArgument - l2: - for j := 0; j < sig.Params().Len(); j++ { - param := sig.Params().At(j) - for _, oldArg := range methodField.Args { - if strings.EqualFold(oldArg.Name, param.Name()) { - oldArg.Type.Modifiers = e.modifiersFromGoType(param.Type()) - newArgs = append(newArgs, oldArg) - continue l2 - } - } - e.errorf("cannot match argument " + param.Name() + " to any argument in " + t.String()) - } - methodField.Args = newArgs - - if sig.Results().Len() == 1 { - methodField.NoErr = true - } else if sig.Results().Len() != 2 { - e.errorf("weird number of results on %s. expected either (result), or (result, error)", method.Name()) - } - } - } - - e.findBindTargets(t.Underlying(), object) - return true - - case *types.Struct: - for i := 0; i < t.NumFields(); i++ { - field := t.Field(i) - // Todo: struct tags, name and - at least - - if !field.Exported() { - continue - } - - // Todo: check for type matches before binding too? - if objectField := object.GetField(field.Name()); objectField != nil { - objectField.VarName = "it." + field.Name() - objectField.Type.Modifiers = e.modifiersFromGoType(field.Type()) - } - } - t.Underlying() - return true - } - - return false -} - -const ( - modList = "[]" - modPtr = "*" -) diff --git a/main.go b/main.go index ea29d71bb53..c97c7ac93f3 100644 --- a/main.go +++ b/main.go @@ -10,6 +10,7 @@ import ( "syscall" + "github.com/vektah/gqlgen/codegen" "github.com/vektah/gqlgen/neelance/schema" "golang.org/x/tools/imports" ) @@ -45,67 +46,22 @@ func main() { os.Exit(1) } - e := extractor{ - PackageName: getPkgName(), - goTypeMap: loadTypeMap(), - SchemaRaw: string(schemaRaw), - schema: schema, - Imports: map[string]string{ - "context": "context", - "fmt": "fmt", - "io": "io", - "strconv": "strconv", - "time": "time", - "reflect": "reflect", - "strings": "strings", - "sync": "sync", - - "mapstructure": "github.com/mitchellh/mapstructure", - "introspection": "github.com/vektah/gqlgen/neelance/introspection", - "errors": "github.com/vektah/gqlgen/neelance/errors", - "query": "github.com/vektah/gqlgen/neelance/query", - "schema": "github.com/vektah/gqlgen/neelance/schema", - "validation": "github.com/vektah/gqlgen/neelance/validation", - "jsonw": "github.com/vektah/gqlgen/jsonw", - }, - } - e.extract() - - // Poke a few magic methods into query - q := e.GetObject(e.QueryRoot) - q.Fields = append(q.Fields, Field{ - Type: e.getType("__Schema").Ptr(), - GraphQLName: "__schema", - NoErr: true, - MethodName: "ec.introspectSchema", - Object: q, - }) - q.Fields = append(q.Fields, Field{ - Type: e.getType("__Type").Ptr(), - GraphQLName: "__type", - NoErr: true, - MethodName: "ec.introspectType", - Args: []FieldArgument{{Name: "name", Type: kind{Scalar: true, Name: "string"}}}, - Object: q, - }) - - if len(e.Errors) != 0 { - for _, err := range e.Errors { - fmt.Fprintln(os.Stderr, "err: "+err) - } - os.Exit(1) - } - if *output != "-" { _ = syscall.Unlink(*output) } - if err = e.introspect(); err != nil { - fmt.Fprintln(os.Stderr, err.Error()) + build, err := codegen.Bind(schema, loadTypeMap(), dirName()) + if err != nil { + fmt.Fprintln(os.Stderr, "failed to generate code: "+err.Error()) os.Exit(1) } + build.SchemaRaw = string(schemaRaw) - buf, err := runTemplate(&e) + if *packageName != "" { + build.PackageName = *packageName + } + + buf, err := runTemplate(build) if err != nil { fmt.Fprintf(os.Stderr, "unable to generate code: "+err.Error()) os.Exit(1) @@ -131,22 +87,22 @@ func main() { func gofmt(filename string, b []byte) []byte { out, err := imports.Process(filename, b, nil) if err != nil { - fmt.Fprintln(os.Stderr, "unable to gofmt: "+*output+":"+err.Error()) + fmt.Fprintln(os.Stderr, "unable to gofmt: "+err.Error()) return b } return out } -func getPkgName() string { - pkgName := *packageName - if pkgName == "" { - absPath, err := filepath.Abs(*output) - if err != nil { - panic(err) - } - pkgName = filepath.Base(filepath.Dir(absPath)) +func absOutput() string { + absPath, err := filepath.Abs(*output) + if err != nil { + panic(err) } - return pkgName + return absPath +} + +func dirName() string { + return filepath.Dir(absOutput()) } func loadTypeMap() map[string]string { @@ -157,6 +113,12 @@ func loadTypeMap() map[string]string { "__EnumValue": "github.com/vektah/gqlgen/neelance/introspection.EnumValue", "__InputValue": "github.com/vektah/gqlgen/neelance/introspection.InputValue", "__Schema": "github.com/vektah/gqlgen/neelance/introspection.Schema", + "Int": "int", + "Float": "float64", + "String": "string", + "Boolean": "bool", + "ID": "string", + "Time": "time.Time", } b, err := ioutil.ReadFile(*typemap) if err != nil { diff --git a/templates.go b/templates.go index b6ac93d8e4a..a64cc060089 100644 --- a/templates.go +++ b/templates.go @@ -6,10 +6,11 @@ import ( "text/template" "unicode" + "github.com/vektah/gqlgen/codegen" "github.com/vektah/gqlgen/templates" ) -func runTemplate(e *extractor) (*bytes.Buffer, error) { +func runTemplate(e *codegen.Build) (*bytes.Buffer, error) { t, err := template.New("").Funcs(template.FuncMap{ "ucFirst": ucFirst, "lcFirst": lcFirst, diff --git a/templates/args.go b/templates/args.go index 69e5665fe29..ca3472f8420 100644 --- a/templates/args.go +++ b/templates/args.go @@ -3,9 +3,9 @@ package templates var argsTpl = ` {{- define "args" }} {{- range $i, $arg := . }} - var arg{{$i}} {{$arg.Type.Local }} - {{- if eq $arg.Type.FullName "time.Time" }} - if tmp, ok := field.Args[{{$arg.Name|quote}}]; ok { + var arg{{$i}} {{$arg.Signature }} + {{- if eq $arg.FullName "time.Time" }} + if tmp, ok := field.Args[{{$arg.GQLName|quote}}]; ok { if tmpStr, ok := tmp.(string); ok { tmpDate, err := time.Parse(time.RFC3339, tmpStr) if err != nil { @@ -14,22 +14,22 @@ var argsTpl = ` } arg{{$i}} = {{if $arg.Type.IsPtr}}&{{end}}tmpDate } else { - ec.Errorf("Time '{{$arg.Name}}' should be RFC3339 formatted string") + ec.Errorf("Time '{{$arg.GQLName}}' should be RFC3339 formatted string") continue } } - {{- else if eq $arg.Type.Name "map[string]interface{}" }} - if tmp, ok := field.Args[{{$arg.Name|quote}}]; ok { + {{- else if eq $arg.GoType "map[string]interface{}" }} + if tmp, ok := field.Args[{{$arg.GQLName|quote}}]; ok { {{- if $arg.Type.IsPtr }} - tmp2 := tmp.({{$arg.Type.Name}}) + tmp2 := tmp.({{$arg.GoType}}) arg{{$i}} = &tmp2 {{- else }} - arg{{$i}} = tmp.({{$arg.Type.Name}}) + arg{{$i}} = tmp.({{$arg.GoType}}) {{- end }} } - {{- else if $arg.Type.Scalar }} - if tmp, ok := field.Args[{{$arg.Name|quote}}]; ok { - tmp2, err := coerce{{$arg.Type.Name|ucFirst}}(tmp) + {{- else if $arg.IsScalar }} + if tmp, ok := field.Args[{{$arg.GQLName|quote}}]; ok { + tmp2, err := coerce{{$arg.GoType|ucFirst}}(tmp) if err != nil { ec.Error(err) continue @@ -37,7 +37,7 @@ var argsTpl = ` arg{{$i}} = {{if $arg.Type.IsPtr}}&{{end}}tmp2 } {{- else }} - err := unpackComplexArg(&arg{{$i}}, field.Args[{{$arg.Name|quote}}]) + err := unpackComplexArg(&arg{{$i}}, field.Args[{{$arg.GQLName|quote}}]) if err != nil { ec.Error(err) continue diff --git a/templates/file.go b/templates/file.go index 16dfd9f2869..d8d0aee53b2 100644 --- a/templates/file.go +++ b/templates/file.go @@ -7,8 +7,8 @@ const fileTpl = ` package {{ .PackageName }} import ( -{{- range $as, $import := .Imports }} - {{- $as }} "{{ $import }}" +{{- range $import := .Imports }} + {{- $import.Write }} {{ end }} ) @@ -46,10 +46,10 @@ func NewExecutor(resolvers Resolvers) func(context.Context, string, string, map[ var data jsonw.Writer if op.Type == query.Query { - data = c._{{.QueryRoot|lcFirst}}(op.Selections, nil) + data = c._{{.QueryRoot.GQLType|lcFirst}}(op.Selections, nil) {{- if .MutationRoot}} } else if op.Type == query.Mutation { - data = c._{{.MutationRoot|lcFirst}}(op.Selections, nil) + data = c._{{.MutationRoot.GQLType|lcFirst}}(op.Selections, nil) {{- end}} } else { return []*errors.QueryError{errors.Errorf("unsupported operation type")} diff --git a/templates/interface.go b/templates/interface.go index afc308022e9..3ff68849dd3 100644 --- a/templates/interface.go +++ b/templates/interface.go @@ -4,16 +4,16 @@ const interfaceTpl = ` {{- define "interface"}} {{- $interface := . }} -func (ec *executionContext) _{{$interface.Type.GraphQLName|lcFirst}}(sel []query.Selection, it *{{$interface.Type.Local}}) jsonw.Writer { +func (ec *executionContext) _{{$interface.GQLType|lcFirst}}(sel []query.Selection, it *{{$interface.FullName}}) jsonw.Writer { switch it := (*it).(type) { case nil: return jsonw.Null - {{- range $implementor := $interface.Type.Implementors }} - case {{$implementor.Local}}: - return ec._{{$implementor.GraphQLName|lcFirst}}(sel, &it) + {{- range $implementor := $interface.Implementors }} + case {{$implementor.FullName}}: + return ec._{{$implementor.GQLType|lcFirst}}(sel, &it) - case *{{$implementor.Local}}: - return ec._{{$implementor.GraphQLName|lcFirst}}(sel, it) + case *{{$implementor.FullName}}: + return ec._{{$implementor.GQLType|lcFirst}}(sel, it) {{- end }} default: diff --git a/templates/object.go b/templates/object.go index cf4fa498441..253fbe121d7 100644 --- a/templates/object.go +++ b/templates/object.go @@ -4,11 +4,11 @@ const objectTpl = ` {{- define "object" }} {{ $object := . }} -var {{ $object.Type.GraphQLName|lcFirst}}Implementors = {{$object.Implementors}} +var {{ $object.GQLType|lcFirst}}Implementors = {{$object.Implementors}} // nolint: gocyclo, errcheck, gas, goconst -func (ec *executionContext) _{{$object.Type.GraphQLName|lcFirst}}(sel []query.Selection, it *{{$object.Type.Local}}) jsonw.Writer { - fields := ec.collectFields(sel, {{$object.Type.GraphQLName|lcFirst}}Implementors, map[string]bool{}) +func (ec *executionContext) _{{$object.GQLType|lcFirst}}(sel []query.Selection, it *{{$object.FullName}}) jsonw.Writer { + fields := ec.collectFields(sel, {{$object.GQLType|lcFirst}}Implementors, map[string]bool{}) out := jsonw.NewOrderedMap(len(fields)) for i, field := range fields { out.Keys[i] = field.Alias @@ -16,7 +16,7 @@ func (ec *executionContext) _{{$object.Type.GraphQLName|lcFirst}}(sel []query.Se switch field.Name { {{- range $field := $object.Fields }} - case "{{$field.GraphQLName}}": + case "{{$field.GQLName}}": {{- template "args" $field.Args }} {{- if $field.IsConcurrent }} @@ -25,20 +25,20 @@ func (ec *executionContext) _{{$object.Type.GraphQLName|lcFirst}}(sel []query.Se defer ec.wg.Done() {{- end }} - {{- if $field.VarName }} - res := {{$field.VarName}} - {{- else if $field.MethodName }} + {{- if $field.GoVarName }} + res := {{$field.GoVarName}} + {{- else if $field.GoMethodName }} {{- if $field.NoErr }} - res := {{$field.MethodName}}({{ $field.CallArgs }}) + res := {{$field.GoMethodName}}({{ $field.CallArgs }}) {{- else }} - res, err := {{$field.MethodName}}({{ $field.CallArgs }}) + res, err := {{$field.GoMethodName}}({{ $field.CallArgs }}) if err != nil { ec.Error(err) {{ if $field.IsConcurrent }}return{{ else }}continue{{end}} } {{- end }} {{- else }} - res, err := ec.resolvers.{{ $object.Name }}_{{ $field.GraphQLName }}({{ $field.CallArgs }}) + res, err := ec.resolvers.{{ $object.GQLType }}_{{ $field.GQLName }}({{ $field.CallArgs }}) if err != nil { ec.Error(err) {{ if $field.IsConcurrent }}return{{ else }}continue{{end}}