Skip to content

Commit

Permalink
Add result context
Browse files Browse the repository at this point in the history
  • Loading branch information
vektah committed Oct 30, 2019
1 parent c3dbcf8 commit ab5665a
Show file tree
Hide file tree
Showing 8 changed files with 258 additions and 217 deletions.
146 changes: 12 additions & 134 deletions graphql/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@ package graphql
import (
"context"
"errors"
"fmt"
"sync"

"github.com/vektah/gqlparser/ast"
"github.com/vektah/gqlparser/gqlerror"
)

type Resolver func(ctx context.Context) (res interface{}, err error)
type ResultMiddleware func(ctx context.Context, next ResultHandler) *Response
type FieldMiddleware func(ctx context.Context, next Resolver) (res interface{}, err error)
type ComplexityLimitFunc func(ctx context.Context) int

Expand All @@ -30,15 +29,7 @@ type RequestContext struct {
Recover RecoverFunc
ResolverMiddleware FieldMiddleware
DirectiveMiddleware FieldMiddleware
RequestMiddleware ResponseInterceptor

errorsMu sync.Mutex
Errors gqlerror.List
extensionsMu sync.Mutex

// @deprecated use ResponseContext instead, in the case of subscriptions there are many responses returned
// and each can have its own set of extensions
Extensions map[string]interface{}
RequestMiddleware OperationInterceptor

Stats Stats
}
Expand Down Expand Up @@ -80,41 +71,22 @@ func DefaultDirectiveMiddleware(ctx context.Context, next Resolver) (res interfa
return next(ctx)
}

func DefaultRequestMiddleware(ctx context.Context, next func(ctx context.Context) []byte) []byte {
return next(ctx)
}

// Deprecated: construct RequestContext directly & call Validate method.
func NewRequestContext(doc *ast.QueryDocument, query string, variables map[string]interface{}) *RequestContext {
rc := &RequestContext{
Doc: doc,
RawQuery: query,
Variables: variables,
}
err := rc.Validate(context.Background())
if err != nil {
panic(err)
}

return rc
}

type key string

const (
request key = "request_context"
resolver key = "resolver_context"
requestCtx key = "request_context"
resolverCtx key = "resolver_context"
)

func GetRequestContext(ctx context.Context) *RequestContext {
if val, ok := ctx.Value(request).(*RequestContext); ok {
if val, ok := ctx.Value(requestCtx).(*RequestContext); ok {
return val
}
return nil
}

func WithRequestContext(ctx context.Context, rc *RequestContext) context.Context {
return context.WithValue(ctx, request, rc)
return context.WithValue(ctx, requestCtx, rc)
}

type ResolverContext struct {
Expand Down Expand Up @@ -153,15 +125,15 @@ func (r *ResolverContext) Path() []interface{} {
}

func GetResolverContext(ctx context.Context) *ResolverContext {
if val, ok := ctx.Value(resolver).(*ResolverContext); ok {
if val, ok := ctx.Value(resolverCtx).(*ResolverContext); ok {
return val
}
return nil
}

func WithResolverContext(ctx context.Context, rc *ResolverContext) context.Context {
rc.Parent = GetResolverContext(ctx)
return context.WithValue(ctx, resolver, rc)
return context.WithValue(ctx, resolverCtx, rc)
}

// This is just a convenient wrapper method for CollectFields
Expand Down Expand Up @@ -189,48 +161,15 @@ Next:
}

// Errorf sends an error string to the client, passing it through the formatter.
// Deprecated: use graphql.AddErrorf(ctx, err) instead
func (c *RequestContext) Errorf(ctx context.Context, format string, args ...interface{}) {
c.errorsMu.Lock()
defer c.errorsMu.Unlock()

c.Errors = append(c.Errors, c.ErrorPresenter(ctx, fmt.Errorf(format, args...)))
AddErrorf(ctx, format, args...)
}

// Error sends an error to the client, passing it through the formatter.
// Deprecated: use graphql.AddError(ctx, err) instead
func (c *RequestContext) Error(ctx context.Context, err error) {
c.errorsMu.Lock()
defer c.errorsMu.Unlock()

c.Errors = append(c.Errors, c.ErrorPresenter(ctx, err))
}

// HasError returns true if the current field has already errored
func (c *RequestContext) HasError(rctx *ResolverContext) bool {
c.errorsMu.Lock()
defer c.errorsMu.Unlock()
path := rctx.Path()

for _, err := range c.Errors {
if equalPath(err.Path, path) {
return true
}
}
return false
}

// GetErrors returns a list of errors that occurred in the current field
func (c *RequestContext) GetErrors(rctx *ResolverContext) gqlerror.List {
c.errorsMu.Lock()
defer c.errorsMu.Unlock()
path := rctx.Path()

var errs gqlerror.List
for _, err := range c.Errors {
if equalPath(err.Path, path) {
errs = append(errs, err)
}
}
return errs
AddError(ctx, err)
}

func equalPath(a []interface{}, b []interface{}) bool {
Expand All @@ -247,67 +186,6 @@ func equalPath(a []interface{}, b []interface{}) bool {
return true
}

// AddError is a convenience method for adding an error to the current response
func AddError(ctx context.Context, err error) {
GetRequestContext(ctx).Error(ctx, err)
}

// AddErrorf is a convenience method for adding an error to the current response
func AddErrorf(ctx context.Context, format string, args ...interface{}) {
GetRequestContext(ctx).Errorf(ctx, format, args...)
}

// RegisterExtension registers an extension, returns error if extension has already been registered
func (c *RequestContext) RegisterExtension(key string, value interface{}) error {
c.extensionsMu.Lock()
defer c.extensionsMu.Unlock()

if c.Extensions == nil {
c.Extensions = make(map[string]interface{})
}

if _, ok := c.Extensions[key]; ok {
return fmt.Errorf("extension already registered for key %s", key)
}

c.Extensions[key] = value
return nil
}

// ChainFieldMiddleware add chain by FieldMiddleware
func ChainFieldMiddleware(handleFunc ...FieldMiddleware) FieldMiddleware {
n := len(handleFunc)

if n > 1 {
lastI := n - 1
return func(ctx context.Context, next Resolver) (interface{}, error) {
var (
chainHandler Resolver
curI int
)
chainHandler = func(currentCtx context.Context) (interface{}, error) {
if curI == lastI {
return next(currentCtx)
}
curI++
res, err := handleFunc[curI](currentCtx, chainHandler)
curI--
return res, err

}
return handleFunc[0](ctx, chainHandler)
}
}

if n == 1 {
return handleFunc[0]
}

return func(ctx context.Context, next Resolver) (interface{}, error) {
return next(ctx)
}
}

var _ RequestContextMutator = ComplexityLimitFunc(nil)

func (c ComplexityLimitFunc) MutateRequestContext(ctx context.Context, rc *RequestContext) *gqlerror.Error {
Expand Down
12 changes: 0 additions & 12 deletions graphql/extensions.go

This file was deleted.

22 changes: 14 additions & 8 deletions graphql/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@ import (
)

type (
Handler func(ctx context.Context, writer Writer)
ResponseStream func() *Response
Writer func(Status, *Response)
Status int
OperationHandler func(ctx context.Context, writer Writer)
ResultHandler func(ctx context.Context) *Response
ResponseStream func() *Response
Writer func(Status, *Response)
Status int

RawParams struct {
Query string `json:"query"`
Expand Down Expand Up @@ -56,17 +57,22 @@ type (
MutateRequestContext(ctx context.Context, rc *RequestContext) *gqlerror.Error
}

// ResponseInterceptor is called around each graphql operation. This can be called many times in the case of
// batching and subscriptions.
ResponseInterceptor interface {
InterceptResponse(next Handler) Handler
OperationInterceptor interface {
InterceptOperation(next OperationHandler) OperationHandler
}

// ResultInterceptor is called around each graphql operation result. This can be called many times for a single
// operation the case of subscriptions.
ResultInterceptor interface {
InterceptResult(ctx context.Context, next ResultHandler) *Response
}

// FieldInterceptor called around each field
FieldInterceptor interface {
InterceptField(ctx context.Context, next Resolver) (res interface{}, err error)
}

// Transport provides support for different wire level encodings of graphql requests, eg Form, Get, Post, Websocket
Transport interface {
Supports(r *http.Request) bool
Do(w http.ResponseWriter, r *http.Request, exec GraphExecutor)
Expand Down
70 changes: 33 additions & 37 deletions graphql/handler/apollotracing/tracer.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package apollotracing

import (
"context"
"fmt"
"sync"
"time"

Expand Down Expand Up @@ -40,7 +39,7 @@ type (
}
)

var _ graphql.ResponseInterceptor = ApolloTracing{}
var _ graphql.ResultInterceptor = ApolloTracing{}
var _ graphql.FieldInterceptor = ApolloTracing{}

func New() graphql.HandlerPlugin {
Expand All @@ -49,7 +48,7 @@ func New() graphql.HandlerPlugin {

func (a ApolloTracing) InterceptField(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
rc := graphql.GetRequestContext(ctx)
td, ok := rc.Extensions["tracing"].(*TracingExtension)
td, ok := graphql.GetExtension(ctx, "tracing").(*TracingExtension)
if !ok {
panic("missing tracing extension")
}
Expand All @@ -76,39 +75,36 @@ func (a ApolloTracing) InterceptField(ctx context.Context, next graphql.Resolver
return next(ctx)
}

func (a ApolloTracing) InterceptResponse(next graphql.Handler) graphql.Handler {
return func(ctx context.Context, writer graphql.Writer) {
rc := graphql.GetRequestContext(ctx)

start := rc.Stats.OperationStart

td := &TracingExtension{
Version: 1,
StartTime: start,
Parsing: Span{
StartOffset: rc.Stats.Parsing.Start.Sub(start),
Duration: rc.Stats.Parsing.End.Sub(rc.Stats.Parsing.Start),
},

Validation: Span{
StartOffset: rc.Stats.Validation.Start.Sub(start),
Duration: rc.Stats.Validation.End.Sub(rc.Stats.Validation.Start),
},

Execution: struct {
Resolvers []ResolverExecution `json:"resolvers"`
}{},
}

if err := rc.RegisterExtension("tracing", td); err != nil {
panic(fmt.Errorf("unable to register extension: %s", err.Error()))
}

next(ctx, func(status graphql.Status, response *graphql.Response) {
end := graphql.Now()
td.EndTime = end
td.Duration = end.Sub(start)
writer(status, response)
})
func (a ApolloTracing) InterceptResult(ctx context.Context, next graphql.ResultHandler) *graphql.Response {
rc := graphql.GetRequestContext(ctx)

start := rc.Stats.OperationStart

td := &TracingExtension{
Version: 1,
StartTime: start,
Parsing: Span{
StartOffset: rc.Stats.Parsing.Start.Sub(start),
Duration: rc.Stats.Parsing.End.Sub(rc.Stats.Parsing.Start),
},

Validation: Span{
StartOffset: rc.Stats.Validation.Start.Sub(start),
Duration: rc.Stats.Validation.End.Sub(rc.Stats.Validation.Start),
},

Execution: struct {
Resolvers []ResolverExecution `json:"resolvers"`
}{},
}

graphql.RegisterExtension(ctx, "tracing", td)

resp := next(ctx)

end := graphql.Now()
td.EndTime = end
td.Duration = end.Sub(start)

return resp
}
Loading

0 comments on commit ab5665a

Please sign in to comment.