diff --git a/graphql/context.go b/graphql/context.go index adfe91e41e..9c39d909e3 100644 --- a/graphql/context.go +++ b/graphql/context.go @@ -21,6 +21,28 @@ type RequestContext struct { RequestMiddleware RequestMiddleware } +func DefaultResolverMiddleware(ctx context.Context, next Resolver) (res interface{}, err error) { + return next(ctx) +} + +func DefaultRequestMiddleware(ctx context.Context, next func(ctx context.Context) []byte) []byte { + return next(ctx) +} + +func NewRequestContext(doc *query.Document, query string, variables map[string]interface{}) *RequestContext { + return &RequestContext{ + Doc: doc, + RawQuery: query, + Variables: variables, + ResolverMiddleware: DefaultResolverMiddleware, + RequestMiddleware: DefaultRequestMiddleware, + Recover: DefaultRecover, + ErrorBuilder: ErrorBuilder{ + ErrorPresenter: DefaultErrorPresenter, + }, + } +} + type key string const ( diff --git a/graphql/recovery.go b/graphql/recovery.go index ef5c45c991..3aa032dc5a 100644 --- a/graphql/recovery.go +++ b/graphql/recovery.go @@ -10,7 +10,7 @@ import ( type RecoverFunc func(ctx context.Context, err interface{}) (userMessage error) -func DefaultRecoverFunc(ctx context.Context, err interface{}) error { +func DefaultRecover(ctx context.Context, err interface{}) error { fmt.Fprintln(os.Stderr, err) fmt.Fprintln(os.Stderr) debug.PrintStack() diff --git a/handler/graphql.go b/handler/graphql.go index ab79463202..98e1b90ecf 100644 --- a/handler/graphql.go +++ b/handler/graphql.go @@ -28,6 +28,27 @@ type Config struct { requestHook graphql.RequestMiddleware } +func (c *Config) newRequestContext(doc *query.Document, query string, variables map[string]interface{}) *graphql.RequestContext { + reqCtx := graphql.NewRequestContext(doc, query, variables) + if hook := c.recover; hook != nil { + reqCtx.Recover = hook + } + + if hook := c.errorPresenter; hook != nil { + reqCtx.ErrorPresenter = hook + } + + if hook := c.resolverHook; hook != nil { + reqCtx.ResolverMiddleware = hook + } + + if hook := c.requestHook; hook != nil { + reqCtx.RequestMiddleware = hook + } + + return reqCtx +} + type Option func(cfg *Config) func WebsocketUpgrader(upgrader websocket.Upgrader) Option { @@ -91,8 +112,6 @@ func RequestMiddleware(middleware graphql.RequestMiddleware) Option { func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc { cfg := Config{ - recover: graphql.DefaultRecoverFunc, - errorPresenter: graphql.DefaultErrorPresenter, upgrader: websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, @@ -103,18 +122,6 @@ func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc option(&cfg) } - if cfg.resolverHook == nil { - cfg.resolverHook = func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) { - return next(ctx) - } - } - - if cfg.requestHook == nil { - cfg.requestHook = func(ctx context.Context, next func(ctx context.Context) []byte) []byte { - return next(ctx) - } - } - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodOptions { w.Header().Set("Allow", "OPTIONS, GET, POST") @@ -168,17 +175,7 @@ func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc return } - ctx := graphql.WithRequestContext(r.Context(), &graphql.RequestContext{ - Doc: doc, - RawQuery: reqParams.Query, - Variables: reqParams.Variables, - Recover: cfg.recover, - ResolverMiddleware: cfg.resolverHook, - RequestMiddleware: cfg.requestHook, - ErrorBuilder: graphql.ErrorBuilder{ - ErrorPresenter: cfg.errorPresenter, - }, - }) + ctx := graphql.WithRequestContext(r.Context(), cfg.newRequestContext(doc, reqParams.Query, reqParams.Variables)) defer func() { if err := recover(); err != nil { diff --git a/handler/websocket.go b/handler/websocket.go index 2b89d256ed..7430619d87 100644 --- a/handler/websocket.go +++ b/handler/websocket.go @@ -154,17 +154,7 @@ func (c *wsConnection) subscribe(message *operationMessage) bool { return true } - ctx := graphql.WithRequestContext(c.ctx, &graphql.RequestContext{ - Doc: doc, - Variables: reqParams.Variables, - RawQuery: reqParams.Query, - Recover: c.cfg.recover, - ResolverMiddleware: c.cfg.resolverHook, - RequestMiddleware: c.cfg.requestHook, - ErrorBuilder: graphql.ErrorBuilder{ - ErrorPresenter: c.cfg.errorPresenter, - }, - }) + ctx := graphql.WithRequestContext(c.ctx, c.cfg.newRequestContext(doc, reqParams.Query, reqParams.Variables)) if op.Type != query.Subscription { var result *graphql.Response