From bc98156929b06eeca1d95f475501470bd9034c2d Mon Sep 17 00:00:00 2001 From: Adam Date: Tue, 5 Nov 2019 15:18:04 +1100 Subject: [PATCH] Combine root handlers in ExecutableSchema into a single Exec method --- codegen/field.gotpl | 7 +- codegen/generated!.gotpl | 122 +++--- codegen/testserver/introspection_test.go | 2 +- codegen/testserver/panics_test.go | 3 +- codegen/testserver/response_extension_test.go | 8 +- codegen/testserver/tracer_test.go | 346 ------------------ codegen/type.gotpl | 4 +- complexity/complexity_test.go | 53 +-- graphql/handler/executor.go | 12 +- graphql/handler/server.go | 10 + graphql/handler/transport/websocket.go | 7 +- handler/handler.go | 7 +- 12 files changed, 98 insertions(+), 483 deletions(-) delete mode 100644 codegen/testserver/tracer_test.go diff --git a/codegen/field.gotpl b/codegen/field.gotpl index f0bf5db7629..97ec20a684c 100644 --- a/codegen/field.gotpl +++ b/codegen/field.gotpl @@ -5,13 +5,12 @@ func (ec *executionContext) _{{$object.Name}}_{{$field.Name}}(ctx context.Contex {{- if $object.Stream }} {{- $null = "nil" }} {{- end }} - ctx = ec.Tracer.StartFieldExecution(ctx, field) + defer func () { if r := recover(); r != nil { ec.Error(ctx, ec.Recover(ctx, r)) ret = {{ $null }} } - ec.Tracer.EndFieldExecution(ctx) }() rctx := &graphql.ResolverContext{ Object: {{$object.Name|quote}}, @@ -29,7 +28,6 @@ func (ec *executionContext) _{{$object.Name}}_{{$field.Name}}(ctx context.Contex } rctx.Args = args {{- end }} - ctx = ec.Tracer.StartFieldResolverExecution(ctx, rctx) {{- if $.Directives.LocationDirectives "FIELD" }} resTmp := ec._fieldMiddleware(ctx, {{if $object.Root}}nil{{else}}obj{{end}}, func(rctx context.Context) (interface{}, error) { {{ template "field" $field }} @@ -45,7 +43,7 @@ func (ec *executionContext) _{{$object.Name}}_{{$field.Name}}(ctx context.Contex {{- end }} if resTmp == nil { {{- if $field.TypeReference.GQL.NonNull }} - if !ec.HasError(rctx) { + if !graphql.HasFieldError(ctx, rctx) { ec.Errorf(ctx, "must not be null") } {{- end }} @@ -68,7 +66,6 @@ func (ec *executionContext) _{{$object.Name}}_{{$field.Name}}(ctx context.Contex {{- else }} res := resTmp.({{$field.TypeReference.GO | ref}}) rctx.Result = res - ctx = ec.Tracer.StartFieldChildExecution(ctx) return ec.{{ $field.TypeReference.MarshalFunc }}(ctx, field.Selections, res) {{- end }} } diff --git a/codegen/generated!.gotpl b/codegen/generated!.gotpl index a95e57b625c..59bf9173ca4 100644 --- a/codegen/generated!.gotpl +++ b/codegen/generated!.gotpl @@ -112,101 +112,79 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return 0, false } -func (e *executableSchema) Query(ctx context.Context, op *ast.OperationDefinition) *graphql.Response { - {{- if .QueryRoot }} - ec := executionContext{graphql.GetRequestContext(ctx), e} - - buf := ec.RequestMiddleware(ctx, func(ctx context.Context) []byte { - {{ if .Directives.LocationDirectives "QUERY" -}} - data := ec._queryMiddleware(ctx, op, func(ctx context.Context) (interface{}, error){ - return ec._{{.QueryRoot.Name}}(ctx, op.SelectionSet), nil - }) - {{- else -}} - data := ec._{{.QueryRoot.Name}}(ctx, op.SelectionSet) - {{- end }} +func (e *executableSchema) Exec(ctx context.Context) graphql.ResponseHandler { + rc := graphql.GetRequestContext(ctx) + ec := executionContext{rc, e} + first := true + + switch rc.Operation.Operation { + {{- if .QueryRoot }} case ast.Query: + return func(ctx context.Context) *graphql.Response { + if !first { return nil } + first = false + {{ if .Directives.LocationDirectives "QUERY" -}} + data := ec._queryMiddleware(ctx, rc.Operation, func(ctx context.Context) (interface{}, error){ + return ec._{{.QueryRoot.Name}}(ctx, rc.Operation.SelectionSet), nil + }) + {{- else -}} + data := ec._{{.QueryRoot.Name}}(ctx, rc.Operation.SelectionSet) + {{- end }} var buf bytes.Buffer data.MarshalGQL(&buf) - return buf.Bytes() - }) - return &graphql.Response{ - Data: buf, - Errors: ec.Errors, - Extensions: ec.Extensions, + return &graphql.Response{ + Data: buf.Bytes(), + } } - {{- else }} - return graphql.ErrorResponse(ctx, "queries are not supported") - {{- end }} -} - -func (e *executableSchema) Mutation(ctx context.Context, op *ast.OperationDefinition) *graphql.Response { - {{- if .MutationRoot }} - ec := executionContext{graphql.GetRequestContext(ctx), e} + {{ end }} - buf := ec.RequestMiddleware(ctx, func(ctx context.Context) []byte { - {{ if .Directives.LocationDirectives "MUTATION" -}} - data := ec._mutationMiddleware(ctx, op, func(ctx context.Context) (interface{}, error){ - return ec._{{.MutationRoot.Name}}(ctx, op.SelectionSet), nil - }) - {{- else -}} - data := ec._{{.MutationRoot.Name}}(ctx, op.SelectionSet) - {{- end }} + {{- if .MutationRoot }} case ast.Mutation: + return func(ctx context.Context) *graphql.Response { + if !first { return nil } + first = false + {{ if .Directives.LocationDirectives "MUTATION" -}} + data := ec._mutationMiddleware(ctx, rc.Operation, func(ctx context.Context) (interface{}, error){ + return ec._{{.MutationRoot.Name}}(ctx, rc.Operation.SelectionSet), nil + }) + {{- else -}} + data := ec._{{.MutationRoot.Name}}(ctx, rc.Operation.SelectionSet) + {{- end }} var buf bytes.Buffer data.MarshalGQL(&buf) - return buf.Bytes() - }) - return &graphql.Response{ - Data: buf, - Errors: ec.Errors, - Extensions: ec.Extensions, + return &graphql.Response{ + Data: buf.Bytes(), + } } - {{- else }} - return graphql.ErrorResponse(ctx, "mutations are not supported") - {{- end }} -} - -func (e *executableSchema) Subscription(ctx context.Context, op *ast.OperationDefinition) func() *graphql.Response { - {{- if .SubscriptionRoot }} - ec := executionContext{graphql.GetRequestContext(ctx), e} + {{ end }} + {{- if .SubscriptionRoot }} case ast.Subscription: {{ if .Directives.LocationDirectives "SUBSCRIPTION" -}} - next := ec._subscriptionMiddleware(ctx, op, func(ctx context.Context) (interface{}, error){ - return ec._{{.SubscriptionRoot.Name}}(ctx, op.SelectionSet),nil + next := ec._subscriptionMiddleware(ctx, rc.Operation, func(ctx context.Context) (interface{}, error){ + return ec._{{.SubscriptionRoot.Name}}(ctx, rc.Operation.SelectionSet),nil }) {{- else -}} - next := ec._{{.SubscriptionRoot.Name}}(ctx, op.SelectionSet) + next := ec._{{.SubscriptionRoot.Name}}(ctx, rc.Operation.SelectionSet) {{- end }} - if ec.Errors != nil { - return graphql.OneShot(&graphql.Response{Data: []byte("null"), Errors: ec.Errors}) - } var buf bytes.Buffer - return func() *graphql.Response { - buf := ec.RequestMiddleware(ctx, func(ctx context.Context) []byte { - buf.Reset() - data := next() - - if data == nil { - return nil - } - data.MarshalGQL(&buf) - return buf.Bytes() - }) + return func(ctx context.Context) *graphql.Response { + buf.Reset() + data := next() - if buf == nil { + if data == nil { return nil } + data.MarshalGQL(&buf) return &graphql.Response{ - Data: buf, - Errors: ec.Errors, - Extensions: ec.Extensions, + Data: buf.Bytes(), } } - {{- else }} - return graphql.OneShot(graphql.ErrorResponse(ctx, "subscriptions are not supported")) - {{- end }} + {{ end }} + default: + return graphql.OneShot(graphql.ErrorResponse(ctx, "unsupported GraphQL operation")) + } } type executionContext struct { diff --git a/codegen/testserver/introspection_test.go b/codegen/testserver/introspection_test.go index 97178798665..c33ea6e5bdd 100644 --- a/codegen/testserver/introspection_test.go +++ b/codegen/testserver/introspection_test.go @@ -67,7 +67,7 @@ func TestIntrospection(t *testing.T) { c := client.New(handler.GraphQL( NewExecutableSchema(Config{Resolvers: resolvers}), - handler.RequestMiddleware(func(ctx context.Context, next func(ctx context.Context) []byte) []byte { + handler.RequestMiddleware(func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response { graphql.GetRequestContext(ctx).DisableIntrospection = true return next(ctx) diff --git a/codegen/testserver/panics_test.go b/codegen/testserver/panics_test.go index 92319ed4abb..f5de5ab4715 100644 --- a/codegen/testserver/panics_test.go +++ b/codegen/testserver/panics_test.go @@ -4,10 +4,9 @@ import ( "context" "testing" - "github.com/stretchr/testify/require" - "github.com/99designs/gqlgen/client" "github.com/99designs/gqlgen/handler" + "github.com/stretchr/testify/require" ) func TestPanics(t *testing.T) { diff --git a/codegen/testserver/response_extension_test.go b/codegen/testserver/response_extension_test.go index 985c7aa09e4..198ea08e561 100644 --- a/codegen/testserver/response_extension_test.go +++ b/codegen/testserver/response_extension_test.go @@ -18,11 +18,9 @@ func TestResponseExtension(t *testing.T) { srv := handler.GraphQL( NewExecutableSchema(Config{Resolvers: resolvers}), - handler.RequestMiddleware(func(ctx context.Context, next func(ctx context.Context) []byte) []byte { - rctx := graphql.GetRequestContext(ctx) - if err := rctx.RegisterExtension("example", "value"); err != nil { - panic(err) - } + handler.RequestMiddleware(func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response { + graphql.RegisterExtension(ctx, "example", "value") + return next(ctx) }), ) diff --git a/codegen/testserver/tracer_test.go b/codegen/testserver/tracer_test.go deleted file mode 100644 index 07c39b080f5..00000000000 --- a/codegen/testserver/tracer_test.go +++ /dev/null @@ -1,346 +0,0 @@ -package testserver - -import ( - "context" - "fmt" - "sync" - "testing" - - "github.com/99designs/gqlgen/client" - "github.com/99designs/gqlgen/graphql" - "github.com/99designs/gqlgen/handler" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestTracer(t *testing.T) { - resolvers := &Stub{} - resolvers.QueryResolver.User = func(ctx context.Context, id int) (user *User, e error) { - return &User{ID: 1}, nil - } - t.Run("called in the correct order", func(t *testing.T) { - var tracerLog []string - var mu sync.Mutex - - srv := handler.GraphQL( - NewExecutableSchema(Config{Resolvers: resolvers}), - handler.ResolverMiddleware(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) { - path, _ := ctx.Value("path").([]int) - return next(context.WithValue(ctx, "path", append(path, 1))) - }), - handler.ResolverMiddleware(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) { - path, _ := ctx.Value("path").([]int) - return next(context.WithValue(ctx, "path", append(path, 2))) - }), - handler.Tracer(&testTracer{ - id: 1, - append: func(s string) { - mu.Lock() - defer mu.Unlock() - tracerLog = append(tracerLog, s) - }, - }), - handler.Tracer(&testTracer{ - id: 2, - append: func(s string) { - mu.Lock() - defer mu.Unlock() - tracerLog = append(tracerLog, s) - }, - }), - ) - c := client.New(srv) - - var resp struct { - User struct { - ID int - Friends []struct { - ID int - } - } - } - - called := false - resolvers.UserResolver.Friends = func(ctx context.Context, obj *User) ([]*User, error) { - assert.Equal(t, []string{ - "op:p:start:1", "op:p:start:2", - "op:v:start:1", "op:v:start:2", - "op:e:start:1", "op:e:start:2", - "field'a:e:start:1:user", "field'a:e:start:2:user", - "field'b:e:start:1:[user]", "field'b:e:start:2:[user]", - "field'c:e:start:1", "field'c:e:start:2", - "field'a:e:start:1:friends", "field'a:e:start:2:friends", - "field'b:e:start:1:[user friends]", "field'b:e:start:2:[user friends]", - }, ctx.Value("tracer")) - called = true - return []*User{}, nil - } - - err := c.Post(`query { user(id: 1) { id, friends { id } } }`, &resp) - - require.NoError(t, err) - require.True(t, called) - mu.Lock() - defer mu.Unlock() - assert.Equal(t, []string{ - "op:p:start:1", "op:p:start:2", - "op:p:end:2", "op:p:end:1", - - "op:v:start:1", "op:v:start:2", - "op:v:end:2", "op:v:end:1", - - "op:e:start:1", "op:e:start:2", - - "field'a:e:start:1:user", "field'a:e:start:2:user", - "field'b:e:start:1:[user]", "field'b:e:start:2:[user]", - "field'c:e:start:1", "field'c:e:start:2", - "field'a:e:start:1:id", "field'a:e:start:2:id", - "field'b:e:start:1:[user id]", "field'b:e:start:2:[user id]", - "field'c:e:start:1", "field'c:e:start:2", - "field:e:end:2", "field:e:end:1", - "field'a:e:start:1:friends", "field'a:e:start:2:friends", - "field'b:e:start:1:[user friends]", "field'b:e:start:2:[user friends]", - "field'c:e:start:1", "field'c:e:start:2", - "field:e:end:2", "field:e:end:1", - "field:e:end:2", "field:e:end:1", - - "op:e:end:2", "op:e:end:1", - }, tracerLog) - }) - - t.Run("take ctx over from prev step", func(t *testing.T) { - - configurableTracer := &configurableTracer{ - StartOperationParsingCallback: func(ctx context.Context) context.Context { - return context.WithValue(ctx, "StartOperationParsing", true) - }, - EndOperationParsingCallback: func(ctx context.Context) { - assert.NotNil(t, ctx.Value("StartOperationParsing")) - }, - - StartOperationValidationCallback: func(ctx context.Context) context.Context { - return context.WithValue(ctx, "StartOperationValidation", true) - }, - EndOperationValidationCallback: func(ctx context.Context) { - assert.NotNil(t, ctx.Value("StartOperationParsing")) - assert.NotNil(t, ctx.Value("StartOperationValidation")) - }, - - StartOperationExecutionCallback: func(ctx context.Context) context.Context { - return context.WithValue(ctx, "StartOperationExecution", true) - }, - StartFieldExecutionCallback: func(ctx context.Context, field graphql.CollectedField) context.Context { - return context.WithValue(ctx, "StartFieldExecution", true) - }, - StartFieldResolverExecutionCallback: func(ctx context.Context, rc *graphql.ResolverContext) context.Context { - return context.WithValue(ctx, "StartFieldResolverExecution", true) - }, - StartFieldChildExecutionCallback: func(ctx context.Context) context.Context { - return context.WithValue(ctx, "StartFieldChildExecution", true) - }, - EndFieldExecutionCallback: func(ctx context.Context) { - assert.NotNil(t, ctx.Value("StartOperationParsing")) - assert.NotNil(t, ctx.Value("StartOperationValidation")) - assert.NotNil(t, ctx.Value("StartOperationExecution")) - assert.NotNil(t, ctx.Value("StartFieldExecution")) - assert.NotNil(t, ctx.Value("StartFieldResolverExecution")) - assert.NotNil(t, ctx.Value("StartFieldChildExecution")) - }, - - EndOperationExecutionCallback: func(ctx context.Context) { - assert.NotNil(t, ctx.Value("StartOperationParsing")) - assert.NotNil(t, ctx.Value("StartOperationValidation")) - assert.NotNil(t, ctx.Value("StartOperationExecution")) - }, - } - - c := client.New(handler.GraphQL( - NewExecutableSchema(Config{Resolvers: resolvers}), - handler.Tracer(configurableTracer), - )) - - var resp struct { - User struct { - ID int - Friends []struct { - ID int - } - } - } - - called := false - resolvers.UserResolver.Friends = func(ctx context.Context, obj *User) ([]*User, error) { - called = true - return []*User{}, nil - } - - err := c.Post(`query { user(id: 1) { id, friends { id } } }`, &resp) - - require.NoError(t, err) - require.True(t, called) - }) -} - -var _ graphql.Tracer = (*configurableTracer)(nil) - -type configurableTracer struct { - StartOperationParsingCallback func(ctx context.Context) context.Context - EndOperationParsingCallback func(ctx context.Context) - StartOperationValidationCallback func(ctx context.Context) context.Context - EndOperationValidationCallback func(ctx context.Context) - StartOperationExecutionCallback func(ctx context.Context) context.Context - StartFieldExecutionCallback func(ctx context.Context, field graphql.CollectedField) context.Context - StartFieldResolverExecutionCallback func(ctx context.Context, rc *graphql.ResolverContext) context.Context - StartFieldChildExecutionCallback func(ctx context.Context) context.Context - EndFieldExecutionCallback func(ctx context.Context) - EndOperationExecutionCallback func(ctx context.Context) -} - -func (ct *configurableTracer) StartOperationParsing(ctx context.Context) context.Context { - if f := ct.StartOperationParsingCallback; f != nil { - ctx = f(ctx) - } - return ctx -} - -func (ct *configurableTracer) EndOperationParsing(ctx context.Context) { - if f := ct.EndOperationParsingCallback; f != nil { - f(ctx) - } -} - -func (ct *configurableTracer) StartOperationValidation(ctx context.Context) context.Context { - if f := ct.StartOperationValidationCallback; f != nil { - ctx = f(ctx) - } - return ctx -} - -func (ct *configurableTracer) EndOperationValidation(ctx context.Context) { - if f := ct.EndOperationValidationCallback; f != nil { - f(ctx) - } -} - -func (ct *configurableTracer) StartOperationExecution(ctx context.Context) context.Context { - if f := ct.StartOperationExecutionCallback; f != nil { - ctx = f(ctx) - } - return ctx -} - -func (ct *configurableTracer) StartFieldExecution(ctx context.Context, field graphql.CollectedField) context.Context { - if f := ct.StartFieldExecutionCallback; f != nil { - ctx = f(ctx, field) - } - return ctx -} - -func (ct *configurableTracer) StartFieldResolverExecution(ctx context.Context, rc *graphql.ResolverContext) context.Context { - if f := ct.StartFieldResolverExecutionCallback; f != nil { - ctx = f(ctx, rc) - } - return ctx -} - -func (ct *configurableTracer) StartFieldChildExecution(ctx context.Context) context.Context { - if f := ct.StartFieldChildExecutionCallback; f != nil { - ctx = f(ctx) - } - return ctx -} - -func (ct *configurableTracer) EndFieldExecution(ctx context.Context) { - if f := ct.EndFieldExecutionCallback; f != nil { - f(ctx) - } -} - -func (ct *configurableTracer) EndOperationExecution(ctx context.Context) { - if f := ct.EndOperationExecutionCallback; f != nil { - f(ctx) - } -} - -var _ graphql.Tracer = (*testTracer)(nil) - -type testTracer struct { - id int - append func(string) -} - -func (tt *testTracer) StartOperationParsing(ctx context.Context) context.Context { - line := fmt.Sprintf("op:p:start:%d", tt.id) - - tracerLogs, _ := ctx.Value("tracer").([]string) - ctx = context.WithValue(ctx, "tracer", append(append([]string{}, tracerLogs...), line)) - tt.append(line) - - return ctx -} - -func (tt *testTracer) EndOperationParsing(ctx context.Context) { - tt.append(fmt.Sprintf("op:p:end:%d", tt.id)) -} - -func (tt *testTracer) StartOperationValidation(ctx context.Context) context.Context { - line := fmt.Sprintf("op:v:start:%d", tt.id) - - tracerLogs, _ := ctx.Value("tracer").([]string) - ctx = context.WithValue(ctx, "tracer", append(append([]string{}, tracerLogs...), line)) - tt.append(line) - - return ctx -} - -func (tt *testTracer) EndOperationValidation(ctx context.Context) { - tt.append(fmt.Sprintf("op:v:end:%d", tt.id)) -} - -func (tt *testTracer) StartOperationExecution(ctx context.Context) context.Context { - line := fmt.Sprintf("op:e:start:%d", tt.id) - - tracerLogs, _ := ctx.Value("tracer").([]string) - ctx = context.WithValue(ctx, "tracer", append(append([]string{}, tracerLogs...), line)) - tt.append(line) - - return ctx -} - -func (tt *testTracer) StartFieldExecution(ctx context.Context, field graphql.CollectedField) context.Context { - line := fmt.Sprintf("field'a:e:start:%d:%s", tt.id, field.Name) - - tracerLogs, _ := ctx.Value("tracer").([]string) - ctx = context.WithValue(ctx, "tracer", append(append([]string{}, tracerLogs...), line)) - tt.append(line) - - return ctx -} - -func (tt *testTracer) StartFieldResolverExecution(ctx context.Context, rc *graphql.ResolverContext) context.Context { - line := fmt.Sprintf("field'b:e:start:%d:%v", tt.id, rc.Path()) - - tracerLogs, _ := ctx.Value("tracer").([]string) - ctx = context.WithValue(ctx, "tracer", append(append([]string{}, tracerLogs...), line)) - tt.append(line) - - return ctx -} - -func (tt *testTracer) StartFieldChildExecution(ctx context.Context) context.Context { - line := fmt.Sprintf("field'c:e:start:%d", tt.id) - - tracerLogs, _ := ctx.Value("tracer").([]string) - ctx = context.WithValue(ctx, "tracer", append(append([]string{}, tracerLogs...), line)) - tt.append(line) - - return ctx -} - -func (tt *testTracer) EndFieldExecution(ctx context.Context) { - tt.append(fmt.Sprintf("field:e:end:%d", tt.id)) -} - -func (tt *testTracer) EndOperationExecution(ctx context.Context) { - tt.append(fmt.Sprintf("op:e:end:%d", tt.id)) -} diff --git a/codegen/type.gotpl b/codegen/type.gotpl index cb2782c39ed..11425c30f2e 100644 --- a/codegen/type.gotpl +++ b/codegen/type.gotpl @@ -50,7 +50,7 @@ {{- if $type.IsNilable }} if v == nil { {{- if $type.GQL.NonNull }} - if !ec.HasError(graphql.GetResolverContext(ctx)) { + if !graphql.HasFieldError(ctx, graphql.GetResolverContext(ctx)) { ec.Errorf(ctx, "must not be null") } {{- end }} @@ -113,7 +113,7 @@ {{- else if $type.GQL.NonNull }} res := {{ $type.Marshaler | call }}({{- if $type.CastType }}{{ $type.CastType | ref }}(v){{else}}v{{- end }}) if res == graphql.Null { - if !ec.HasError(graphql.GetResolverContext(ctx)) { + if !graphql.HasFieldError(ctx, graphql.GetResolverContext(ctx)) { ec.Errorf(ctx, "must not be null") } } diff --git a/complexity/complexity_test.go b/complexity/complexity_test.go index 93b61980992..a7a2f25565b 100644 --- a/complexity/complexity_test.go +++ b/complexity/complexity_test.go @@ -1,7 +1,6 @@ package complexity import ( - "context" "math" "testing" @@ -50,7 +49,24 @@ var schema = gqlparser.MustLoadSchema( func requireComplexity(t *testing.T, source string, complexity int) { t.Helper() query := gqlparser.MustLoadQuery(schema, source) - es := &executableSchemaStub{} + + es := &graphql.ExecutableSchemaMock{ + ComplexityFunc: func(typeName, field string, childComplexity int, args map[string]interface{}) (int, bool) { + switch typeName + "." + field { + case "ExpensiveItem.name": + return 5, true + case "Query.list", "Item.list": + return int(args["size"].(int64)) * childComplexity, true + case "Query.customObject": + return 1, true + } + return 0, false + }, + SchemaFunc: func() *ast.Schema { + return schema + }, + } + actualComplexity := Calculate(es, query.Operations[0], nil) require.Equal(t, complexity, actualComplexity) } @@ -197,36 +213,3 @@ func TestCalculate(t *testing.T) { requireComplexity(t, query, math.MaxInt64) }) } - -type executableSchemaStub struct { -} - -var _ graphql.ExecutableSchema = &executableSchemaStub{} - -func (e *executableSchemaStub) Schema() *ast.Schema { - return schema -} - -func (e *executableSchemaStub) Complexity(typeName, field string, childComplexity int, args map[string]interface{}) (int, bool) { - switch typeName + "." + field { - case "ExpensiveItem.name": - return 5, true - case "Query.list", "Item.list": - return int(args["size"].(int64)) * childComplexity, true - case "Query.customObject": - return 1, true - } - return 0, false -} - -func (e *executableSchemaStub) Query(ctx context.Context, op *ast.OperationDefinition) *graphql.Response { - panic("Query should never be called by complexity calculations") -} - -func (e *executableSchemaStub) Mutation(ctx context.Context, op *ast.OperationDefinition) *graphql.Response { - panic("Mutation should never be called by complexity calculations") -} - -func (e *executableSchemaStub) Subscription(ctx context.Context, op *ast.OperationDefinition) func() *graphql.Response { - panic("Subscription should never be called by complexity calculations") -} diff --git a/graphql/handler/executor.go b/graphql/handler/executor.go index bdb639026db..dca3ecf92d1 100644 --- a/graphql/handler/executor.go +++ b/graphql/handler/executor.go @@ -80,12 +80,18 @@ func newExecutor(s *Server) executor { return e } -func (e executor) DispatchRequest(ctx context.Context, rc *graphql.RequestContext) (graphql.ResponseHandler, context.Context) { +func (e executor) DispatchRequest(ctx context.Context, rc *graphql.RequestContext) (h graphql.ResponseHandler, resctx context.Context) { + ctx = graphql.WithRequestContext(ctx, rc) + var innerCtx context.Context - res := e.operationMiddleware(graphql.WithRequestContext(ctx, rc), func(ctx context.Context) graphql.ResponseHandler { + res := e.operationMiddleware(ctx, func(ctx context.Context) graphql.ResponseHandler { innerCtx = ctx - responses := e.server.es.Exec(ctx) + tmpResponseContext := graphql.WithResponseContext(ctx, e.server.errorPresenter, e.server.recoverFunc) + responses := e.server.es.Exec(tmpResponseContext) + if errs := graphql.GetErrors(tmpResponseContext); errs != nil { + return graphql.OneShot(&graphql.Response{Errors: errs}) + } return func(ctx context.Context) *graphql.Response { ctx = graphql.WithResponseContext(ctx, e.server.errorPresenter, e.server.recoverFunc) diff --git a/graphql/handler/server.go b/graphql/handler/server.go index 0e3d44fb0a5..7c6517e605d 100644 --- a/graphql/handler/server.go +++ b/graphql/handler/server.go @@ -90,6 +90,16 @@ func (s *Server) getTransport(r *http.Request) graphql.Transport { } func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + defer func() { + if err := recover(); err != nil { + err := s.errorPresenter(r.Context(), s.recoverFunc(r.Context(), err)) + resp := &graphql.Response{Errors: []*gqlerror.Error{err}} + b, _ := json.Marshal(resp) + w.WriteHeader(http.StatusUnprocessableEntity) + w.Write(b) + } + }() + r = r.WithContext(graphql.StartOperationTrace(r.Context())) transport := s.getTransport(r) diff --git a/graphql/handler/transport/websocket.go b/graphql/handler/transport/websocket.go index a5fba9ff685..4a7764f86d6 100644 --- a/graphql/handler/transport/websocket.go +++ b/graphql/handler/transport/websocket.go @@ -229,11 +229,6 @@ func (c *wsConnection) subscribe(message *operationMessage) bool { break } - msgType := dataMsg - if len(response.Errors) > 0 { - msgType = errorMsg - } - b, err := json.Marshal(response) if err != nil { panic(err) @@ -241,7 +236,7 @@ func (c *wsConnection) subscribe(message *operationMessage) bool { c.write(&operationMessage{ Payload: b, ID: message.ID, - Type: msgType, + Type: dataMsg, }) } c.write(&operationMessage{ID: message.ID, Type: completeMsg}) diff --git a/handler/handler.go b/handler/handler.go index 47b26b69f99..13720835f8a 100644 --- a/handler/handler.go +++ b/handler/handler.go @@ -45,12 +45,7 @@ func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc srv.SetErrorPresenter(cfg.errorPresenter) } for _, hook := range cfg.fieldHooks { - srv.AroundFields(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) { - if !graphql.GetResolverContext(ctx).IsMethod { - return next(ctx) - } - return hook(ctx, next) - }) + srv.AroundFields(hook) } for _, hook := range cfg.requestHooks { srv.AroundResponses(hook)