Skip to content

Commit

Permalink
Allow configuring the complexity limit dynamically per request
Browse files Browse the repository at this point in the history
  • Loading branch information
tgwizard committed Mar 12, 2019
1 parent 485ddf3 commit dd28814
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 2 deletions.
2 changes: 1 addition & 1 deletion docs/content/reference/complexity.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func main() {
}
```

Now any query with complexity greater than 5 is rejected by the API. By default, each field and level of depth adds one to the overall query complexity.
Now any query with complexity greater than 5 is rejected by the API. By default, each field and level of depth adds one to the overall query complexity. You can also use `handler.ComplexityLimitFunc` to dynamically configure the complexity limit per request.

This helps, but we still have a problem: the `posts` and `related` fields, which return arrays, are much more expensive to resolve than the scalar `title` and `text` fields. However, the default complexity calculation weights them equally. It would make more sense to apply a higher cost to the array fields.

Expand Down
1 change: 1 addition & 0 deletions graphql/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
type Resolver func(ctx context.Context) (res interface{}, err error)
type FieldMiddleware func(ctx context.Context, next Resolver) (res interface{}, err error)
type RequestMiddleware func(ctx context.Context, next func(ctx context.Context) []byte) []byte
type ComplexityLimitFunc func(ctx context.Context) int

type RequestContext struct {
RawQuery string
Expand Down
16 changes: 15 additions & 1 deletion handler/graphql.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ type Config struct {
requestHook graphql.RequestMiddleware
tracer graphql.Tracer
complexityLimit int
complexityLimitFunc graphql.ComplexityLimitFunc
disableIntrospection bool
connectionKeepAlivePingInterval time.Duration
}
Expand Down Expand Up @@ -62,7 +63,7 @@ func (c *Config) newRequestContext(es graphql.ExecutableSchema, doc *ast.QueryDo
reqCtx.Tracer = hook
}

if c.complexityLimit > 0 {
if c.complexityLimit > 0 || c.complexityLimitFunc != nil {
reqCtx.ComplexityLimit = c.complexityLimit
operationComplexity := complexity.Calculate(es, op, variables)
reqCtx.OperationComplexity = operationComplexity
Expand Down Expand Up @@ -110,6 +111,15 @@ func ComplexityLimit(limit int) Option {
}
}

// ComplexityLimitFunc allows you to define a function to dynamically set the maximum query complexity that is allowed
// to be executed.
// If a query is submitted that exceeds the limit, a 422 status code will be returned.
func ComplexityLimitFunc(complexityLimitFunc graphql.ComplexityLimitFunc) Option {
return func(cfg *Config) {
cfg.complexityLimitFunc = complexityLimitFunc
}
}

// ResolverMiddleware allows you to define a function that will be called around every resolver,
// useful for logging.
func ResolverMiddleware(middleware graphql.FieldMiddleware) Option {
Expand Down Expand Up @@ -381,6 +391,10 @@ func (gh *graphqlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}()

if gh.cfg.complexityLimitFunc != nil {
reqCtx.ComplexityLimit = gh.cfg.complexityLimitFunc(ctx)
}

if reqCtx.ComplexityLimit > 0 && reqCtx.OperationComplexity > reqCtx.ComplexityLimit {
sendErrorf(w, http.StatusUnprocessableEntity, "operation has complexity %d, which exceeds the limit of %d", reqCtx.OperationComplexity, reqCtx.ComplexityLimit)
return
Expand Down
49 changes: 49 additions & 0 deletions handler/graphql_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
package handler

import (
"context"
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/99designs/gqlgen/graphql"

"github.com/stretchr/testify/assert"
)

Expand Down Expand Up @@ -127,6 +130,52 @@ func TestHandlerHead(t *testing.T) {
assert.Equal(t, http.StatusMethodNotAllowed, resp.Code)
}

func TestHandlerComplexity(t *testing.T) {
t.Run("static complexity", func(t *testing.T) {
h := GraphQL(&executableSchemaStub{}, ComplexityLimit(2))

t.Run("below complexity limit", func(t *testing.T) {
resp := doRequest(h, "POST", "/graphql", `{"query":"{ me { name } }"}`)
assert.Equal(t, http.StatusOK, resp.Code)
assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String())
})

t.Run("above complexity limit", func(t *testing.T) {
resp := doRequest(h, "POST", "/graphql", `{"query":"{ a: me { name } b: me { name } }"}`)
assert.Equal(t, http.StatusUnprocessableEntity, resp.Code)
assert.Equal(t, `{"errors":[{"message":"operation has complexity 4, which exceeds the limit of 2"}],"data":null}`, resp.Body.String())
})
})

t.Run("dynamic complexity", func(t *testing.T) {
h := GraphQL(&executableSchemaStub{}, ComplexityLimitFunc(func(ctx context.Context) int {
reqCtx := graphql.GetRequestContext(ctx)
if strings.Contains(reqCtx.RawQuery, "dummy") {
return 4
}
return 2
}))

t.Run("below complexity limit", func(t *testing.T) {
resp := doRequest(h, "POST", "/graphql", `{"query":"{ me { name } }"}`)
assert.Equal(t, http.StatusOK, resp.Code)
assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String())
})

t.Run("above complexity limit", func(t *testing.T) {
resp := doRequest(h, "POST", "/graphql", `{"query":"{ a: me { name } b: me { name } }"}`)
assert.Equal(t, http.StatusUnprocessableEntity, resp.Code)
assert.Equal(t, `{"errors":[{"message":"operation has complexity 4, which exceeds the limit of 2"}],"data":null}`, resp.Body.String())
})

t.Run("within dynamic complexity limit", func(t *testing.T) {
resp := doRequest(h, "POST", "/graphql", `{"query":"{ a: me { name } dummy: me { name } }"}`)
assert.Equal(t, http.StatusOK, resp.Code)
assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String())
})
})
}

func doRequest(handler http.Handler, method string, target string, body string) *httptest.ResponseRecorder {
r := httptest.NewRequest(method, target, strings.NewReader(body))
w := httptest.NewRecorder()
Expand Down

0 comments on commit dd28814

Please sign in to comment.