Skip to content

Commit

Permalink
Fix complexity case selection
Browse files Browse the repository at this point in the history
Use the GraphQL field name rather than the Go field name in the generated
`Complexity` func.

Before this patch, overloading complexity funcs was ineffective because they
were never executed.

It also ensures that overlapping fields are now generated; mapping all possible
field names to the associated complexity func.
  • Loading branch information
mbranch committed May 1, 2019
1 parent 5ff6092 commit 02e9dd8
Show file tree
Hide file tree
Showing 18 changed files with 417 additions and 219 deletions.
6 changes: 3 additions & 3 deletions codegen/complexity.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package codegen

func (o *Object) UniqueFields() map[string]*Field {
m := map[string]*Field{}
func (o *Object) UniqueFields() map[string][]*Field {
m := map[string][]*Field{}

for _, f := range o.Fields {
m[f.GoFieldName] = f
m[f.GoFieldName] = append(m[f.GoFieldName], f)
}

return m
Expand Down
32 changes: 19 additions & 13 deletions codegen/generated!.gotpl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ type ComplexityRoot struct {
{{ range $object := .Objects }}
{{ if not $object.IsReserved -}}
{{ $object.Name|go }} struct {
{{ range $field := $object.UniqueFields -}}
{{ range $_, $fields := $object.UniqueFields }}
{{- $field := index $fields 0 -}}
{{ if not $field.IsReserved -}}
{{ $field.GoFieldName }} {{ $field.ComplexitySignature }}
{{ end }}
Expand Down Expand Up @@ -84,20 +85,25 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in
switch typeName + "." + field {
{{ range $object := .Objects }}
{{ if not $object.IsReserved }}
{{ range $field := $object.UniqueFields }}
{{ if not $field.IsReserved }}
case "{{$object.Name}}.{{$field.GoFieldName}}":
if e.complexity.{{$object.Name|go}}.{{$field.GoFieldName}} == nil {
break
}
{{ if $field.Args }}
args, err := ec.{{ $field.ArgsFunc }}(context.TODO(),rawArgs)
if err != nil {
return 0, false
{{ range $_, $fields := $object.UniqueFields }}
{{- $len := len $fields }}
{{- range $i, $field := $fields }}
{{- $last := eq (add $i 1) $len }}
{{- if not $field.IsReserved }}
{{- if eq $i 0 }}case {{ end }}"{{$object.Name}}.{{$field.Name}}"{{ if not $last }},{{ else }}:
if e.complexity.{{$object.Name|go}}.{{$field.GoFieldName}} == nil {
break
}
{{ if $field.Args }}
args, err := ec.{{ $field.ArgsFunc }}(context.TODO(),rawArgs)
if err != nil {
return 0, false
}
{{ end }}
return e.complexity.{{$object.Name|go}}.{{$field.GoFieldName}}(childComplexity{{if $field.Args}}, {{$field.ComplexityArgs}} {{ end }}), true
{{ end }}
return e.complexity.{{$object.Name|go}}.{{$field.GoFieldName}}(childComplexity{{if $field.Args}}, {{$field.ComplexityArgs}} {{end}}), true
{{ end }}
{{- end }}
{{- end }}
{{ end }}
{{ end }}
{{ end }}
Expand Down
71 changes: 71 additions & 0 deletions codegen/testserver/complexity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,76 @@ func TestComplexityCollisions(t *testing.T) {
require.Equal(t, 2, resp.Overlapping.OldFoo)
require.Equal(t, 3, resp.Overlapping.NewFoo)
require.Equal(t, 3, resp.Overlapping.New_foo)
}

func TestComplexityFuncs(t *testing.T) {
resolvers := &Stub{}
cfg := Config{Resolvers: resolvers}
cfg.Complexity.OverlappingFields.Foo = func(childComplexity int) int { return 1000 }
cfg.Complexity.OverlappingFields.NewFoo = func(childComplexity int) int { return 5 }

srv := httptest.NewServer(handler.GraphQL(NewExecutableSchema(cfg), handler.ComplexityLimit(10)))
c := client.New(srv.URL)

resolvers.QueryResolver.Overlapping = func(ctx context.Context) (fields *OverlappingFields, e error) {
return &OverlappingFields{
Foo: 2,
NewFoo: 3,
}, nil
}

t.Run("with high complexity limit will not run", func(t *testing.T) {
ran := false
resolvers.OverlappingFieldsResolver.OldFoo = func(ctx context.Context, obj *OverlappingFields) (i int, e error) {
ran = true
return obj.Foo, nil
}

var resp struct {
Overlapping interface{}
}
err := c.Post(`query { overlapping { oneFoo, twoFoo, oldFoo, newFoo, new_foo } }`, &resp)

require.EqualError(t, err, `http 422: {"errors":[{"message":"operation has complexity 2012, which exceeds the limit of 10"}],"data":null}`)
require.False(t, ran)
})

t.Run("with low complexity will run", func(t *testing.T) {
ran := false
resolvers.QueryResolver.Overlapping = func(ctx context.Context) (fields *OverlappingFields, e error) {
ran = true
return &OverlappingFields{
Foo: 2,
NewFoo: 3,
}, nil
}

var resp struct {
Overlapping interface{}
}
c.MustPost(`query { overlapping { newFoo } }`, &resp)

require.True(t, ran)
})

t.Run("with multiple low complexity will not run", func(t *testing.T) {
ran := false
resolvers.QueryResolver.Overlapping = func(ctx context.Context) (fields *OverlappingFields, e error) {
ran = true
return &OverlappingFields{
Foo: 2,
NewFoo: 3,
}, nil
}

var resp interface{}
err := c.Post(`query {
a: overlapping { newFoo },
b: overlapping { newFoo },
c: overlapping { newFoo },
}`, &resp)

require.EqualError(t, err, `http 422: {"errors":[{"message":"operation has complexity 18, which exceeds the limit of 10"}],"data":null}`)
require.False(t, ran)
})
}
Loading

0 comments on commit 02e9dd8

Please sign in to comment.