From e7539f110a3d90ff9a0bce861e7c024fd91e2a02 Mon Sep 17 00:00:00 2001 From: Adam Scarr Date: Thu, 26 Apr 2018 16:42:05 +1000 Subject: [PATCH] Add an error message when using types inside inputs --- codegen/codegen.go | 36 +++++++++++++++---------------- codegen/input_build.go | 25 +++++++++++++++------ codegen/models_build.go | 5 ++++- codegen/object_build.go | 5 ++++- codegen/tests/input_union_test.go | 20 ++++++++++++++++- 5 files changed, 64 insertions(+), 27 deletions(-) diff --git a/codegen/codegen.go b/codegen/codegen.go index 284f63e7b16..7fd9ecc6a52 100644 --- a/codegen/codegen.go +++ b/codegen/codegen.go @@ -1,6 +1,7 @@ package codegen import ( + "bytes" "fmt" "go/build" "io/ioutil" @@ -21,15 +22,15 @@ type Config struct { schema *schema.Schema - ExecFilename string - ExecPackageName string - execDir string - fullExecPackageName string + ExecFilename string + ExecPackageName string + execPackagePath string + execDir string - ModelFilename string - ModelPackageName string - modelDir string - fullModelPackageName string + ModelFilename string + ModelPackageName string + modelPackagePath string + modelDir string } func Generate(cfg Config) error { @@ -46,8 +47,8 @@ func Generate(cfg Config) error { } if len(modelsBuild.Models) > 0 { modelsBuild.PackageName = cfg.ModelPackageName - - buf, err := templates.Run("models.gotpl", modelsBuild) + var buf *bytes.Buffer + buf, err = templates.Run("models.gotpl", modelsBuild) if err != nil { return errors.Wrap(err, "model generation failed") } @@ -56,11 +57,11 @@ func Generate(cfg Config) error { return err } for _, model := range modelsBuild.Models { - cfg.Typemap[model.GQLType] = cfg.fullModelPackageName + "." + model.GoType + cfg.Typemap[model.GQLType] = cfg.modelPackagePath + "." + model.GoType } for _, enum := range modelsBuild.Enums { - cfg.Typemap[enum.GQLType] = cfg.fullModelPackageName + "." + enum.GoType + cfg.Typemap[enum.GQLType] = cfg.modelPackagePath + "." + enum.GoType } } @@ -71,7 +72,8 @@ func Generate(cfg Config) error { build.SchemaRaw = cfg.SchemaStr build.PackageName = cfg.ExecPackageName - buf, err := templates.Run("generated.gotpl", build) + var buf *bytes.Buffer + buf, err = templates.Run("generated.gotpl", build) if err != nil { return errors.Wrap(err, "exec codegen failed") } @@ -91,7 +93,7 @@ func (cfg *Config) normalize() error { if cfg.ModelPackageName == "" { cfg.ModelPackageName = filepath.Base(cfg.modelDir) } - cfg.fullModelPackageName = fullPackageName(cfg.modelDir, cfg.ModelPackageName) + cfg.modelPackagePath = fullPackageName(cfg.modelDir, cfg.ModelPackageName) if cfg.ExecFilename == "" { return errors.New("ModelFilename is required") @@ -101,7 +103,7 @@ func (cfg *Config) normalize() error { if cfg.ExecPackageName == "" { cfg.ExecPackageName = filepath.Base(cfg.execDir) } - cfg.fullExecPackageName = fullPackageName(cfg.execDir, cfg.ExecPackageName) + cfg.execPackagePath = fullPackageName(cfg.execDir, cfg.ExecPackageName) builtins := map[string]string{ "__Directive": "github.com/vektah/gqlgen/neelance/introspection.Directive", @@ -145,9 +147,7 @@ func fullPackageName(dir string, pkgName string) string { for _, gopath := range filepath.SplitList(build.Default.GOPATH) { gopath = filepath.Join(gopath, "src") + string(os.PathSeparator) - if strings.HasPrefix(fullPkgName, gopath) { - fullPkgName = fullPkgName[len(gopath):] - } + fullPkgName = strings.TrimPrefix(fullPkgName, gopath) } return filepath.ToSlash(fullPkgName) } diff --git a/codegen/input_build.go b/codegen/input_build.go index 13241495a82..72817d77c2e 100644 --- a/codegen/input_build.go +++ b/codegen/input_build.go @@ -16,7 +16,10 @@ func (cfg *Config) buildInputs(namedTypes NamedTypes, prog *loader.Program, impo for _, typ := range cfg.schema.Types { switch typ := typ.(type) { case *schema.InputObject: - input := buildInput(namedTypes, typ) + input, err := buildInput(namedTypes, typ) + if err != nil { + return nil, err + } def, err := findGoType(prog, input.Package, input.GoType) if err != nil { @@ -24,7 +27,10 @@ func (cfg *Config) buildInputs(namedTypes NamedTypes, prog *loader.Program, impo } if def != nil { input.Marshaler = buildInputMarshaler(typ, def) - bindObject(def.Type(), input, imports) + err = bindObject(def.Type(), input, imports) + if err != nil { + return nil, err + } } inputs = append(inputs, input) @@ -38,17 +44,24 @@ func (cfg *Config) buildInputs(namedTypes NamedTypes, prog *loader.Program, impo return inputs, nil } -func buildInput(types NamedTypes, typ *schema.InputObject) *Object { +func buildInput(types NamedTypes, typ *schema.InputObject) (*Object, error) { obj := &Object{NamedType: types[typ.TypeName()]} for _, field := range typ.Values { - obj.Fields = append(obj.Fields, Field{ + newField := Field{ GQLName: field.Name.Name, Type: types.getType(field.Type), Object: obj, - }) + } + + if !newField.Type.IsInput && !newField.Type.IsScalar { + return nil, errors.Errorf("%s cannot be used as a field of %s. only input and scalar types are allowed", newField.GQLType, obj.GQLType) + } + + obj.Fields = append(obj.Fields, newField) + } - return obj + return obj, nil } // if user has implemented an UnmarshalGQL method on the input type manually, use it diff --git a/codegen/models_build.go b/codegen/models_build.go index 0c03ae7fcbf..d694ce08e57 100644 --- a/codegen/models_build.go +++ b/codegen/models_build.go @@ -24,7 +24,10 @@ func (cfg *Config) buildModels(types NamedTypes, prog *loader.Program) ([]Model, } model = cfg.obj2Model(obj) case *schema.InputObject: - obj := buildInput(types, typ) + obj, err := buildInput(types, typ) + if err != nil { + return nil, err + } if obj.IsUserDefined { continue } diff --git a/codegen/object_build.go b/codegen/object_build.go index 50c2c65c94b..3b9e092dc71 100644 --- a/codegen/object_build.go +++ b/codegen/object_build.go @@ -25,7 +25,10 @@ func (cfg *Config) buildObjects(types NamedTypes, prog *loader.Program, imports return nil, err } if def != nil { - bindObject(def.Type(), obj, imports) + err = bindObject(def.Type(), obj, imports) + if err != nil { + return nil, err + } } objects = append(objects, obj) diff --git a/codegen/tests/input_union_test.go b/codegen/tests/input_union_test.go index a9ee295037e..f20d702b391 100644 --- a/codegen/tests/input_union_test.go +++ b/codegen/tests/input_union_test.go @@ -7,7 +7,7 @@ import ( "github.com/vektah/gqlgen/codegen" ) -func TestInputUnion(t *testing.T) { +func TestTypeUnionAsInput(t *testing.T) { err := codegen.Generate(codegen.Config{ SchemaStr: ` type Query { @@ -22,3 +22,21 @@ func TestInputUnion(t *testing.T) { require.EqualError(t, err, "model plan failed: Bookmarkable! cannot be used as argument of Query.addBookmark. only input and scalar types are allowed") } + +func TestTypeInInput(t *testing.T) { + err := codegen.Generate(codegen.Config{ + SchemaStr: ` + type Query { + addBookmark(b: BookmarkableInput!): Boolean! + } + type Item {} + input BookmarkableInput { + item: Item + } + `, + ExecFilename: "gen/typeinput/exec.go", + ModelFilename: "gen/typeinput/model.go", + }) + + require.EqualError(t, err, "model plan failed: Item cannot be used as a field of BookmarkableInput. only input and scalar types are allowed") +}