Skip to content

Commit

Permalink
Add resolver middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
Adam Scarr committed Apr 5, 2018
1 parent 28d0c81 commit 57adb24
Show file tree
Hide file tree
Showing 6 changed files with 176 additions and 52 deletions.
14 changes: 12 additions & 2 deletions codegen/templates/field.gotpl
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,22 @@
}
{{- end }}
{{- else }}
rctx := graphql.WithResolverContext(ctx, &graphql.ResolverContext{Field: field})
res, err := ec.resolvers.{{ $object.GQLType }}_{{ $field.GQLName }}({{ $field.CallArgs }})
rctx := graphql.WithResolverContext(ctx, &graphql.ResolverContext{
Object: {{$object.GQLType|quote}},
Args: {{if $field.Args }}args{{else}}nil{{end}},
Field: field,
})
resTmp, err := ec.Middleware(rctx, func(rctx context.Context) (interface{}, error) {
return ec.resolvers.{{ $object.GQLType }}_{{ $field.GQLName }}({{ $field.CallArgs }})
})
if err != nil {
ec.Error(err)
return graphql.Null
}
if resTmp == nil {
return graphql.Null
}
res := resTmp.({{$field.Signature}})
{{- end }}
{{ $field.WriteJson }}
{{- if $field.IsConcurrent }}
Expand Down
13 changes: 12 additions & 1 deletion example/starwars/server/server.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,27 @@
package main

import (
"context"
"fmt"
"log"
"net/http"

"github.com/vektah/gqlgen/example/starwars"
"github.com/vektah/gqlgen/graphql"
"github.com/vektah/gqlgen/handler"
)

func main() {
http.Handle("/", handler.Playground("Starwars", "/query"))
http.Handle("/query", handler.GraphQL(starwars.MakeExecutableSchema(starwars.NewResolver())))
http.Handle("/query", handler.GraphQL(starwars.MakeExecutableSchema(starwars.NewResolver()),
handler.Use(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
rc := graphql.GetResolverContext(ctx)
fmt.Println("Entered", rc.Object, rc.Field.Name)
res, err = next(ctx)
fmt.Println("Left", rc.Object, rc.Field.Name, "=>", res, err)
return res, err
}),
))

log.Fatal(http.ListenAndServe(":8080", nil))
}
15 changes: 12 additions & 3 deletions graphql/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,16 @@ import (
"github.com/vektah/gqlgen/neelance/query"
)

type Resolver func(ctx context.Context) (res interface{}, err error)
type ResolverMiddleware func(ctx context.Context, next Resolver) (res interface{}, err error)

type RequestContext struct {
errors.Builder

Variables map[string]interface{}
Doc *query.Document
Recover RecoverFunc
Variables map[string]interface{}
Doc *query.Document
Recover RecoverFunc
Middleware ResolverMiddleware
}

type key string
Expand All @@ -36,6 +40,11 @@ func WithRequestContext(ctx context.Context, rc *RequestContext) context.Context
}

type ResolverContext struct {
// The name of the type this field belongs to
Object string
// These are the args after processing, they can be mutated in middleware to change what the resolver will get.
Args map[string]interface{}
// The raw field
Field CollectedField
}

Expand Down
40 changes: 33 additions & 7 deletions handler/graphql.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (

"strings"

"context"

"github.com/gorilla/websocket"
"github.com/vektah/gqlgen/graphql"
"github.com/vektah/gqlgen/neelance/errors"
Expand All @@ -21,9 +23,10 @@ type params struct {
}

type Config struct {
upgrader websocket.Upgrader
recover graphql.RecoverFunc
formatError func(error) string
upgrader websocket.Upgrader
recover graphql.RecoverFunc
formatError func(error) string
resolverHook graphql.ResolverMiddleware
}

type Option func(cfg *Config)
Expand All @@ -46,6 +49,22 @@ func FormatErrorFunc(f func(error) string) Option {
}
}

func Use(middleware graphql.ResolverMiddleware) Option {
return func(cfg *Config) {
if cfg.resolverHook == nil {
cfg.resolverHook = middleware
return
}

lastResolve := cfg.resolverHook
cfg.resolverHook = func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
return lastResolve(ctx, func(ctx context.Context) (res interface{}, err error) {
return middleware(ctx, next)
})
}
}
}

func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc {
cfg := Config{
recover: graphql.DefaultRecoverFunc,
Expand All @@ -59,6 +78,12 @@ func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc
option(&cfg)
}

if cfg.resolverHook == nil {
cfg.resolverHook = func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
return next(ctx)
}
}

return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodOptions {
w.Header().Set("Allow", "OPTIONS, GET, POST")
Expand All @@ -67,7 +92,7 @@ func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc
}

if strings.Contains(r.Header.Get("Upgrade"), "websocket") {
connectWs(exec, w, r, cfg.upgrader, cfg.recover)
connectWs(exec, w, r, &cfg)
return
}

Expand Down Expand Up @@ -113,9 +138,10 @@ func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc
}

ctx := graphql.WithRequestContext(r.Context(), &graphql.RequestContext{
Doc: doc,
Variables: reqParams.Variables,
Recover: cfg.recover,
Doc: doc,
Variables: reqParams.Variables,
Recover: cfg.recover,
Middleware: cfg.resolverHook,
Builder: errors.Builder{
ErrorMessageFn: cfg.formatError,
},
Expand Down
40 changes: 20 additions & 20 deletions handler/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ const (
dataMsg = "data" // Server -> Client
errorMsg = "error" // Server -> Client
completeMsg = "complete" // Server -> Client
//connectionKeepAliveMsg = "ka" // Server -> Client TODO: keepalives
//connectionKeepAliveMsg = "ka" // Server -> Client TODO: keepalives
)

type operationMessage struct {
Expand All @@ -36,17 +36,16 @@ type operationMessage struct {
}

type wsConnection struct {
ctx context.Context
conn *websocket.Conn
exec graphql.ExecutableSchema
active map[string]context.CancelFunc
mu sync.Mutex
recover graphql.RecoverFunc
formatError func(err error) string
ctx context.Context
conn *websocket.Conn
exec graphql.ExecutableSchema
active map[string]context.CancelFunc
mu sync.Mutex
cfg *Config
}

func connectWs(exec graphql.ExecutableSchema, w http.ResponseWriter, r *http.Request, upgrader websocket.Upgrader, recover graphql.RecoverFunc) {
ws, err := upgrader.Upgrade(w, r, http.Header{
func connectWs(exec graphql.ExecutableSchema, w http.ResponseWriter, r *http.Request, cfg *Config) {
ws, err := cfg.upgrader.Upgrade(w, r, http.Header{
"Sec-Websocket-Protocol": []string{"graphql-ws"},
})
if err != nil {
Expand All @@ -56,11 +55,11 @@ func connectWs(exec graphql.ExecutableSchema, w http.ResponseWriter, r *http.Req
}

conn := wsConnection{
active: map[string]context.CancelFunc{},
exec: exec,
conn: ws,
ctx: r.Context(),
recover: recover,
active: map[string]context.CancelFunc{},
exec: exec,
conn: ws,
ctx: r.Context(),
cfg: cfg,
}

if !conn.init() {
Expand Down Expand Up @@ -157,11 +156,12 @@ func (c *wsConnection) subscribe(message *operationMessage) bool {
}

ctx := graphql.WithRequestContext(c.ctx, &graphql.RequestContext{
Doc: doc,
Variables: reqParams.Variables,
Recover: c.recover,
Doc: doc,
Variables: reqParams.Variables,
Recover: c.cfg.recover,
Middleware: c.cfg.resolverHook,
Builder: errors.Builder{
ErrorMessageFn: c.formatError,
ErrorMessageFn: c.cfg.formatError,
},
})

Expand All @@ -185,7 +185,7 @@ func (c *wsConnection) subscribe(message *operationMessage) bool {
go func() {
defer func() {
if r := recover(); r != nil {
userErr := c.recover(r)
userErr := c.cfg.recover(r)
c.sendError(message.ID, &errors.QueryError{Message: userErr.Error()})
}
}()
Expand Down
Loading

0 comments on commit 57adb24

Please sign in to comment.