Skip to content

Commit

Permalink
use plugins instead of middleware so multiple hooks can be configured
Browse files Browse the repository at this point in the history
  • Loading branch information
vektah committed Oct 29, 2019
1 parent a7c5e66 commit f00e5fa
Show file tree
Hide file tree
Showing 24 changed files with 403 additions and 458 deletions.
27 changes: 7 additions & 20 deletions graphql/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (

type Resolver func(ctx context.Context) (res interface{}, err error)
type FieldMiddleware func(ctx context.Context, next Resolver) (res interface{}, err error)
type RequestMiddleware func(ctx context.Context, next func(ctx context.Context) []byte) []byte
type ComplexityLimitFunc func(ctx context.Context) int

type RequestContext struct {
Expand Down Expand Up @@ -56,9 +55,6 @@ func (rc *RequestContext) Validate(ctx context.Context) error {
if rc.DirectiveMiddleware == nil {
rc.DirectiveMiddleware = DefaultDirectiveMiddleware
}
if rc.RequestMiddleware == nil {
rc.RequestMiddleware = DefaultRequestMiddleware
}
if rc.Recover == nil {
rc.Recover = DefaultRecover
}
Expand All @@ -75,22 +71,6 @@ func (rc *RequestContext) Validate(ctx context.Context) error {
return nil
}

// AddRequestMiddleware allows you to define a function that will be called around the root request,
// after the query has been parsed. This is useful for logging
func (cfg *RequestContext) AddRequestMiddleware(middleware RequestMiddleware) {
if cfg.RequestMiddleware == nil {
cfg.RequestMiddleware = middleware
return
}

lastResolve := cfg.RequestMiddleware
cfg.RequestMiddleware = func(ctx context.Context, next func(ctx context.Context) []byte) []byte {
return lastResolve(ctx, func(ctx context.Context) []byte {
return middleware(ctx, next)
})
}
}

func (cfg *RequestContext) AddTracer(tracer Tracer) {
if cfg.Tracer == nil {
cfg.Tracer = tracer
Expand Down Expand Up @@ -338,3 +318,10 @@ func ChainFieldMiddleware(handleFunc ...FieldMiddleware) FieldMiddleware {
return next(ctx)
}
}

var _ RequestContextMutator = ComplexityLimitFunc(nil)

func (c ComplexityLimitFunc) MutateRequestContext(ctx context.Context, rc *RequestContext) *gqlerror.Error {
rc.ComplexityLimit = c(ctx)
return nil
}
7 changes: 7 additions & 0 deletions graphql/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,10 @@ func DefaultErrorPresenter(ctx context.Context, err error) *gqlerror.Error {
Extensions: extensions,
}
}

var _ RequestContextMutator = ErrorPresenterFunc(nil)

func (f ErrorPresenterFunc) MutateRequestContext(ctx context.Context, rc *RequestContext) *gqlerror.Error {
rc.ErrorPresenter = f
return nil
}
36 changes: 34 additions & 2 deletions graphql/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,40 @@ import (

type (
Handler func(ctx context.Context, writer Writer)
Middleware func(next Handler) Handler
ResponseStream func() *Response
Writer func(Status, *Response)
Status int

RawParams struct {
Query string `json:"query"`
OperationName string `json:"operationName"`
Variables map[string]interface{} `json:"variables"`
Extensions map[string]interface{} `json:"extensions"`
}

GraphExecutor interface {
CreateRequestContext(ctx context.Context, params *RawParams) (*RequestContext, gqlerror.List)
DispatchRequest(ctx context.Context, writer Writer)
}

// HandlerPlugin interface is entirely optional, see the list of possible hook points below
HandlerPlugin interface{}

RequestMutator interface {
MutateRequest(ctx context.Context, request *RawParams) *gqlerror.Error
}

RequestContextMutator interface {
MutateRequestContext(ctx context.Context, rc *RequestContext) *gqlerror.Error
}

RequestMiddleware interface {
InterceptRequest(next Handler) Handler
}

Transport interface {
Supports(r *http.Request) bool
Do(w http.ResponseWriter, r *http.Request, handler Handler)
Do(w http.ResponseWriter, r *http.Request, exec GraphExecutor)
}
)

Expand All @@ -39,3 +65,9 @@ func (w Writer) Error(msg string) {
Errors: gqlerror.List{{Message: msg}},
})
}

func (w Writer) GraphqlErr(err ...*gqlerror.Error) {
w(StatusResolverError, &Response{
Errors: err,
})
}
152 changes: 152 additions & 0 deletions graphql/handler/executor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
package handler

import (
"context"

"github.com/99designs/gqlgen/graphql"
"github.com/vektah/gqlparser/ast"
"github.com/vektah/gqlparser/gqlerror"
"github.com/vektah/gqlparser/parser"
"github.com/vektah/gqlparser/validator"
)

type executor struct {
handler graphql.Handler
es graphql.ExecutableSchema
requestMutators []graphql.RequestMutator
requestContextMutators []graphql.RequestContextMutator
}

var _ graphql.GraphExecutor = executor{}

func newExecutor(es graphql.ExecutableSchema, plugins []graphql.HandlerPlugin) executor {
e := executor{
es: es,
}
handler := e.executableSchemaHandler
// this loop goes backwards so the first plugin is the outer most middleware and runs first.
for i := len(plugins) - 1; i >= 0; i-- {
p := plugins[i]
if p, ok := p.(graphql.RequestMiddleware); ok {
handler = p.InterceptRequest(handler)
}
}

for _, p := range plugins {

if p, ok := p.(graphql.RequestMutator); ok {
e.requestMutators = append(e.requestMutators, p)
}

if p, ok := p.(graphql.RequestContextMutator); ok {
e.requestContextMutators = append(e.requestContextMutators, p)
}
}

e.handler = handler

return e
}

func (e executor) DispatchRequest(ctx context.Context, writer graphql.Writer) {
e.handler(ctx, writer)
}

func (e executor) CreateRequestContext(ctx context.Context, params *graphql.RawParams) (*graphql.RequestContext, gqlerror.List) {
for _, p := range e.requestMutators {
if err := p.MutateRequest(ctx, params); err != nil {
return nil, gqlerror.List{err}
}
}

var gerr *gqlerror.Error

rc := &graphql.RequestContext{
DisableIntrospection: true,
Recover: graphql.DefaultRecover,
ErrorPresenter: graphql.DefaultErrorPresenter,
ResolverMiddleware: nil,
RequestMiddleware: nil,
Tracer: graphql.NopTracer{},
ComplexityLimit: 0,
RawQuery: params.Query,
OperationName: params.OperationName,
Variables: params.Variables,
Extensions: params.Extensions,
}

rc.Doc, gerr = e.parseOperation(ctx, rc)
if gerr != nil {
return nil, []*gqlerror.Error{gerr}
}

ctx, op, listErr := e.validateOperation(ctx, rc)
if len(listErr) != 0 {
return nil, listErr
}

vars, err := validator.VariableValues(e.es.Schema(), op, rc.Variables)
if err != nil {
return nil, gqlerror.List{err}
}

rc.Variables = vars

for _, p := range e.requestContextMutators {
if err := p.MutateRequestContext(ctx, rc); err != nil {
return nil, gqlerror.List{err}
}
}

return rc, nil
}

// executableSchemaHandler is the inner most handler, it invokes the graph directly after all middleware
// and sends responses to the transport so it can be returned to the client
func (e *executor) executableSchemaHandler(ctx context.Context, write graphql.Writer) {
rc := graphql.GetRequestContext(ctx)

op := rc.Doc.Operations.ForName(rc.OperationName)

switch op.Operation {
case ast.Query:
resp := e.es.Query(ctx, op)

write(getStatus(resp), resp)
case ast.Mutation:
resp := e.es.Mutation(ctx, op)
write(getStatus(resp), resp)
case ast.Subscription:
resp := e.es.Subscription(ctx, op)

for w := resp(); w != nil; w = resp() {
write(getStatus(w), w)
}
default:
write(graphql.StatusValidationError, graphql.ErrorResponse(ctx, "unsupported GraphQL operation"))
}
}

func (e executor) parseOperation(ctx context.Context, rc *graphql.RequestContext) (*ast.QueryDocument, *gqlerror.Error) {
ctx = rc.Tracer.StartOperationValidation(ctx)
defer func() { rc.Tracer.EndOperationValidation(ctx) }()

return parser.ParseQuery(&ast.Source{Input: rc.RawQuery})
}

func (e executor) validateOperation(ctx context.Context, rc *graphql.RequestContext) (context.Context, *ast.OperationDefinition, gqlerror.List) {
ctx = rc.Tracer.StartOperationValidation(ctx)
defer func() { rc.Tracer.EndOperationValidation(ctx) }()

listErr := validator.Validate(e.es.Schema(), rc.Doc)
if len(listErr) != 0 {
return ctx, nil, listErr
}

op := rc.Doc.Operations.ForName(rc.OperationName)
if op == nil {
return ctx, nil, gqlerror.List{gqlerror.Errorf("operation %s not found", rc.OperationName)}
}

return ctx, op, nil
}
69 changes: 32 additions & 37 deletions graphql/handler/middleware/apq.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"

"github.com/99designs/gqlgen/graphql"
"github.com/mitchellh/mapstructure"
Expand All @@ -18,50 +19,44 @@ const (
// does not yet know what the query is for the hash it will respond telling the client to send the query along with the
// hash in the next request.
// see https://github.com/apollographql/apollo-link-persisted-queries
func AutomaticPersistedQuery(cache graphql.Cache) graphql.Middleware {
return func(next graphql.Handler) graphql.Handler {
return func(ctx context.Context, writer graphql.Writer) {
rc := graphql.GetRequestContext(ctx)
type AutomaticPersistedQuery struct {
Cache graphql.Cache
}

if rc.Extensions["persistedQuery"] == nil {
next(ctx, writer)
return
}
func (a AutomaticPersistedQuery) MutateRequest(ctx context.Context, rawParams *graphql.RawParams) error {
if rawParams.Extensions["persistedQuery"] == nil {
return nil
}

var extension struct {
Sha256 string `json:"sha256Hash"`
Version int64 `json:"version"`
}
var extension struct {
Sha256 string `json:"sha256Hash"`
Version int64 `json:"version"`
}

if err := mapstructure.Decode(rc.Extensions["persistedQuery"], &extension); err != nil {
writer.Error("Invalid APQ extension data")
return
}
if err := mapstructure.Decode(rawParams.Extensions["persistedQuery"], &extension); err != nil {
return errors.New("Invalid APQ extension data")
}

if extension.Version != 1 {
writer.Error("Unsupported APQ version")
return
}
if extension.Version != 1 {
return errors.New("Unsupported APQ version")
}

if rc.RawQuery == "" {
// client sent optimistic query hash without query string, get it from the cache
query, ok := cache.Get(extension.Sha256)
if !ok {
writer.Error(errPersistedQueryNotFound)
return
}
rc.RawQuery = query.(string)
} else {
// client sent optimistic query hash with query string, verify and store it
if computeQueryHash(rc.RawQuery) != extension.Sha256 {
writer.Error("Provided APQ hash does not match query")
return
}
cache.Add(extension.Sha256, rc.RawQuery)
}
next(ctx, writer)
if rawParams.Query == "" {
// client sent optimistic query hash without query string, get it from the cache
query, ok := a.Cache.Get(extension.Sha256)
if !ok {
return errors.New(errPersistedQueryNotFound)
}
rawParams.Query = query.(string)
} else {
// client sent optimistic query hash with query string, verify and store it
if computeQueryHash(rawParams.Query) != extension.Sha256 {
return errors.New("Provided APQ hash does not match query")
}
a.Cache.Add(extension.Sha256, rawParams.Query)
}

return nil
}

func computeQueryHash(query string) string {
Expand Down
Loading

0 comments on commit f00e5fa

Please sign in to comment.