diff --git a/graphql/context.go b/graphql/context.go index 875963daa84..b3c7cc73e1c 100644 --- a/graphql/context.go +++ b/graphql/context.go @@ -3,14 +3,13 @@ package graphql import ( "context" "errors" - "fmt" - "sync" "github.com/vektah/gqlparser/ast" "github.com/vektah/gqlparser/gqlerror" ) type Resolver func(ctx context.Context) (res interface{}, err error) +type ResultMiddleware func(ctx context.Context, next ResultHandler) *Response type FieldMiddleware func(ctx context.Context, next Resolver) (res interface{}, err error) type ComplexityLimitFunc func(ctx context.Context) int @@ -30,15 +29,7 @@ type RequestContext struct { Recover RecoverFunc ResolverMiddleware FieldMiddleware DirectiveMiddleware FieldMiddleware - RequestMiddleware ResponseInterceptor - - errorsMu sync.Mutex - Errors gqlerror.List - extensionsMu sync.Mutex - - // @deprecated use ResponseContext instead, in the case of subscriptions there are many responses returned - // and each can have its own set of extensions - Extensions map[string]interface{} + RequestMiddleware OperationInterceptor Stats Stats } @@ -80,41 +71,22 @@ func DefaultDirectiveMiddleware(ctx context.Context, next Resolver) (res interfa return next(ctx) } -func DefaultRequestMiddleware(ctx context.Context, next func(ctx context.Context) []byte) []byte { - return next(ctx) -} - -// Deprecated: construct RequestContext directly & call Validate method. -func NewRequestContext(doc *ast.QueryDocument, query string, variables map[string]interface{}) *RequestContext { - rc := &RequestContext{ - Doc: doc, - RawQuery: query, - Variables: variables, - } - err := rc.Validate(context.Background()) - if err != nil { - panic(err) - } - - return rc -} - type key string const ( - request key = "request_context" - resolver key = "resolver_context" + requestCtx key = "request_context" + resolverCtx key = "resolver_context" ) func GetRequestContext(ctx context.Context) *RequestContext { - if val, ok := ctx.Value(request).(*RequestContext); ok { + if val, ok := ctx.Value(requestCtx).(*RequestContext); ok { return val } return nil } func WithRequestContext(ctx context.Context, rc *RequestContext) context.Context { - return context.WithValue(ctx, request, rc) + return context.WithValue(ctx, requestCtx, rc) } type ResolverContext struct { @@ -153,7 +125,7 @@ func (r *ResolverContext) Path() []interface{} { } func GetResolverContext(ctx context.Context) *ResolverContext { - if val, ok := ctx.Value(resolver).(*ResolverContext); ok { + if val, ok := ctx.Value(resolverCtx).(*ResolverContext); ok { return val } return nil @@ -161,7 +133,7 @@ func GetResolverContext(ctx context.Context) *ResolverContext { func WithResolverContext(ctx context.Context, rc *ResolverContext) context.Context { rc.Parent = GetResolverContext(ctx) - return context.WithValue(ctx, resolver, rc) + return context.WithValue(ctx, resolverCtx, rc) } // This is just a convenient wrapper method for CollectFields @@ -189,48 +161,15 @@ Next: } // Errorf sends an error string to the client, passing it through the formatter. +// Deprecated: use graphql.AddErrorf(ctx, err) instead func (c *RequestContext) Errorf(ctx context.Context, format string, args ...interface{}) { - c.errorsMu.Lock() - defer c.errorsMu.Unlock() - - c.Errors = append(c.Errors, c.ErrorPresenter(ctx, fmt.Errorf(format, args...))) + AddErrorf(ctx, format, args...) } // Error sends an error to the client, passing it through the formatter. +// Deprecated: use graphql.AddError(ctx, err) instead func (c *RequestContext) Error(ctx context.Context, err error) { - c.errorsMu.Lock() - defer c.errorsMu.Unlock() - - c.Errors = append(c.Errors, c.ErrorPresenter(ctx, err)) -} - -// HasError returns true if the current field has already errored -func (c *RequestContext) HasError(rctx *ResolverContext) bool { - c.errorsMu.Lock() - defer c.errorsMu.Unlock() - path := rctx.Path() - - for _, err := range c.Errors { - if equalPath(err.Path, path) { - return true - } - } - return false -} - -// GetErrors returns a list of errors that occurred in the current field -func (c *RequestContext) GetErrors(rctx *ResolverContext) gqlerror.List { - c.errorsMu.Lock() - defer c.errorsMu.Unlock() - path := rctx.Path() - - var errs gqlerror.List - for _, err := range c.Errors { - if equalPath(err.Path, path) { - errs = append(errs, err) - } - } - return errs + AddError(ctx, err) } func equalPath(a []interface{}, b []interface{}) bool { @@ -247,67 +186,6 @@ func equalPath(a []interface{}, b []interface{}) bool { return true } -// AddError is a convenience method for adding an error to the current response -func AddError(ctx context.Context, err error) { - GetRequestContext(ctx).Error(ctx, err) -} - -// AddErrorf is a convenience method for adding an error to the current response -func AddErrorf(ctx context.Context, format string, args ...interface{}) { - GetRequestContext(ctx).Errorf(ctx, format, args...) -} - -// RegisterExtension registers an extension, returns error if extension has already been registered -func (c *RequestContext) RegisterExtension(key string, value interface{}) error { - c.extensionsMu.Lock() - defer c.extensionsMu.Unlock() - - if c.Extensions == nil { - c.Extensions = make(map[string]interface{}) - } - - if _, ok := c.Extensions[key]; ok { - return fmt.Errorf("extension already registered for key %s", key) - } - - c.Extensions[key] = value - return nil -} - -// ChainFieldMiddleware add chain by FieldMiddleware -func ChainFieldMiddleware(handleFunc ...FieldMiddleware) FieldMiddleware { - n := len(handleFunc) - - if n > 1 { - lastI := n - 1 - return func(ctx context.Context, next Resolver) (interface{}, error) { - var ( - chainHandler Resolver - curI int - ) - chainHandler = func(currentCtx context.Context) (interface{}, error) { - if curI == lastI { - return next(currentCtx) - } - curI++ - res, err := handleFunc[curI](currentCtx, chainHandler) - curI-- - return res, err - - } - return handleFunc[0](ctx, chainHandler) - } - } - - if n == 1 { - return handleFunc[0] - } - - return func(ctx context.Context, next Resolver) (interface{}, error) { - return next(ctx) - } -} - var _ RequestContextMutator = ComplexityLimitFunc(nil) func (c ComplexityLimitFunc) MutateRequestContext(ctx context.Context, rc *RequestContext) *gqlerror.Error { diff --git a/graphql/extensions.go b/graphql/extensions.go deleted file mode 100644 index 28c60bb7fdd..00000000000 --- a/graphql/extensions.go +++ /dev/null @@ -1,12 +0,0 @@ -package graphql - -import "context" - -func GetExtensions(ctx context.Context) map[string]interface{} { - ext := GetRequestContext(ctx).Extensions - if ext == nil { - return map[string]interface{}{} - } - - return ext -} diff --git a/graphql/handler.go b/graphql/handler.go index 83567fe63c6..c7fafe9af30 100644 --- a/graphql/handler.go +++ b/graphql/handler.go @@ -9,10 +9,11 @@ import ( ) type ( - Handler func(ctx context.Context, writer Writer) - ResponseStream func() *Response - Writer func(Status, *Response) - Status int + OperationHandler func(ctx context.Context, writer Writer) + ResultHandler func(ctx context.Context) *Response + ResponseStream func() *Response + Writer func(Status, *Response) + Status int RawParams struct { Query string `json:"query"` @@ -56,10 +57,14 @@ type ( MutateRequestContext(ctx context.Context, rc *RequestContext) *gqlerror.Error } - // ResponseInterceptor is called around each graphql operation. This can be called many times in the case of - // batching and subscriptions. - ResponseInterceptor interface { - InterceptResponse(next Handler) Handler + OperationInterceptor interface { + InterceptOperation(next OperationHandler) OperationHandler + } + + // ResultInterceptor is called around each graphql operation result. This can be called many times for a single + // operation the case of subscriptions. + ResultInterceptor interface { + InterceptResult(ctx context.Context, next ResultHandler) *Response } // FieldInterceptor called around each field @@ -67,6 +72,7 @@ type ( InterceptField(ctx context.Context, next Resolver) (res interface{}, err error) } + // Transport provides support for different wire level encodings of graphql requests, eg Form, Get, Post, Websocket Transport interface { Supports(r *http.Request) bool Do(w http.ResponseWriter, r *http.Request, exec GraphExecutor) diff --git a/graphql/handler/apollotracing/tracer.go b/graphql/handler/apollotracing/tracer.go index 6895a70dfa1..2dafd618a0f 100644 --- a/graphql/handler/apollotracing/tracer.go +++ b/graphql/handler/apollotracing/tracer.go @@ -2,7 +2,6 @@ package apollotracing import ( "context" - "fmt" "sync" "time" @@ -40,7 +39,7 @@ type ( } ) -var _ graphql.ResponseInterceptor = ApolloTracing{} +var _ graphql.ResultInterceptor = ApolloTracing{} var _ graphql.FieldInterceptor = ApolloTracing{} func New() graphql.HandlerPlugin { @@ -49,7 +48,7 @@ func New() graphql.HandlerPlugin { func (a ApolloTracing) InterceptField(ctx context.Context, next graphql.Resolver) (res interface{}, err error) { rc := graphql.GetRequestContext(ctx) - td, ok := rc.Extensions["tracing"].(*TracingExtension) + td, ok := graphql.GetExtension(ctx, "tracing").(*TracingExtension) if !ok { panic("missing tracing extension") } @@ -76,39 +75,36 @@ func (a ApolloTracing) InterceptField(ctx context.Context, next graphql.Resolver return next(ctx) } -func (a ApolloTracing) InterceptResponse(next graphql.Handler) graphql.Handler { - return func(ctx context.Context, writer graphql.Writer) { - rc := graphql.GetRequestContext(ctx) - - start := rc.Stats.OperationStart - - td := &TracingExtension{ - Version: 1, - StartTime: start, - Parsing: Span{ - StartOffset: rc.Stats.Parsing.Start.Sub(start), - Duration: rc.Stats.Parsing.End.Sub(rc.Stats.Parsing.Start), - }, - - Validation: Span{ - StartOffset: rc.Stats.Validation.Start.Sub(start), - Duration: rc.Stats.Validation.End.Sub(rc.Stats.Validation.Start), - }, - - Execution: struct { - Resolvers []ResolverExecution `json:"resolvers"` - }{}, - } - - if err := rc.RegisterExtension("tracing", td); err != nil { - panic(fmt.Errorf("unable to register extension: %s", err.Error())) - } - - next(ctx, func(status graphql.Status, response *graphql.Response) { - end := graphql.Now() - td.EndTime = end - td.Duration = end.Sub(start) - writer(status, response) - }) +func (a ApolloTracing) InterceptResult(ctx context.Context, next graphql.ResultHandler) *graphql.Response { + rc := graphql.GetRequestContext(ctx) + + start := rc.Stats.OperationStart + + td := &TracingExtension{ + Version: 1, + StartTime: start, + Parsing: Span{ + StartOffset: rc.Stats.Parsing.Start.Sub(start), + Duration: rc.Stats.Parsing.End.Sub(rc.Stats.Parsing.Start), + }, + + Validation: Span{ + StartOffset: rc.Stats.Validation.Start.Sub(start), + Duration: rc.Stats.Validation.End.Sub(rc.Stats.Validation.Start), + }, + + Execution: struct { + Resolvers []ResolverExecution `json:"resolvers"` + }{}, } + + graphql.RegisterExtension(ctx, "tracing", td) + + resp := next(ctx) + + end := graphql.Now() + td.EndTime = end + td.Duration = end.Sub(start) + + return resp } diff --git a/graphql/handler/executor.go b/graphql/handler/executor.go index 5fcb1bcf464..a6db6d5426e 100644 --- a/graphql/handler/executor.go +++ b/graphql/handler/executor.go @@ -11,7 +11,8 @@ import ( ) type executor struct { - handler graphql.Handler + operationHandler graphql.OperationHandler + resultHandler graphql.ResultMiddleware responseMiddleware graphql.FieldMiddleware es graphql.ExecutableSchema requestParamMutators []graphql.RequestParameterMutator @@ -24,7 +25,10 @@ func newExecutor(es graphql.ExecutableSchema, plugins []graphql.HandlerPlugin) e e := executor{ es: es, } - e.handler = e.executableSchemaHandler + e.operationHandler = e.executableSchemaHandler + e.resultHandler = func(ctx context.Context, next graphql.ResultHandler) *graphql.Response { + return next(ctx) + } e.responseMiddleware = func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) { return next(ctx) } @@ -32,9 +36,18 @@ func newExecutor(es graphql.ExecutableSchema, plugins []graphql.HandlerPlugin) e // this loop goes backwards so the first plugin is the outer most middleware and runs first. for i := len(plugins) - 1; i >= 0; i-- { p := plugins[i] - if p, ok := p.(graphql.ResponseInterceptor); ok { - previous := e.handler - e.handler = p.InterceptResponse(previous) + if p, ok := p.(graphql.OperationInterceptor); ok { + previous := e.operationHandler + e.operationHandler = p.InterceptOperation(previous) + } + + if p, ok := p.(graphql.ResultInterceptor); ok { + previous := e.resultHandler + e.resultHandler = func(ctx context.Context, next graphql.ResultHandler) *graphql.Response { + return p.InterceptResult(ctx, func(ctx context.Context) *graphql.Response { + return previous(ctx, next) + }) + } } if p, ok := p.(graphql.FieldInterceptor); ok { @@ -62,7 +75,7 @@ func newExecutor(es graphql.ExecutableSchema, plugins []graphql.HandlerPlugin) e } func (e executor) DispatchRequest(ctx context.Context, writer graphql.Writer) { - e.handler(ctx, writer) + e.operationHandler(ctx, writer) } func (e executor) CreateRequestContext(ctx context.Context, params *graphql.RawParams) (*graphql.RequestContext, gqlerror.List) { @@ -84,7 +97,6 @@ func (e executor) CreateRequestContext(ctx context.Context, params *graphql.RawP RawQuery: params.Query, OperationName: params.OperationName, Variables: params.Variables, - Extensions: params.Extensions, } rc.Stats.OperationStart = graphql.GetStartTime(ctx) @@ -114,7 +126,16 @@ func (e executor) CreateRequestContext(ctx context.Context, params *graphql.RawP return rc, nil } -// executableSchemaHandler is the inner most handler, it invokes the graph directly after all middleware +func (e *executor) write(ctx context.Context, resp *graphql.Response, writer graphql.Writer) { + resp.Extensions = graphql.GetExtensions(ctx) + + for _, err := range graphql.GetErrors(ctx) { + resp.Errors = append(resp.Errors, err) + } + writer(getStatus(resp), resp) +} + +// executableSchemaHandler is the inner most operation handler, it invokes the graph directly after all middleware // and sends responses to the transport so it can be returned to the client func (e *executor) executableSchemaHandler(ctx context.Context, write graphql.Writer) { rc := graphql.GetRequestContext(ctx) @@ -123,21 +144,37 @@ func (e *executor) executableSchemaHandler(ctx context.Context, write graphql.Wr switch op.Operation { case ast.Query: - resp := e.es.Query(ctx, op) - resp.Extensions = graphql.GetExtensions(ctx) - write(getStatus(resp), resp) - case ast.Mutation: - resp := e.es.Mutation(ctx, op) - resp.Extensions = graphql.GetExtensions(ctx) - write(getStatus(resp), resp) - case ast.Subscription: - resp := e.es.Subscription(ctx, op) + resCtx := graphql.WithResultContext(ctx) + resp := e.resultHandler(resCtx, func(ctx context.Context) *graphql.Response { + return e.es.Query(ctx, op) + }) + e.write(resCtx, resp, write) - for w := resp(); w != nil; w = resp() { - w.Extensions = graphql.GetExtensions(ctx) + case ast.Mutation: + resCtx := graphql.WithResultContext(ctx) + resp := e.resultHandler(resCtx, func(ctx context.Context) *graphql.Response { + return e.es.Mutation(ctx, op) + }) + e.write(resCtx, resp, write) - write(getStatus(w), w) + case ast.Subscription: + responses := e.es.Subscription(ctx, op) + for { + resCtx := graphql.WithResultContext(ctx) + resp := e.resultHandler(resCtx, func(ctx context.Context) *graphql.Response { + resp := responses() + if resp == nil { + return nil + } + resp.Extensions = graphql.GetExtensions(ctx) + return resp + }) + if resp == nil { + break + } + e.write(resCtx, resp, write) } + default: write(graphql.StatusValidationError, graphql.ErrorResponse(ctx, "unsupported GraphQL operation")) } diff --git a/graphql/handler/server.go b/graphql/handler/server.go index 1d15d1ef82c..79d83e76365 100644 --- a/graphql/handler/server.go +++ b/graphql/handler/server.go @@ -34,8 +34,9 @@ func (s *Server) Use(plugin graphql.HandlerPlugin) { switch plugin.(type) { case graphql.RequestParameterMutator, graphql.RequestContextMutator, - graphql.ResponseInterceptor, - graphql.FieldInterceptor: + graphql.OperationInterceptor, + graphql.FieldInterceptor, + graphql.ResultInterceptor: s.plugins = append(s.plugins, plugin) s.exec = newExecutor(s.es, s.plugins) diff --git a/graphql/handler/server_test.go b/graphql/handler/server_test.go index 72ab650f947..1df1b8ca048 100644 --- a/graphql/handler/server_test.go +++ b/graphql/handler/server_test.go @@ -71,13 +71,13 @@ func TestServer(t *testing.T) { t.Run("invokes operation middleware in order", func(t *testing.T) { var calls []string - srv.Use(opFunc(func(next graphql.Handler) graphql.Handler { + srv.Use(opFunc(func(next graphql.OperationHandler) graphql.OperationHandler { return func(ctx context.Context, writer graphql.Writer) { calls = append(calls, "first") next(ctx, writer) } })) - srv.Use(opFunc(func(next graphql.Handler) graphql.Handler { + srv.Use(opFunc(func(next graphql.OperationHandler) graphql.OperationHandler { return func(ctx context.Context, writer graphql.Writer) { calls = append(calls, "second") next(ctx, writer) @@ -108,9 +108,9 @@ func TestServer(t *testing.T) { }) } -type opFunc func(next graphql.Handler) graphql.Handler +type opFunc func(next graphql.OperationHandler) graphql.OperationHandler -func (r opFunc) InterceptResponse(next graphql.Handler) graphql.Handler { +func (r opFunc) InterceptOperation(next graphql.OperationHandler) graphql.OperationHandler { return r(next) } diff --git a/graphql/result_context.go b/graphql/result_context.go new file mode 100644 index 00000000000..a7d56f07008 --- /dev/null +++ b/graphql/result_context.go @@ -0,0 +1,135 @@ +package graphql + +import ( + "context" + "fmt" + "sync" + + "github.com/vektah/gqlparser/gqlerror" +) + +type resultContext struct { + errors gqlerror.List + errorsMu sync.Mutex + + extensions map[string]interface{} + extensionsMu sync.Mutex +} + +var resultCtx key = "result_context" + +func getResultContext(ctx context.Context) *resultContext { + val, _ := ctx.Value(resultCtx).(*resultContext) + return val +} + +func WithResultContext(ctx context.Context) context.Context { + return context.WithValue(ctx, resultCtx, &resultContext{}) +} + +// AddErrorf writes a formatted error to the client, first passing it through the error presenter. +func AddErrorf(ctx context.Context, format string, args ...interface{}) { + c := getResultContext(ctx) + + c.errorsMu.Lock() + defer c.errorsMu.Unlock() + + c.errors = append(c.errors, GetRequestContext(ctx).ErrorPresenter(ctx, fmt.Errorf(format, args...))) +} + +// AddError sends an error to the client, first passing it through the error presenter. +func AddError(ctx context.Context, err error) { + c := getResultContext(ctx) + + c.errorsMu.Lock() + defer c.errorsMu.Unlock() + + c.errors = append(c.errors, GetRequestContext(ctx).ErrorPresenter(ctx, err)) +} + +// HasFieldError returns true if the given field has already errored +func HasFieldError(ctx context.Context, rctx *ResolverContext) bool { + c := getResultContext(ctx) + + c.errorsMu.Lock() + defer c.errorsMu.Unlock() + path := rctx.Path() + + for _, err := range c.errors { + if equalPath(err.Path, path) { + return true + } + } + return false +} + +// GetFieldErrors returns a list of errors that occurred in the given field +func GetFieldErrors(ctx context.Context, rctx *ResolverContext) gqlerror.List { + c := getResultContext(ctx) + + c.errorsMu.Lock() + defer c.errorsMu.Unlock() + path := rctx.Path() + + var errs gqlerror.List + for _, err := range c.errors { + if equalPath(err.Path, path) { + errs = append(errs, err) + } + } + return errs +} + +func GetErrors(ctx context.Context) gqlerror.List { + resCtx := getResultContext(ctx) + resCtx.errorsMu.Lock() + defer resCtx.errorsMu.Unlock() + + if len(resCtx.errors) == 0 { + return nil + } + + errs := resCtx.errors + cpy := make(gqlerror.List, len(errs)) + for i := range errs { + errCpy := *errs[i] + cpy[i] = &errCpy + } + return cpy +} + +// RegisterExtension allows you to add a new extension into the graphql response +func RegisterExtension(ctx context.Context, key string, value interface{}) { + c := getResultContext(ctx) + c.extensionsMu.Lock() + defer c.extensionsMu.Unlock() + + if c.extensions == nil { + c.extensions = make(map[string]interface{}) + } + + if _, ok := c.extensions[key]; ok { + panic(fmt.Errorf("extension already registered for key %s", key)) + } + + c.extensions[key] = value +} + +// GetExtensions returns any extensions registered in the current result context +func GetExtensions(ctx context.Context) map[string]interface{} { + ext := getResultContext(ctx).extensions + if ext == nil { + return map[string]interface{}{} + } + + return ext +} + +func GetExtension(ctx context.Context, name string) interface{} { + ext := getResultContext(ctx).extensions + if ext == nil { + return nil + } + + return ext[name] +}