From 57adb244df6b509c2ecf4bfb15ea7d4ecd00e17c Mon Sep 17 00:00:00 2001 From: Adam Scarr Date: Thu, 5 Apr 2018 13:38:05 +1000 Subject: [PATCH] Add resolver middleware --- codegen/templates/field.gotpl | 14 +++- example/starwars/server/server.go | 13 +++- graphql/context.go | 15 ++++- handler/graphql.go | 40 +++++++++-- handler/websocket.go | 40 +++++------ test/generated.go | 106 ++++++++++++++++++++++++------ 6 files changed, 176 insertions(+), 52 deletions(-) diff --git a/codegen/templates/field.gotpl b/codegen/templates/field.gotpl index 2f6bb87438..9ab270b572 100644 --- a/codegen/templates/field.gotpl +++ b/codegen/templates/field.gotpl @@ -48,12 +48,22 @@ } {{- end }} {{- else }} - rctx := graphql.WithResolverContext(ctx, &graphql.ResolverContext{Field: field}) - res, err := ec.resolvers.{{ $object.GQLType }}_{{ $field.GQLName }}({{ $field.CallArgs }}) + rctx := graphql.WithResolverContext(ctx, &graphql.ResolverContext{ + Object: {{$object.GQLType|quote}}, + Args: {{if $field.Args }}args{{else}}nil{{end}}, + Field: field, + }) + resTmp, err := ec.Middleware(rctx, func(rctx context.Context) (interface{}, error) { + return ec.resolvers.{{ $object.GQLType }}_{{ $field.GQLName }}({{ $field.CallArgs }}) + }) if err != nil { ec.Error(err) return graphql.Null } + if resTmp == nil { + return graphql.Null + } + res := resTmp.({{$field.Signature}}) {{- end }} {{ $field.WriteJson }} {{- if $field.IsConcurrent }} diff --git a/example/starwars/server/server.go b/example/starwars/server/server.go index 01f591728a..dc2c57d2ae 100644 --- a/example/starwars/server/server.go +++ b/example/starwars/server/server.go @@ -1,16 +1,27 @@ package main import ( + "context" + "fmt" "log" "net/http" "github.com/vektah/gqlgen/example/starwars" + "github.com/vektah/gqlgen/graphql" "github.com/vektah/gqlgen/handler" ) func main() { http.Handle("/", handler.Playground("Starwars", "/query")) - http.Handle("/query", handler.GraphQL(starwars.MakeExecutableSchema(starwars.NewResolver()))) + http.Handle("/query", handler.GraphQL(starwars.MakeExecutableSchema(starwars.NewResolver()), + handler.Use(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) { + rc := graphql.GetResolverContext(ctx) + fmt.Println("Entered", rc.Object, rc.Field.Name) + res, err = next(ctx) + fmt.Println("Left", rc.Object, rc.Field.Name, "=>", res, err) + return res, err + }), + )) log.Fatal(http.ListenAndServe(":8080", nil)) } diff --git a/graphql/context.go b/graphql/context.go index a1abec5dc5..4f2dcfcef7 100644 --- a/graphql/context.go +++ b/graphql/context.go @@ -7,12 +7,16 @@ import ( "github.com/vektah/gqlgen/neelance/query" ) +type Resolver func(ctx context.Context) (res interface{}, err error) +type ResolverMiddleware func(ctx context.Context, next Resolver) (res interface{}, err error) + type RequestContext struct { errors.Builder - Variables map[string]interface{} - Doc *query.Document - Recover RecoverFunc + Variables map[string]interface{} + Doc *query.Document + Recover RecoverFunc + Middleware ResolverMiddleware } type key string @@ -36,6 +40,11 @@ func WithRequestContext(ctx context.Context, rc *RequestContext) context.Context } type ResolverContext struct { + // The name of the type this field belongs to + Object string + // These are the args after processing, they can be mutated in middleware to change what the resolver will get. + Args map[string]interface{} + // The raw field Field CollectedField } diff --git a/handler/graphql.go b/handler/graphql.go index af1f6117da..51c62b3f4b 100644 --- a/handler/graphql.go +++ b/handler/graphql.go @@ -7,6 +7,8 @@ import ( "strings" + "context" + "github.com/gorilla/websocket" "github.com/vektah/gqlgen/graphql" "github.com/vektah/gqlgen/neelance/errors" @@ -21,9 +23,10 @@ type params struct { } type Config struct { - upgrader websocket.Upgrader - recover graphql.RecoverFunc - formatError func(error) string + upgrader websocket.Upgrader + recover graphql.RecoverFunc + formatError func(error) string + resolverHook graphql.ResolverMiddleware } type Option func(cfg *Config) @@ -46,6 +49,22 @@ func FormatErrorFunc(f func(error) string) Option { } } +func Use(middleware graphql.ResolverMiddleware) Option { + return func(cfg *Config) { + if cfg.resolverHook == nil { + cfg.resolverHook = middleware + return + } + + lastResolve := cfg.resolverHook + cfg.resolverHook = func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) { + return lastResolve(ctx, func(ctx context.Context) (res interface{}, err error) { + return middleware(ctx, next) + }) + } + } +} + func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc { cfg := Config{ recover: graphql.DefaultRecoverFunc, @@ -59,6 +78,12 @@ func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc option(&cfg) } + if cfg.resolverHook == nil { + cfg.resolverHook = func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) { + return next(ctx) + } + } + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodOptions { w.Header().Set("Allow", "OPTIONS, GET, POST") @@ -67,7 +92,7 @@ func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc } if strings.Contains(r.Header.Get("Upgrade"), "websocket") { - connectWs(exec, w, r, cfg.upgrader, cfg.recover) + connectWs(exec, w, r, &cfg) return } @@ -113,9 +138,10 @@ func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc } ctx := graphql.WithRequestContext(r.Context(), &graphql.RequestContext{ - Doc: doc, - Variables: reqParams.Variables, - Recover: cfg.recover, + Doc: doc, + Variables: reqParams.Variables, + Recover: cfg.recover, + Middleware: cfg.resolverHook, Builder: errors.Builder{ ErrorMessageFn: cfg.formatError, }, diff --git a/handler/websocket.go b/handler/websocket.go index a5214aeb98..5275e0ba54 100644 --- a/handler/websocket.go +++ b/handler/websocket.go @@ -26,7 +26,7 @@ const ( dataMsg = "data" // Server -> Client errorMsg = "error" // Server -> Client completeMsg = "complete" // Server -> Client - //connectionKeepAliveMsg = "ka" // Server -> Client TODO: keepalives + //connectionKeepAliveMsg = "ka" // Server -> Client TODO: keepalives ) type operationMessage struct { @@ -36,17 +36,16 @@ type operationMessage struct { } type wsConnection struct { - ctx context.Context - conn *websocket.Conn - exec graphql.ExecutableSchema - active map[string]context.CancelFunc - mu sync.Mutex - recover graphql.RecoverFunc - formatError func(err error) string + ctx context.Context + conn *websocket.Conn + exec graphql.ExecutableSchema + active map[string]context.CancelFunc + mu sync.Mutex + cfg *Config } -func connectWs(exec graphql.ExecutableSchema, w http.ResponseWriter, r *http.Request, upgrader websocket.Upgrader, recover graphql.RecoverFunc) { - ws, err := upgrader.Upgrade(w, r, http.Header{ +func connectWs(exec graphql.ExecutableSchema, w http.ResponseWriter, r *http.Request, cfg *Config) { + ws, err := cfg.upgrader.Upgrade(w, r, http.Header{ "Sec-Websocket-Protocol": []string{"graphql-ws"}, }) if err != nil { @@ -56,11 +55,11 @@ func connectWs(exec graphql.ExecutableSchema, w http.ResponseWriter, r *http.Req } conn := wsConnection{ - active: map[string]context.CancelFunc{}, - exec: exec, - conn: ws, - ctx: r.Context(), - recover: recover, + active: map[string]context.CancelFunc{}, + exec: exec, + conn: ws, + ctx: r.Context(), + cfg: cfg, } if !conn.init() { @@ -157,11 +156,12 @@ func (c *wsConnection) subscribe(message *operationMessage) bool { } ctx := graphql.WithRequestContext(c.ctx, &graphql.RequestContext{ - Doc: doc, - Variables: reqParams.Variables, - Recover: c.recover, + Doc: doc, + Variables: reqParams.Variables, + Recover: c.cfg.recover, + Middleware: c.cfg.resolverHook, Builder: errors.Builder{ - ErrorMessageFn: c.formatError, + ErrorMessageFn: c.cfg.formatError, }, }) @@ -185,7 +185,7 @@ func (c *wsConnection) subscribe(message *operationMessage) bool { go func() { defer func() { if r := recover(); r != nil { - userErr := c.recover(r) + userErr := c.cfg.recover(r) c.sendError(message.ID, &errors.QueryError{Message: userErr.Error()}) } }() diff --git a/test/generated.go b/test/generated.go index c2427e537a..189fdb2bf4 100644 --- a/test/generated.go +++ b/test/generated.go @@ -214,8 +214,15 @@ func (ec *executionContext) _OuterObject_inner(ctx context.Context, field graphq ret = graphql.Null } }() - rctx := graphql.WithResolverContext(ctx, &graphql.ResolverContext{Field: field}) - res, err := ec.resolvers.OuterObject_inner(rctx, obj) + rctx := graphql.WithResolverContext(ctx, &graphql.ResolverContext{ + Object: "OuterObject", + Args: nil, + Field: field, + }) + resTmp, err := ec.Middleware(rctx, func(ctx context.Context) (interface{}, error) { + return ec.resolvers.OuterObject_inner(rctx, obj) + }) + res := resTmp.(models.InnerObject) if err != nil { ec.Error(err) return graphql.Null @@ -263,6 +270,7 @@ func (ec *executionContext) _Query(ctx context.Context, sel []query.Selection) g } func (ec *executionContext) _Query_nestedInputs(ctx context.Context, field graphql.CollectedField) graphql.Marshaler { + args := map[string]interface{}{} var arg0 [][]models.OuterInput if tmp, ok := field.Args["input"]; ok { var err error @@ -297,6 +305,7 @@ func (ec *executionContext) _Query_nestedInputs(ctx context.Context, field graph } } + args["input"] = arg0 return graphql.Defer(func() (ret graphql.Marshaler) { defer func() { if r := recover(); r != nil { @@ -305,8 +314,15 @@ func (ec *executionContext) _Query_nestedInputs(ctx context.Context, field graph ret = graphql.Null } }() - rctx := graphql.WithResolverContext(ctx, &graphql.ResolverContext{Field: field}) - res, err := ec.resolvers.Query_nestedInputs(rctx, arg0) + rctx := graphql.WithResolverContext(ctx, &graphql.ResolverContext{ + Object: "Query", + Args: args, + Field: field, + }) + resTmp, err := ec.Middleware(rctx, func(ctx context.Context) (interface{}, error) { + return ec.resolvers.Query_nestedInputs(rctx, args["input"].([][]models.OuterInput)) + }) + res := resTmp.(*bool) if err != nil { ec.Error(err) return graphql.Null @@ -327,8 +343,15 @@ func (ec *executionContext) _Query_nestedOutputs(ctx context.Context, field grap ret = graphql.Null } }() - rctx := graphql.WithResolverContext(ctx, &graphql.ResolverContext{Field: field}) - res, err := ec.resolvers.Query_nestedOutputs(rctx) + rctx := graphql.WithResolverContext(ctx, &graphql.ResolverContext{ + Object: "Query", + Args: nil, + Field: field, + }) + resTmp, err := ec.Middleware(rctx, func(ctx context.Context) (interface{}, error) { + return ec.resolvers.Query_nestedOutputs(rctx) + }) + res := resTmp.([][]models.OuterObject) if err != nil { ec.Error(err) return graphql.Null @@ -356,8 +379,15 @@ func (ec *executionContext) _Query_shapes(ctx context.Context, field graphql.Col ret = graphql.Null } }() - rctx := graphql.WithResolverContext(ctx, &graphql.ResolverContext{Field: field}) - res, err := ec.resolvers.Query_shapes(rctx) + rctx := graphql.WithResolverContext(ctx, &graphql.ResolverContext{ + Object: "Query", + Args: nil, + Field: field, + }) + resTmp, err := ec.Middleware(rctx, func(ctx context.Context) (interface{}, error) { + return ec.resolvers.Query_shapes(rctx) + }) + res := resTmp.([]Shape) if err != nil { ec.Error(err) return graphql.Null @@ -371,6 +401,7 @@ func (ec *executionContext) _Query_shapes(ctx context.Context, field graphql.Col } func (ec *executionContext) _Query_recursive(ctx context.Context, field graphql.CollectedField) graphql.Marshaler { + args := map[string]interface{}{} var arg0 *RecursiveInputSlice if tmp, ok := field.Args["input"]; ok { var err error @@ -385,6 +416,7 @@ func (ec *executionContext) _Query_recursive(ctx context.Context, field graphql. return graphql.Null } } + args["input"] = arg0 return graphql.Defer(func() (ret graphql.Marshaler) { defer func() { if r := recover(); r != nil { @@ -393,8 +425,15 @@ func (ec *executionContext) _Query_recursive(ctx context.Context, field graphql. ret = graphql.Null } }() - rctx := graphql.WithResolverContext(ctx, &graphql.ResolverContext{Field: field}) - res, err := ec.resolvers.Query_recursive(rctx, arg0) + rctx := graphql.WithResolverContext(ctx, &graphql.ResolverContext{ + Object: "Query", + Args: args, + Field: field, + }) + resTmp, err := ec.Middleware(rctx, func(ctx context.Context) (interface{}, error) { + return ec.resolvers.Query_recursive(rctx, args["input"].(*RecursiveInputSlice)) + }) + res := resTmp.(*bool) if err != nil { ec.Error(err) return graphql.Null @@ -407,6 +446,7 @@ func (ec *executionContext) _Query_recursive(ctx context.Context, field graphql. } func (ec *executionContext) _Query_mapInput(ctx context.Context, field graphql.CollectedField) graphql.Marshaler { + args := map[string]interface{}{} var arg0 *map[string]interface{} if tmp, ok := field.Args["input"]; ok { var err error @@ -421,6 +461,7 @@ func (ec *executionContext) _Query_mapInput(ctx context.Context, field graphql.C return graphql.Null } } + args["input"] = arg0 return graphql.Defer(func() (ret graphql.Marshaler) { defer func() { if r := recover(); r != nil { @@ -429,8 +470,15 @@ func (ec *executionContext) _Query_mapInput(ctx context.Context, field graphql.C ret = graphql.Null } }() - rctx := graphql.WithResolverContext(ctx, &graphql.ResolverContext{Field: field}) - res, err := ec.resolvers.Query_mapInput(rctx, arg0) + rctx := graphql.WithResolverContext(ctx, &graphql.ResolverContext{ + Object: "Query", + Args: args, + Field: field, + }) + resTmp, err := ec.Middleware(rctx, func(ctx context.Context) (interface{}, error) { + return ec.resolvers.Query_mapInput(rctx, args["input"].(*map[string]interface{})) + }) + res := resTmp.(*bool) if err != nil { ec.Error(err) return graphql.Null @@ -451,8 +499,15 @@ func (ec *executionContext) _Query_collision(ctx context.Context, field graphql. ret = graphql.Null } }() - rctx := graphql.WithResolverContext(ctx, &graphql.ResolverContext{Field: field}) - res, err := ec.resolvers.Query_collision(rctx) + rctx := graphql.WithResolverContext(ctx, &graphql.ResolverContext{ + Object: "Query", + Args: nil, + Field: field, + }) + resTmp, err := ec.Middleware(rctx, func(ctx context.Context) (interface{}, error) { + return ec.resolvers.Query_collision(rctx) + }) + res := resTmp.(*introspection1.It) if err != nil { ec.Error(err) return graphql.Null @@ -473,8 +528,15 @@ func (ec *executionContext) _Query_invalidIdentifier(ctx context.Context, field ret = graphql.Null } }() - rctx := graphql.WithResolverContext(ctx, &graphql.ResolverContext{Field: field}) - res, err := ec.resolvers.Query_invalidIdentifier(rctx) + rctx := graphql.WithResolverContext(ctx, &graphql.ResolverContext{ + Object: "Query", + Args: nil, + Field: field, + }) + resTmp, err := ec.Middleware(rctx, func(ctx context.Context) (interface{}, error) { + return ec.resolvers.Query_invalidIdentifier(rctx) + }) + res := resTmp.(*invalid_identifier.InvalidIdentifier) if err != nil { ec.Error(err) return graphql.Null @@ -495,6 +557,7 @@ func (ec *executionContext) _Query___schema(ctx context.Context, field graphql.C } func (ec *executionContext) _Query___type(ctx context.Context, field graphql.CollectedField) graphql.Marshaler { + args := map[string]interface{}{} var arg0 string if tmp, ok := field.Args["name"]; ok { var err error @@ -504,7 +567,8 @@ func (ec *executionContext) _Query___type(ctx context.Context, field graphql.Col return graphql.Null } } - res := ec.introspectType(arg0) + args["name"] = arg0 + res := ec.introspectType(args["name"].(string)) if res == nil { return graphql.Null } @@ -949,6 +1013,7 @@ func (ec *executionContext) ___Type_description(ctx context.Context, field graph } func (ec *executionContext) ___Type_fields(ctx context.Context, field graphql.CollectedField, obj *introspection.Type) graphql.Marshaler { + args := map[string]interface{}{} var arg0 bool if tmp, ok := field.Args["includeDeprecated"]; ok { var err error @@ -958,7 +1023,8 @@ func (ec *executionContext) ___Type_fields(ctx context.Context, field graphql.Co return graphql.Null } } - res := obj.Fields(arg0) + args["includeDeprecated"] = arg0 + res := obj.Fields(args["includeDeprecated"].(bool)) arr1 := graphql.Array{} for idx1 := range res { arr1 = append(arr1, func() graphql.Marshaler { @@ -1000,6 +1066,7 @@ func (ec *executionContext) ___Type_possibleTypes(ctx context.Context, field gra } func (ec *executionContext) ___Type_enumValues(ctx context.Context, field graphql.CollectedField, obj *introspection.Type) graphql.Marshaler { + args := map[string]interface{}{} var arg0 bool if tmp, ok := field.Args["includeDeprecated"]; ok { var err error @@ -1009,7 +1076,8 @@ func (ec *executionContext) ___Type_enumValues(ctx context.Context, field graphq return graphql.Null } } - res := obj.EnumValues(arg0) + args["includeDeprecated"] = arg0 + res := obj.EnumValues(args["includeDeprecated"].(bool)) arr1 := graphql.Array{} for idx1 := range res { arr1 = append(arr1, func() graphql.Marshaler {