Skip to content

Commit

Permalink
Split model generation into its own stage
Browse files Browse the repository at this point in the history
  • Loading branch information
vektah committed Feb 25, 2018
1 parent 926384d commit cf580c2
Show file tree
Hide file tree
Showing 19 changed files with 252 additions and 140 deletions.
46 changes: 36 additions & 10 deletions codegen/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
type Build struct {
PackageName string
Objects Objects
Models Objects
Inputs Objects
Interfaces []*Interface
Imports Imports
Expand All @@ -23,12 +22,38 @@ type Build struct {
SchemaRaw string
}

type ModelBuild struct {
PackageName string
Imports Imports
Models Objects
}

// Create a list of models that need to be generated
func Models(schema *schema.Schema, userTypes map[string]string, destDir string) *ModelBuild {
namedTypes := buildNamedTypes(schema, userTypes)

imports := buildImports(namedTypes, destDir)
prog, err := loadProgram(imports, true)
if err != nil {
panic(err)
}

bindTypes(imports, namedTypes, prog)

models := buildModels(namedTypes, schema)
return &ModelBuild{
PackageName: filepath.Base(destDir),
Models: models,
Imports: buildImports(namedTypes, destDir),
}
}

// Bind a schema together with some code to generate a Build
func Bind(schema *schema.Schema, userTypes map[string]string, destDir string) (*Build, error) {
namedTypes := buildNamedTypes(schema, userTypes)

imports := buildImports(namedTypes, destDir)
prog, err := loadProgram(imports)
prog, err := loadProgram(imports, false)
if err != nil {
return nil, err
}
Expand All @@ -37,12 +62,10 @@ func Bind(schema *schema.Schema, userTypes map[string]string, destDir string) (*

objects := buildObjects(namedTypes, schema, prog, imports)
inputs := buildInputs(namedTypes, schema, prog, imports)
models := append(findMissing(objects), findMissing(inputs)...)

b := &Build{
PackageName: filepath.Base(destDir),
Objects: objects,
Models: models,
Interfaces: buildInterfaces(namedTypes, schema),
Inputs: inputs,
Imports: imports,
Expand Down Expand Up @@ -83,12 +106,15 @@ func Bind(schema *schema.Schema, userTypes map[string]string, destDir string) (*
return b, nil
}

func loadProgram(imports Imports) (*loader.Program, error) {
conf := loader.Config{
AllowErrors: true,
TypeChecker: types.Config{
Error: func(e error) {},
},
func loadProgram(imports Imports, allowErrors bool) (*loader.Program, error) {
conf := loader.Config{}
if allowErrors {
conf = loader.Config{
AllowErrors: true,
TypeChecker: types.Config{
Error: func(e error) {},
},
}
}
for _, imp := range imports {
if imp.Package != "" {
Expand Down
2 changes: 1 addition & 1 deletion codegen/input_build.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func buildInputs(namedTypes NamedTypes, s *schema.Schema, prog *loader.Program,
}

func buildInput(types NamedTypes, typ *schema.InputObject) *Object {
obj := &Object{NamedType: types[typ.TypeName()]}
obj := &Object{NamedType: types[typ.TypeName()], Input: true}

for _, field := range typ.Values {
obj.Fields = append(obj.Fields, Field{
Expand Down
59 changes: 59 additions & 0 deletions codegen/models_build.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package codegen

import (
"sort"
"strings"

"github.com/vektah/gqlgen/neelance/schema"
)

func buildModels(types NamedTypes, s *schema.Schema) Objects {
var models Objects

for _, typ := range s.Types {
var model *Object
switch typ := typ.(type) {
case *schema.Object:
model = buildObject(types, typ, s)

case *schema.InputObject:
model = buildInput(types, typ)
}

if model == nil || model.Root || model.GoType != "" {
continue
}

bindGenerated(types, model)

models = append(models, model)
}

sort.Slice(models, func(i, j int) bool {
return strings.Compare(models[i].GQLType, models[j].GQLType) == -1
})

return models
}

func bindGenerated(types NamedTypes, object *Object) {
object.GoType = ucFirst(object.GQLType)
object.Marshaler = &Ref{GoType: object.GoType}

for i := range object.Fields {
field := &object.Fields[i]

if field.IsScalar {
field.GoVarName = ucFirst(field.GQLName)
if field.GoVarName == "Id" {
field.GoVarName = "ID"
}
} else if object.Input {
field.GoFKName = ucFirst(field.GQLName)
field.GoFKType = types[field.GQLType].GoType
} else {
field.GoFKName = ucFirst(field.GQLName) + "ID"
field.GoFKType = "int" // todo: use schema to determine type of id?
}
}
}
1 change: 1 addition & 0 deletions codegen/object.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ type Object struct {
Root bool
DisableConcurrency bool
Stream bool
Input bool
}

type Field struct {
Expand Down
70 changes: 17 additions & 53 deletions codegen/object_build.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ func buildObjects(types NamedTypes, s *schema.Schema, prog *loader.Program, impo
for _, typ := range s.Types {
switch typ := typ.(type) {
case *schema.Object:
obj := buildObject(types, typ)
obj := buildObject(types, typ, s)

def, err := findGoType(prog, obj.Package, obj.GoType)
if err != nil {
Expand All @@ -30,65 +30,14 @@ func buildObjects(types NamedTypes, s *schema.Schema, prog *loader.Program, impo
}
}

for name, typ := range s.EntryPoints {
obj := typ.(*schema.Object)
objects.ByName(obj.Name).Root = true
if name == "mutation" {
objects.ByName(obj.Name).DisableConcurrency = true
}
if name == "subscription" {
objects.ByName(obj.Name).Stream = true
}
}

sort.Slice(objects, func(i, j int) bool {
return strings.Compare(objects[i].GQLType, objects[j].GQLType) == -1
})

return objects
}

func findMissing(objects Objects) Objects {
var missingObjects Objects

for _, object := range objects {
if !object.Generated || object.Root {
continue
}
object.GoType = ucFirst(object.GQLType)
object.Marshaler = &Ref{GoType: object.GoType}

for i := range object.Fields {
field := &object.Fields[i]

if field.IsScalar {
field.GoVarName = ucFirst(field.GQLName)
if field.GoVarName == "Id" {
field.GoVarName = "ID"
}
} else {
field.GoFKName = ucFirst(field.GQLName) + "ID"
field.GoFKType = "int"

for _, f := range objects.ByName(field.Type.GQLType).Fields {
if strings.EqualFold(f.GQLName, "id") {
field.GoFKType = f.GoType
}
}
}
}

missingObjects = append(missingObjects, object)
}

sort.Slice(missingObjects, func(i, j int) bool {
return strings.Compare(missingObjects[i].GQLType, missingObjects[j].GQLType) == -1
})

return missingObjects
}

func buildObject(types NamedTypes, typ *schema.Object) *Object {
func buildObject(types NamedTypes, typ *schema.Object, s *schema.Schema) *Object {
obj := &Object{NamedType: types[typ.TypeName()]}

for _, i := range typ.Interfaces {
Expand Down Expand Up @@ -118,5 +67,20 @@ func buildObject(types NamedTypes, typ *schema.Object) *Object {
Object: obj,
})
}

for name, typ := range s.EntryPoints {
schemaObj := typ.(*schema.Object)
if schemaObj.TypeName() != obj.GQLType {
continue
}

obj.Root = true
if name == "mutation" {
obj.DisableConcurrency = true
}
if name == "subscription" {
obj.Stream = true
}
}
return obj
}
Loading

0 comments on commit cf580c2

Please sign in to comment.