From dd2881455f20f37ec601791bf09e32db9e928790 Mon Sep 17 00:00:00 2001 From: Adam Renberg Tamm Date: Tue, 12 Mar 2019 11:17:58 +0100 Subject: [PATCH] Allow configuring the complexity limit dynamically per request --- docs/content/reference/complexity.md | 2 +- graphql/context.go | 1 + handler/graphql.go | 16 ++++++++- handler/graphql_test.go | 49 ++++++++++++++++++++++++++++ 4 files changed, 66 insertions(+), 2 deletions(-) diff --git a/docs/content/reference/complexity.md b/docs/content/reference/complexity.md index 70640351ba4..5f92f5d9f14 100644 --- a/docs/content/reference/complexity.md +++ b/docs/content/reference/complexity.md @@ -56,7 +56,7 @@ func main() { } ``` -Now any query with complexity greater than 5 is rejected by the API. By default, each field and level of depth adds one to the overall query complexity. +Now any query with complexity greater than 5 is rejected by the API. By default, each field and level of depth adds one to the overall query complexity. You can also use `handler.ComplexityLimitFunc` to dynamically configure the complexity limit per request. This helps, but we still have a problem: the `posts` and `related` fields, which return arrays, are much more expensive to resolve than the scalar `title` and `text` fields. However, the default complexity calculation weights them equally. It would make more sense to apply a higher cost to the array fields. diff --git a/graphql/context.go b/graphql/context.go index cc8d659b116..d6b28456cb4 100644 --- a/graphql/context.go +++ b/graphql/context.go @@ -12,6 +12,7 @@ import ( type Resolver func(ctx context.Context) (res interface{}, err error) type FieldMiddleware func(ctx context.Context, next Resolver) (res interface{}, err error) type RequestMiddleware func(ctx context.Context, next func(ctx context.Context) []byte) []byte +type ComplexityLimitFunc func(ctx context.Context) int type RequestContext struct { RawQuery string diff --git a/handler/graphql.go b/handler/graphql.go index 585897a9248..92a0471ce24 100644 --- a/handler/graphql.go +++ b/handler/graphql.go @@ -34,6 +34,7 @@ type Config struct { requestHook graphql.RequestMiddleware tracer graphql.Tracer complexityLimit int + complexityLimitFunc graphql.ComplexityLimitFunc disableIntrospection bool connectionKeepAlivePingInterval time.Duration } @@ -62,7 +63,7 @@ func (c *Config) newRequestContext(es graphql.ExecutableSchema, doc *ast.QueryDo reqCtx.Tracer = hook } - if c.complexityLimit > 0 { + if c.complexityLimit > 0 || c.complexityLimitFunc != nil { reqCtx.ComplexityLimit = c.complexityLimit operationComplexity := complexity.Calculate(es, op, variables) reqCtx.OperationComplexity = operationComplexity @@ -110,6 +111,15 @@ func ComplexityLimit(limit int) Option { } } +// ComplexityLimitFunc allows you to define a function to dynamically set the maximum query complexity that is allowed +// to be executed. +// If a query is submitted that exceeds the limit, a 422 status code will be returned. +func ComplexityLimitFunc(complexityLimitFunc graphql.ComplexityLimitFunc) Option { + return func(cfg *Config) { + cfg.complexityLimitFunc = complexityLimitFunc + } +} + // ResolverMiddleware allows you to define a function that will be called around every resolver, // useful for logging. func ResolverMiddleware(middleware graphql.FieldMiddleware) Option { @@ -381,6 +391,10 @@ func (gh *graphqlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } }() + if gh.cfg.complexityLimitFunc != nil { + reqCtx.ComplexityLimit = gh.cfg.complexityLimitFunc(ctx) + } + if reqCtx.ComplexityLimit > 0 && reqCtx.OperationComplexity > reqCtx.ComplexityLimit { sendErrorf(w, http.StatusUnprocessableEntity, "operation has complexity %d, which exceeds the limit of %d", reqCtx.OperationComplexity, reqCtx.ComplexityLimit) return diff --git a/handler/graphql_test.go b/handler/graphql_test.go index e6c2d9063e1..f938640716e 100644 --- a/handler/graphql_test.go +++ b/handler/graphql_test.go @@ -1,11 +1,14 @@ package handler import ( + "context" "net/http" "net/http/httptest" "strings" "testing" + "github.com/99designs/gqlgen/graphql" + "github.com/stretchr/testify/assert" ) @@ -127,6 +130,52 @@ func TestHandlerHead(t *testing.T) { assert.Equal(t, http.StatusMethodNotAllowed, resp.Code) } +func TestHandlerComplexity(t *testing.T) { + t.Run("static complexity", func(t *testing.T) { + h := GraphQL(&executableSchemaStub{}, ComplexityLimit(2)) + + t.Run("below complexity limit", func(t *testing.T) { + resp := doRequest(h, "POST", "/graphql", `{"query":"{ me { name } }"}`) + assert.Equal(t, http.StatusOK, resp.Code) + assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String()) + }) + + t.Run("above complexity limit", func(t *testing.T) { + resp := doRequest(h, "POST", "/graphql", `{"query":"{ a: me { name } b: me { name } }"}`) + assert.Equal(t, http.StatusUnprocessableEntity, resp.Code) + assert.Equal(t, `{"errors":[{"message":"operation has complexity 4, which exceeds the limit of 2"}],"data":null}`, resp.Body.String()) + }) + }) + + t.Run("dynamic complexity", func(t *testing.T) { + h := GraphQL(&executableSchemaStub{}, ComplexityLimitFunc(func(ctx context.Context) int { + reqCtx := graphql.GetRequestContext(ctx) + if strings.Contains(reqCtx.RawQuery, "dummy") { + return 4 + } + return 2 + })) + + t.Run("below complexity limit", func(t *testing.T) { + resp := doRequest(h, "POST", "/graphql", `{"query":"{ me { name } }"}`) + assert.Equal(t, http.StatusOK, resp.Code) + assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String()) + }) + + t.Run("above complexity limit", func(t *testing.T) { + resp := doRequest(h, "POST", "/graphql", `{"query":"{ a: me { name } b: me { name } }"}`) + assert.Equal(t, http.StatusUnprocessableEntity, resp.Code) + assert.Equal(t, `{"errors":[{"message":"operation has complexity 4, which exceeds the limit of 2"}],"data":null}`, resp.Body.String()) + }) + + t.Run("within dynamic complexity limit", func(t *testing.T) { + resp := doRequest(h, "POST", "/graphql", `{"query":"{ a: me { name } dummy: me { name } }"}`) + assert.Equal(t, http.StatusOK, resp.Code) + assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String()) + }) + }) +} + func doRequest(handler http.Handler, method string, target string, body string) *httptest.ResponseRecorder { r := httptest.NewRequest(method, target, strings.NewReader(body)) w := httptest.NewRecorder()