diff --git a/codegen/config.go b/codegen/config.go index 1ded1c4175..db0e467b78 100644 --- a/codegen/config.go +++ b/codegen/config.go @@ -84,8 +84,8 @@ type TypeMapEntry struct { } type TypeMapField struct { - Resolver bool `yaml:"resolver"` - FieldName string `yaml:"fieldName"` + Resolver bool `yaml:"resolver"` + FieldName string `yaml:"fieldName"` } func (c *PackageConfig) normalize() error { diff --git a/codegen/object.go b/codegen/object.go index b5d137d11a..9675c3996e 100644 --- a/codegen/object.go +++ b/codegen/object.go @@ -81,10 +81,18 @@ func (o *Object) IsConcurrent() bool { return false } +func (o *Object) IsReserved() bool { + return strings.HasPrefix(o.GQLType, "__") +} + func (f *Field) IsResolver() bool { return f.GoFieldName == "" } +func (f *Field) IsReserved() bool { + return strings.HasPrefix(f.GQLName, "__") +} + func (f *Field) IsMethod() bool { return f.GoFieldType == GoFieldMethod } @@ -165,6 +173,24 @@ func (f *Field) ResolverDeclaration() string { return res } +func (f *Field) ComplexitySignature() string { + res := fmt.Sprintf("func(childComplexity int") + for _, arg := range f.Args { + res += fmt.Sprintf(", %s %s", arg.GoVarName, arg.Signature()) + } + res += ") int" + return res +} + +func (f *Field) ComplexityArgs() string { + var args []string + for _, arg := range f.Args { + args = append(args, "args["+strconv.Quote(arg.GQLName)+"].("+arg.Signature()+")") + } + + return strings.Join(args, ", ") +} + func (f *Field) CallArgs() string { var args []string @@ -198,7 +224,7 @@ func (f *Field) doWriteJson(val string, remainingMods []string, astType *ast.Typ ec.Errorf(ctx, "must not be null") } {{- end }} - return graphql.Null + return graphql.Null } {{.next }}`, map[string]interface{}{ "val": val, diff --git a/codegen/templates/data.go b/codegen/templates/data.go index 6f70f82ae4..8030bca457 100644 --- a/codegen/templates/data.go +++ b/codegen/templates/data.go @@ -3,7 +3,7 @@ package templates var data = map[string]string{ "args.gotpl": "\t{{- if . }}args := map[string]interface{}{} {{end}}\n\t{{- range $i, $arg := . }}\n\t\tvar arg{{$i}} {{$arg.Signature }}\n\t\tif tmp, ok := rawArgs[{{$arg.GQLName|quote}}]; ok {\n\t\t\tvar err error\n\t\t\t{{$arg.Unmarshal (print \"arg\" $i) \"tmp\" }}\n\t\t\tif err != nil {\n\t\t\t\tec.Error(ctx, err)\n\t\t\t\t{{- if $arg.Stream }}\n\t\t\t\t\treturn nil\n\t\t\t\t{{- else }}\n\t\t\t\t\treturn graphql.Null\n\t\t\t\t{{- end }}\n\t\t\t}\n\t\t}\n\t\targs[{{$arg.GQLName|quote}}] = arg{{$i}}\n\t{{- end -}}\n", "field.gotpl": "{{ $field := . }}\n{{ $object := $field.Object }}\n\n{{- if $object.Stream }}\n\tfunc (ec *executionContext) _{{$object.GQLType}}_{{$field.GQLName}}(ctx context.Context, field graphql.CollectedField) func() graphql.Marshaler {\n\t\t{{- if $field.Args }}\n\t\t\trawArgs := field.ArgumentMap(ec.Variables)\n\t\t\t{{ template \"args.gotpl\" $field.Args }}\n\t\t{{- end }}\n\t\tctx = graphql.WithResolverContext(ctx, &graphql.ResolverContext{\n\t\t\tField: field,\n\t\t})\n\t\tresults, err := ec.resolvers.{{ $field.ShortInvocation }}\n\t\tif err != nil {\n\t\t\tec.Error(ctx, err)\n\t\t\treturn nil\n\t\t}\n\t\treturn func() graphql.Marshaler {\n\t\t\tres, ok := <-results\n\t\t\tif !ok {\n\t\t\t\treturn nil\n\t\t\t}\n\t\t\tvar out graphql.OrderedMap\n\t\t\tout.Add(field.Alias, func() graphql.Marshaler { {{ $field.WriteJson }} }())\n\t\t\treturn &out\n\t\t}\n\t}\n{{ else }}\n\t// nolint: vetshadow\n\tfunc (ec *executionContext) _{{$object.GQLType}}_{{$field.GQLName}}(ctx context.Context, field graphql.CollectedField, {{if not $object.Root}}obj *{{$object.FullName}}{{end}}) graphql.Marshaler {\n\t\t{{- if $field.Args }}\n\t\t\trawArgs := field.ArgumentMap(ec.Variables)\n\t\t\t{{ template \"args.gotpl\" $field.Args }}\n\t\t{{- end }}\n\t\trctx := &graphql.ResolverContext{\n\t\t\tObject: {{$object.GQLType|quote}},\n\t\t\tArgs: {{if $field.Args }}args{{else}}nil{{end}},\n\t\t\tField: field,\n\t\t}\n\t\tctx = graphql.WithResolverContext(ctx, rctx)\n\t\tresTmp := ec.FieldMiddleware(ctx, {{if $object.Root}}nil{{else}}obj{{end}}, func(ctx context.Context) (interface{}, error) {\n\t\t\t{{- if $field.IsResolver }}\n\t\t\t\treturn ec.resolvers.{{ $field.ShortInvocation }}\n\t\t\t{{- else if $field.IsMethod }}\n\t\t\t\t{{- if $field.NoErr }}\n\t\t\t\t\treturn {{$field.GoReceiverName}}.{{$field.GoFieldName}}({{ $field.CallArgs }}), nil\n\t\t\t\t{{- else }}\n\t\t\t\t\treturn {{$field.GoReceiverName}}.{{$field.GoFieldName}}({{ $field.CallArgs }})\n\t\t\t\t{{- end }}\n\t\t\t{{- else if $field.IsVariable }}\n\t\t\t\treturn {{$field.GoReceiverName}}.{{$field.GoFieldName}}, nil\n\t\t\t{{- end }}\n\t\t})\n\t\tif resTmp == nil {\n\t\t\t{{- if $field.ASTType.NonNull }}\n\t\t\t\tif !ec.HasError(rctx) {\n\t\t\t\t\tec.Errorf(ctx, \"must not be null\")\n\t\t\t\t}\n\t\t\t{{- end }}\n\t\t\treturn graphql.Null\n\t\t}\n\t\tres := resTmp.({{$field.Signature}})\n\t\trctx.Result = res\n\t\t{{ $field.WriteJson }}\n\t}\n{{ end }}\n", - "generated.gotpl": "// Code generated by github.com/99designs/gqlgen, DO NOT EDIT.\n\npackage {{ .PackageName }}\n\nimport (\n{{- range $import := .Imports }}\n\t{{- $import.Write }}\n{{ end }}\n)\n\n// NewExecutableSchema creates an ExecutableSchema from the ResolverRoot interface.\nfunc NewExecutableSchema(cfg Config) graphql.ExecutableSchema {\n\treturn &executableSchema{\n\t\tresolvers: cfg.Resolvers,\n\t\tdirectives: cfg.Directives,\n\t}\n}\n\ntype Config struct {\n\tResolvers ResolverRoot\n\tDirectives DirectiveRoot\n}\n\ntype ResolverRoot interface {\n{{- range $object := .Objects -}}\n\t{{ if $object.HasResolvers -}}\n\t\t{{$object.GQLType}}() {{$object.GQLType}}Resolver\n\t{{ end }}\n{{- end }}\n}\n\ntype DirectiveRoot struct {\n{{ range $directive := .Directives }}\n\t{{ $directive.Declaration }}\n{{ end }}\n}\n\n{{- range $object := .Objects -}}\n\t{{ if $object.HasResolvers }}\n\t\ttype {{$object.GQLType}}Resolver interface {\n\t\t{{ range $field := $object.Fields -}}\n\t\t\t{{ $field.ShortResolverDeclaration }}\n\t\t{{ end }}\n\t\t}\n\t{{- end }}\n{{- end }}\n\ntype executableSchema struct {\n\tresolvers ResolverRoot\n\tdirectives DirectiveRoot\n}\n\nfunc (e *executableSchema) Schema() *ast.Schema {\n\treturn parsedSchema\n}\n\nfunc (e *executableSchema) Query(ctx context.Context, op *ast.OperationDefinition) *graphql.Response {\n\t{{- if .QueryRoot }}\n\t\tec := executionContext{graphql.GetRequestContext(ctx), e}\n\n\t\tbuf := ec.RequestMiddleware(ctx, func(ctx context.Context) []byte {\n\t\t\tdata := ec._{{.QueryRoot.GQLType}}(ctx, op.SelectionSet)\n\t\t\tvar buf bytes.Buffer\n\t\t\tdata.MarshalGQL(&buf)\n\t\t\treturn buf.Bytes()\n\t\t})\n\n\t\treturn &graphql.Response{\n\t\t\tData: buf,\n\t\t\tErrors: ec.Errors,\n\t\t}\n\t{{- else }}\n\t\treturn graphql.ErrorResponse(ctx, \"queries are not supported\")\n\t{{- end }}\n}\n\nfunc (e *executableSchema) Mutation(ctx context.Context, op *ast.OperationDefinition) *graphql.Response {\n\t{{- if .MutationRoot }}\n\t\tec := executionContext{graphql.GetRequestContext(ctx), e}\n\n\t\tbuf := ec.RequestMiddleware(ctx, func(ctx context.Context) []byte {\n\t\t\tdata := ec._{{.MutationRoot.GQLType}}(ctx, op.SelectionSet)\n\t\t\tvar buf bytes.Buffer\n\t\t\tdata.MarshalGQL(&buf)\n\t\t\treturn buf.Bytes()\n\t\t})\n\n\t\treturn &graphql.Response{\n\t\t\tData: buf,\n\t\t\tErrors: ec.Errors,\n\t\t}\n\t{{- else }}\n\t\treturn graphql.ErrorResponse(ctx, \"mutations are not supported\")\n\t{{- end }}\n}\n\nfunc (e *executableSchema) Subscription(ctx context.Context, op *ast.OperationDefinition) func() *graphql.Response {\n\t{{- if .SubscriptionRoot }}\n\t\tec := executionContext{graphql.GetRequestContext(ctx), e}\n\n\t\tnext := ec._{{.SubscriptionRoot.GQLType}}(ctx, op.SelectionSet)\n\t\tif ec.Errors != nil {\n\t\t\treturn graphql.OneShot(&graphql.Response{Data: []byte(\"null\"), Errors: ec.Errors})\n\t\t}\n\n\t\tvar buf bytes.Buffer\n\t\treturn func() *graphql.Response {\n\t\t\tbuf := ec.RequestMiddleware(ctx, func(ctx context.Context) []byte {\n\t\t\t\tbuf.Reset()\n\t\t\t\tdata := next()\n\n\t\t\t\tif data == nil {\n\t\t\t\t\treturn nil\n\t\t\t\t}\n\t\t\t\tdata.MarshalGQL(&buf)\n\t\t\t\treturn buf.Bytes()\n\t\t\t})\n\n\t\t\treturn &graphql.Response{\n\t\t\t\tData: buf,\n\t\t\t\tErrors: ec.Errors,\n\t\t\t}\n\t\t}\n\t{{- else }}\n\t\treturn graphql.OneShot(graphql.ErrorResponse(ctx, \"subscriptions are not supported\"))\n\t{{- end }}\n}\n\ntype executionContext struct {\n\t*graphql.RequestContext\n\t*executableSchema\n}\n\n{{- range $object := .Objects }}\n\t{{ template \"object.gotpl\" $object }}\n\n\t{{- range $field := $object.Fields }}\n\t\t{{ template \"field.gotpl\" $field }}\n\t{{ end }}\n{{- end}}\n\n{{- range $interface := .Interfaces }}\n\t{{ template \"interface.gotpl\" $interface }}\n{{- end }}\n\n{{- range $input := .Inputs }}\n\t{{ template \"input.gotpl\" $input }}\n{{- end }}\n\nfunc (ec *executionContext) FieldMiddleware(ctx context.Context, obj interface{}, next graphql.Resolver) (ret interface{}) {\n\tdefer func() {\n\t\tif r := recover(); r != nil {\n\t\t\tec.Error(ctx, ec.Recover(ctx, r))\n\t\t\tret = nil\n\t\t}\n\t}()\n\t{{- if .Directives }}\n\trctx := graphql.GetResolverContext(ctx)\n\tfor _, d := range rctx.Field.Definition.Directives {\n\t\tswitch d.Name {\n\t\t{{- range $directive := .Directives }}\n\t\tcase \"{{$directive.Name}}\":\n\t\t\tif ec.directives.{{$directive.Name|ucFirst}} != nil {\n\t\t\t\t{{- if $directive.Args }}\n\t\t\t\t\trawArgs := d.ArgumentMap(ec.Variables)\n\t\t\t\t\t{{ template \"args.gotpl\" $directive.Args }}\n\t\t\t\t{{- end }}\n\t\t\t\tn := next\n\t\t\t\tnext = func(ctx context.Context) (interface{}, error) {\n\t\t\t\t\treturn ec.directives.{{$directive.Name|ucFirst}}({{$directive.CallArgs}})\n\t\t\t\t}\n\t\t\t}\n\t\t{{- end }}\n\t\t}\n\t}\n\t{{- end }}\n\tres, err := ec.ResolverMiddleware(ctx, next)\n\tif err != nil {\n\t\tec.Error(ctx, err)\n\t\treturn nil\n\t}\n\treturn res\n}\n\nfunc (ec *executionContext) introspectSchema() *introspection.Schema {\n\treturn introspection.WrapSchema(parsedSchema)\n}\n\nfunc (ec *executionContext) introspectType(name string) *introspection.Type {\n\treturn introspection.WrapTypeFromDef(parsedSchema, parsedSchema.Types[name])\n}\n\nvar parsedSchema = gqlparser.MustLoadSchema(\n\t&ast.Source{Name: {{.SchemaFilename|quote}}, Input: {{.SchemaRaw|rawQuote}}},\n)\n", + "generated.gotpl": "// Code generated by github.com/99designs/gqlgen, DO NOT EDIT.\n\npackage {{ .PackageName }}\n\nimport (\n{{- range $import := .Imports }}\n\t{{- $import.Write }}\n{{ end }}\n)\n\n// NewExecutableSchema creates an ExecutableSchema from the ResolverRoot interface.\nfunc NewExecutableSchema(cfg Config) graphql.ExecutableSchema {\n\treturn &executableSchema{\n\t\tresolvers: cfg.Resolvers,\n\t\tdirectives: cfg.Directives,\n\t\tcomplexity: cfg.Complexity,\n\t}\n}\n\ntype Config struct {\n\tResolvers ResolverRoot\n\tDirectives DirectiveRoot\n\tComplexity ComplexityRoot\n}\n\ntype ResolverRoot interface {\n{{- range $object := .Objects -}}\n\t{{ if $object.HasResolvers -}}\n\t\t{{$object.GQLType}}() {{$object.GQLType}}Resolver\n\t{{ end }}\n{{- end }}\n}\n\ntype DirectiveRoot struct {\n{{ range $directive := .Directives }}\n\t{{ $directive.Declaration }}\n{{ end }}\n}\n\ntype ComplexityRoot struct {\n{{ range $object := .Objects }}\n\t{{ if $object.IsReserved }}{{ else -}}\n\t\t{{ $object.GQLType|toCamel }} struct {\n\t\t{{ range $field := $object.Fields -}}\n\t\t\t{{ if $field.IsReserved }}{{ else -}}\n\t\t\t\t{{ $field.GQLName|toCamel }} {{ $field.ComplexitySignature }}\n\t\t\t{{ end }}\n\t\t{{- end }}\n\t\t}\n\t{{- end }}\n{{ end }}\n}\n\n{{ range $object := .Objects -}}\n\t{{ if $object.HasResolvers }}\n\t\ttype {{$object.GQLType}}Resolver interface {\n\t\t{{ range $field := $object.Fields -}}\n\t\t\t{{ $field.ShortResolverDeclaration }}\n\t\t{{ end }}\n\t\t}\n\t{{- end }}\n{{- end }}\n\ntype executableSchema struct {\n\tresolvers ResolverRoot\n\tdirectives DirectiveRoot\n\tcomplexity ComplexityRoot\n}\n\nfunc (e *executableSchema) Schema() *ast.Schema {\n\treturn parsedSchema\n}\n\nfunc (e *executableSchema) Complexity(typeName, field string, childComplexity int, rawArgs map[string]interface{}) (int, bool) {\n\tswitch typeName + \".\" + field {\n\t{{ range $object := .Objects }}\n\t\t{{ if $object.IsReserved }}{{ else }}\n\t\t\t{{ range $field := $object.Fields }}\n\t\t\t\t{{ if $field.IsReserved }}{{ else }}\n\t\t\t\t\tcase \"{{$object.GQLType}}.{{$field.GQLName}}\":\n\t\t\t\t\t\tif e.complexity.{{$object.GQLType|toCamel}}.{{$field.GQLName|toCamel}} == nil {\n\t\t\t\t\t\t\tbreak\n\t\t\t\t\t\t}\n\t\t\t\t\t\t{{ if $field.Args }}args := map[string]interface{}{} {{end}}\n\t\t\t\t\t\t{{ range $i, $arg := $field.Args }}\n\t\t\t\t\t\t\tvar arg{{$i}} {{$arg.Signature }}\n\t\t\t\t\t\t\tif tmp, ok := rawArgs[{{$arg.GQLName|quote}}]; ok {\n\t\t\t\t\t\t\t\tvar err error\n\t\t\t\t\t\t\t\t{{$arg.Unmarshal (print \"arg\" $i) \"tmp\" }}\n\t\t\t\t\t\t\t\tif err != nil {\n\t\t\t\t\t\t\t\t\treturn 0, false\n\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\targs[{{$arg.GQLName|quote}}] = arg{{$i}}\n\t\t\t\t\t\t{{ end }}\n\t\t\t\t\t\treturn e.complexity.{{$object.GQLType|toCamel}}.{{$field.GQLName|toCamel}}(childComplexity{{if $field.Args}}, {{$field.ComplexityArgs}} {{end}}), true\n\t\t\t\t{{ end }}\n\t\t\t{{ end }}\n\t\t{{ end }}\n\t{{ end }}\n\t}\n\treturn 0, false\n}\n\nfunc (e *executableSchema) Query(ctx context.Context, op *ast.OperationDefinition) *graphql.Response {\n\t{{- if .QueryRoot }}\n\t\tec := executionContext{graphql.GetRequestContext(ctx), e}\n\n\t\tbuf := ec.RequestMiddleware(ctx, func(ctx context.Context) []byte {\n\t\t\tdata := ec._{{.QueryRoot.GQLType}}(ctx, op.SelectionSet)\n\t\t\tvar buf bytes.Buffer\n\t\t\tdata.MarshalGQL(&buf)\n\t\t\treturn buf.Bytes()\n\t\t})\n\n\t\treturn &graphql.Response{\n\t\t\tData: buf,\n\t\t\tErrors: ec.Errors,\n\t\t}\n\t{{- else }}\n\t\treturn graphql.ErrorResponse(ctx, \"queries are not supported\")\n\t{{- end }}\n}\n\nfunc (e *executableSchema) Mutation(ctx context.Context, op *ast.OperationDefinition) *graphql.Response {\n\t{{- if .MutationRoot }}\n\t\tec := executionContext{graphql.GetRequestContext(ctx), e}\n\n\t\tbuf := ec.RequestMiddleware(ctx, func(ctx context.Context) []byte {\n\t\t\tdata := ec._{{.MutationRoot.GQLType}}(ctx, op.SelectionSet)\n\t\t\tvar buf bytes.Buffer\n\t\t\tdata.MarshalGQL(&buf)\n\t\t\treturn buf.Bytes()\n\t\t})\n\n\t\treturn &graphql.Response{\n\t\t\tData: buf,\n\t\t\tErrors: ec.Errors,\n\t\t}\n\t{{- else }}\n\t\treturn graphql.ErrorResponse(ctx, \"mutations are not supported\")\n\t{{- end }}\n}\n\nfunc (e *executableSchema) Subscription(ctx context.Context, op *ast.OperationDefinition) func() *graphql.Response {\n\t{{- if .SubscriptionRoot }}\n\t\tec := executionContext{graphql.GetRequestContext(ctx), e}\n\n\t\tnext := ec._{{.SubscriptionRoot.GQLType}}(ctx, op.SelectionSet)\n\t\tif ec.Errors != nil {\n\t\t\treturn graphql.OneShot(&graphql.Response{Data: []byte(\"null\"), Errors: ec.Errors})\n\t\t}\n\n\t\tvar buf bytes.Buffer\n\t\treturn func() *graphql.Response {\n\t\t\tbuf := ec.RequestMiddleware(ctx, func(ctx context.Context) []byte {\n\t\t\t\tbuf.Reset()\n\t\t\t\tdata := next()\n\n\t\t\t\tif data == nil {\n\t\t\t\t\treturn nil\n\t\t\t\t}\n\t\t\t\tdata.MarshalGQL(&buf)\n\t\t\t\treturn buf.Bytes()\n\t\t\t})\n\n\t\t\treturn &graphql.Response{\n\t\t\t\tData: buf,\n\t\t\t\tErrors: ec.Errors,\n\t\t\t}\n\t\t}\n\t{{- else }}\n\t\treturn graphql.OneShot(graphql.ErrorResponse(ctx, \"subscriptions are not supported\"))\n\t{{- end }}\n}\n\ntype executionContext struct {\n\t*graphql.RequestContext\n\t*executableSchema\n}\n\n{{- range $object := .Objects }}\n\t{{ template \"object.gotpl\" $object }}\n\n\t{{- range $field := $object.Fields }}\n\t\t{{ template \"field.gotpl\" $field }}\n\t{{ end }}\n{{- end}}\n\n{{- range $interface := .Interfaces }}\n\t{{ template \"interface.gotpl\" $interface }}\n{{- end }}\n\n{{- range $input := .Inputs }}\n\t{{ template \"input.gotpl\" $input }}\n{{- end }}\n\nfunc (ec *executionContext) FieldMiddleware(ctx context.Context, obj interface{}, next graphql.Resolver) (ret interface{}) {\n\tdefer func() {\n\t\tif r := recover(); r != nil {\n\t\t\tec.Error(ctx, ec.Recover(ctx, r))\n\t\t\tret = nil\n\t\t}\n\t}()\n\t{{- if .Directives }}\n\trctx := graphql.GetResolverContext(ctx)\n\tfor _, d := range rctx.Field.Definition.Directives {\n\t\tswitch d.Name {\n\t\t{{- range $directive := .Directives }}\n\t\tcase \"{{$directive.Name}}\":\n\t\t\tif ec.directives.{{$directive.Name|ucFirst}} != nil {\n\t\t\t\t{{- if $directive.Args }}\n\t\t\t\t\trawArgs := d.ArgumentMap(ec.Variables)\n\t\t\t\t\t{{ template \"args.gotpl\" $directive.Args }}\n\t\t\t\t{{- end }}\n\t\t\t\tn := next\n\t\t\t\tnext = func(ctx context.Context) (interface{}, error) {\n\t\t\t\t\treturn ec.directives.{{$directive.Name|ucFirst}}({{$directive.CallArgs}})\n\t\t\t\t}\n\t\t\t}\n\t\t{{- end }}\n\t\t}\n\t}\n\t{{- end }}\n\tres, err := ec.ResolverMiddleware(ctx, next)\n\tif err != nil {\n\t\tec.Error(ctx, err)\n\t\treturn nil\n\t}\n\treturn res\n}\n\nfunc (ec *executionContext) introspectSchema() *introspection.Schema {\n\treturn introspection.WrapSchema(parsedSchema)\n}\n\nfunc (ec *executionContext) introspectType(name string) *introspection.Type {\n\treturn introspection.WrapTypeFromDef(parsedSchema, parsedSchema.Types[name])\n}\n\nvar parsedSchema = gqlparser.MustLoadSchema(\n\t&ast.Source{Name: {{.SchemaFilename|quote}}, Input: {{.SchemaRaw|rawQuote}}},\n)\n", "input.gotpl": "\t{{- if .IsMarshaled }}\n\tfunc Unmarshal{{ .GQLType }}(v interface{}) ({{.FullName}}, error) {\n\t\tvar it {{.FullName}}\n\t\tvar asMap = v.(map[string]interface{})\n\t\t{{ range $field := .Fields}}\n\t\t\t{{- if $field.Default}}\n\t\t\t\tif _, present := asMap[{{$field.GQLName|quote}}] ; !present {\n\t\t\t\t\tasMap[{{$field.GQLName|quote}}] = {{ $field.Default | dump }}\n\t\t\t\t}\n\t\t\t{{- end}}\n\t\t{{- end }}\n\n\t\tfor k, v := range asMap {\n\t\t\tswitch k {\n\t\t\t{{- range $field := .Fields }}\n\t\t\tcase {{$field.GQLName|quote}}:\n\t\t\t\tvar err error\n\t\t\t\t{{ $field.Unmarshal (print \"it.\" $field.GoFieldName) \"v\" }}\n\t\t\t\tif err != nil {\n\t\t\t\t\treturn it, err\n\t\t\t\t}\n\t\t\t{{- end }}\n\t\t\t}\n\t\t}\n\n\t\treturn it, nil\n\t}\n\t{{- end }}\n", "interface.gotpl": "{{- $interface := . }}\n\nfunc (ec *executionContext) _{{$interface.GQLType}}(ctx context.Context, sel ast.SelectionSet, obj *{{$interface.FullName}}) graphql.Marshaler {\n\tswitch obj := (*obj).(type) {\n\tcase nil:\n\t\treturn graphql.Null\n\t{{- range $implementor := $interface.Implementors }}\n\t\t{{- if $implementor.ValueReceiver }}\n\t\t\tcase {{$implementor.FullName}}:\n\t\t\t\treturn ec._{{$implementor.GQLType}}(ctx, sel, &obj)\n\t\t{{- end}}\n\t\tcase *{{$implementor.FullName}}:\n\t\t\treturn ec._{{$implementor.GQLType}}(ctx, sel, obj)\n\t{{- end }}\n\tdefault:\n\t\tpanic(fmt.Errorf(\"unexpected type %T\", obj))\n\t}\n}\n", "models.gotpl": "// Code generated by github.com/99designs/gqlgen, DO NOT EDIT.\n\npackage {{ .PackageName }}\n\nimport (\n{{- range $import := .Imports }}\n\t{{- $import.Write }}\n{{ end }}\n)\n\n{{ range $model := .Models }}\n\t{{with .Description}} {{.|prefixLines \"// \"}} {{end}}\n\t{{- if .IsInterface }}\n\t\ttype {{.GoType}} interface {}\n\t{{- else }}\n\t\ttype {{.GoType}} struct {\n\t\t\t{{- range $field := .Fields }}\n\t\t\t\t{{- with .Description}}\n\t\t\t\t\t{{.|prefixLines \"// \"}}\n\t\t\t\t{{- end}}\n\t\t\t\t{{- if $field.GoFieldName }}\n\t\t\t\t\t{{ $field.GoFieldName }} {{$field.Signature}} `json:\"{{$field.GQLName}}\"`\n\t\t\t\t{{- else }}\n\t\t\t\t\t{{ $field.GoFKName }} {{$field.GoFKType}}\n\t\t\t\t{{- end }}\n\t\t\t{{- end }}\n\t\t}\n\t{{- end }}\n{{- end}}\n\n{{ range $enum := .Enums }}\n\t{{with .Description}}{{.|prefixLines \"// \"}} {{end}}\n\ttype {{.GoType}} string\n\tconst (\n\t{{- range $value := .Values}}\n\t\t{{- with .Description}}\n\t\t\t{{.|prefixLines \"// \"}}\n\t\t{{- end}}\n\t\t{{$enum.GoType}}{{ .Name|toCamel }} {{$enum.GoType}} = {{.Name|quote}}\n\t{{- end }}\n\t)\n\n\tfunc (e {{.GoType}}) IsValid() bool {\n\t\tswitch e {\n\t\tcase {{ range $index, $element := .Values}}{{if $index}},{{end}}{{ $enum.GoType }}{{ $element.Name|toCamel }}{{end}}:\n\t\t\treturn true\n\t\t}\n\t\treturn false\n\t}\n\n\tfunc (e {{.GoType}}) String() string {\n\t\treturn string(e)\n\t}\n\n\tfunc (e *{{.GoType}}) UnmarshalGQL(v interface{}) error {\n\t\tstr, ok := v.(string)\n\t\tif !ok {\n\t\t\treturn fmt.Errorf(\"enums must be strings\")\n\t\t}\n\n\t\t*e = {{.GoType}}(str)\n\t\tif !e.IsValid() {\n\t\t\treturn fmt.Errorf(\"%s is not a valid {{.GQLType}}\", str)\n\t\t}\n\t\treturn nil\n\t}\n\n\tfunc (e {{.GoType}}) MarshalGQL(w io.Writer) {\n\t\tfmt.Fprint(w, strconv.Quote(e.String()))\n\t}\n\n{{- end }}\n", diff --git a/codegen/templates/generated.gotpl b/codegen/templates/generated.gotpl index 26c0f323e4..ccef008456 100644 --- a/codegen/templates/generated.gotpl +++ b/codegen/templates/generated.gotpl @@ -13,12 +13,14 @@ func NewExecutableSchema(cfg Config) graphql.ExecutableSchema { return &executableSchema{ resolvers: cfg.Resolvers, directives: cfg.Directives, + complexity: cfg.Complexity, } } type Config struct { Resolvers ResolverRoot Directives DirectiveRoot + Complexity ComplexityRoot } type ResolverRoot interface { @@ -35,7 +37,21 @@ type DirectiveRoot struct { {{ end }} } -{{- range $object := .Objects -}} +type ComplexityRoot struct { +{{ range $object := .Objects }} + {{ if $object.IsReserved }}{{ else -}} + {{ $object.GQLType|toCamel }} struct { + {{ range $field := $object.Fields -}} + {{ if $field.IsReserved }}{{ else -}} + {{ $field.GQLName|toCamel }} {{ $field.ComplexitySignature }} + {{ end }} + {{- end }} + } + {{- end }} +{{ end }} +} + +{{ range $object := .Objects -}} {{ if $object.HasResolvers }} type {{$object.GQLType}}Resolver interface { {{ range $field := $object.Fields -}} @@ -48,12 +64,44 @@ type DirectiveRoot struct { type executableSchema struct { resolvers ResolverRoot directives DirectiveRoot + complexity ComplexityRoot } func (e *executableSchema) Schema() *ast.Schema { return parsedSchema } +func (e *executableSchema) Complexity(typeName, field string, childComplexity int, rawArgs map[string]interface{}) (int, bool) { + switch typeName + "." + field { + {{ range $object := .Objects }} + {{ if $object.IsReserved }}{{ else }} + {{ range $field := $object.Fields }} + {{ if $field.IsReserved }}{{ else }} + case "{{$object.GQLType}}.{{$field.GQLName}}": + if e.complexity.{{$object.GQLType|toCamel}}.{{$field.GQLName|toCamel}} == nil { + break + } + {{ if $field.Args }}args := map[string]interface{}{} {{end}} + {{ range $i, $arg := $field.Args }} + var arg{{$i}} {{$arg.Signature }} + if tmp, ok := rawArgs[{{$arg.GQLName|quote}}]; ok { + var err error + {{$arg.Unmarshal (print "arg" $i) "tmp" }} + if err != nil { + return 0, false + } + } + args[{{$arg.GQLName|quote}}] = arg{{$i}} + {{ end }} + return e.complexity.{{$object.GQLType|toCamel}}.{{$field.GQLName|toCamel}}(childComplexity{{if $field.Args}}, {{$field.ComplexityArgs}} {{end}}), true + {{ end }} + {{ end }} + {{ end }} + {{ end }} + } + return 0, false +} + func (e *executableSchema) Query(ctx context.Context, op *ast.OperationDefinition) *graphql.Response { {{- if .QueryRoot }} ec := executionContext{graphql.GetRequestContext(ctx), e} diff --git a/codegen/testserver/generated.go b/codegen/testserver/generated.go index e300041d2b..6df8c0247f 100644 --- a/codegen/testserver/generated.go +++ b/codegen/testserver/generated.go @@ -22,12 +22,14 @@ func NewExecutableSchema(cfg Config) graphql.ExecutableSchema { return &executableSchema{ resolvers: cfg.Resolvers, directives: cfg.Directives, + complexity: cfg.Complexity, } } type Config struct { Resolvers ResolverRoot Directives DirectiveRoot + Complexity ComplexityRoot } type ResolverRoot interface { @@ -37,6 +39,61 @@ type ResolverRoot interface { type DirectiveRoot struct { } + +type ComplexityRoot struct { + Circle struct { + Radius func(childComplexity int) int + Area func(childComplexity int) int + } + + Error struct { + Id func(childComplexity int) int + ErrorOnNonRequiredField func(childComplexity int) int + ErrorOnRequiredField func(childComplexity int) int + NilOnRequiredField func(childComplexity int) int + } + + ForcedResolver struct { + Field func(childComplexity int) int + } + + InnerObject struct { + Id func(childComplexity int) int + } + + InvalidIdentifier struct { + Id func(childComplexity int) int + } + + It struct { + Id func(childComplexity int) int + } + + OuterObject struct { + Inner func(childComplexity int) int + } + + Query struct { + InvalidIdentifier func(childComplexity int) int + Collision func(childComplexity int) int + MapInput func(childComplexity int, input *map[string]interface{}) int + Recursive func(childComplexity int, input *RecursiveInputSlice) int + NestedInputs func(childComplexity int, input [][]*OuterInput) int + NestedOutputs func(childComplexity int) int + Keywords func(childComplexity int, input *Keywords) int + Shapes func(childComplexity int) int + ErrorBubble func(childComplexity int) int + Valid func(childComplexity int) int + KeywordArgs func(childComplexity int, breakArg string, defaultArg string, funcArg string, interfaceArg string, selectArg string, caseArg string, deferArg string, goArg string, mapArg string, structArg string, chanArg string, elseArg string, gotoArg string, packageArg string, switchArg string, constArg string, fallthroughArg string, ifArg string, rangeArg string, typeArg string, continueArg string, forArg string, importArg string, returnArg string, varArg string) int + } + + Rectangle struct { + Length func(childComplexity int) int + Width func(childComplexity int) int + Area func(childComplexity int) int + } +} + type ForcedResolverResolver interface { Field(ctx context.Context, obj *ForcedResolver) (*Circle, error) } @@ -57,12 +114,531 @@ type QueryResolver interface { type executableSchema struct { resolvers ResolverRoot directives DirectiveRoot + complexity ComplexityRoot } func (e *executableSchema) Schema() *ast.Schema { return parsedSchema } +func (e *executableSchema) Complexity(typeName, field string, childComplexity int, rawArgs map[string]interface{}) (int, bool) { + switch typeName + "." + field { + + case "Circle.radius": + if e.complexity.Circle.Radius == nil { + break + } + + return e.complexity.Circle.Radius(childComplexity), true + + case "Circle.area": + if e.complexity.Circle.Area == nil { + break + } + + return e.complexity.Circle.Area(childComplexity), true + + case "Error.id": + if e.complexity.Error.Id == nil { + break + } + + return e.complexity.Error.Id(childComplexity), true + + case "Error.errorOnNonRequiredField": + if e.complexity.Error.ErrorOnNonRequiredField == nil { + break + } + + return e.complexity.Error.ErrorOnNonRequiredField(childComplexity), true + + case "Error.errorOnRequiredField": + if e.complexity.Error.ErrorOnRequiredField == nil { + break + } + + return e.complexity.Error.ErrorOnRequiredField(childComplexity), true + + case "Error.nilOnRequiredField": + if e.complexity.Error.NilOnRequiredField == nil { + break + } + + return e.complexity.Error.NilOnRequiredField(childComplexity), true + + case "ForcedResolver.field": + if e.complexity.ForcedResolver.Field == nil { + break + } + + return e.complexity.ForcedResolver.Field(childComplexity), true + + case "InnerObject.id": + if e.complexity.InnerObject.Id == nil { + break + } + + return e.complexity.InnerObject.Id(childComplexity), true + + case "InvalidIdentifier.id": + if e.complexity.InvalidIdentifier.Id == nil { + break + } + + return e.complexity.InvalidIdentifier.Id(childComplexity), true + + case "It.id": + if e.complexity.It.Id == nil { + break + } + + return e.complexity.It.Id(childComplexity), true + + case "OuterObject.inner": + if e.complexity.OuterObject.Inner == nil { + break + } + + return e.complexity.OuterObject.Inner(childComplexity), true + + case "Query.invalidIdentifier": + if e.complexity.Query.InvalidIdentifier == nil { + break + } + + return e.complexity.Query.InvalidIdentifier(childComplexity), true + + case "Query.collision": + if e.complexity.Query.Collision == nil { + break + } + + return e.complexity.Query.Collision(childComplexity), true + + case "Query.mapInput": + if e.complexity.Query.MapInput == nil { + break + } + args := map[string]interface{}{} + + var arg0 *map[string]interface{} + if tmp, ok := rawArgs["input"]; ok { + var err error + var ptr1 map[string]interface{} + if tmp != nil { + ptr1 = tmp.(map[string]interface{}) + arg0 = &ptr1 + } + + if err != nil { + return 0, false + } + } + args["input"] = arg0 + + return e.complexity.Query.MapInput(childComplexity, args["input"].(*map[string]interface{})), true + + case "Query.recursive": + if e.complexity.Query.Recursive == nil { + break + } + args := map[string]interface{}{} + + var arg0 *RecursiveInputSlice + if tmp, ok := rawArgs["input"]; ok { + var err error + var ptr1 RecursiveInputSlice + if tmp != nil { + ptr1, err = UnmarshalRecursiveInputSlice(tmp) + arg0 = &ptr1 + } + + if err != nil { + return 0, false + } + } + args["input"] = arg0 + + return e.complexity.Query.Recursive(childComplexity, args["input"].(*RecursiveInputSlice)), true + + case "Query.nestedInputs": + if e.complexity.Query.NestedInputs == nil { + break + } + args := map[string]interface{}{} + + var arg0 [][]*OuterInput + if tmp, ok := rawArgs["input"]; ok { + var err error + var rawIf1 []interface{} + if tmp != nil { + if tmp1, ok := tmp.([]interface{}); ok { + rawIf1 = tmp1 + } else { + rawIf1 = []interface{}{tmp} + } + } + arg0 = make([][]*OuterInput, len(rawIf1)) + for idx1 := range rawIf1 { + var rawIf2 []interface{} + if rawIf1[idx1] != nil { + if tmp1, ok := rawIf1[idx1].([]interface{}); ok { + rawIf2 = tmp1 + } else { + rawIf2 = []interface{}{rawIf1[idx1]} + } + } + arg0[idx1] = make([]*OuterInput, len(rawIf2)) + for idx2 := range rawIf2 { + var ptr3 OuterInput + if rawIf2[idx2] != nil { + ptr3, err = UnmarshalOuterInput(rawIf2[idx2]) + arg0[idx1][idx2] = &ptr3 + } + } + } + if err != nil { + return 0, false + } + } + args["input"] = arg0 + + return e.complexity.Query.NestedInputs(childComplexity, args["input"].([][]*OuterInput)), true + + case "Query.nestedOutputs": + if e.complexity.Query.NestedOutputs == nil { + break + } + + return e.complexity.Query.NestedOutputs(childComplexity), true + + case "Query.keywords": + if e.complexity.Query.Keywords == nil { + break + } + args := map[string]interface{}{} + + var arg0 *Keywords + if tmp, ok := rawArgs["input"]; ok { + var err error + var ptr1 Keywords + if tmp != nil { + ptr1, err = UnmarshalKeywords(tmp) + arg0 = &ptr1 + } + + if err != nil { + return 0, false + } + } + args["input"] = arg0 + + return e.complexity.Query.Keywords(childComplexity, args["input"].(*Keywords)), true + + case "Query.shapes": + if e.complexity.Query.Shapes == nil { + break + } + + return e.complexity.Query.Shapes(childComplexity), true + + case "Query.errorBubble": + if e.complexity.Query.ErrorBubble == nil { + break + } + + return e.complexity.Query.ErrorBubble(childComplexity), true + + case "Query.valid": + if e.complexity.Query.Valid == nil { + break + } + + return e.complexity.Query.Valid(childComplexity), true + + case "Query.keywordArgs": + if e.complexity.Query.KeywordArgs == nil { + break + } + args := map[string]interface{}{} + + var arg0 string + if tmp, ok := rawArgs["break"]; ok { + var err error + arg0, err = graphql.UnmarshalString(tmp) + if err != nil { + return 0, false + } + } + args["break"] = arg0 + + var arg1 string + if tmp, ok := rawArgs["default"]; ok { + var err error + arg1, err = graphql.UnmarshalString(tmp) + if err != nil { + return 0, false + } + } + args["default"] = arg1 + + var arg2 string + if tmp, ok := rawArgs["func"]; ok { + var err error + arg2, err = graphql.UnmarshalString(tmp) + if err != nil { + return 0, false + } + } + args["func"] = arg2 + + var arg3 string + if tmp, ok := rawArgs["interface"]; ok { + var err error + arg3, err = graphql.UnmarshalString(tmp) + if err != nil { + return 0, false + } + } + args["interface"] = arg3 + + var arg4 string + if tmp, ok := rawArgs["select"]; ok { + var err error + arg4, err = graphql.UnmarshalString(tmp) + if err != nil { + return 0, false + } + } + args["select"] = arg4 + + var arg5 string + if tmp, ok := rawArgs["case"]; ok { + var err error + arg5, err = graphql.UnmarshalString(tmp) + if err != nil { + return 0, false + } + } + args["case"] = arg5 + + var arg6 string + if tmp, ok := rawArgs["defer"]; ok { + var err error + arg6, err = graphql.UnmarshalString(tmp) + if err != nil { + return 0, false + } + } + args["defer"] = arg6 + + var arg7 string + if tmp, ok := rawArgs["go"]; ok { + var err error + arg7, err = graphql.UnmarshalString(tmp) + if err != nil { + return 0, false + } + } + args["go"] = arg7 + + var arg8 string + if tmp, ok := rawArgs["map"]; ok { + var err error + arg8, err = graphql.UnmarshalString(tmp) + if err != nil { + return 0, false + } + } + args["map"] = arg8 + + var arg9 string + if tmp, ok := rawArgs["struct"]; ok { + var err error + arg9, err = graphql.UnmarshalString(tmp) + if err != nil { + return 0, false + } + } + args["struct"] = arg9 + + var arg10 string + if tmp, ok := rawArgs["chan"]; ok { + var err error + arg10, err = graphql.UnmarshalString(tmp) + if err != nil { + return 0, false + } + } + args["chan"] = arg10 + + var arg11 string + if tmp, ok := rawArgs["else"]; ok { + var err error + arg11, err = graphql.UnmarshalString(tmp) + if err != nil { + return 0, false + } + } + args["else"] = arg11 + + var arg12 string + if tmp, ok := rawArgs["goto"]; ok { + var err error + arg12, err = graphql.UnmarshalString(tmp) + if err != nil { + return 0, false + } + } + args["goto"] = arg12 + + var arg13 string + if tmp, ok := rawArgs["package"]; ok { + var err error + arg13, err = graphql.UnmarshalString(tmp) + if err != nil { + return 0, false + } + } + args["package"] = arg13 + + var arg14 string + if tmp, ok := rawArgs["switch"]; ok { + var err error + arg14, err = graphql.UnmarshalString(tmp) + if err != nil { + return 0, false + } + } + args["switch"] = arg14 + + var arg15 string + if tmp, ok := rawArgs["const"]; ok { + var err error + arg15, err = graphql.UnmarshalString(tmp) + if err != nil { + return 0, false + } + } + args["const"] = arg15 + + var arg16 string + if tmp, ok := rawArgs["fallthrough"]; ok { + var err error + arg16, err = graphql.UnmarshalString(tmp) + if err != nil { + return 0, false + } + } + args["fallthrough"] = arg16 + + var arg17 string + if tmp, ok := rawArgs["if"]; ok { + var err error + arg17, err = graphql.UnmarshalString(tmp) + if err != nil { + return 0, false + } + } + args["if"] = arg17 + + var arg18 string + if tmp, ok := rawArgs["range"]; ok { + var err error + arg18, err = graphql.UnmarshalString(tmp) + if err != nil { + return 0, false + } + } + args["range"] = arg18 + + var arg19 string + if tmp, ok := rawArgs["type"]; ok { + var err error + arg19, err = graphql.UnmarshalString(tmp) + if err != nil { + return 0, false + } + } + args["type"] = arg19 + + var arg20 string + if tmp, ok := rawArgs["continue"]; ok { + var err error + arg20, err = graphql.UnmarshalString(tmp) + if err != nil { + return 0, false + } + } + args["continue"] = arg20 + + var arg21 string + if tmp, ok := rawArgs["for"]; ok { + var err error + arg21, err = graphql.UnmarshalString(tmp) + if err != nil { + return 0, false + } + } + args["for"] = arg21 + + var arg22 string + if tmp, ok := rawArgs["import"]; ok { + var err error + arg22, err = graphql.UnmarshalString(tmp) + if err != nil { + return 0, false + } + } + args["import"] = arg22 + + var arg23 string + if tmp, ok := rawArgs["return"]; ok { + var err error + arg23, err = graphql.UnmarshalString(tmp) + if err != nil { + return 0, false + } + } + args["return"] = arg23 + + var arg24 string + if tmp, ok := rawArgs["var"]; ok { + var err error + arg24, err = graphql.UnmarshalString(tmp) + if err != nil { + return 0, false + } + } + args["var"] = arg24 + + return e.complexity.Query.KeywordArgs(childComplexity, args["break"].(string), args["default"].(string), args["func"].(string), args["interface"].(string), args["select"].(string), args["case"].(string), args["defer"].(string), args["go"].(string), args["map"].(string), args["struct"].(string), args["chan"].(string), args["else"].(string), args["goto"].(string), args["package"].(string), args["switch"].(string), args["const"].(string), args["fallthrough"].(string), args["if"].(string), args["range"].(string), args["type"].(string), args["continue"].(string), args["for"].(string), args["import"].(string), args["return"].(string), args["var"].(string)), true + + case "Rectangle.length": + if e.complexity.Rectangle.Length == nil { + break + } + + return e.complexity.Rectangle.Length(childComplexity), true + + case "Rectangle.width": + if e.complexity.Rectangle.Width == nil { + break + } + + return e.complexity.Rectangle.Width(childComplexity), true + + case "Rectangle.area": + if e.complexity.Rectangle.Area == nil { + break + } + + return e.complexity.Rectangle.Area(childComplexity), true + + } + return 0, false +} + func (e *executableSchema) Query(ctx context.Context, op *ast.OperationDefinition) *graphql.Response { ec := executionContext{graphql.GetRequestContext(ctx), e} diff --git a/complexity/complexity.go b/complexity/complexity.go new file mode 100644 index 0000000000..d5b46bf451 --- /dev/null +++ b/complexity/complexity.go @@ -0,0 +1,104 @@ +package complexity + +import ( + "github.com/99designs/gqlgen/graphql" + "github.com/vektah/gqlparser/ast" +) + +func Calculate(es graphql.ExecutableSchema, op *ast.OperationDefinition, vars map[string]interface{}) int { + walker := complexityWalker{ + es: es, + schema: es.Schema(), + vars: vars, + } + return walker.selectionSetComplexity(op.SelectionSet) +} + +type complexityWalker struct { + es graphql.ExecutableSchema + schema *ast.Schema + vars map[string]interface{} +} + +func (cw complexityWalker) selectionSetComplexity(selectionSet ast.SelectionSet) int { + var complexity int + for _, selection := range selectionSet { + switch s := selection.(type) { + case *ast.Field: + fieldDefinition := cw.schema.Types[s.Definition.Type.Name()] + var childComplexity int + switch fieldDefinition.Kind { + case ast.Object, ast.Interface, ast.Union: + childComplexity = cw.selectionSetComplexity(s.SelectionSet) + } + + args := s.ArgumentMap(cw.vars) + var fieldComplexity int + if s.ObjectDefinition.Kind == ast.Interface { + fieldComplexity = cw.interfaceFieldComplexity(s.ObjectDefinition, s.Name, childComplexity, args) + } else { + fieldComplexity = cw.fieldComplexity(s.ObjectDefinition.Name, s.Name, childComplexity, args) + } + complexity = safeAdd(complexity, fieldComplexity) + + case *ast.FragmentSpread: + complexity = safeAdd(complexity, cw.selectionSetComplexity(s.Definition.SelectionSet)) + + case *ast.InlineFragment: + complexity = safeAdd(complexity, cw.selectionSetComplexity(s.SelectionSet)) + } + } + return complexity +} + +func (cw complexityWalker) interfaceFieldComplexity(def *ast.Definition, field string, childComplexity int, args map[string]interface{}) int { + // Interfaces don't have their own separate field costs, so they have to assume the worst case. + // We iterate over all implementors and choose the most expensive one. + maxComplexity := 0 + implementors := cw.schema.GetPossibleTypes(def) + for _, t := range implementors { + fieldComplexity := cw.fieldComplexity(t.Name, field, childComplexity, args) + if fieldComplexity > maxComplexity { + maxComplexity = fieldComplexity + } + } + return maxComplexity +} + +func (cw complexityWalker) fieldComplexity(object, field string, childComplexity int, args map[string]interface{}) int { + if customComplexity, ok := cw.es.Complexity(object, field, childComplexity, args); ok && customComplexity >= childComplexity { + return customComplexity + } + // default complexity calculation + return safeAdd(1, childComplexity) +} + +const maxInt = int(^uint(0) >> 1) + +// safeAdd is a saturating add of a and b that ignores negative operands. +// If a + b would overflow through normal Go addition, +// it returns the maximum integer value instead. +// +// Adding complexities with this function prevents attackers from intentionally +// overflowing the complexity calculation to allow overly-complex queries. +// +// It also helps mitigate the impact of custom complexities that accidentally +// return negative values. +func safeAdd(a, b int) int { + // Ignore negative operands. + if a < 0 { + if b < 0 { + return 1 + } + return b + } else if b < 0 { + return a + } + + c := a + b + if c < a { + // Set c to maximum integer instead of overflowing. + c = maxInt + } + return c +} diff --git a/complexity/complexity_test.go b/complexity/complexity_test.go new file mode 100644 index 0000000000..c913ac2cd8 --- /dev/null +++ b/complexity/complexity_test.go @@ -0,0 +1,232 @@ +package complexity + +import ( + "context" + "math" + "testing" + + "github.com/99designs/gqlgen/graphql" + "github.com/stretchr/testify/require" + "github.com/vektah/gqlparser" + "github.com/vektah/gqlparser/ast" +) + +var schema = gqlparser.MustLoadSchema( + &ast.Source{ + Name: "test.graphql", + Input: ` + interface NameInterface { + name: String + } + + type Item implements NameInterface { + scalar: String + name: String + list(size: Int = 10): [Item] + } + + type ExpensiveItem implements NameInterface { + name: String + } + + type Named { + name: String + } + + union NameUnion = Item | Named + + type Query { + scalar: String + object: Item + interface: NameInterface + union: NameUnion + customObject: Item + list(size: Int = 10): [Item] + } + `, + }, +) + +func requireComplexity(t *testing.T, source string, vars map[string]interface{}, complexity int) { + t.Helper() + query := gqlparser.MustLoadQuery(schema, source) + es := &executableSchemaStub{} + actualComplexity := Calculate(es, query.Operations[0], vars) + require.Equal(t, complexity, actualComplexity) +} + +func TestCalculate(t *testing.T) { + t.Run("uses default cost", func(t *testing.T) { + const query = ` + { + scalar + } + ` + requireComplexity(t, query, nil, 1) + }) + + t.Run("adds together fields", func(t *testing.T) { + const query = ` + { + scalar1: scalar + scalar2: scalar + } + ` + requireComplexity(t, query, nil, 2) + }) + + t.Run("a level of nesting adds complexity", func(t *testing.T) { + const query = ` + { + object { + scalar + } + } + ` + requireComplexity(t, query, nil, 2) + }) + + t.Run("adds together children", func(t *testing.T) { + const query = ` + { + scalar + object { + scalar + } + } + ` + requireComplexity(t, query, nil, 3) + }) + + t.Run("adds inline fragments", func(t *testing.T) { + const query = ` + { + ... { + scalar + } + } + ` + requireComplexity(t, query, nil, 1) + }) + + t.Run("adds fragments", func(t *testing.T) { + const query = ` + { + ... Fragment + } + + fragment Fragment on Query { + scalar + } + ` + requireComplexity(t, query, nil, 1) + }) + + t.Run("uses custom complexity", func(t *testing.T) { + const query = ` + { + list { + scalar + } + } + ` + requireComplexity(t, query, nil, 10) + }) + + t.Run("ignores negative custom complexity values", func(t *testing.T) { + const query = ` + { + list(size: -100) { + scalar + } + } + ` + requireComplexity(t, query, nil, 2) + }) + + t.Run("custom complexity must be >= child complexity", func(t *testing.T) { + const query = ` + { + customObject { + list(size: 100) { + scalar + } + } + } + ` + requireComplexity(t, query, nil, 101) + }) + + t.Run("interfaces take max concrete cost", func(t *testing.T) { + const query = ` + { + interface { + name + } + } + ` + requireComplexity(t, query, nil, 6) + }) + + t.Run("guards against integer overflow", func(t *testing.T) { + if maxInt == math.MaxInt32 { + // this test is written assuming 64-bit ints + t.Skip() + } + const query = ` + { + list1: list(size: 2147483647) { + list(size: 2147483647) { + list(size: 2) { + scalar + } + } + } + # total cost so far: 2*0x7fffffff*0x7fffffff + # = 0x7ffffffe00000002 + # Adding the same again should cause overflow + list2: list(size: 2147483647) { + list(size: 2147483647) { + list(size: 2) { + scalar + } + } + } + } + ` + requireComplexity(t, query, nil, math.MaxInt64) + }) +} + +type executableSchemaStub struct { +} + +var _ graphql.ExecutableSchema = &executableSchemaStub{} + +func (e *executableSchemaStub) Schema() *ast.Schema { + return schema +} + +func (e *executableSchemaStub) Complexity(typeName, field string, childComplexity int, args map[string]interface{}) (int, bool) { + switch typeName + "." + field { + case "ExpensiveItem.name": + return 5, true + case "Query.list", "Item.list": + return int(args["size"].(int64)) * childComplexity, true + case "Query.customObject": + return 1, true + } + return 0, false +} + +func (e *executableSchemaStub) Query(ctx context.Context, op *ast.OperationDefinition) *graphql.Response { + panic("Query should never be called by complexity calculations") +} + +func (e *executableSchemaStub) Mutation(ctx context.Context, op *ast.OperationDefinition) *graphql.Response { + panic("Mutation should never be called by complexity calculations") +} + +func (e *executableSchemaStub) Subscription(ctx context.Context, op *ast.OperationDefinition) func() *graphql.Response { + panic("Subscription should never be called by complexity calculations") +} diff --git a/docs/content/reference/complexity.md b/docs/content/reference/complexity.md new file mode 100644 index 0000000000..70640351ba --- /dev/null +++ b/docs/content/reference/complexity.md @@ -0,0 +1,87 @@ +--- +title: 'Preventing overly complex queries' +description: Avoid denial of service attacks by calculating query costs and limiting complexity. +linkTitle: Query Complexity +menu: { main: { parent: 'reference' } } +--- + +GraphQL provides a powerful way to query your data, but putting great power in the hands of your API clients also exposes you to a risk of denial of service attacks. You can mitigate that risk with gqlgen by limiting the complexity of the queries you allow. + +## Expensive Queries + +Consider a schema that allows listing blog posts. Each blog post is also related to other posts. + +```graphql +type Query { + posts(count: Int = 10): [Post!]! +} + +type Post { + title: String! + text: String! + related(count: Int = 10): [Post!]! +} +``` + +It's not too hard to craft a query that will cause a very large response: + +```graphql +{ + posts(count: 100) { + related(count: 100) { + related(count: 100) { + related(count: 100) { + title + } + } + } + } +} +``` + +The size of the response grows exponentially with each additional level of the `related` field. Fortunately, gqlgen's `http.Handler` includes a way to guard against this type of query. + +## Limiting Query Complexity + +Limiting query complexity is as simple as adding a parameter to the `handler.GraphQL` function call: + +```go +func main() { + c := Config{ Resolvers: &resolvers{} } + gqlHandler := handler.GraphQL( + blog.NewExecutableSchema(c), + handler.ComplexityLimit(5), // This line is the key + ) + http.Handle("/query", gqlHandler) +} +``` + +Now any query with complexity greater than 5 is rejected by the API. By default, each field and level of depth adds one to the overall query complexity. + +This helps, but we still have a problem: the `posts` and `related` fields, which return arrays, are much more expensive to resolve than the scalar `title` and `text` fields. However, the default complexity calculation weights them equally. It would make more sense to apply a higher cost to the array fields. + +## Custom Complexity Calculation + +To apply higher costs to certain fields, we can use custom complexity functions. + +```go +func main() { + c := Config{ Resolvers: &resolvers{} } + + countComplexity := func(childComplexity, count int) int { + return count * childComplexity + } + c.Complexity.Query.Posts = countComplexity + c.Complexity.Post.Related = countComplexity + + gqlHandler := handler.GraphQL( + blog.NewExecutableSchema(c), + handler.ComplexityLimit(100), + ) + http.Handle("/query", gqlHandler) +} +``` + +When we assign a function to the appropriate `Complexity` field, that function is used in the complexity calculation. Here, the `posts` and `related` fields are weighted according to the value of their `count` parameter. This means that the more posts a client requests, the higher the query complexity. And just like the size of the response would increase exponentially in our original query, the complexity would also increase exponentially, so any client trying to abuse the API would run into the limit very quickly. + +By applying a query complexity limit and specifying custom complexity functions in the right places, you can easily prevent clients from using a disproportionate amount of resources and disrupting your service. diff --git a/example/chat/generated.go b/example/chat/generated.go index 588d5f884b..bf6fa37438 100644 --- a/example/chat/generated.go +++ b/example/chat/generated.go @@ -20,12 +20,14 @@ func NewExecutableSchema(cfg Config) graphql.ExecutableSchema { return &executableSchema{ resolvers: cfg.Resolvers, directives: cfg.Directives, + complexity: cfg.Complexity, } } type Config struct { Resolvers ResolverRoot Directives DirectiveRoot + Complexity ComplexityRoot } type ResolverRoot interface { @@ -36,6 +38,33 @@ type ResolverRoot interface { type DirectiveRoot struct { } + +type ComplexityRoot struct { + Chatroom struct { + Name func(childComplexity int) int + Messages func(childComplexity int) int + } + + Message struct { + Id func(childComplexity int) int + Text func(childComplexity int) int + CreatedBy func(childComplexity int) int + CreatedAt func(childComplexity int) int + } + + Mutation struct { + Post func(childComplexity int, text string, username string, roomName string) int + } + + Query struct { + Room func(childComplexity int, name string) int + } + + Subscription struct { + MessageAdded func(childComplexity int, roomName string) int + } +} + type MutationResolver interface { Post(ctx context.Context, text string, username string, roomName string) (Message, error) } @@ -49,12 +78,136 @@ type SubscriptionResolver interface { type executableSchema struct { resolvers ResolverRoot directives DirectiveRoot + complexity ComplexityRoot } func (e *executableSchema) Schema() *ast.Schema { return parsedSchema } +func (e *executableSchema) Complexity(typeName, field string, childComplexity int, rawArgs map[string]interface{}) (int, bool) { + switch typeName + "." + field { + + case "Chatroom.name": + if e.complexity.Chatroom.Name == nil { + break + } + + return e.complexity.Chatroom.Name(childComplexity), true + + case "Chatroom.messages": + if e.complexity.Chatroom.Messages == nil { + break + } + + return e.complexity.Chatroom.Messages(childComplexity), true + + case "Message.id": + if e.complexity.Message.Id == nil { + break + } + + return e.complexity.Message.Id(childComplexity), true + + case "Message.text": + if e.complexity.Message.Text == nil { + break + } + + return e.complexity.Message.Text(childComplexity), true + + case "Message.createdBy": + if e.complexity.Message.CreatedBy == nil { + break + } + + return e.complexity.Message.CreatedBy(childComplexity), true + + case "Message.createdAt": + if e.complexity.Message.CreatedAt == nil { + break + } + + return e.complexity.Message.CreatedAt(childComplexity), true + + case "Mutation.post": + if e.complexity.Mutation.Post == nil { + break + } + args := map[string]interface{}{} + + var arg0 string + if tmp, ok := rawArgs["text"]; ok { + var err error + arg0, err = graphql.UnmarshalString(tmp) + if err != nil { + return 0, false + } + } + args["text"] = arg0 + + var arg1 string + if tmp, ok := rawArgs["username"]; ok { + var err error + arg1, err = graphql.UnmarshalString(tmp) + if err != nil { + return 0, false + } + } + args["username"] = arg1 + + var arg2 string + if tmp, ok := rawArgs["roomName"]; ok { + var err error + arg2, err = graphql.UnmarshalString(tmp) + if err != nil { + return 0, false + } + } + args["roomName"] = arg2 + + return e.complexity.Mutation.Post(childComplexity, args["text"].(string), args["username"].(string), args["roomName"].(string)), true + + case "Query.room": + if e.complexity.Query.Room == nil { + break + } + args := map[string]interface{}{} + + var arg0 string + if tmp, ok := rawArgs["name"]; ok { + var err error + arg0, err = graphql.UnmarshalString(tmp) + if err != nil { + return 0, false + } + } + args["name"] = arg0 + + return e.complexity.Query.Room(childComplexity, args["name"].(string)), true + + case "Subscription.messageAdded": + if e.complexity.Subscription.MessageAdded == nil { + break + } + args := map[string]interface{}{} + + var arg0 string + if tmp, ok := rawArgs["roomName"]; ok { + var err error + arg0, err = graphql.UnmarshalString(tmp) + if err != nil { + return 0, false + } + } + args["roomName"] = arg0 + + return e.complexity.Subscription.MessageAdded(childComplexity, args["roomName"].(string)), true + + } + return 0, false +} + func (e *executableSchema) Query(ctx context.Context, op *ast.OperationDefinition) *graphql.Response { ec := executionContext{graphql.GetRequestContext(ctx), e} diff --git a/example/config/generated.go b/example/config/generated.go index 5f6f51403c..55bf813481 100644 --- a/example/config/generated.go +++ b/example/config/generated.go @@ -19,12 +19,14 @@ func NewExecutableSchema(cfg Config) graphql.ExecutableSchema { return &executableSchema{ resolvers: cfg.Resolvers, directives: cfg.Directives, + complexity: cfg.Complexity, } } type Config struct { Resolvers ResolverRoot Directives DirectiveRoot + Complexity ComplexityRoot } type ResolverRoot interface { @@ -35,6 +37,30 @@ type ResolverRoot interface { type DirectiveRoot struct { } + +type ComplexityRoot struct { + Mutation struct { + CreateTodo func(childComplexity int, input NewTodo) int + } + + Query struct { + Todos func(childComplexity int) int + } + + Todo struct { + Id func(childComplexity int) int + DatabaseId func(childComplexity int) int + Text func(childComplexity int) int + Done func(childComplexity int) int + User func(childComplexity int) int + } + + User struct { + Id func(childComplexity int) int + Name func(childComplexity int) int + } +} + type MutationResolver interface { CreateTodo(ctx context.Context, input NewTodo) (Todo, error) } @@ -48,12 +74,94 @@ type TodoResolver interface { type executableSchema struct { resolvers ResolverRoot directives DirectiveRoot + complexity ComplexityRoot } func (e *executableSchema) Schema() *ast.Schema { return parsedSchema } +func (e *executableSchema) Complexity(typeName, field string, childComplexity int, rawArgs map[string]interface{}) (int, bool) { + switch typeName + "." + field { + + case "Mutation.createTodo": + if e.complexity.Mutation.CreateTodo == nil { + break + } + args := map[string]interface{}{} + + var arg0 NewTodo + if tmp, ok := rawArgs["input"]; ok { + var err error + arg0, err = UnmarshalNewTodo(tmp) + if err != nil { + return 0, false + } + } + args["input"] = arg0 + + return e.complexity.Mutation.CreateTodo(childComplexity, args["input"].(NewTodo)), true + + case "Query.todos": + if e.complexity.Query.Todos == nil { + break + } + + return e.complexity.Query.Todos(childComplexity), true + + case "Todo.id": + if e.complexity.Todo.Id == nil { + break + } + + return e.complexity.Todo.Id(childComplexity), true + + case "Todo.databaseId": + if e.complexity.Todo.DatabaseId == nil { + break + } + + return e.complexity.Todo.DatabaseId(childComplexity), true + + case "Todo.text": + if e.complexity.Todo.Text == nil { + break + } + + return e.complexity.Todo.Text(childComplexity), true + + case "Todo.done": + if e.complexity.Todo.Done == nil { + break + } + + return e.complexity.Todo.Done(childComplexity), true + + case "Todo.user": + if e.complexity.Todo.User == nil { + break + } + + return e.complexity.Todo.User(childComplexity), true + + case "User.id": + if e.complexity.User.Id == nil { + break + } + + return e.complexity.User.Id(childComplexity), true + + case "User.name": + if e.complexity.User.Name == nil { + break + } + + return e.complexity.User.Name(childComplexity), true + + } + return 0, false +} + func (e *executableSchema) Query(ctx context.Context, op *ast.OperationDefinition) *graphql.Response { ec := executionContext{graphql.GetRequestContext(ctx), e} diff --git a/example/dataloader/generated.go b/example/dataloader/generated.go index 8ad2fc63ed..dec0733dd7 100644 --- a/example/dataloader/generated.go +++ b/example/dataloader/generated.go @@ -20,12 +20,14 @@ func NewExecutableSchema(cfg Config) graphql.ExecutableSchema { return &executableSchema{ resolvers: cfg.Resolvers, directives: cfg.Directives, + complexity: cfg.Complexity, } } type Config struct { Resolvers ResolverRoot Directives DirectiveRoot + Complexity ComplexityRoot } type ResolverRoot interface { @@ -36,6 +38,39 @@ type ResolverRoot interface { type DirectiveRoot struct { } + +type ComplexityRoot struct { + Address struct { + Id func(childComplexity int) int + Street func(childComplexity int) int + Country func(childComplexity int) int + } + + Customer struct { + Id func(childComplexity int) int + Name func(childComplexity int) int + Address func(childComplexity int) int + Orders func(childComplexity int) int + } + + Item struct { + Name func(childComplexity int) int + } + + Order struct { + Id func(childComplexity int) int + Date func(childComplexity int) int + Amount func(childComplexity int) int + Items func(childComplexity int) int + } + + Query struct { + Customers func(childComplexity int) int + Torture1d func(childComplexity int, customerIds []int) int + Torture2d func(childComplexity int, customerIds [][]int) int + } +} + type CustomerResolver interface { Address(ctx context.Context, obj *Customer) (*Address, error) Orders(ctx context.Context, obj *Customer) ([]Order, error) @@ -52,12 +87,180 @@ type QueryResolver interface { type executableSchema struct { resolvers ResolverRoot directives DirectiveRoot + complexity ComplexityRoot } func (e *executableSchema) Schema() *ast.Schema { return parsedSchema } +func (e *executableSchema) Complexity(typeName, field string, childComplexity int, rawArgs map[string]interface{}) (int, bool) { + switch typeName + "." + field { + + case "Address.id": + if e.complexity.Address.Id == nil { + break + } + + return e.complexity.Address.Id(childComplexity), true + + case "Address.street": + if e.complexity.Address.Street == nil { + break + } + + return e.complexity.Address.Street(childComplexity), true + + case "Address.country": + if e.complexity.Address.Country == nil { + break + } + + return e.complexity.Address.Country(childComplexity), true + + case "Customer.id": + if e.complexity.Customer.Id == nil { + break + } + + return e.complexity.Customer.Id(childComplexity), true + + case "Customer.name": + if e.complexity.Customer.Name == nil { + break + } + + return e.complexity.Customer.Name(childComplexity), true + + case "Customer.address": + if e.complexity.Customer.Address == nil { + break + } + + return e.complexity.Customer.Address(childComplexity), true + + case "Customer.orders": + if e.complexity.Customer.Orders == nil { + break + } + + return e.complexity.Customer.Orders(childComplexity), true + + case "Item.name": + if e.complexity.Item.Name == nil { + break + } + + return e.complexity.Item.Name(childComplexity), true + + case "Order.id": + if e.complexity.Order.Id == nil { + break + } + + return e.complexity.Order.Id(childComplexity), true + + case "Order.date": + if e.complexity.Order.Date == nil { + break + } + + return e.complexity.Order.Date(childComplexity), true + + case "Order.amount": + if e.complexity.Order.Amount == nil { + break + } + + return e.complexity.Order.Amount(childComplexity), true + + case "Order.items": + if e.complexity.Order.Items == nil { + break + } + + return e.complexity.Order.Items(childComplexity), true + + case "Query.customers": + if e.complexity.Query.Customers == nil { + break + } + + return e.complexity.Query.Customers(childComplexity), true + + case "Query.torture1d": + if e.complexity.Query.Torture1d == nil { + break + } + args := map[string]interface{}{} + + var arg0 []int + if tmp, ok := rawArgs["customerIds"]; ok { + var err error + var rawIf1 []interface{} + if tmp != nil { + if tmp1, ok := tmp.([]interface{}); ok { + rawIf1 = tmp1 + } else { + rawIf1 = []interface{}{tmp} + } + } + arg0 = make([]int, len(rawIf1)) + for idx1 := range rawIf1 { + arg0[idx1], err = graphql.UnmarshalInt(rawIf1[idx1]) + } + if err != nil { + return 0, false + } + } + args["customerIds"] = arg0 + + return e.complexity.Query.Torture1d(childComplexity, args["customerIds"].([]int)), true + + case "Query.torture2d": + if e.complexity.Query.Torture2d == nil { + break + } + args := map[string]interface{}{} + + var arg0 [][]int + if tmp, ok := rawArgs["customerIds"]; ok { + var err error + var rawIf1 []interface{} + if tmp != nil { + if tmp1, ok := tmp.([]interface{}); ok { + rawIf1 = tmp1 + } else { + rawIf1 = []interface{}{tmp} + } + } + arg0 = make([][]int, len(rawIf1)) + for idx1 := range rawIf1 { + var rawIf2 []interface{} + if rawIf1[idx1] != nil { + if tmp1, ok := rawIf1[idx1].([]interface{}); ok { + rawIf2 = tmp1 + } else { + rawIf2 = []interface{}{rawIf1[idx1]} + } + } + arg0[idx1] = make([]int, len(rawIf2)) + for idx2 := range rawIf2 { + arg0[idx1][idx2], err = graphql.UnmarshalInt(rawIf2[idx2]) + } + } + if err != nil { + return 0, false + } + } + args["customerIds"] = arg0 + + return e.complexity.Query.Torture2d(childComplexity, args["customerIds"].([][]int)), true + + } + return 0, false +} + func (e *executableSchema) Query(ctx context.Context, op *ast.OperationDefinition) *graphql.Response { ec := executionContext{graphql.GetRequestContext(ctx), e} diff --git a/example/scalars/generated.go b/example/scalars/generated.go index b44bc68e07..3728f5f46d 100644 --- a/example/scalars/generated.go +++ b/example/scalars/generated.go @@ -22,12 +22,14 @@ func NewExecutableSchema(cfg Config) graphql.ExecutableSchema { return &executableSchema{ resolvers: cfg.Resolvers, directives: cfg.Directives, + complexity: cfg.Complexity, } } type Config struct { Resolvers ResolverRoot Directives DirectiveRoot + Complexity ComplexityRoot } type ResolverRoot interface { @@ -37,6 +39,30 @@ type ResolverRoot interface { type DirectiveRoot struct { } + +type ComplexityRoot struct { + Address struct { + Id func(childComplexity int) int + Location func(childComplexity int) int + } + + Query struct { + User func(childComplexity int, id external.ObjectID) int + Search func(childComplexity int, input model.SearchArgs) int + } + + User struct { + Id func(childComplexity int) int + Name func(childComplexity int) int + Created func(childComplexity int) int + IsBanned func(childComplexity int) int + PrimitiveResolver func(childComplexity int) int + CustomResolver func(childComplexity int) int + Address func(childComplexity int) int + Tier func(childComplexity int) int + } +} + type QueryResolver interface { User(ctx context.Context, id external.ObjectID) (*model.User, error) Search(ctx context.Context, input model.SearchArgs) ([]model.User, error) @@ -49,12 +75,126 @@ type UserResolver interface { type executableSchema struct { resolvers ResolverRoot directives DirectiveRoot + complexity ComplexityRoot } func (e *executableSchema) Schema() *ast.Schema { return parsedSchema } +func (e *executableSchema) Complexity(typeName, field string, childComplexity int, rawArgs map[string]interface{}) (int, bool) { + switch typeName + "." + field { + + case "Address.id": + if e.complexity.Address.Id == nil { + break + } + + return e.complexity.Address.Id(childComplexity), true + + case "Address.location": + if e.complexity.Address.Location == nil { + break + } + + return e.complexity.Address.Location(childComplexity), true + + case "Query.user": + if e.complexity.Query.User == nil { + break + } + args := map[string]interface{}{} + + var arg0 external.ObjectID + if tmp, ok := rawArgs["id"]; ok { + var err error + arg0, err = model.UnmarshalID(tmp) + if err != nil { + return 0, false + } + } + args["id"] = arg0 + + return e.complexity.Query.User(childComplexity, args["id"].(external.ObjectID)), true + + case "Query.search": + if e.complexity.Query.Search == nil { + break + } + args := map[string]interface{}{} + + var arg0 model.SearchArgs + if tmp, ok := rawArgs["input"]; ok { + var err error + arg0, err = UnmarshalSearchArgs(tmp) + if err != nil { + return 0, false + } + } + args["input"] = arg0 + + return e.complexity.Query.Search(childComplexity, args["input"].(model.SearchArgs)), true + + case "User.id": + if e.complexity.User.Id == nil { + break + } + + return e.complexity.User.Id(childComplexity), true + + case "User.name": + if e.complexity.User.Name == nil { + break + } + + return e.complexity.User.Name(childComplexity), true + + case "User.created": + if e.complexity.User.Created == nil { + break + } + + return e.complexity.User.Created(childComplexity), true + + case "User.isBanned": + if e.complexity.User.IsBanned == nil { + break + } + + return e.complexity.User.IsBanned(childComplexity), true + + case "User.primitiveResolver": + if e.complexity.User.PrimitiveResolver == nil { + break + } + + return e.complexity.User.PrimitiveResolver(childComplexity), true + + case "User.customResolver": + if e.complexity.User.CustomResolver == nil { + break + } + + return e.complexity.User.CustomResolver(childComplexity), true + + case "User.address": + if e.complexity.User.Address == nil { + break + } + + return e.complexity.User.Address(childComplexity), true + + case "User.tier": + if e.complexity.User.Tier == nil { + break + } + + return e.complexity.User.Tier(childComplexity), true + + } + return 0, false +} + func (e *executableSchema) Query(ctx context.Context, op *ast.OperationDefinition) *graphql.Response { ec := executionContext{graphql.GetRequestContext(ctx), e} diff --git a/example/selection/generated.go b/example/selection/generated.go index f49b0a6e02..c1b243601e 100644 --- a/example/selection/generated.go +++ b/example/selection/generated.go @@ -21,12 +21,14 @@ func NewExecutableSchema(cfg Config) graphql.ExecutableSchema { return &executableSchema{ resolvers: cfg.Resolvers, directives: cfg.Directives, + complexity: cfg.Complexity, } } type Config struct { Resolvers ResolverRoot Directives DirectiveRoot + Complexity ComplexityRoot } type ResolverRoot interface { @@ -35,6 +37,27 @@ type ResolverRoot interface { type DirectiveRoot struct { } + +type ComplexityRoot struct { + Like struct { + Reaction func(childComplexity int) int + Sent func(childComplexity int) int + Selection func(childComplexity int) int + Collected func(childComplexity int) int + } + + Post struct { + Message func(childComplexity int) int + Sent func(childComplexity int) int + Selection func(childComplexity int) int + Collected func(childComplexity int) int + } + + Query struct { + Events func(childComplexity int) int + } +} + type QueryResolver interface { Events(ctx context.Context) ([]Event, error) } @@ -42,12 +65,83 @@ type QueryResolver interface { type executableSchema struct { resolvers ResolverRoot directives DirectiveRoot + complexity ComplexityRoot } func (e *executableSchema) Schema() *ast.Schema { return parsedSchema } +func (e *executableSchema) Complexity(typeName, field string, childComplexity int, rawArgs map[string]interface{}) (int, bool) { + switch typeName + "." + field { + + case "Like.reaction": + if e.complexity.Like.Reaction == nil { + break + } + + return e.complexity.Like.Reaction(childComplexity), true + + case "Like.sent": + if e.complexity.Like.Sent == nil { + break + } + + return e.complexity.Like.Sent(childComplexity), true + + case "Like.selection": + if e.complexity.Like.Selection == nil { + break + } + + return e.complexity.Like.Selection(childComplexity), true + + case "Like.collected": + if e.complexity.Like.Collected == nil { + break + } + + return e.complexity.Like.Collected(childComplexity), true + + case "Post.message": + if e.complexity.Post.Message == nil { + break + } + + return e.complexity.Post.Message(childComplexity), true + + case "Post.sent": + if e.complexity.Post.Sent == nil { + break + } + + return e.complexity.Post.Sent(childComplexity), true + + case "Post.selection": + if e.complexity.Post.Selection == nil { + break + } + + return e.complexity.Post.Selection(childComplexity), true + + case "Post.collected": + if e.complexity.Post.Collected == nil { + break + } + + return e.complexity.Post.Collected(childComplexity), true + + case "Query.events": + if e.complexity.Query.Events == nil { + break + } + + return e.complexity.Query.Events(childComplexity), true + + } + return 0, false +} + func (e *executableSchema) Query(ctx context.Context, op *ast.OperationDefinition) *graphql.Response { ec := executionContext{graphql.GetRequestContext(ctx), e} diff --git a/example/starwars/generated.go b/example/starwars/generated.go index 9213b5753f..a33e4c4f35 100644 --- a/example/starwars/generated.go +++ b/example/starwars/generated.go @@ -21,12 +21,14 @@ func NewExecutableSchema(cfg Config) graphql.ExecutableSchema { return &executableSchema{ resolvers: cfg.Resolvers, directives: cfg.Directives, + complexity: cfg.Complexity, } } type Config struct { Resolvers ResolverRoot Directives DirectiveRoot + Complexity ComplexityRoot } type ResolverRoot interface { @@ -40,6 +42,74 @@ type ResolverRoot interface { type DirectiveRoot struct { } + +type ComplexityRoot struct { + Droid struct { + Id func(childComplexity int) int + Name func(childComplexity int) int + Friends func(childComplexity int) int + FriendsConnection func(childComplexity int, first *int, after *string) int + AppearsIn func(childComplexity int) int + PrimaryFunction func(childComplexity int) int + } + + FriendsConnection struct { + TotalCount func(childComplexity int) int + Edges func(childComplexity int) int + Friends func(childComplexity int) int + PageInfo func(childComplexity int) int + } + + FriendsEdge struct { + Cursor func(childComplexity int) int + Node func(childComplexity int) int + } + + Human struct { + Id func(childComplexity int) int + Name func(childComplexity int) int + Height func(childComplexity int, unit LengthUnit) int + Mass func(childComplexity int) int + Friends func(childComplexity int) int + FriendsConnection func(childComplexity int, first *int, after *string) int + AppearsIn func(childComplexity int) int + Starships func(childComplexity int) int + } + + Mutation struct { + CreateReview func(childComplexity int, episode Episode, review Review) int + } + + PageInfo struct { + StartCursor func(childComplexity int) int + EndCursor func(childComplexity int) int + HasNextPage func(childComplexity int) int + } + + Query struct { + Hero func(childComplexity int, episode Episode) int + Reviews func(childComplexity int, episode Episode, since *time.Time) int + Search func(childComplexity int, text string) int + Character func(childComplexity int, id string) int + Droid func(childComplexity int, id string) int + Human func(childComplexity int, id string) int + Starship func(childComplexity int, id string) int + } + + Review struct { + Stars func(childComplexity int) int + Commentary func(childComplexity int) int + Time func(childComplexity int) int + } + + Starship struct { + Id func(childComplexity int) int + Name func(childComplexity int) int + Length func(childComplexity int, unit LengthUnit) int + History func(childComplexity int) int + } +} + type DroidResolver interface { Friends(ctx context.Context, obj *Droid) ([]Character, error) FriendsConnection(ctx context.Context, obj *Droid, first *int, after *string) (FriendsConnection, error) @@ -73,12 +143,483 @@ type StarshipResolver interface { type executableSchema struct { resolvers ResolverRoot directives DirectiveRoot + complexity ComplexityRoot } func (e *executableSchema) Schema() *ast.Schema { return parsedSchema } +func (e *executableSchema) Complexity(typeName, field string, childComplexity int, rawArgs map[string]interface{}) (int, bool) { + switch typeName + "." + field { + + case "Droid.id": + if e.complexity.Droid.Id == nil { + break + } + + return e.complexity.Droid.Id(childComplexity), true + + case "Droid.name": + if e.complexity.Droid.Name == nil { + break + } + + return e.complexity.Droid.Name(childComplexity), true + + case "Droid.friends": + if e.complexity.Droid.Friends == nil { + break + } + + return e.complexity.Droid.Friends(childComplexity), true + + case "Droid.friendsConnection": + if e.complexity.Droid.FriendsConnection == nil { + break + } + args := map[string]interface{}{} + + var arg0 *int + if tmp, ok := rawArgs["first"]; ok { + var err error + var ptr1 int + if tmp != nil { + ptr1, err = graphql.UnmarshalInt(tmp) + arg0 = &ptr1 + } + + if err != nil { + return 0, false + } + } + args["first"] = arg0 + + var arg1 *string + if tmp, ok := rawArgs["after"]; ok { + var err error + var ptr1 string + if tmp != nil { + ptr1, err = graphql.UnmarshalID(tmp) + arg1 = &ptr1 + } + + if err != nil { + return 0, false + } + } + args["after"] = arg1 + + return e.complexity.Droid.FriendsConnection(childComplexity, args["first"].(*int), args["after"].(*string)), true + + case "Droid.appearsIn": + if e.complexity.Droid.AppearsIn == nil { + break + } + + return e.complexity.Droid.AppearsIn(childComplexity), true + + case "Droid.primaryFunction": + if e.complexity.Droid.PrimaryFunction == nil { + break + } + + return e.complexity.Droid.PrimaryFunction(childComplexity), true + + case "FriendsConnection.totalCount": + if e.complexity.FriendsConnection.TotalCount == nil { + break + } + + return e.complexity.FriendsConnection.TotalCount(childComplexity), true + + case "FriendsConnection.edges": + if e.complexity.FriendsConnection.Edges == nil { + break + } + + return e.complexity.FriendsConnection.Edges(childComplexity), true + + case "FriendsConnection.friends": + if e.complexity.FriendsConnection.Friends == nil { + break + } + + return e.complexity.FriendsConnection.Friends(childComplexity), true + + case "FriendsConnection.pageInfo": + if e.complexity.FriendsConnection.PageInfo == nil { + break + } + + return e.complexity.FriendsConnection.PageInfo(childComplexity), true + + case "FriendsEdge.cursor": + if e.complexity.FriendsEdge.Cursor == nil { + break + } + + return e.complexity.FriendsEdge.Cursor(childComplexity), true + + case "FriendsEdge.node": + if e.complexity.FriendsEdge.Node == nil { + break + } + + return e.complexity.FriendsEdge.Node(childComplexity), true + + case "Human.id": + if e.complexity.Human.Id == nil { + break + } + + return e.complexity.Human.Id(childComplexity), true + + case "Human.name": + if e.complexity.Human.Name == nil { + break + } + + return e.complexity.Human.Name(childComplexity), true + + case "Human.height": + if e.complexity.Human.Height == nil { + break + } + args := map[string]interface{}{} + + var arg0 LengthUnit + if tmp, ok := rawArgs["unit"]; ok { + var err error + err = (&arg0).UnmarshalGQL(tmp) + if err != nil { + return 0, false + } + } + args["unit"] = arg0 + + return e.complexity.Human.Height(childComplexity, args["unit"].(LengthUnit)), true + + case "Human.mass": + if e.complexity.Human.Mass == nil { + break + } + + return e.complexity.Human.Mass(childComplexity), true + + case "Human.friends": + if e.complexity.Human.Friends == nil { + break + } + + return e.complexity.Human.Friends(childComplexity), true + + case "Human.friendsConnection": + if e.complexity.Human.FriendsConnection == nil { + break + } + args := map[string]interface{}{} + + var arg0 *int + if tmp, ok := rawArgs["first"]; ok { + var err error + var ptr1 int + if tmp != nil { + ptr1, err = graphql.UnmarshalInt(tmp) + arg0 = &ptr1 + } + + if err != nil { + return 0, false + } + } + args["first"] = arg0 + + var arg1 *string + if tmp, ok := rawArgs["after"]; ok { + var err error + var ptr1 string + if tmp != nil { + ptr1, err = graphql.UnmarshalID(tmp) + arg1 = &ptr1 + } + + if err != nil { + return 0, false + } + } + args["after"] = arg1 + + return e.complexity.Human.FriendsConnection(childComplexity, args["first"].(*int), args["after"].(*string)), true + + case "Human.appearsIn": + if e.complexity.Human.AppearsIn == nil { + break + } + + return e.complexity.Human.AppearsIn(childComplexity), true + + case "Human.starships": + if e.complexity.Human.Starships == nil { + break + } + + return e.complexity.Human.Starships(childComplexity), true + + case "Mutation.createReview": + if e.complexity.Mutation.CreateReview == nil { + break + } + args := map[string]interface{}{} + + var arg0 Episode + if tmp, ok := rawArgs["episode"]; ok { + var err error + err = (&arg0).UnmarshalGQL(tmp) + if err != nil { + return 0, false + } + } + args["episode"] = arg0 + + var arg1 Review + if tmp, ok := rawArgs["review"]; ok { + var err error + arg1, err = UnmarshalReviewInput(tmp) + if err != nil { + return 0, false + } + } + args["review"] = arg1 + + return e.complexity.Mutation.CreateReview(childComplexity, args["episode"].(Episode), args["review"].(Review)), true + + case "PageInfo.startCursor": + if e.complexity.PageInfo.StartCursor == nil { + break + } + + return e.complexity.PageInfo.StartCursor(childComplexity), true + + case "PageInfo.endCursor": + if e.complexity.PageInfo.EndCursor == nil { + break + } + + return e.complexity.PageInfo.EndCursor(childComplexity), true + + case "PageInfo.hasNextPage": + if e.complexity.PageInfo.HasNextPage == nil { + break + } + + return e.complexity.PageInfo.HasNextPage(childComplexity), true + + case "Query.hero": + if e.complexity.Query.Hero == nil { + break + } + args := map[string]interface{}{} + + var arg0 Episode + if tmp, ok := rawArgs["episode"]; ok { + var err error + err = (&arg0).UnmarshalGQL(tmp) + if err != nil { + return 0, false + } + } + args["episode"] = arg0 + + return e.complexity.Query.Hero(childComplexity, args["episode"].(Episode)), true + + case "Query.reviews": + if e.complexity.Query.Reviews == nil { + break + } + args := map[string]interface{}{} + + var arg0 Episode + if tmp, ok := rawArgs["episode"]; ok { + var err error + err = (&arg0).UnmarshalGQL(tmp) + if err != nil { + return 0, false + } + } + args["episode"] = arg0 + + var arg1 *time.Time + if tmp, ok := rawArgs["since"]; ok { + var err error + var ptr1 time.Time + if tmp != nil { + ptr1, err = graphql.UnmarshalTime(tmp) + arg1 = &ptr1 + } + + if err != nil { + return 0, false + } + } + args["since"] = arg1 + + return e.complexity.Query.Reviews(childComplexity, args["episode"].(Episode), args["since"].(*time.Time)), true + + case "Query.search": + if e.complexity.Query.Search == nil { + break + } + args := map[string]interface{}{} + + var arg0 string + if tmp, ok := rawArgs["text"]; ok { + var err error + arg0, err = graphql.UnmarshalString(tmp) + if err != nil { + return 0, false + } + } + args["text"] = arg0 + + return e.complexity.Query.Search(childComplexity, args["text"].(string)), true + + case "Query.character": + if e.complexity.Query.Character == nil { + break + } + args := map[string]interface{}{} + + var arg0 string + if tmp, ok := rawArgs["id"]; ok { + var err error + arg0, err = graphql.UnmarshalID(tmp) + if err != nil { + return 0, false + } + } + args["id"] = arg0 + + return e.complexity.Query.Character(childComplexity, args["id"].(string)), true + + case "Query.droid": + if e.complexity.Query.Droid == nil { + break + } + args := map[string]interface{}{} + + var arg0 string + if tmp, ok := rawArgs["id"]; ok { + var err error + arg0, err = graphql.UnmarshalID(tmp) + if err != nil { + return 0, false + } + } + args["id"] = arg0 + + return e.complexity.Query.Droid(childComplexity, args["id"].(string)), true + + case "Query.human": + if e.complexity.Query.Human == nil { + break + } + args := map[string]interface{}{} + + var arg0 string + if tmp, ok := rawArgs["id"]; ok { + var err error + arg0, err = graphql.UnmarshalID(tmp) + if err != nil { + return 0, false + } + } + args["id"] = arg0 + + return e.complexity.Query.Human(childComplexity, args["id"].(string)), true + + case "Query.starship": + if e.complexity.Query.Starship == nil { + break + } + args := map[string]interface{}{} + + var arg0 string + if tmp, ok := rawArgs["id"]; ok { + var err error + arg0, err = graphql.UnmarshalID(tmp) + if err != nil { + return 0, false + } + } + args["id"] = arg0 + + return e.complexity.Query.Starship(childComplexity, args["id"].(string)), true + + case "Review.stars": + if e.complexity.Review.Stars == nil { + break + } + + return e.complexity.Review.Stars(childComplexity), true + + case "Review.commentary": + if e.complexity.Review.Commentary == nil { + break + } + + return e.complexity.Review.Commentary(childComplexity), true + + case "Review.time": + if e.complexity.Review.Time == nil { + break + } + + return e.complexity.Review.Time(childComplexity), true + + case "Starship.id": + if e.complexity.Starship.Id == nil { + break + } + + return e.complexity.Starship.Id(childComplexity), true + + case "Starship.name": + if e.complexity.Starship.Name == nil { + break + } + + return e.complexity.Starship.Name(childComplexity), true + + case "Starship.length": + if e.complexity.Starship.Length == nil { + break + } + args := map[string]interface{}{} + + var arg0 LengthUnit + if tmp, ok := rawArgs["unit"]; ok { + var err error + err = (&arg0).UnmarshalGQL(tmp) + if err != nil { + return 0, false + } + } + args["unit"] = arg0 + + return e.complexity.Starship.Length(childComplexity, args["unit"].(LengthUnit)), true + + case "Starship.history": + if e.complexity.Starship.History == nil { + break + } + + return e.complexity.Starship.History(childComplexity), true + + } + return 0, false +} + func (e *executableSchema) Query(ctx context.Context, op *ast.OperationDefinition) *graphql.Response { ec := executionContext{graphql.GetRequestContext(ctx), e} diff --git a/example/todo/generated.go b/example/todo/generated.go index 49485cb0e4..e98ab8b48b 100644 --- a/example/todo/generated.go +++ b/example/todo/generated.go @@ -19,12 +19,14 @@ func NewExecutableSchema(cfg Config) graphql.ExecutableSchema { return &executableSchema{ resolvers: cfg.Resolvers, directives: cfg.Directives, + complexity: cfg.Complexity, } } type Config struct { Resolvers ResolverRoot Directives DirectiveRoot + Complexity ComplexityRoot } type ResolverRoot interface { @@ -35,6 +37,26 @@ type ResolverRoot interface { type DirectiveRoot struct { HasRole func(ctx context.Context, obj interface{}, next graphql.Resolver, role Role) (res interface{}, err error) } + +type ComplexityRoot struct { + MyMutation struct { + CreateTodo func(childComplexity int, todo TodoInput) int + UpdateTodo func(childComplexity int, id int, changes map[string]interface{}) int + } + + MyQuery struct { + Todo func(childComplexity int, id int) int + LastTodo func(childComplexity int) int + Todos func(childComplexity int) int + } + + Todo struct { + Id func(childComplexity int) int + Text func(childComplexity int) int + Done func(childComplexity int) int + } +} + type MyMutationResolver interface { CreateTodo(ctx context.Context, todo TodoInput) (Todo, error) UpdateTodo(ctx context.Context, id int, changes map[string]interface{}) (*Todo, error) @@ -48,12 +70,119 @@ type MyQueryResolver interface { type executableSchema struct { resolvers ResolverRoot directives DirectiveRoot + complexity ComplexityRoot } func (e *executableSchema) Schema() *ast.Schema { return parsedSchema } +func (e *executableSchema) Complexity(typeName, field string, childComplexity int, rawArgs map[string]interface{}) (int, bool) { + switch typeName + "." + field { + + case "MyMutation.createTodo": + if e.complexity.MyMutation.CreateTodo == nil { + break + } + args := map[string]interface{}{} + + var arg0 TodoInput + if tmp, ok := rawArgs["todo"]; ok { + var err error + arg0, err = UnmarshalTodoInput(tmp) + if err != nil { + return 0, false + } + } + args["todo"] = arg0 + + return e.complexity.MyMutation.CreateTodo(childComplexity, args["todo"].(TodoInput)), true + + case "MyMutation.updateTodo": + if e.complexity.MyMutation.UpdateTodo == nil { + break + } + args := map[string]interface{}{} + + var arg0 int + if tmp, ok := rawArgs["id"]; ok { + var err error + arg0, err = graphql.UnmarshalInt(tmp) + if err != nil { + return 0, false + } + } + args["id"] = arg0 + + var arg1 map[string]interface{} + if tmp, ok := rawArgs["changes"]; ok { + var err error + arg1 = tmp.(map[string]interface{}) + if err != nil { + return 0, false + } + } + args["changes"] = arg1 + + return e.complexity.MyMutation.UpdateTodo(childComplexity, args["id"].(int), args["changes"].(map[string]interface{})), true + + case "MyQuery.todo": + if e.complexity.MyQuery.Todo == nil { + break + } + args := map[string]interface{}{} + + var arg0 int + if tmp, ok := rawArgs["id"]; ok { + var err error + arg0, err = graphql.UnmarshalInt(tmp) + if err != nil { + return 0, false + } + } + args["id"] = arg0 + + return e.complexity.MyQuery.Todo(childComplexity, args["id"].(int)), true + + case "MyQuery.lastTodo": + if e.complexity.MyQuery.LastTodo == nil { + break + } + + return e.complexity.MyQuery.LastTodo(childComplexity), true + + case "MyQuery.todos": + if e.complexity.MyQuery.Todos == nil { + break + } + + return e.complexity.MyQuery.Todos(childComplexity), true + + case "Todo.id": + if e.complexity.Todo.Id == nil { + break + } + + return e.complexity.Todo.Id(childComplexity), true + + case "Todo.text": + if e.complexity.Todo.Text == nil { + break + } + + return e.complexity.Todo.Text(childComplexity), true + + case "Todo.done": + if e.complexity.Todo.Done == nil { + break + } + + return e.complexity.Todo.Done(childComplexity), true + + } + return 0, false +} + func (e *executableSchema) Query(ctx context.Context, op *ast.OperationDefinition) *graphql.Response { ec := executionContext{graphql.GetRequestContext(ctx), e} diff --git a/graphql/exec.go b/graphql/exec.go index 387c34cc09..9beb314907 100644 --- a/graphql/exec.go +++ b/graphql/exec.go @@ -10,6 +10,7 @@ import ( type ExecutableSchema interface { Schema() *ast.Schema + Complexity(typeName, fieldName string, childComplexity int, args map[string]interface{}) (int, bool) Query(ctx context.Context, op *ast.OperationDefinition) *Response Mutation(ctx context.Context, op *ast.OperationDefinition) *Response Subscription(ctx context.Context, op *ast.OperationDefinition) func() *Response diff --git a/handler/graphql.go b/handler/graphql.go index 87c3e66bb2..9d222826ab 100644 --- a/handler/graphql.go +++ b/handler/graphql.go @@ -8,6 +8,7 @@ import ( "net/http" "strings" + "github.com/99designs/gqlgen/complexity" "github.com/99designs/gqlgen/graphql" "github.com/gorilla/websocket" "github.com/hashicorp/golang-lru" @@ -24,12 +25,13 @@ type params struct { } type Config struct { - cacheSize int - upgrader websocket.Upgrader - recover graphql.RecoverFunc - errorPresenter graphql.ErrorPresenterFunc - resolverHook graphql.FieldMiddleware - requestHook graphql.RequestMiddleware + cacheSize int + upgrader websocket.Upgrader + recover graphql.RecoverFunc + errorPresenter graphql.ErrorPresenterFunc + resolverHook graphql.FieldMiddleware + requestHook graphql.RequestMiddleware + complexityLimit int } func (c *Config) newRequestContext(doc *ast.QueryDocument, query string, variables map[string]interface{}) *graphql.RequestContext { @@ -76,6 +78,14 @@ func ErrorPresenter(f graphql.ErrorPresenterFunc) Option { } } +// ComplexityLimit sets a maximum query complexity that is allowed to be executed. +// If a query is submitted that exceeds the limit, a 422 status code will be returned. +func ComplexityLimit(limit int) Option { + return func(cfg *Config) { + cfg.complexityLimit = limit + } +} + // ResolverMiddleware allows you to define a function that will be called around every resolver, // useful for tracing and logging. func ResolverMiddleware(middleware graphql.FieldMiddleware) Option { @@ -226,6 +236,14 @@ func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc } }() + if cfg.complexityLimit > 0 { + queryComplexity := complexity.Calculate(exec, op, vars) + if queryComplexity > cfg.complexityLimit { + sendErrorf(w, http.StatusUnprocessableEntity, "query has complexity %d, which exceeds the limit of %d", queryComplexity, cfg.complexityLimit) + return + } + } + switch op.Operation { case ast.Query: b, err := json.Marshal(exec.Query(ctx, op)) diff --git a/handler/stub.go b/handler/stub.go index a5b1a19b77..d237e18892 100644 --- a/handler/stub.go +++ b/handler/stub.go @@ -25,6 +25,10 @@ func (e *executableSchemaStub) Schema() *ast.Schema { `}) } +func (e *executableSchemaStub) Complexity(typeName, field string, childComplexity int, args map[string]interface{}) (int, bool) { + return 0, false +} + func (e *executableSchemaStub) Query(ctx context.Context, op *ast.OperationDefinition) *graphql.Response { return &graphql.Response{Data: []byte(`{"name":"test"}`)} } diff --git a/integration/generated.go b/integration/generated.go index bf648ad6f6..02581e1ddb 100644 --- a/integration/generated.go +++ b/integration/generated.go @@ -21,12 +21,14 @@ func NewExecutableSchema(cfg Config) graphql.ExecutableSchema { return &executableSchema{ resolvers: cfg.Resolvers, directives: cfg.Directives, + complexity: cfg.Complexity, } } type Config struct { Resolvers ResolverRoot Directives DirectiveRoot + Complexity ComplexityRoot } type ResolverRoot interface { @@ -38,6 +40,32 @@ type ResolverRoot interface { type DirectiveRoot struct { Magic func(ctx context.Context, obj interface{}, next graphql.Resolver, kind *int) (res interface{}, err error) } + +type ComplexityRoot struct { + Element struct { + Child func(childComplexity int) int + Error func(childComplexity int) int + Mismatched func(childComplexity int) int + } + + Query struct { + Path func(childComplexity int) int + Date func(childComplexity int, filter models.DateFilter) int + Viewer func(childComplexity int) int + JsonEncoding func(childComplexity int) int + Error func(childComplexity int, typeArg models.ErrorType) int + } + + User struct { + Name func(childComplexity int) int + Likes func(childComplexity int) int + } + + Viewer struct { + User func(childComplexity int) int + } +} + type ElementResolver interface { Child(ctx context.Context, obj *models.Element) (models.Element, error) Error(ctx context.Context, obj *models.Element) (bool, error) @@ -57,12 +85,119 @@ type UserResolver interface { type executableSchema struct { resolvers ResolverRoot directives DirectiveRoot + complexity ComplexityRoot } func (e *executableSchema) Schema() *ast.Schema { return parsedSchema } +func (e *executableSchema) Complexity(typeName, field string, childComplexity int, rawArgs map[string]interface{}) (int, bool) { + switch typeName + "." + field { + + case "Element.child": + if e.complexity.Element.Child == nil { + break + } + + return e.complexity.Element.Child(childComplexity), true + + case "Element.error": + if e.complexity.Element.Error == nil { + break + } + + return e.complexity.Element.Error(childComplexity), true + + case "Element.mismatched": + if e.complexity.Element.Mismatched == nil { + break + } + + return e.complexity.Element.Mismatched(childComplexity), true + + case "Query.path": + if e.complexity.Query.Path == nil { + break + } + + return e.complexity.Query.Path(childComplexity), true + + case "Query.date": + if e.complexity.Query.Date == nil { + break + } + args := map[string]interface{}{} + + var arg0 models.DateFilter + if tmp, ok := rawArgs["filter"]; ok { + var err error + arg0, err = UnmarshalDateFilter(tmp) + if err != nil { + return 0, false + } + } + args["filter"] = arg0 + + return e.complexity.Query.Date(childComplexity, args["filter"].(models.DateFilter)), true + + case "Query.viewer": + if e.complexity.Query.Viewer == nil { + break + } + + return e.complexity.Query.Viewer(childComplexity), true + + case "Query.jsonEncoding": + if e.complexity.Query.JsonEncoding == nil { + break + } + + return e.complexity.Query.JsonEncoding(childComplexity), true + + case "Query.error": + if e.complexity.Query.Error == nil { + break + } + args := map[string]interface{}{} + + var arg0 models.ErrorType + if tmp, ok := rawArgs["type"]; ok { + var err error + err = (&arg0).UnmarshalGQL(tmp) + if err != nil { + return 0, false + } + } + args["type"] = arg0 + + return e.complexity.Query.Error(childComplexity, args["type"].(models.ErrorType)), true + + case "User.name": + if e.complexity.User.Name == nil { + break + } + + return e.complexity.User.Name(childComplexity), true + + case "User.likes": + if e.complexity.User.Likes == nil { + break + } + + return e.complexity.User.Likes(childComplexity), true + + case "Viewer.user": + if e.complexity.Viewer.User == nil { + break + } + + return e.complexity.Viewer.User(childComplexity), true + + } + return 0, false +} + func (e *executableSchema) Query(ctx context.Context, op *ast.OperationDefinition) *graphql.Response { ec := executionContext{graphql.GetRequestContext(ctx), e}