Skip to content

Commit

Permalink
Generate resolver if configured
Browse files Browse the repository at this point in the history
  • Loading branch information
creativej committed Jul 31, 2018
1 parent 7031264 commit 58831ac
Show file tree
Hide file tree
Showing 10 changed files with 209 additions and 59 deletions.
61 changes: 52 additions & 9 deletions codegen/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,28 @@ type ModelBuild struct {
Enums []Enum
}

type ResolverBuild struct {
PackageName string
Imports []*Import
ResolverType string
Objects Objects
ResolverFound bool
}

// Create a list of models that need to be generated
func (cfg *Config) models() (*ModelBuild, error) {
namedTypes := cfg.buildNamedTypes()

prog, err := cfg.loadProgram(namedTypes, true)
progLoader := newLoader(namedTypes, true)
prog, err := progLoader.Load()
if err != nil {
return nil, errors.Wrap(err, "loading failed")
}
imports := buildImports(namedTypes, cfg.Model.Dir())

cfg.bindTypes(imports, namedTypes, cfg.Model.Dir(), prog)

models, err := cfg.buildModels(namedTypes, prog)
models, err := cfg.buildModels(namedTypes, prog, imports)
if err != nil {
return nil, err
}
Expand All @@ -55,11 +64,47 @@ func (cfg *Config) models() (*ModelBuild, error) {
}, nil
}

// bind a schema together with some code to generate a Build
func (cfg *Config) resolver() (*ResolverBuild, error) {
progLoader := newLoader(cfg.buildNamedTypes(), true)
progLoader.Import(cfg.Resolver.ImportPath())

prog, err := progLoader.Load()
if err != nil {
return nil, err
}

destDir := cfg.Resolver.Dir()

namedTypes := cfg.buildNamedTypes()
imports := buildImports(namedTypes, destDir)
imports.add(cfg.Exec.ImportPath())

cfg.bindTypes(imports, namedTypes, destDir, prog)

objects, err := cfg.buildObjects(namedTypes, prog, imports)
if err != nil {
return nil, err
}

def, _ := findGoType(prog, cfg.Resolver.ImportPath(), cfg.Resolver.Type)
resolverFound := def != nil

return &ResolverBuild{
PackageName: cfg.Resolver.Package,
Imports: imports.finalize(),
Objects: objects,
ResolverType: cfg.Resolver.Type,
ResolverFound: resolverFound,
}, nil
}

// bind a schema together with some code to generate a Build
func (cfg *Config) bind() (*Build, error) {
namedTypes := cfg.buildNamedTypes()

prog, err := cfg.loadProgram(namedTypes, true)
progLoader := newLoader(namedTypes, true)
prog, err := progLoader.Load()
if err != nil {
return nil, errors.Wrap(err, "loading failed")
}
Expand Down Expand Up @@ -105,13 +150,12 @@ func (cfg *Config) bind() (*Build, error) {
}

func (cfg *Config) validate() error {
namedTypes := cfg.buildNamedTypes()

_, err := cfg.loadProgram(namedTypes, false)
progLoader := newLoader(cfg.buildNamedTypes(), false)
_, err := progLoader.Load()
return err
}

func (cfg *Config) loadProgram(namedTypes NamedTypes, allowErrors bool) (*loader.Program, error) {
func newLoader(namedTypes NamedTypes, allowErrors bool) loader.Config {
conf := loader.Config{}
if allowErrors {
conf = loader.Config{
Expand All @@ -130,8 +174,7 @@ func (cfg *Config) loadProgram(namedTypes NamedTypes, allowErrors bool) (*loader
conf.Import(imp.Package)
}
}

return conf.Load()
return conf
}

func resolvePkg(pkgName string) (string, error) {
Expand Down
85 changes: 39 additions & 46 deletions codegen/codegen.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
package codegen

import (
"bytes"
"fmt"
"io/ioutil"
"log"
"os"
"path/filepath"
"regexp"
Expand All @@ -14,7 +12,6 @@ import (
"github.com/vektah/gqlparser"
"github.com/vektah/gqlparser/ast"
"github.com/vektah/gqlparser/gqlerror"
"golang.org/x/tools/imports"
)

func Generate(cfg Config) error {
Expand All @@ -30,15 +27,10 @@ func Generate(cfg Config) error {
return errors.Wrap(err, "model plan failed")
}
if len(modelsBuild.Models) > 0 || len(modelsBuild.Enums) > 0 {
var buf *bytes.Buffer
buf, err = templates.Run("models.gotpl", modelsBuild)
if err != nil {
return errors.Wrap(err, "model generation failed")
}

if err = write(cfg.Model.Filename, buf.Bytes()); err != nil {
if err = templates.RenderToFile("models.gotpl", cfg.Model.Filename, modelsBuild); err != nil {
return err
}

for _, model := range modelsBuild.Models {
modelCfg := cfg.Models[model.GQLType]
modelCfg.Model = cfg.Model.ImportPath() + "." + model.GoType
Expand All @@ -57,23 +49,46 @@ func Generate(cfg Config) error {
return errors.Wrap(err, "exec plan failed")
}

var buf *bytes.Buffer
buf, err = templates.Run("generated.gotpl", build)
if err != nil {
return errors.Wrap(err, "exec codegen failed")
if err := templates.RenderToFile("generated.gotpl", cfg.Exec.Filename, build); err != nil {
return err
}

if err = write(cfg.Exec.Filename, buf.Bytes()); err != nil {
return err
if cfg.Resolver.IsDefined() {
if err := generateResolver(cfg); err != nil {
return errors.Wrap(err, "generating resolver failed")
}
}

if err = cfg.validate(); err != nil {
if err := cfg.validate(); err != nil {
return errors.Wrap(err, "validation failed")
}

return nil
}

func generateResolver(cfg Config) error {
resolverBuild, err := cfg.resolver()
if err != nil {
return errors.Wrap(err, "resolver build failed")
}
filename := cfg.Resolver.Filename

if resolverBuild.ResolverFound {
log.Printf("Skipped resolver: %s.%s already exists\n", cfg.Resolver.ImportPath(), cfg.Resolver.Type)
return nil
}

if _, err := os.Stat(filename); os.IsNotExist(errors.Cause(err)) {
if err := templates.RenderToFile("resolver.gotpl", filename, resolverBuild); err != nil {
return err
}
} else {
log.Printf("Skipped resolver: %s already exists\n", filename)
}

return nil
}

func (cfg *Config) normalize() error {
if err := cfg.Model.normalize(); err != nil {
return errors.Wrap(err, "model")
Expand All @@ -83,6 +98,12 @@ func (cfg *Config) normalize() error {
return errors.Wrap(err, "exec")
}

if cfg.Resolver.IsDefined() {
if err := cfg.Resolver.normalize(); err != nil {
return errors.Wrap(err, "resolver")
}
}

builtins := TypeMap{
"__Directive": {Model: "github.com/vektah/gqlgen/graphql/introspection.Directive"},
"__Type": {Model: "github.com/vektah/gqlgen/graphql/introspection.Type"},
Expand Down Expand Up @@ -129,31 +150,3 @@ func abs(path string) string {
}
return filepath.ToSlash(absPath)
}

func gofmt(filename string, b []byte) ([]byte, error) {
out, err := imports.Process(filename, b, nil)
if err != nil {
return b, errors.Wrap(err, "unable to gofmt")
}
return out, nil
}

func write(filename string, b []byte) error {
err := os.MkdirAll(filepath.Dir(filename), 0755)
if err != nil {
return errors.Wrap(err, "failed to create directory")
}

formatted, err := gofmt(filename, b)
if err != nil {
fmt.Fprintf(os.Stderr, "gofmt failed: %s\n", err.Error())
formatted = b
}

err = ioutil.WriteFile(filename, formatted, 0644)
if err != nil {
return errors.Wrapf(err, "failed to write %s", filename)
}

return nil
}
9 changes: 9 additions & 0 deletions codegen/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ type Config struct {
SchemaStr string `yaml:"-"`
Exec PackageConfig `yaml:"exec"`
Model PackageConfig `yaml:"model"`
Resolver PackageConfig `yaml:"resolver,omitempty"`
Models TypeMap `yaml:"models,omitempty"`

schema *ast.Schema `yaml:"-"`
Expand All @@ -68,6 +69,7 @@ type Config struct {
type PackageConfig struct {
Filename string `yaml:"filename,omitempty"`
Package string `yaml:"package,omitempty"`
Type string `yaml:"type,omitempty"`
}

type TypeMapEntry struct {
Expand Down Expand Up @@ -128,6 +130,10 @@ func (c *PackageConfig) Check() error {
return nil
}

func (c *PackageConfig) IsDefined() bool {
return c.Filename != ""
}

func (cfg *Config) Check() error {
if err := cfg.Models.Check(); err != nil {
return errors.Wrap(err, "config.models")
Expand All @@ -138,6 +144,9 @@ func (cfg *Config) Check() error {
if err := cfg.Model.Check(); err != nil {
return errors.Wrap(err, "config.model")
}
if err := cfg.Resolver.Check(); err != nil {
return errors.Wrap(err, "config.resolver")
}
return nil
}

Expand Down
6 changes: 6 additions & 0 deletions codegen/interface_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package codegen
import (
"testing"

"syscall"

"github.com/stretchr/testify/require"
"golang.org/x/tools/go/loader"
)
Expand Down Expand Up @@ -41,7 +43,11 @@ func generate(name string, schema string, typemap ...TypeMap) error {
SchemaStr: schema,
Exec: PackageConfig{Filename: "tests/gen/" + name + "/exec.go"},
Model: PackageConfig{Filename: "tests/gen/" + name + "/model.go"},
Resolver: PackageConfig{Filename: "tests/gen/" + name + "/resolver.go", Type: "Resolver"},
}

_ = syscall.Unlink(cfg.Resolver.Filename)

if len(typemap) > 0 {
cfg.Models = typemap[0]
}
Expand Down
4 changes: 2 additions & 2 deletions codegen/models_build.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@ import (
"golang.org/x/tools/go/loader"
)

func (cfg *Config) buildModels(types NamedTypes, prog *loader.Program) ([]Model, error) {
func (cfg *Config) buildModels(types NamedTypes, prog *loader.Program, imports *Imports) ([]Model, error) {
var models []Model

for _, typ := range cfg.schema.Types {
var model Model
switch typ.Kind {
case ast.Object:
obj, err := cfg.buildObject(types, typ)
obj, err := cfg.buildObject(types, typ, imports)
if err != nil {
return nil, err
}
Expand Down
10 changes: 10 additions & 0 deletions codegen/object.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ type Object struct {

Fields []Field
Satisfies []string
ResolverInterface *Ref
Root bool
DisableConcurrency bool
Stream bool
Expand Down Expand Up @@ -77,6 +78,15 @@ func (f *Field) ShortInvocation() string {
return fmt.Sprintf("%s().%s(%s)", f.Object.GQLType, shortName, f.CallArgs())
}

func (f *Field) ResolverType() string {
if !f.IsResolver() {
return ""
}
shortName := strings.ToUpper(f.GQLName[:1]) + f.GQLName[1:]

return fmt.Sprintf("%s().%s(%s)", f.Object.GQLType, shortName, f.CallArgs())
}

func (f *Field) ShortResolverDeclaration() string {
if !f.IsResolver() {
return ""
Expand Down
7 changes: 5 additions & 2 deletions codegen/object_build.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func (cfg *Config) buildObjects(types NamedTypes, prog *loader.Program, imports
continue
}

obj, err := cfg.buildObject(types, typ)
obj, err := cfg.buildObject(types, typ, imports)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -81,10 +81,13 @@ func sanitizeGoName(name string) string {
return name
}

func (cfg *Config) buildObject(types NamedTypes, typ *ast.Definition) (*Object, error) {
func (cfg *Config) buildObject(types NamedTypes, typ *ast.Definition, imports *Imports) (*Object, error) {
obj := &Object{NamedType: types[typ.Name]}
typeEntry, entryExists := cfg.Models[typ.Name]

imp := imports.findByPath(cfg.Exec.ImportPath())
obj.ResolverInterface = &Ref{GoType: obj.GQLType + "Resolver", Import: imp}

if typ == cfg.schema.Query {
obj.Root = true
}
Expand Down
1 change: 1 addition & 0 deletions codegen/templates/data.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 58831ac

Please sign in to comment.