diff --git a/Gopkg.lock b/Gopkg.lock index cc59153f348..a10062414a5 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -26,7 +26,7 @@ version = "v3.3.2" [[projects]] - digest = "1:78907d832e27dbfc6e3fdfc52bd2e5e2e05c1d0e3789d4825b824489fbeab233" + digest = "1:f3df613325a793ffb3d0ce7644a3bb6f62db45ac744dafe20172fe999c61cdbf" name = "github.com/gogo/protobuf" packages = [ "io", @@ -60,6 +60,17 @@ revision = "ea4d1f681babbce9545c9c5f3d5194a789c89f5b" version = "v1.2.0" +[[projects]] + branch = "master" + digest = "1:cf296baa185baae04a9a7004efee8511d08e2f5f51d4cbe5375da89722d681db" + name = "github.com/hashicorp/golang-lru" + packages = [ + ".", + "simplelru", + ] + pruneopts = "UT" + revision = "0fb14efe8c47ae851c0034ed7a448854d3d34cf3" + [[projects]] digest = "1:870d441fe217b8e689d7949fef6e43efbc787e50f200cb1e70dbca9204a1d6be" name = "github.com/inconshreveable/mousetrap" @@ -96,7 +107,7 @@ version = "v1.0.0" [[projects]] - digest = "1:27af6024faa3c28426a698b8c653be0fd908bc96e25b7d76f2192eb342427db6" + digest = "1:450b7623b185031f3a456801155c8320209f75d0d4c4e633c6b1e59d44d6e392" name = "github.com/opentracing/opentracing-go" packages = [ ".", @@ -140,7 +151,7 @@ revision = "ffb13db8def02f545acc58bd288ec6057c2bbfb9" [[projects]] - digest = "1:872fa275c31e1f9db31d66fa9b1d4a7bb9a080ff184e6977da01f36bfbe07f11" + digest = "1:645cabccbb4fa8aab25a956cbcbdf6a6845ca736b2c64e197ca7cbb9d210b939" name = "github.com/spf13/cobra" packages = ["."] pruneopts = "UT" @@ -156,7 +167,7 @@ version = "v1.0.1" [[projects]] - digest = "1:73697231b93fb74a73ebd8384b68b9a60c57ea6b13c56d2425414566a72c8e6d" + digest = "1:7e8d267900c7fa7f35129a2a37596e38ed0f11ca746d6d9ba727980ee138f9f6" name = "github.com/stretchr/testify" packages = [ "assert", @@ -200,7 +211,7 @@ [[projects]] branch = "master" - digest = "1:77fe642412bfed48743e2b75163e3ab5c430cfe22dd488788647b89b28794635" + digest = "1:3cbc05413b8aac22b1f6d4350ed696b5a83a8515a4136db8f1ec3a0aee3d76e1" name = "golang.org/x/tools" packages = [ "go/ast/astutil", @@ -221,7 +232,7 @@ [[projects]] branch = "master" - digest = "1:7ddb3a7b35cc853fe0db36a1b2473bdff03f28add7d28e4725e692603111266e" + digest = "1:741ebea9214cc226789d3003baeca9b169e04b5b336fb1a3b2c16e75bd296bb5" name = "sourcegraph.com/sourcegraph/appdash" packages = [ ".", @@ -237,7 +248,7 @@ [[projects]] branch = "master" - digest = "1:be108b48d79c3b3c345811a57a47ee87fdbe895beb4bb56239da71d4943e5be7" + digest = "1:8e0a2957fe342f22d70a543c3fcdf390f7627419c3d82d87ab4fd715a9ef5716" name = "sourcegraph.com/sourcegraph/appdash-data" packages = ["."] pruneopts = "UT" @@ -249,6 +260,7 @@ input-imports = [ "github.com/go-chi/chi", "github.com/gorilla/websocket", + "github.com/hashicorp/golang-lru", "github.com/mitchellh/mapstructure", "github.com/opentracing-contrib/go-stdlib/nethttp", "github.com/opentracing/opentracing-go", diff --git a/cmd/gen.go b/cmd/gen.go index 6a529ac5b5e..a941eda8693 100644 --- a/cmd/gen.go +++ b/cmd/gen.go @@ -8,7 +8,6 @@ import ( "github.com/99designs/gqlgen/codegen" "github.com/pkg/errors" "github.com/spf13/cobra" - "gopkg.in/yaml.v2" ) func init() { @@ -39,7 +38,6 @@ var genCmd = &cobra.Command{ } // overwrite by commandline options - var emitYamlGuidance bool if schemaFilename != "" { config.SchemaFilename = schemaFilename } @@ -55,10 +53,6 @@ var genCmd = &cobra.Command{ if modelPackageName != "" { config.Model.Package = modelPackageName } - if typemap != "" { - config.Models = loadModelMap() - emitYamlGuidance = true - } schemaRaw, err := ioutil.ReadFile(config.SchemaFilename) if err != nil { @@ -72,17 +66,6 @@ var genCmd = &cobra.Command{ os.Exit(1) } - if emitYamlGuidance { - var b []byte - b, err = yaml.Marshal(config) - if err != nil { - fmt.Fprintln(os.Stderr, "unable to marshal yaml: "+err.Error()) - os.Exit(1) - } - - fmt.Fprintf(os.Stderr, "DEPRECATION WARNING: we are moving away from the json typemap, instead create a gqlgen.yml with the following content:\n\n%s\n", string(b)) - } - err = codegen.Generate(*config) if err != nil { fmt.Fprintln(os.Stderr, err.Error()) @@ -90,26 +73,3 @@ var genCmd = &cobra.Command{ } }, } - -func loadModelMap() codegen.TypeMap { - var goTypes map[string]string - b, err := ioutil.ReadFile(typemap) - if err != nil { - fmt.Fprintln(os.Stderr, "unable to open typemap: "+err.Error()) - return nil - } - - if err = yaml.Unmarshal(b, &goTypes); err != nil { - fmt.Fprintln(os.Stderr, "unable to parse typemap: "+err.Error()) - os.Exit(1) - } - - typeMap := make(codegen.TypeMap) - for typeName, entityPath := range goTypes { - typeMap[typeName] = codegen.TypeMapEntry{ - Model: entityPath, - } - } - - return typeMap -} diff --git a/cmd/init.go b/cmd/init.go index 4b7792f6c2f..f8b31945cfd 100644 --- a/cmd/init.go +++ b/cmd/init.go @@ -141,9 +141,6 @@ func initConfig() *codegen.Config { if modelPackageName != "" { config.Model.Package = modelPackageName } - if typemap != "" { - config.Models = loadModelMap() - } var buf bytes.Buffer buf.WriteString(strings.TrimSpace(configComment)) diff --git a/cmd/root.go b/cmd/root.go index 8598acd315c..1e9894cedbf 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -16,7 +16,6 @@ var verbose bool var output string var models string var schemaFilename string -var typemap string var packageName string var modelPackageName string var serverFilename string @@ -29,7 +28,6 @@ func init() { rootCmd.PersistentFlags().StringVar(&output, "out", "", "the file to write to") rootCmd.PersistentFlags().StringVar(&models, "models", "", "the file to write the models to") rootCmd.PersistentFlags().StringVar(&schemaFilename, "schema", "", "the graphql schema to generate types from") - rootCmd.PersistentFlags().StringVar(&typemap, "typemap", "", "a json map going from graphql to golang types") rootCmd.PersistentFlags().StringVar(&packageName, "package", "", "the package name") rootCmd.PersistentFlags().StringVar(&modelPackageName, "modelpackage", "", "the package name to use for models") } diff --git a/codegen/config.go b/codegen/config.go index 8e2a5e74b31..1ded1c41758 100644 --- a/codegen/config.go +++ b/codegen/config.go @@ -65,6 +65,7 @@ type Config struct { Model PackageConfig `yaml:"model"` Resolver PackageConfig `yaml:"resolver,omitempty"` Models TypeMap `yaml:"models,omitempty"` + StructTag string `yaml:"struct_tag,omitempty"` FilePath string `yaml:"-"` diff --git a/codegen/input_build.go b/codegen/input_build.go index c333201522d..1059601a3c5 100644 --- a/codegen/input_build.go +++ b/codegen/input_build.go @@ -27,7 +27,7 @@ func (cfg *Config) buildInputs(namedTypes NamedTypes, prog *loader.Program, impo } if def != nil { input.Marshaler = buildInputMarshaler(typ, def) - bindErrs := bindObject(def.Type(), input, imports) + bindErrs := bindObject(def.Type(), input, imports, cfg.StructTag) if len(bindErrs) > 0 { return nil, bindErrs } diff --git a/codegen/object_build.go b/codegen/object_build.go index 95602a9138d..686b2bfe2a4 100644 --- a/codegen/object_build.go +++ b/codegen/object_build.go @@ -28,7 +28,7 @@ func (cfg *Config) buildObjects(types NamedTypes, prog *loader.Program, imports return nil, err } if def != nil { - for _, bindErr := range bindObject(def.Type(), obj, imports) { + for _, bindErr := range bindObject(def.Type(), obj, imports, cfg.StructTag) { log.Println(bindErr.Error()) log.Println(" Adding resolver method") } diff --git a/codegen/util.go b/codegen/util.go index fae94adeadd..1849f100bb1 100644 --- a/codegen/util.go +++ b/codegen/util.go @@ -3,6 +3,7 @@ package codegen import ( "fmt" "go/types" + "reflect" "regexp" "strings" @@ -104,19 +105,50 @@ func findMethod(typ *types.Named, name string) *types.Func { return nil } -func findField(typ *types.Struct, name string) *types.Var { +// findField attempts to match the name to a struct field with the following +// priorites: +// 1. If struct tag is passed then struct tag has highest priority +// 2. Field in an embedded struct +// 3. Actual Field name +func findField(typ *types.Struct, name, structTag string) (*types.Var, error) { + var foundField *types.Var + foundFieldWasTag := false + for i := 0; i < typ.NumFields(); i++ { field := typ.Field(i) + + if structTag != "" { + tags := reflect.StructTag(typ.Tag(i)) + if val, ok := tags.Lookup(structTag); ok { + if strings.EqualFold(val, name) { + if foundField != nil && foundFieldWasTag { + return nil, errors.Errorf("tag %s is ambigious; multiple fields have the same tag value of %s", structTag, val) + } + + foundField = field + foundFieldWasTag = true + } + } + } + if field.Anonymous() { if named, ok := field.Type().(*types.Struct); ok { - if f := findField(named, name); f != nil { - return f + f, err := findField(named, name, structTag) + if err != nil && !strings.HasPrefix(err.Error(), "no field named") { + return nil, err + } + if f != nil && foundField == nil { + foundField = f } } if named, ok := field.Type().Underlying().(*types.Struct); ok { - if f := findField(named, name); f != nil { - return f + f, err := findField(named, name, structTag) + if err != nil && !strings.HasPrefix(err.Error(), "no field named") { + return nil, err + } + if f != nil && foundField == nil { + foundField = f } } } @@ -125,11 +157,16 @@ func findField(typ *types.Struct, name string) *types.Var { continue } - if strings.EqualFold(field.Name(), name) { - return field + if strings.EqualFold(field.Name(), name) && foundField == nil { + foundField = field } } - return nil + + if foundField == nil { + return nil, fmt.Errorf("no field named %s", name) + } + + return foundField, nil } type BindError struct { @@ -161,7 +198,7 @@ func (b BindErrors) Error() string { return strings.Join(errs, "\n\n") } -func bindObject(t types.Type, object *Object, imports *Imports) BindErrors { +func bindObject(t types.Type, object *Object, imports *Imports, structTag string) BindErrors { var errs BindErrors for i := range object.Fields { field := &object.Fields[i] @@ -177,7 +214,7 @@ func bindObject(t types.Type, object *Object, imports *Imports) BindErrors { } // otherwise try binding to a var - varErr := bindVar(imports, t, field) + varErr := bindVar(imports, t, field, structTag) if varErr != nil { errs = append(errs, BindError{ @@ -231,7 +268,7 @@ func bindMethod(imports *Imports, t types.Type, field *Field) error { return nil } -func bindVar(imports *Imports, t types.Type, field *Field) error { +func bindVar(imports *Imports, t types.Type, field *Field, structTag string) error { underlying, ok := t.Underlying().(*types.Struct) if !ok { return fmt.Errorf("not a struct") @@ -241,9 +278,9 @@ func bindVar(imports *Imports, t types.Type, field *Field) error { if field.GoFieldName != "" { goName = field.GoFieldName } - structField := findField(underlying, goName) - if structField == nil { - return fmt.Errorf("no field named %s", field.GQLName) + structField, err := findField(underlying, goName, structTag) + if err != nil { + return err } if err := validateTypeBinding(imports, field, structField.Type()); err != nil { diff --git a/codegen/util_test.go b/codegen/util_test.go index cb41170ddb6..aedfda04694 100644 --- a/codegen/util_test.go +++ b/codegen/util_test.go @@ -1,6 +1,11 @@ package codegen import ( + "go/ast" + "go/importer" + "go/parser" + "go/token" + "go/types" "testing" "github.com/stretchr/testify/require" @@ -12,3 +17,84 @@ func TestNormalizeVendor(t *testing.T) { require.Equal(t, "*bar/baz", normalizeVendor("*foo/vendor/bar/baz")) require.Equal(t, "*[]*bar/baz", normalizeVendor("*[]*foo/vendor/bar/baz")) } + +func TestFindField(t *testing.T) { + input := ` +package test + +type Std struct { + Name string + Value int +} +type Anon struct { + Name string + Tags +} +type Tags struct { + Bar string ` + "`" + `gqlgen:"foo"` + "`" + ` + Foo int ` + "`" + `gqlgen:"bar"` + "`" + ` +} +type Amb struct { + Bar string ` + "`" + `gqlgen:"foo"` + "`" + ` + Foo int ` + "`" + `gqlgen:"foo"` + "`" + ` +} +type Embed struct { + Std + Test string +} +` + scope, err := parseScope(input, "test") + require.NoError(t, err) + + std := scope.Lookup("Std").Type().Underlying().(*types.Struct) + anon := scope.Lookup("Anon").Type().Underlying().(*types.Struct) + tags := scope.Lookup("Tags").Type().Underlying().(*types.Struct) + amb := scope.Lookup("Amb").Type().Underlying().(*types.Struct) + embed := scope.Lookup("Embed").Type().Underlying().(*types.Struct) + + tests := []struct { + Name string + Struct *types.Struct + Field string + Tag string + Expected string + ShouldError bool + }{ + {"Finds a field by name with no tag", std, "name", "", "Name", false}, + {"Finds a field by name when passed tag but tag not used", std, "name", "gqlgen", "Name", false}, + {"Ignores tags when not passed a tag", tags, "foo", "", "Foo", false}, + {"Picks field with tag over field name when passed a tag", tags, "foo", "gqlgen", "Bar", false}, + {"Errors when ambigious", amb, "foo", "gqlgen", "", true}, + {"Finds a field that is in embedded struct", anon, "bar", "", "Bar", false}, + {"Finds field that is not in embedded struct", embed, "test", "", "Test", false}, + } + + for _, tt := range tests { + tt := tt + field, err := findField(tt.Struct, tt.Field, tt.Tag) + if tt.ShouldError { + require.Nil(t, field, tt.Name) + require.Error(t, err, tt.Name) + } else { + require.NoError(t, err, tt.Name) + require.Equal(t, tt.Expected, field.Name(), tt.Name) + } + } +} + +func parseScope(input interface{}, packageName string) (*types.Scope, error) { + // test setup to parse the types + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, "test.go", input, 0) + if err != nil { + return nil, err + } + + conf := types.Config{Importer: importer.Default()} + pkg, err := conf.Check(packageName, fset, []*ast.File{f}, nil) + if err != nil { + return nil, err + } + + return pkg.Scope(), nil +} diff --git a/docs/content/config.md b/docs/content/config.md index 8415fa37684..c6052ee9edc 100644 --- a/docs/content/config.md +++ b/docs/content/config.md @@ -27,6 +27,9 @@ resolver: filename: resolver.go # where to write them type: Resolver # whats the resolver root implementation type called? +# Optional, turns on binding to field names by tag provided +struct_tag: json + # Tell gqlgen about any existing models you want to reuse for # graphql. These normally come from the db or a remote api. models: diff --git a/handler/graphql.go b/handler/graphql.go index 0485af865b8..ccb3b38748f 100644 --- a/handler/graphql.go +++ b/handler/graphql.go @@ -10,6 +10,7 @@ import ( "github.com/99designs/gqlgen/graphql" "github.com/gorilla/websocket" + "github.com/hashicorp/golang-lru" "github.com/vektah/gqlparser" "github.com/vektah/gqlparser/ast" "github.com/vektah/gqlparser/gqlerror" @@ -23,6 +24,7 @@ type params struct { } type Config struct { + cacheSize int upgrader websocket.Upgrader recover graphql.RecoverFunc errorPresenter graphql.ErrorPresenterFunc @@ -110,8 +112,19 @@ func RequestMiddleware(middleware graphql.RequestMiddleware) Option { } } +// CacheSize sets the maximum size of the query cache. +// If size is less than or equal to 0, the cache is disabled. +func CacheSize(size int) Option { + return func(cfg *Config) { + cfg.cacheSize = size + } +} + +const DefaultCacheSize = 1000 + func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc { cfg := Config{ + cacheSize: DefaultCacheSize, upgrader: websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, @@ -122,6 +135,17 @@ func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc option(&cfg) } + var cache *lru.Cache + if cfg.cacheSize > 0 { + var err error + cache, err = lru.New(DefaultCacheSize) + if err != nil { + // An error is only returned for non-positive cache size + // and we already checked for that. + panic("unexpected error creating cache: " + err.Error()) + } + } + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodOptions { w.Header().Set("Allow", "OPTIONS, GET, POST") @@ -157,10 +181,23 @@ func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc } w.Header().Set("Content-Type", "application/json") - doc, qErr := gqlparser.LoadQuery(exec.Schema(), reqParams.Query) - if len(qErr) > 0 { - sendError(w, http.StatusUnprocessableEntity, qErr...) - return + var doc *ast.QueryDocument + if cache != nil { + val, ok := cache.Get(reqParams.Query) + if ok { + doc = val.(*ast.QueryDocument) + } + } + if doc == nil { + var qErr gqlerror.List + doc, qErr = gqlparser.LoadQuery(exec.Schema(), reqParams.Query) + if len(qErr) > 0 { + sendError(w, http.StatusUnprocessableEntity, qErr...) + return + } + if cache != nil { + cache.Add(reqParams.Query, doc) + } } op := doc.Operations.ForName(reqParams.OperationName)