diff --git a/codegen/generated!.gotpl b/codegen/generated!.gotpl index dce8ce977c..a1a1a76044 100644 --- a/codegen/generated!.gotpl +++ b/codegen/generated!.gotpl @@ -86,7 +86,7 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in {{ if not $object.IsReserved }} {{ range $field := $object.UniqueFields }} {{ if not $field.IsReserved }} - case "{{$object.Name}}.{{$field.GoFieldName}}": + case "{{$object.Name}}.{{$field.Name}}": if e.complexity.{{$object.Name|go}}.{{$field.GoFieldName}} == nil { break } diff --git a/integration/generated.go b/integration/generated.go index 6b17e93743..a92cdbaee6 100644 --- a/integration/generated.go +++ b/integration/generated.go @@ -52,6 +52,7 @@ type ComplexityRoot struct { } Query struct { + Complexity func(childComplexity int, value int) int Date func(childComplexity int, filter models.DateFilter) int Error func(childComplexity int, typeArg *models.ErrorType) int JSONEncoding func(childComplexity int) int @@ -80,6 +81,7 @@ type QueryResolver interface { Viewer(ctx context.Context) (*models.Viewer, error) JSONEncoding(ctx context.Context) (string, error) Error(ctx context.Context, typeArg *models.ErrorType) (bool, error) + Complexity(ctx context.Context, value int) (bool, error) } type UserResolver interface { Likes(ctx context.Context, obj *remote_api.User) ([]string, error) @@ -100,28 +102,40 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in _ = ec switch typeName + "." + field { - case "Element.Child": + case "Element.child": if e.complexity.Element.Child == nil { break } return e.complexity.Element.Child(childComplexity), true - case "Element.Error": + case "Element.error": if e.complexity.Element.Error == nil { break } return e.complexity.Element.Error(childComplexity), true - case "Element.Mismatched": + case "Element.mismatched": if e.complexity.Element.Mismatched == nil { break } return e.complexity.Element.Mismatched(childComplexity), true - case "Query.Date": + case "Query.complexity": + if e.complexity.Query.Complexity == nil { + break + } + + args, err := ec.field_Query_complexity_args(context.TODO(), rawArgs) + if err != nil { + return 0, false + } + + return e.complexity.Query.Complexity(childComplexity, args["value"].(int)), true + + case "Query.date": if e.complexity.Query.Date == nil { break } @@ -133,7 +147,7 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Query.Date(childComplexity, args["filter"].(models.DateFilter)), true - case "Query.Error": + case "Query.error": if e.complexity.Query.Error == nil { break } @@ -145,42 +159,42 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Query.Error(childComplexity, args["type"].(*models.ErrorType)), true - case "Query.JSONEncoding": + case "Query.jsonEncoding": if e.complexity.Query.JSONEncoding == nil { break } return e.complexity.Query.JSONEncoding(childComplexity), true - case "Query.Path": + case "Query.path": if e.complexity.Query.Path == nil { break } return e.complexity.Query.Path(childComplexity), true - case "Query.Viewer": + case "Query.viewer": if e.complexity.Query.Viewer == nil { break } return e.complexity.Query.Viewer(childComplexity), true - case "User.Likes": + case "User.likes": if e.complexity.User.Likes == nil { break } return e.complexity.User.Likes(childComplexity), true - case "User.Name": + case "User.name": if e.complexity.User.Name == nil { break } return e.complexity.User.Name(childComplexity), true - case "Viewer.User": + case "Viewer.user": if e.complexity.Viewer.User == nil { break } @@ -306,6 +320,7 @@ type Query { viewer: Viewer jsonEncoding: String! error(type: ErrorType = NORMAL): Boolean! + complexity(value: Int!): Boolean! } enum ErrorType { @@ -354,6 +369,20 @@ func (ec *executionContext) field_Query___type_args(ctx context.Context, rawArgs return args, nil } +func (ec *executionContext) field_Query_complexity_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { + var err error + args := map[string]interface{}{} + var arg0 int + if tmp, ok := rawArgs["value"]; ok { + arg0, err = ec.unmarshalNInt2int(ctx, tmp) + if err != nil { + return nil, err + } + } + args["value"] = arg0 + return args, nil +} + func (ec *executionContext) field_Query_date_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { var err error args := map[string]interface{}{} @@ -635,6 +664,40 @@ func (ec *executionContext) _Query_error(ctx context.Context, field graphql.Coll return ec.marshalNBoolean2bool(ctx, field.Selections, res) } +func (ec *executionContext) _Query_complexity(ctx context.Context, field graphql.CollectedField) graphql.Marshaler { + ctx = ec.Tracer.StartFieldExecution(ctx, field) + defer func() { ec.Tracer.EndFieldExecution(ctx) }() + rctx := &graphql.ResolverContext{ + Object: "Query", + Field: field, + Args: nil, + IsMethod: true, + } + ctx = graphql.WithResolverContext(ctx, rctx) + rawArgs := field.ArgumentMap(ec.Variables) + args, err := ec.field_Query_complexity_args(ctx, rawArgs) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + rctx.Args = args + ctx = ec.Tracer.StartFieldResolverExecution(ctx, rctx) + resTmp := ec.FieldMiddleware(ctx, nil, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return ec.resolvers.Query().Complexity(rctx, args["value"].(int)) + }) + if resTmp == nil { + if !ec.HasError(rctx) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(bool) + rctx.Result = res + ctx = ec.Tracer.StartFieldChildExecution(ctx) + return ec.marshalNBoolean2bool(ctx, field.Selections, res) +} + func (ec *executionContext) _Query___type(ctx context.Context, field graphql.CollectedField) graphql.Marshaler { ctx = ec.Tracer.StartFieldExecution(ctx, field) defer func() { ec.Tracer.EndFieldExecution(ctx) }() @@ -1784,6 +1847,20 @@ func (ec *executionContext) _Query(ctx context.Context, sel ast.SelectionSet) gr } return res }) + case "complexity": + field := field + out.Concurrently(i, func() (res graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + } + }() + res = ec._Query_complexity(ctx, field) + if res == graphql.Null { + invalid = true + } + return res + }) case "__type": out.Values[i] = ec._Query___type(ctx, field) case "__schema": @@ -2135,6 +2212,14 @@ func (ec *executionContext) marshalNElement2ᚖgithubᚗcomᚋ99designsᚋgqlgen return ec._Element(ctx, sel, v) } +func (ec *executionContext) unmarshalNInt2int(ctx context.Context, v interface{}) (int, error) { + return graphql.UnmarshalInt(v) +} + +func (ec *executionContext) marshalNInt2int(ctx context.Context, sel ast.SelectionSet, v int) graphql.Marshaler { + return graphql.MarshalInt(v) +} + func (ec *executionContext) unmarshalNString2string(ctx context.Context, v interface{}) (string, error) { return graphql.UnmarshalString(v) } diff --git a/integration/integration-test.js b/integration/integration-test.js index 43414ec17d..e2af6486f3 100644 --- a/integration/integration-test.js +++ b/integration/integration-test.js @@ -43,6 +43,29 @@ describe('Input defaults', () => { }); }); +describe('Complexity', () => { + it('should fail when complexity is too high', async () => { + let res; + try { + res = await client.query({ + query: gql`{ complexity(value: 200) }`, + }); + } catch (err) { + expect(err.networkError.statusCode).toEqual(422); + } + expect(res).toBe(undefined); + }); + + it('should succeed when complexity is not too high', async () => { + let res = await client.query({ + query: gql`{ complexity(value: 100) }`, + }); + + expect(res.data.complexity).toBe(true); + expect(res.errors).toBe(undefined); + }); +}); + describe('Errors', () => { it('should respond with correct paths', async () => { let res = await client.query({ diff --git a/integration/resolver.go b/integration/resolver.go index 5e685ceb19..7ed7353c8d 100644 --- a/integration/resolver.go +++ b/integration/resolver.go @@ -91,6 +91,10 @@ func (r *queryResolver) JSONEncoding(ctx context.Context) (string, error) { return "\U000fe4ed", nil } +func (r *queryResolver) Complexity(ctx context.Context, value int) (bool, error) { + return true, nil +} + type userResolver struct{ *Resolver } func (r *userResolver) Likes(ctx context.Context, obj *remote_api.User) ([]string, error) { diff --git a/integration/schema.graphql b/integration/schema.graphql index ac2df27bed..73acb7278d 100644 --- a/integration/schema.graphql +++ b/integration/schema.graphql @@ -35,6 +35,7 @@ type Query { viewer: Viewer jsonEncoding: String! error(type: ErrorType = NORMAL): Boolean! + complexity(value: Int!): Boolean! } enum ErrorType { diff --git a/integration/server/server.go b/integration/server/server.go index f17ed2ada9..b0206fdfec 100644 --- a/integration/server/server.go +++ b/integration/server/server.go @@ -21,9 +21,15 @@ func main() { port = defaultPort } + cfg := integration.Config{Resolvers: &integration.Resolver{}} + cfg.Complexity.Query.Complexity = func(childComplexity, value int) int { + // Allow the integration client to dictate the complexity, to verify this + // function is executed. + return value + } http.Handle("/", handler.Playground("GraphQL playground", "/query")) http.Handle("/query", handler.GraphQL( - integration.NewExecutableSchema(integration.Config{Resolvers: &integration.Resolver{}}), + integration.NewExecutableSchema(cfg), handler.ErrorPresenter(func(ctx context.Context, e error) *gqlerror.Error { if e, ok := errors.Cause(e).(*integration.CustomError); ok { return &gqlerror.Error{ @@ -33,6 +39,7 @@ func main() { } return graphql.DefaultErrorPresenter(ctx, e) }), + handler.ComplexityLimit(100), )) log.Printf("connect to http://localhost:%s/ for GraphQL playground", port)