diff --git a/graphql/context.go b/graphql/context.go index cee727e5eaf..91778a5dd87 100644 --- a/graphql/context.go +++ b/graphql/context.go @@ -12,7 +12,6 @@ import ( type Resolver func(ctx context.Context) (res interface{}, err error) type FieldMiddleware func(ctx context.Context, next Resolver) (res interface{}, err error) -type RequestMiddleware func(ctx context.Context, next func(ctx context.Context) []byte) []byte type ComplexityLimitFunc func(ctx context.Context) int type RequestContext struct { @@ -56,9 +55,6 @@ func (rc *RequestContext) Validate(ctx context.Context) error { if rc.DirectiveMiddleware == nil { rc.DirectiveMiddleware = DefaultDirectiveMiddleware } - if rc.RequestMiddleware == nil { - rc.RequestMiddleware = DefaultRequestMiddleware - } if rc.Recover == nil { rc.Recover = DefaultRecover } @@ -75,22 +71,6 @@ func (rc *RequestContext) Validate(ctx context.Context) error { return nil } -// AddRequestMiddleware allows you to define a function that will be called around the root request, -// after the query has been parsed. This is useful for logging -func (cfg *RequestContext) AddRequestMiddleware(middleware RequestMiddleware) { - if cfg.RequestMiddleware == nil { - cfg.RequestMiddleware = middleware - return - } - - lastResolve := cfg.RequestMiddleware - cfg.RequestMiddleware = func(ctx context.Context, next func(ctx context.Context) []byte) []byte { - return lastResolve(ctx, func(ctx context.Context) []byte { - return middleware(ctx, next) - }) - } -} - func (cfg *RequestContext) AddTracer(tracer Tracer) { if cfg.Tracer == nil { cfg.Tracer = tracer @@ -338,3 +318,10 @@ func ChainFieldMiddleware(handleFunc ...FieldMiddleware) FieldMiddleware { return next(ctx) } } + +var _ RequestContextMutator = ComplexityLimitFunc(nil) + +func (c ComplexityLimitFunc) MutateRequestContext(ctx context.Context, rc *RequestContext) *gqlerror.Error { + rc.ComplexityLimit = c(ctx) + return nil +} diff --git a/graphql/error.go b/graphql/error.go index af8b4ce4088..7d00ade2a5b 100644 --- a/graphql/error.go +++ b/graphql/error.go @@ -31,3 +31,10 @@ func DefaultErrorPresenter(ctx context.Context, err error) *gqlerror.Error { Extensions: extensions, } } + +var _ RequestContextMutator = ErrorPresenterFunc(nil) + +func (f ErrorPresenterFunc) MutateRequestContext(ctx context.Context, rc *RequestContext) *gqlerror.Error { + rc.ErrorPresenter = f + return nil +} diff --git a/graphql/handler.go b/graphql/handler.go index dbe6a3e1220..1f1c052a16e 100644 --- a/graphql/handler.go +++ b/graphql/handler.go @@ -10,14 +10,40 @@ import ( type ( Handler func(ctx context.Context, writer Writer) - Middleware func(next Handler) Handler ResponseStream func() *Response Writer func(Status, *Response) Status int + RawParams struct { + Query string `json:"query"` + OperationName string `json:"operationName"` + Variables map[string]interface{} `json:"variables"` + Extensions map[string]interface{} `json:"extensions"` + } + + GraphExecutor interface { + CreateRequestContext(ctx context.Context, params *RawParams) (*RequestContext, gqlerror.List) + DispatchRequest(ctx context.Context, writer Writer) + } + + // HandlerPlugin interface is entirely optional, see the list of possible hook points below + HandlerPlugin interface{} + + RequestMutator interface { + MutateRequest(ctx context.Context, request *RawParams) *gqlerror.Error + } + + RequestContextMutator interface { + MutateRequestContext(ctx context.Context, rc *RequestContext) *gqlerror.Error + } + + RequestMiddleware interface { + InterceptRequest(next Handler) Handler + } + Transport interface { Supports(r *http.Request) bool - Do(w http.ResponseWriter, r *http.Request, handler Handler) + Do(w http.ResponseWriter, r *http.Request, exec GraphExecutor) } ) @@ -39,3 +65,9 @@ func (w Writer) Error(msg string) { Errors: gqlerror.List{{Message: msg}}, }) } + +func (w Writer) GraphqlErr(err ...*gqlerror.Error) { + w(StatusResolverError, &Response{ + Errors: err, + }) +} diff --git a/graphql/handler/executor.go b/graphql/handler/executor.go new file mode 100644 index 00000000000..273bffe2d56 --- /dev/null +++ b/graphql/handler/executor.go @@ -0,0 +1,152 @@ +package handler + +import ( + "context" + + "github.com/99designs/gqlgen/graphql" + "github.com/vektah/gqlparser/ast" + "github.com/vektah/gqlparser/gqlerror" + "github.com/vektah/gqlparser/parser" + "github.com/vektah/gqlparser/validator" +) + +type executor struct { + handler graphql.Handler + es graphql.ExecutableSchema + requestMutators []graphql.RequestMutator + requestContextMutators []graphql.RequestContextMutator +} + +var _ graphql.GraphExecutor = executor{} + +func newExecutor(es graphql.ExecutableSchema, plugins []graphql.HandlerPlugin) executor { + e := executor{ + es: es, + } + handler := e.executableSchemaHandler + // this loop goes backwards so the first plugin is the outer most middleware and runs first. + for i := len(plugins) - 1; i >= 0; i-- { + p := plugins[i] + if p, ok := p.(graphql.RequestMiddleware); ok { + handler = p.InterceptRequest(handler) + } + } + + for _, p := range plugins { + + if p, ok := p.(graphql.RequestMutator); ok { + e.requestMutators = append(e.requestMutators, p) + } + + if p, ok := p.(graphql.RequestContextMutator); ok { + e.requestContextMutators = append(e.requestContextMutators, p) + } + } + + e.handler = handler + + return e +} + +func (e executor) DispatchRequest(ctx context.Context, writer graphql.Writer) { + e.handler(ctx, writer) +} + +func (e executor) CreateRequestContext(ctx context.Context, params *graphql.RawParams) (*graphql.RequestContext, gqlerror.List) { + for _, p := range e.requestMutators { + if err := p.MutateRequest(ctx, params); err != nil { + return nil, gqlerror.List{err} + } + } + + var gerr *gqlerror.Error + + rc := &graphql.RequestContext{ + DisableIntrospection: true, + Recover: graphql.DefaultRecover, + ErrorPresenter: graphql.DefaultErrorPresenter, + ResolverMiddleware: nil, + RequestMiddleware: nil, + Tracer: graphql.NopTracer{}, + ComplexityLimit: 0, + RawQuery: params.Query, + OperationName: params.OperationName, + Variables: params.Variables, + Extensions: params.Extensions, + } + + rc.Doc, gerr = e.parseOperation(ctx, rc) + if gerr != nil { + return nil, []*gqlerror.Error{gerr} + } + + ctx, op, listErr := e.validateOperation(ctx, rc) + if len(listErr) != 0 { + return nil, listErr + } + + vars, err := validator.VariableValues(e.es.Schema(), op, rc.Variables) + if err != nil { + return nil, gqlerror.List{err} + } + + rc.Variables = vars + + for _, p := range e.requestContextMutators { + if err := p.MutateRequestContext(ctx, rc); err != nil { + return nil, gqlerror.List{err} + } + } + + return rc, nil +} + +// executableSchemaHandler is the inner most handler, it invokes the graph directly after all middleware +// and sends responses to the transport so it can be returned to the client +func (e *executor) executableSchemaHandler(ctx context.Context, write graphql.Writer) { + rc := graphql.GetRequestContext(ctx) + + op := rc.Doc.Operations.ForName(rc.OperationName) + + switch op.Operation { + case ast.Query: + resp := e.es.Query(ctx, op) + + write(getStatus(resp), resp) + case ast.Mutation: + resp := e.es.Mutation(ctx, op) + write(getStatus(resp), resp) + case ast.Subscription: + resp := e.es.Subscription(ctx, op) + + for w := resp(); w != nil; w = resp() { + write(getStatus(w), w) + } + default: + write(graphql.StatusValidationError, graphql.ErrorResponse(ctx, "unsupported GraphQL operation")) + } +} + +func (e executor) parseOperation(ctx context.Context, rc *graphql.RequestContext) (*ast.QueryDocument, *gqlerror.Error) { + ctx = rc.Tracer.StartOperationValidation(ctx) + defer func() { rc.Tracer.EndOperationValidation(ctx) }() + + return parser.ParseQuery(&ast.Source{Input: rc.RawQuery}) +} + +func (e executor) validateOperation(ctx context.Context, rc *graphql.RequestContext) (context.Context, *ast.OperationDefinition, gqlerror.List) { + ctx = rc.Tracer.StartOperationValidation(ctx) + defer func() { rc.Tracer.EndOperationValidation(ctx) }() + + listErr := validator.Validate(e.es.Schema(), rc.Doc) + if len(listErr) != 0 { + return ctx, nil, listErr + } + + op := rc.Doc.Operations.ForName(rc.OperationName) + if op == nil { + return ctx, nil, gqlerror.List{gqlerror.Errorf("operation %s not found", rc.OperationName)} + } + + return ctx, op, nil +} diff --git a/graphql/handler/middleware/apq.go b/graphql/handler/middleware/apq.go index 62315c9b49c..3658b2915e9 100644 --- a/graphql/handler/middleware/apq.go +++ b/graphql/handler/middleware/apq.go @@ -4,6 +4,7 @@ import ( "context" "crypto/sha256" "encoding/hex" + "errors" "github.com/99designs/gqlgen/graphql" "github.com/mitchellh/mapstructure" @@ -18,50 +19,44 @@ const ( // does not yet know what the query is for the hash it will respond telling the client to send the query along with the // hash in the next request. // see https://github.com/apollographql/apollo-link-persisted-queries -func AutomaticPersistedQuery(cache graphql.Cache) graphql.Middleware { - return func(next graphql.Handler) graphql.Handler { - return func(ctx context.Context, writer graphql.Writer) { - rc := graphql.GetRequestContext(ctx) +type AutomaticPersistedQuery struct { + Cache graphql.Cache +} - if rc.Extensions["persistedQuery"] == nil { - next(ctx, writer) - return - } +func (a AutomaticPersistedQuery) MutateRequest(ctx context.Context, rawParams *graphql.RawParams) error { + if rawParams.Extensions["persistedQuery"] == nil { + return nil + } - var extension struct { - Sha256 string `json:"sha256Hash"` - Version int64 `json:"version"` - } + var extension struct { + Sha256 string `json:"sha256Hash"` + Version int64 `json:"version"` + } - if err := mapstructure.Decode(rc.Extensions["persistedQuery"], &extension); err != nil { - writer.Error("Invalid APQ extension data") - return - } + if err := mapstructure.Decode(rawParams.Extensions["persistedQuery"], &extension); err != nil { + return errors.New("Invalid APQ extension data") + } - if extension.Version != 1 { - writer.Error("Unsupported APQ version") - return - } + if extension.Version != 1 { + return errors.New("Unsupported APQ version") + } - if rc.RawQuery == "" { - // client sent optimistic query hash without query string, get it from the cache - query, ok := cache.Get(extension.Sha256) - if !ok { - writer.Error(errPersistedQueryNotFound) - return - } - rc.RawQuery = query.(string) - } else { - // client sent optimistic query hash with query string, verify and store it - if computeQueryHash(rc.RawQuery) != extension.Sha256 { - writer.Error("Provided APQ hash does not match query") - return - } - cache.Add(extension.Sha256, rc.RawQuery) - } - next(ctx, writer) + if rawParams.Query == "" { + // client sent optimistic query hash without query string, get it from the cache + query, ok := a.Cache.Get(extension.Sha256) + if !ok { + return errors.New(errPersistedQueryNotFound) + } + rawParams.Query = query.(string) + } else { + // client sent optimistic query hash with query string, verify and store it + if computeQueryHash(rawParams.Query) != extension.Sha256 { + return errors.New("Provided APQ hash does not match query") } + a.Cache.Add(extension.Sha256, rawParams.Query) } + + return nil } func computeQueryHash(query string) string { diff --git a/graphql/handler/middleware/apq_test.go b/graphql/handler/middleware/apq_test.go index d1a872cf04b..e3c7fc09d32 100644 --- a/graphql/handler/middleware/apq_test.go +++ b/graphql/handler/middleware/apq_test.go @@ -1,6 +1,7 @@ package middleware import ( + "context" "testing" "github.com/99designs/gqlgen/graphql" @@ -12,117 +13,118 @@ func TestAPQ(t *testing.T) { const hash = "b8d9506e34c83b0e53c2aa463624fcea354713bc38f95276e6f0bd893ffb5b88" t.Run("with query and no hash", func(t *testing.T) { - rc := testMiddleware(AutomaticPersistedQuery(graphql.MapCache{}), graphql.RequestContext{ - RawQuery: "original query", - }) + params := &graphql.RawParams{ + Query: "original query", + } + err := AutomaticPersistedQuery{graphql.MapCache{}}.MutateRequest(context.Background(), params) + require.NoError(t, err) - require.True(t, rc.InvokedNext) - require.Equal(t, "original query", rc.ResultContext.RawQuery) + require.Equal(t, "original query", params.Query) }) t.Run("with hash miss and no query", func(t *testing.T) { - rc := testMiddleware(AutomaticPersistedQuery(graphql.MapCache{}), graphql.RequestContext{ - RawQuery: "", + params := &graphql.RawParams{ Extensions: map[string]interface{}{ "persistedQuery": map[string]interface{}{ "sha256": hash, "version": 1, }, }, - }) + } - require.False(t, rc.InvokedNext) - require.Equal(t, "PersistedQueryNotFound", rc.Response.Errors[0].Message) + err := AutomaticPersistedQuery{graphql.MapCache{}}.MutateRequest(context.Background(), params) + require.EqualError(t, err, "PersistedQueryNotFound") }) t.Run("with hash miss and query", func(t *testing.T) { - cache := graphql.MapCache{} - rc := testMiddleware(AutomaticPersistedQuery(cache), graphql.RequestContext{ - RawQuery: query, + params := &graphql.RawParams{ + Query: query, Extensions: map[string]interface{}{ "persistedQuery": map[string]interface{}{ "sha256": hash, "version": 1, }, }, - }) + } + cache := graphql.MapCache{} + err := AutomaticPersistedQuery{cache}.MutateRequest(context.Background(), params) + require.NoError(t, err) - require.True(t, rc.InvokedNext, rc.Response.Errors) - require.Equal(t, "{ me { name } }", rc.ResultContext.RawQuery) + require.Equal(t, "{ me { name } }", params.Query) require.Equal(t, "{ me { name } }", cache[hash]) }) t.Run("with hash miss and query", func(t *testing.T) { - cache := graphql.MapCache{} - rc := testMiddleware(AutomaticPersistedQuery(cache), graphql.RequestContext{ - RawQuery: query, + params := &graphql.RawParams{ + Query: query, Extensions: map[string]interface{}{ "persistedQuery": map[string]interface{}{ "sha256": hash, "version": 1, }, }, - }) + } + cache := graphql.MapCache{} + err := AutomaticPersistedQuery{cache}.MutateRequest(context.Background(), params) + require.NoError(t, err) - require.True(t, rc.InvokedNext, rc.Response.Errors) - require.Equal(t, "{ me { name } }", rc.ResultContext.RawQuery) + require.Equal(t, "{ me { name } }", params.Query) require.Equal(t, "{ me { name } }", cache[hash]) }) t.Run("with hash hit and no query", func(t *testing.T) { - cache := graphql.MapCache{ - hash: query, - } - rc := testMiddleware(AutomaticPersistedQuery(cache), graphql.RequestContext{ - RawQuery: "", + params := &graphql.RawParams{ Extensions: map[string]interface{}{ "persistedQuery": map[string]interface{}{ "sha256": hash, "version": 1, }, }, - }) + } + cache := graphql.MapCache{ + hash: query, + } + err := AutomaticPersistedQuery{cache}.MutateRequest(context.Background(), params) + require.NoError(t, err) - require.True(t, rc.InvokedNext, rc.Response.Errors) - require.Equal(t, "{ me { name } }", rc.ResultContext.RawQuery) + require.Equal(t, "{ me { name } }", params.Query) }) t.Run("with malformed extension payload", func(t *testing.T) { - rc := testMiddleware(AutomaticPersistedQuery(graphql.MapCache{}), graphql.RequestContext{ + params := &graphql.RawParams{ Extensions: map[string]interface{}{ "persistedQuery": "asdf", }, - }) + } - require.False(t, rc.InvokedNext) - require.Equal(t, "Invalid APQ extension data", rc.Response.Errors[0].Message) + err := AutomaticPersistedQuery{graphql.MapCache{}}.MutateRequest(context.Background(), params) + require.EqualError(t, err, "Invalid APQ extension data") }) t.Run("with invalid extension version", func(t *testing.T) { - rc := testMiddleware(AutomaticPersistedQuery(graphql.MapCache{}), graphql.RequestContext{ + params := &graphql.RawParams{ Extensions: map[string]interface{}{ "persistedQuery": map[string]interface{}{ "version": 2, }, }, - }) - - require.False(t, rc.InvokedNext) - require.Equal(t, "Unsupported APQ version", rc.Response.Errors[0].Message) + } + err := AutomaticPersistedQuery{graphql.MapCache{}}.MutateRequest(context.Background(), params) + require.EqualError(t, err, "Unsupported APQ version") }) t.Run("with hash mismatch", func(t *testing.T) { - rc := testMiddleware(AutomaticPersistedQuery(graphql.MapCache{}), graphql.RequestContext{ - RawQuery: query, + params := &graphql.RawParams{ + Query: query, Extensions: map[string]interface{}{ "persistedQuery": map[string]interface{}{ "sha256": "badhash", "version": 1, }, }, - }) + } - require.False(t, rc.InvokedNext) - require.Equal(t, "Provided APQ hash does not match query", rc.Response.Errors[0].Message) + err := AutomaticPersistedQuery{graphql.MapCache{}}.MutateRequest(context.Background(), params) + require.EqualError(t, err, "Provided APQ hash does not match query") }) } diff --git a/graphql/handler/middleware/complexity.go b/graphql/handler/middleware/complexity.go index fc65e3bebcd..846227d9437 100644 --- a/graphql/handler/middleware/complexity.go +++ b/graphql/handler/middleware/complexity.go @@ -4,30 +4,17 @@ import ( "context" "github.com/99designs/gqlgen/graphql" + "github.com/vektah/gqlparser/gqlerror" ) // ComplexityLimit sets a maximum query complexity that is allowed to be executed. // // If a query is submitted that exceeds the limit, a 422 status code will be returned. -func ComplexityLimit(limit int) graphql.Middleware { - return func(next graphql.Handler) graphql.Handler { - return func(ctx context.Context, writer graphql.Writer) { - graphql.GetRequestContext(ctx).ComplexityLimit = limit - next(ctx, writer) - } - } -} +type ComplexityLimit int -// ComplexityLimitFunc allows you to define a function to dynamically set the maximum query complexity that is allowed -// to be executed. This is mostly just a wrapper to preserve the old interface, consider writing your own middleware -// instead. -// -// If a query is submitted that exceeds the limit, a 422 status code will be returned. -func ComplexityLimitFunc(f graphql.ComplexityLimitFunc) graphql.Middleware { - return func(next graphql.Handler) graphql.Handler { - return func(ctx context.Context, writer graphql.Writer) { - graphql.GetRequestContext(ctx).ComplexityLimit = f(ctx) - next(ctx, writer) - } - } +var _ graphql.RequestContextMutator = ComplexityLimit(0) + +func (c ComplexityLimit) MutateRequestContext(ctx context.Context, rc *graphql.RequestContext) *gqlerror.Error { + rc.ComplexityLimit = int(c) + return nil } diff --git a/graphql/handler/middleware/complexity_test.go b/graphql/handler/middleware/complexity_test.go index 0e141826a3a..016cfae0d0a 100644 --- a/graphql/handler/middleware/complexity_test.go +++ b/graphql/handler/middleware/complexity_test.go @@ -4,23 +4,12 @@ import ( "context" "testing" + "github.com/99designs/gqlgen/graphql" "github.com/stretchr/testify/require" ) func TestComplexityLimit(t *testing.T) { - rc := testMiddleware(ComplexityLimitFunc(func(ctx context.Context) int { - return 10 - })) - - require.True(t, rc.InvokedNext) - require.Equal(t, 10, rc.ResultContext.ComplexityLimit) -} - -func TestComplexityLimitFunc(t *testing.T) { - rc := testMiddleware(ComplexityLimitFunc(func(ctx context.Context) int { - return 22 - })) - - require.True(t, rc.InvokedNext) - require.Equal(t, 22, rc.ResultContext.ComplexityLimit) + rc := &graphql.RequestContext{} + ComplexityLimit(10).MutateRequestContext(context.Background(), rc) + require.Equal(t, 10, rc.ComplexityLimit) } diff --git a/graphql/handler/middleware/errors.go b/graphql/handler/middleware/errors.go deleted file mode 100644 index 584f6524b57..00000000000 --- a/graphql/handler/middleware/errors.go +++ /dev/null @@ -1,30 +0,0 @@ -package middleware - -import ( - "context" - - "github.com/99designs/gqlgen/graphql" -) - -// ErrorPresenter transforms errors found while resolving into errors that will be returned to the user. It provides -// a good place to add any extra fields, like error.type, that might be desired by your frontend. Check the default -// implementation in graphql.DefaultErrorPresenter for an example. -func ErrorPresenter(ep graphql.ErrorPresenterFunc) graphql.Middleware { - return func(next graphql.Handler) graphql.Handler { - return func(ctx context.Context, writer graphql.Writer) { - graphql.GetRequestContext(ctx).ErrorPresenter = ep - next(ctx, writer) - } - } -} - -// RecoverFunc is called to recover from panics inside goroutines. It can be used to send errors to error trackers -// and hide internal error types from clients. -func RecoverFunc(recover graphql.RecoverFunc) graphql.Middleware { - return func(next graphql.Handler) graphql.Handler { - return func(ctx context.Context, writer graphql.Writer) { - graphql.GetRequestContext(ctx).Recover = recover - next(ctx, writer) - } - } -} diff --git a/graphql/handler/middleware/errors_test.go b/graphql/handler/middleware/errors_test.go deleted file mode 100644 index 9471b6ae3bc..00000000000 --- a/graphql/handler/middleware/errors_test.go +++ /dev/null @@ -1,32 +0,0 @@ -package middleware - -import ( - "context" - "fmt" - "testing" - - "github.com/stretchr/testify/require" - - "github.com/stretchr/testify/assert" - "github.com/vektah/gqlparser/gqlerror" -) - -func TestErrorPresenter(t *testing.T) { - rc := testMiddleware(ErrorPresenter(func(i context.Context, e error) *gqlerror.Error { - return &gqlerror.Error{Message: "boom"} - })) - - require.True(t, rc.InvokedNext) - // cant test for function equality in go, so testing the return type instead - require.Equal(t, "boom", rc.ResultContext.ErrorPresenter(nil, nil).Message) -} - -func TestRecoverFunc(t *testing.T) { - rc := testMiddleware(RecoverFunc(func(ctx context.Context, err interface{}) (userMessage error) { - return fmt.Errorf("boom") - })) - - require.True(t, rc.InvokedNext) - // cant test for function equality in go, so testing the return type instead - assert.Equal(t, "boom", rc.ResultContext.Recover(nil, nil).Error()) -} diff --git a/graphql/handler/middleware/introspection.go b/graphql/handler/middleware/introspection.go index 681df097012..d53b7745662 100644 --- a/graphql/handler/middleware/introspection.go +++ b/graphql/handler/middleware/introspection.go @@ -4,14 +4,15 @@ import ( "context" "github.com/99designs/gqlgen/graphql" + "github.com/vektah/gqlparser/gqlerror" ) -// Introspection enables clients to reflect all of the types available on the graph. -func Introspection() graphql.Middleware { - return func(next graphql.Handler) graphql.Handler { - return func(ctx context.Context, writer graphql.Writer) { - graphql.GetRequestContext(ctx).DisableIntrospection = false - next(ctx, writer) - } - } +// EnableIntrospection enables clients to reflect all of the types available on the graph. +type Introspection struct{} + +var _ graphql.RequestContextMutator = Introspection{} + +func (c Introspection) MutateRequestContext(ctx context.Context, rc *graphql.RequestContext) *gqlerror.Error { + rc.DisableIntrospection = false + return nil } diff --git a/graphql/handler/middleware/introspection_test.go b/graphql/handler/middleware/introspection_test.go index 475c908a2d7..0bda353f843 100644 --- a/graphql/handler/middleware/introspection_test.go +++ b/graphql/handler/middleware/introspection_test.go @@ -1,19 +1,17 @@ package middleware import ( + "context" "testing" "github.com/99designs/gqlgen/graphql" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestIntrospection(t *testing.T) { - rc := testMiddleware(Introspection(), graphql.RequestContext{ + rc := &graphql.RequestContext{ DisableIntrospection: true, - }) - - require.True(t, rc.InvokedNext) - // cant test for function equality in go, so testing the return type instead - assert.False(t, rc.ResultContext.DisableIntrospection) + } + Introspection{}.MutateRequestContext(context.Background(), rc) + require.Equal(t, false, rc.DisableIntrospection) } diff --git a/graphql/handler/middleware/tracer.go b/graphql/handler/middleware/tracer.go index 983f03ea02e..4c14b99d9aa 100644 --- a/graphql/handler/middleware/tracer.go +++ b/graphql/handler/middleware/tracer.go @@ -1,26 +1,21 @@ package middleware -import ( - "context" - - "github.com/99designs/gqlgen/graphql" -) - +// todo fixme // Tracer allows you to add a request/resolver tracer that will be called around the root request, // calling resolver. This is useful for tracing -func Tracer(tracer graphql.Tracer) graphql.Middleware { - return func(next graphql.Handler) graphql.Handler { - return func(ctx context.Context, writer graphql.Writer) { - rc := graphql.GetRequestContext(ctx) - rc.AddTracer(tracer) - rc.AddRequestMiddleware(func(ctx context.Context, next func(ctx context.Context) []byte) []byte { - ctx = tracer.StartOperationExecution(ctx) - resp := next(ctx) - tracer.EndOperationExecution(ctx) - - return resp - }) - next(ctx, writer) - } - } -} +//func Tracer(tracer graphql.Tracer) graphql.Middleware { +// return func(next graphql.Handler) graphql.Handler { +// return func(ctx context.Context, writer graphql.Writer) { +// rc := graphql.GetRequestContext(ctx) +// rc.AddTracer(tracer) +// rc.AddRequestMiddleware(func(ctx context.Context, next func(ctx context.Context) []byte) []byte { +// ctx = tracer.StartOperationExecution(ctx) +// resp := next(ctx) +// tracer.EndOperationExecution(ctx) +// +// return resp +// }) +// next(ctx, writer) +// } +// } +//} diff --git a/graphql/handler/middleware/tracer_test.go b/graphql/handler/middleware/tracer_test.go deleted file mode 100644 index bdaf821511d..00000000000 --- a/graphql/handler/middleware/tracer_test.go +++ /dev/null @@ -1,17 +0,0 @@ -package middleware - -import ( - "testing" - - "github.com/99designs/gqlgen/graphql" - "github.com/stretchr/testify/require" -) - -func TestTracer(t *testing.T) { - tracer := &graphql.NopTracer{} - rc := testMiddleware(Tracer(tracer)) - - require.True(t, rc.InvokedNext) - require.Equal(t, tracer, rc.ResultContext.Tracer) - require.NotNil(t, tracer, rc.ResultContext.RequestMiddleware) -} diff --git a/graphql/handler/middleware/utils_test.go b/graphql/handler/middleware/utils_test.go deleted file mode 100644 index af194c479e5..00000000000 --- a/graphql/handler/middleware/utils_test.go +++ /dev/null @@ -1,30 +0,0 @@ -package middleware - -import ( - "context" - - "github.com/99designs/gqlgen/graphql" -) - -type middlewareContext struct { - InvokedNext bool - ResultContext graphql.RequestContext - Response graphql.Response -} - -func testMiddleware(m graphql.Middleware, initialContexts ...graphql.RequestContext) middlewareContext { - var c middlewareContext - initial := &graphql.RequestContext{} - if len(initialContexts) > 0 { - initial = &initialContexts[0] - } - - m(func(ctx context.Context, writer graphql.Writer) { - c.ResultContext = *graphql.GetRequestContext(ctx) - c.InvokedNext = true - })(graphql.WithRequestContext(context.Background(), initial), func(status graphql.Status, response *graphql.Response) { - c.Response = *response - }) - - return c -} diff --git a/graphql/handler/server.go b/graphql/handler/server.go index cd617e73624..dfe98bb57c0 100644 --- a/graphql/handler/server.go +++ b/graphql/handler/server.go @@ -1,56 +1,44 @@ package handler import ( - "context" "encoding/json" "fmt" "net/http" "github.com/99designs/gqlgen/graphql" - "github.com/vektah/gqlparser/ast" "github.com/vektah/gqlparser/gqlerror" - "github.com/vektah/gqlparser/parser" - "github.com/vektah/gqlparser/validator" ) type ( Server struct { - es graphql.ExecutableSchema - transports []graphql.Transport - middlewares []graphql.Middleware - - handler graphql.Handler + es graphql.ExecutableSchema + transports []graphql.Transport + plugins []graphql.HandlerPlugin + exec executor } - - Option func(Server) ) -func (s *Server) AddTransport(transport graphql.Transport) { - s.transports = append(s.transports, transport) +func New(es graphql.ExecutableSchema) *Server { + s := &Server{ + es: es, + } + s.exec = newExecutor(s.es, s.plugins) + return s } -func (s *Server) Use(middleware graphql.Middleware) { - s.middlewares = append(s.middlewares, middleware) - s.buildHandler() +func (s *Server) AddTransport(transport graphql.Transport) { + s.transports = append(s.transports, transport) } -// This is fairly expensive, so we dont want to do it in every request. May be called multiple times while creating -// the server. -func (s *Server) buildHandler() { - handler := s.executableSchemaHandler - for i := len(s.middlewares) - 1; i >= 0; i-- { - handler = s.middlewares[i](handler) - } +func (s *Server) Use(plugin graphql.HandlerPlugin) { + switch plugin.(type) { + case graphql.RequestMutator, graphql.RequestContextMutator, graphql.RequestMiddleware: + s.plugins = append(s.plugins, plugin) + s.exec = newExecutor(s.es, s.plugins) - s.handler = handler -} - -func New(es graphql.ExecutableSchema) *Server { - s := &Server{ - es: es, + default: + panic(fmt.Errorf("cannot Use %T as a gqlgen handler plugin because it does not implement any plugin hooks", plugin)) } - s.handler = s.executableSchemaHandler - return s } func (s *Server) getTransport(r *http.Request) graphql.Transport { @@ -69,60 +57,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - transport.Do(w, r, s.handler) -} - -// executableSchemaHandler is the inner most handler, it invokes the graph directly after all middleware -// and sends responses to the transport so it can be returned to the client -func (s *Server) executableSchemaHandler(ctx context.Context, write graphql.Writer) { - rc := graphql.GetRequestContext(ctx) - - var gerr *gqlerror.Error - - // todo: hmm... how should this work? - if rc.Doc == nil { - rc.Doc, gerr = s.parseOperation(ctx, rc) - if gerr != nil { - write(graphql.StatusParseError, &graphql.Response{Errors: []*gqlerror.Error{gerr}}) - return - } - } - - ctx, op, listErr := s.validateOperation(ctx, rc) - if len(listErr) != 0 { - write(graphql.StatusValidationError, &graphql.Response{ - Errors: listErr, - }) - return - } - - vars, err := validator.VariableValues(s.es.Schema(), op, rc.Variables) - if err != nil { - write(graphql.StatusValidationError, &graphql.Response{ - Errors: gqlerror.List{err}, - }) - return - } - - rc.Variables = vars - - switch op.Operation { - case ast.Query: - resp := s.es.Query(ctx, op) - - write(getStatus(resp), resp) - case ast.Mutation: - resp := s.es.Mutation(ctx, op) - write(getStatus(resp), resp) - case ast.Subscription: - resp := s.es.Subscription(ctx, op) - - for w := resp(); w != nil; w = resp() { - write(getStatus(w), w) - } - default: - write(graphql.StatusValidationError, graphql.ErrorResponse(ctx, "unsupported GraphQL operation")) - } + transport.Do(w, r, s.exec) } func getStatus(resp *graphql.Response) graphql.Status { @@ -132,30 +67,6 @@ func getStatus(resp *graphql.Response) graphql.Status { return graphql.StatusOk } -func (s *Server) parseOperation(ctx context.Context, rc *graphql.RequestContext) (*ast.QueryDocument, *gqlerror.Error) { - ctx = rc.Tracer.StartOperationValidation(ctx) - defer func() { rc.Tracer.EndOperationValidation(ctx) }() - - return parser.ParseQuery(&ast.Source{Input: rc.RawQuery}) -} - -func (gh *Server) validateOperation(ctx context.Context, rc *graphql.RequestContext) (context.Context, *ast.OperationDefinition, gqlerror.List) { - ctx = rc.Tracer.StartOperationValidation(ctx) - defer func() { rc.Tracer.EndOperationValidation(ctx) }() - - listErr := validator.Validate(gh.es.Schema(), rc.Doc) - if len(listErr) != 0 { - return ctx, nil, listErr - } - - op := rc.Doc.Operations.ForName(rc.OperationName) - if op == nil { - return ctx, nil, gqlerror.List{gqlerror.Errorf("operation %s not found", rc.OperationName)} - } - - return ctx, op, nil -} - func sendError(w http.ResponseWriter, code int, errors ...*gqlerror.Error) { w.WriteHeader(code) b, err := json.Marshal(&graphql.Response{Errors: errors}) diff --git a/graphql/handler/server_test.go b/graphql/handler/server_test.go index feee8c9d6fc..27ebb048559 100644 --- a/graphql/handler/server_test.go +++ b/graphql/handler/server_test.go @@ -38,11 +38,11 @@ func TestServer(t *testing.T) { } srv := New(es) srv.AddTransport(&transport.HTTPGet{}) - srv.Use(func(next graphql.Handler) graphql.Handler { + srv.Use(middlewareFunc(func(next graphql.Handler) graphql.Handler { return func(ctx context.Context, writer graphql.Writer) { next(ctx, writer) } - }) + })) t.Run("returns an error if no transport matches", func(t *testing.T) { resp := post(srv, "/foo", "application/json") @@ -56,32 +56,32 @@ func TestServer(t *testing.T) { assert.Equal(t, `{"data":"query resp"}`, resp.Body.String()) }) - t.Run("calls mutation on executable schema", func(t *testing.T) { + t.Run("mutations are forbidden", func(t *testing.T) { resp := get(srv, "/foo?query=mutation{a}") assert.Equal(t, http.StatusOK, resp.Code) - assert.Equal(t, `{"data":"mutation resp"}`, resp.Body.String()) + assert.Equal(t, `{"errors":[{"message":"GET requests only allow query operations"}],"data":null}`, resp.Body.String()) }) - t.Run("calls subscription repeatedly on executable schema", func(t *testing.T) { + t.Run("subscriptions are forbidden", func(t *testing.T) { resp := get(srv, "/foo?query=subscription{a}") assert.Equal(t, http.StatusOK, resp.Code) - assert.Equal(t, `{"data":"subscription resp"}{"data":"subscription resp"}`, resp.Body.String()) + assert.Equal(t, `{"errors":[{"message":"GET requests only allow query operations"}],"data":null}`, resp.Body.String()) }) t.Run("invokes middleware in order", func(t *testing.T) { var calls []string - srv.Use(func(next graphql.Handler) graphql.Handler { + srv.Use(middlewareFunc(func(next graphql.Handler) graphql.Handler { return func(ctx context.Context, writer graphql.Writer) { calls = append(calls, "first") next(ctx, writer) } - }) - srv.Use(func(next graphql.Handler) graphql.Handler { + })) + srv.Use(middlewareFunc(func(next graphql.Handler) graphql.Handler { return func(ctx context.Context, writer graphql.Writer) { calls = append(calls, "second") next(ctx, writer) } - }) + })) resp := get(srv, "/foo?query={a}") assert.Equal(t, http.StatusOK, resp.Code) @@ -89,6 +89,12 @@ func TestServer(t *testing.T) { }) } +type middlewareFunc func(next graphql.Handler) graphql.Handler + +func (r middlewareFunc) InterceptRequest(next graphql.Handler) graphql.Handler { + return r(next) +} + func get(handler http.Handler, target string) *httptest.ResponseRecorder { r := httptest.NewRequest("GET", target, nil) w := httptest.NewRecorder() diff --git a/graphql/handler/transport/http_get.go b/graphql/handler/transport/http_get.go index 26e8a5cd92f..08004dcb380 100644 --- a/graphql/handler/transport/http_get.go +++ b/graphql/handler/transport/http_get.go @@ -6,6 +6,8 @@ import ( "net/http" "strings" + "github.com/vektah/gqlparser/ast" + "github.com/99designs/gqlgen/graphql" ) @@ -21,10 +23,11 @@ func (H HTTPGet) Supports(r *http.Request) bool { return r.Method == "GET" } -func (H HTTPGet) Do(w http.ResponseWriter, r *http.Request, handler graphql.Handler) { - rc := newRequestContext() - rc.RawQuery = r.URL.Query().Get("query") - rc.OperationName = r.URL.Query().Get("operationName") +func (H HTTPGet) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) { + raw := &graphql.RawParams{ + Query: r.URL.Query().Get("query"), + OperationName: r.URL.Query().Get("operationName"), + } writer := graphql.Writer(func(status graphql.Status, response *graphql.Response) { switch status { @@ -41,26 +44,30 @@ func (H HTTPGet) Do(w http.ResponseWriter, r *http.Request, handler graphql.Hand }) if variables := r.URL.Query().Get("variables"); variables != "" { - if err := jsonDecode(strings.NewReader(variables), &rc.Variables); err != nil { + if err := jsonDecode(strings.NewReader(variables), &raw.Variables); err != nil { writer.Errorf("variables could not be decoded") return } } if extensions := r.URL.Query().Get("extensions"); extensions != "" { - if err := jsonDecode(strings.NewReader(extensions), &rc.Extensions); err != nil { + if err := jsonDecode(strings.NewReader(extensions), &raw.Extensions); err != nil { writer.Errorf("extensions could not be decoded") return } } - // TODO: FIXME - //if op.Operation != ast.Query && args.R.Method == http.MethodGet { - // return ctx, nil, nil, gqlerror.List{gqlerror.Errorf("GET requests only allow query operations")} - //} - + rc, err := exec.CreateRequestContext(r.Context(), raw) + if err != nil { + writer.GraphqlErr(err...) + } + op := rc.Doc.Operations.ForName(rc.OperationName) + if op.Operation != ast.Query { + writer.Errorf("GET requests only allow query operations") + return + } ctx := graphql.WithRequestContext(r.Context(), rc) - handler(ctx, writer) + exec.DispatchRequest(ctx, writer) } func jsonDecode(r io.Reader, val interface{}) error { diff --git a/graphql/handler/transport/jsonpost.go b/graphql/handler/transport/jsonpost.go index d2764050e9c..bda13b0de44 100644 --- a/graphql/handler/transport/jsonpost.go +++ b/graphql/handler/transport/jsonpost.go @@ -25,7 +25,7 @@ func (H JsonPostTransport) Supports(r *http.Request) bool { return r.Method == "POST" && mediaType == "application/json" } -func (H JsonPostTransport) Do(w http.ResponseWriter, r *http.Request, handler graphql.Handler) { +func (H JsonPostTransport) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) { w.Header().Set("Content-Type", "application/json") write := graphql.Writer(func(status graphql.Status, response *graphql.Response) { @@ -43,19 +43,19 @@ func (H JsonPostTransport) Do(w http.ResponseWriter, r *http.Request, handler gr w.Write(b) }) - var params rawParams + var params *graphql.RawParams if err := jsonDecode(r.Body, ¶ms); err != nil { w.WriteHeader(http.StatusBadRequest) write.Errorf("json body could not be decoded: " + err.Error()) return } - rc := newRequestContext() - rc.RawQuery = params.Query - rc.OperationName = params.OperationName - rc.Variables = params.Variables - rc.Extensions = params.Extensions - + rc, err := exec.CreateRequestContext(r.Context(), params) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + write.GraphqlErr(err...) + return + } ctx := graphql.WithRequestContext(r.Context(), rc) - handler(ctx, write) + exec.DispatchRequest(ctx, write) } diff --git a/graphql/handler/transport/raw.go b/graphql/handler/transport/raw.go deleted file mode 100644 index 64b0ea8eff2..00000000000 --- a/graphql/handler/transport/raw.go +++ /dev/null @@ -1,8 +0,0 @@ -package transport - -type rawParams struct { - Query string `json:"query"` - OperationName string `json:"operationName"` - Variables map[string]interface{} `json:"variables"` - Extensions map[string]interface{} `json:"extensions"` -} diff --git a/graphql/handler/transport/requestcontext.go b/graphql/handler/transport/requestcontext.go deleted file mode 100644 index b5e443d0681..00000000000 --- a/graphql/handler/transport/requestcontext.go +++ /dev/null @@ -1,15 +0,0 @@ -package transport - -import "github.com/99designs/gqlgen/graphql" - -func newRequestContext() *graphql.RequestContext { - return &graphql.RequestContext{ - DisableIntrospection: true, - Recover: graphql.DefaultRecover, - ErrorPresenter: graphql.DefaultErrorPresenter, - ResolverMiddleware: nil, - RequestMiddleware: nil, - Tracer: graphql.NopTracer{}, - ComplexityLimit: 0, - } -} diff --git a/graphql/handler/transport/websocket.go b/graphql/handler/transport/websocket.go index 153fe38f92b..0f37b3335ed 100644 --- a/graphql/handler/transport/websocket.go +++ b/graphql/handler/transport/websocket.go @@ -41,7 +41,7 @@ type ( active map[string]context.CancelFunc mu sync.Mutex keepAliveTicker *time.Ticker - handler graphql.Handler + exec graphql.GraphExecutor initPayload InitPayload } @@ -59,7 +59,7 @@ func (t WebsocketTransport) Supports(r *http.Request) bool { return r.Header.Get("Upgrade") != "" } -func (t WebsocketTransport) Do(w http.ResponseWriter, r *http.Request, handler graphql.Handler) { +func (t WebsocketTransport) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) { ws, err := t.Upgrader.Upgrade(w, r, http.Header{ "Sec-Websocket-Protocol": []string{"graphql-ws"}, }) @@ -73,7 +73,7 @@ func (t WebsocketTransport) Do(w http.ResponseWriter, r *http.Request, handler g active: map[string]context.CancelFunc{}, conn: ws, ctx: r.Context(), - handler: handler, + exec: exec, WebsocketTransport: t, } @@ -191,17 +191,17 @@ func (c *wsConnection) keepAlive(ctx context.Context) { } func (c *wsConnection) subscribe(message *operationMessage) bool { - var params rawParams + var params *graphql.RawParams if err := jsonDecode(bytes.NewReader(message.Payload), ¶ms); err != nil { c.sendConnectionError("invalid json") return false } - rc := newRequestContext() - rc.RawQuery = params.Query - rc.OperationName = params.OperationName - rc.Variables = params.Variables - rc.Extensions = params.Extensions + rc, err := c.exec.CreateRequestContext(c.ctx, params) + if err != nil { + c.sendError(message.ID, err...) + return false + } ctx := graphql.WithRequestContext(c.ctx, rc) @@ -220,7 +220,7 @@ func (c *wsConnection) subscribe(message *operationMessage) bool { c.sendError(message.ID, &gqlerror.Error{Message: userErr.Error()}) } }() - c.handler(ctx, func(status graphql.Status, response *graphql.Response) { + c.exec.DispatchRequest(ctx, func(status graphql.Status, response *graphql.Response) { msgType := dataMsg switch status { case graphql.StatusOk, graphql.StatusResolverError: diff --git a/graphql/recovery.go b/graphql/recovery.go index 3aa032dc5aa..130d661b373 100644 --- a/graphql/recovery.go +++ b/graphql/recovery.go @@ -6,6 +6,8 @@ import ( "fmt" "os" "runtime/debug" + + "github.com/vektah/gqlparser/gqlerror" ) type RecoverFunc func(ctx context.Context, err interface{}) (userMessage error) @@ -17,3 +19,10 @@ func DefaultRecover(ctx context.Context, err interface{}) error { return errors.New("internal system error") } + +var _ RequestContextMutator = RecoverFunc(nil) + +func (f RecoverFunc) MutateRequestContext(ctx context.Context, rc *RequestContext) *gqlerror.Error { + rc.Recover = f + return nil +} diff --git a/handler/graphql.go b/handler/graphql.go index 3c3f6e7ecfd..8ab912eb573 100644 --- a/handler/graphql.go +++ b/handler/graphql.go @@ -48,7 +48,6 @@ type Config struct { recover graphql.RecoverFunc errorPresenter graphql.ErrorPresenterFunc resolverHook graphql.FieldMiddleware - requestHook graphql.RequestMiddleware tracer graphql.Tracer complexityLimit int complexityLimitFunc graphql.ComplexityLimitFunc