Skip to content

Commit

Permalink
Merge pull request #828 from 99designs/feat-rc
Browse files Browse the repository at this point in the history
introduce RequestContext#Validate and use it instead of NewRequestContext function
  • Loading branch information
vektah authored Sep 24, 2019
2 parents 17f32d2 + 82758be commit cc64f33
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 44 deletions.
64 changes: 51 additions & 13 deletions graphql/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package graphql

import (
"context"
"errors"
"fmt"
"sync"

Expand All @@ -15,9 +16,10 @@ type RequestMiddleware func(ctx context.Context, next func(ctx context.Context)
type ComplexityLimitFunc func(ctx context.Context) int

type RequestContext struct {
RawQuery string
Variables map[string]interface{}
Doc *ast.QueryDocument
RawQuery string
Variables map[string]interface{}
OperationName string
Doc *ast.QueryDocument

ComplexityLimit int
OperationComplexity int
Expand All @@ -38,6 +40,41 @@ type RequestContext struct {
Extensions map[string]interface{}
}

func (rc *RequestContext) Validate(ctx context.Context) error {
if rc.Doc == nil {
return errors.New("field 'Doc' must be required")
}
if rc.RawQuery == "" {
return errors.New("field 'RawQuery' must be required")
}
if rc.Variables == nil {
rc.Variables = make(map[string]interface{})
}
if rc.ResolverMiddleware == nil {
rc.ResolverMiddleware = DefaultResolverMiddleware
}
if rc.DirectiveMiddleware == nil {
rc.DirectiveMiddleware = DefaultDirectiveMiddleware
}
if rc.RequestMiddleware == nil {
rc.RequestMiddleware = DefaultRequestMiddleware
}
if rc.Recover == nil {
rc.Recover = DefaultRecover
}
if rc.ErrorPresenter == nil {
rc.ErrorPresenter = DefaultErrorPresenter
}
if rc.Tracer == nil {
rc.Tracer = &NopTracer{}
}
if rc.ComplexityLimit < 0 {
return errors.New("field 'ComplexityLimit' value must be 0 or more")
}

return nil
}

func DefaultResolverMiddleware(ctx context.Context, next Resolver) (res interface{}, err error) {
return next(ctx)
}
Expand All @@ -50,18 +87,19 @@ func DefaultRequestMiddleware(ctx context.Context, next func(ctx context.Context
return next(ctx)
}

// Deprecated: construct RequestContext directly & call Validate method.
func NewRequestContext(doc *ast.QueryDocument, query string, variables map[string]interface{}) *RequestContext {
return &RequestContext{
Doc: doc,
RawQuery: query,
Variables: variables,
ResolverMiddleware: DefaultResolverMiddleware,
DirectiveMiddleware: DefaultDirectiveMiddleware,
RequestMiddleware: DefaultRequestMiddleware,
Recover: DefaultRecover,
ErrorPresenter: DefaultErrorPresenter,
Tracer: &NopTracer{},
rc := &RequestContext{
Doc: doc,
RawQuery: query,
Variables: variables,
}
err := rc.Validate(context.Background())
if err != nil {
panic(err)
}

return rc
}

type key string
Expand Down
56 changes: 26 additions & 30 deletions handler/graphql.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,37 +72,29 @@ type Config struct {
apqCache PersistedQueryCache
}

func (c *Config) newRequestContext(es graphql.ExecutableSchema, doc *ast.QueryDocument, op *ast.OperationDefinition, query string, variables map[string]interface{}) *graphql.RequestContext {
reqCtx := graphql.NewRequestContext(doc, query, variables)
reqCtx.DisableIntrospection = c.disableIntrospection

if hook := c.recover; hook != nil {
reqCtx.Recover = hook
}

if hook := c.errorPresenter; hook != nil {
reqCtx.ErrorPresenter = hook
}

if hook := c.resolverHook; hook != nil {
reqCtx.ResolverMiddleware = hook
}

if hook := c.requestHook; hook != nil {
reqCtx.RequestMiddleware = hook
}

if hook := c.tracer; hook != nil {
reqCtx.Tracer = hook
}

if c.complexityLimit > 0 || c.complexityLimitFunc != nil {
reqCtx.ComplexityLimit = c.complexityLimit
operationComplexity := complexity.Calculate(es, op, variables)
reqCtx.OperationComplexity = operationComplexity
func (c *Config) newRequestContext(ctx context.Context, es graphql.ExecutableSchema, doc *ast.QueryDocument, op *ast.OperationDefinition, operationName, query string, variables map[string]interface{}) (*graphql.RequestContext, error) {
reqCtx := &graphql.RequestContext{
Doc: doc,
RawQuery: query,
Variables: variables,
OperationName: operationName,
DisableIntrospection: c.disableIntrospection,
Recover: c.recover,
ErrorPresenter: c.errorPresenter,
ResolverMiddleware: c.resolverHook,
RequestMiddleware: c.requestHook,
Tracer: c.tracer,
ComplexityLimit: c.complexityLimit,
}
if reqCtx.ComplexityLimit > 0 || c.complexityLimitFunc != nil {
reqCtx.OperationComplexity = complexity.Calculate(es, op, variables)
}
err := reqCtx.Validate(ctx)
if err != nil {
return nil, err
}

return reqCtx
return reqCtx, nil
}

type Option func(cfg *Config)
Expand Down Expand Up @@ -532,7 +524,11 @@ func (gh *graphqlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
gh.cache.Add(reqParams.Query, doc)
}

reqCtx := gh.cfg.newRequestContext(gh.exec, doc, op, reqParams.Query, vars)
reqCtx, err := gh.cfg.newRequestContext(ctx, gh.exec, doc, op, reqParams.OperationName, reqParams.Query, vars)
if err != nil {
sendErrorf(w, http.StatusBadRequest, "invalid RequestContext was generated: %s", err.Error())
return
}
ctx = graphql.WithRequestContext(ctx, reqCtx)

defer func() {
Expand Down
6 changes: 5 additions & 1 deletion handler/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,11 @@ func (c *wsConnection) subscribe(message *operationMessage) bool {
c.sendError(message.ID, err)
return true
}
reqCtx := c.cfg.newRequestContext(c.exec, doc, op, reqParams.Query, vars)
reqCtx, err2 := c.cfg.newRequestContext(c.ctx, c.exec, doc, op, reqParams.OperationName, reqParams.Query, vars)
if err2 != nil {
c.sendError(message.ID, gqlerror.Errorf(err2.Error()))
return true
}
ctx := graphql.WithRequestContext(c.ctx, reqCtx)

if c.initPayload != nil {
Expand Down

0 comments on commit cc64f33

Please sign in to comment.