From 59493aff86020d170e58900654d334f5ebc2ceee Mon Sep 17 00:00:00 2001 From: roeest Date: Mon, 12 Sep 2022 20:39:30 -0500 Subject: [PATCH] fix: apollo federation tracer was race prone (#2366) The tracer was using a global state across different goroutines Added req headers to operation context to allow it to be fetched in InterceptOperation --- graphql/context_operation.go | 2 + graphql/executor/executor.go | 1 + .../apollofederatedtracingv1/tracing.go | 68 +++++++++++-------- .../apollofederatedtracingv1/tracing_test.go | 53 +++++++++------ 4 files changed, 73 insertions(+), 51 deletions(-) diff --git a/graphql/context_operation.go b/graphql/context_operation.go index bfbbc5c002f..0518ecc6ba2 100644 --- a/graphql/context_operation.go +++ b/graphql/context_operation.go @@ -3,6 +3,7 @@ package graphql import ( "context" "errors" + "net/http" "github.com/vektah/gqlparser/v2/ast" ) @@ -15,6 +16,7 @@ type OperationContext struct { Variables map[string]interface{} OperationName string Doc *ast.QueryDocument + Headers http.Header Operation *ast.OperationDefinition DisableIntrospection bool diff --git a/graphql/executor/executor.go b/graphql/executor/executor.go index d036ea6ff62..6bfb698f5fd 100644 --- a/graphql/executor/executor.go +++ b/graphql/executor/executor.go @@ -58,6 +58,7 @@ func (e *Executor) CreateOperationContext(ctx context.Context, params *graphql.R rc.RawQuery = params.Query rc.OperationName = params.OperationName + rc.Headers = params.Headers var listErr gqlerror.List rc.Doc, listErr = e.parseQuery(ctx, &rc.Stats, params.Query) diff --git a/graphql/handler/apollofederatedtracingv1/tracing.go b/graphql/handler/apollofederatedtracingv1/tracing.go index 46be2a0ec2b..689c3c0640a 100644 --- a/graphql/handler/apollofederatedtracingv1/tracing.go +++ b/graphql/handler/apollofederatedtracingv1/tracing.go @@ -6,18 +6,21 @@ import ( "fmt" "github.com/99designs/gqlgen/graphql" - "github.com/vektah/gqlparser/v2/gqlerror" "google.golang.org/protobuf/proto" ) type ( Tracer struct { - ClientName string - Version string - Hostname string - TreeBuilder *TreeBuilder - ShouldTrace bool + ClientName string + Version string + Hostname string } + + treeBuilderKey string +) + +const ( + key = treeBuilderKey("treeBuilder") ) var _ interface { @@ -25,7 +28,6 @@ var _ interface { graphql.ResponseInterceptor graphql.FieldInterceptor graphql.OperationInterceptor - graphql.OperationParameterMutator } = &Tracer{} // ExtensionName returns the name of the extension @@ -38,30 +40,38 @@ func (Tracer) Validate(graphql.ExecutableSchema) error { return nil } -func (t *Tracer) MutateOperationParameters(ctx context.Context, request *graphql.RawParams) *gqlerror.Error { - t.ShouldTrace = request.Headers.Get("apollo-federation-include-trace") == "ftv1" // check for header +func (t *Tracer) shouldTrace(ctx context.Context) bool { + return graphql.GetOperationContext(ctx).Headers.Get("apollo-federation-include-trace") == "ftv1" +} + +func (t *Tracer) getTreeBuilder(ctx context.Context) *TreeBuilder { + val := ctx.Value(key) + if val == nil { + return nil + } + if tb, ok := val.(*TreeBuilder); ok { + return tb + } return nil } // InterceptOperation acts on each Graph operation; on each operation, start a tree builder and start the tree's timer for tracing func (t *Tracer) InterceptOperation(ctx context.Context, next graphql.OperationHandler) graphql.ResponseHandler { - if !t.ShouldTrace { + if !t.shouldTrace(ctx) { return next(ctx) } - - t.TreeBuilder = NewTreeBuilder() - - return next(ctx) + return next(context.WithValue(ctx, key, NewTreeBuilder())) } // InterceptField is called on each field's resolution, including information about the path and parent node. // This information is then used to build the relevant Node Tree used in the FTV1 tracing format func (t *Tracer) InterceptField(ctx context.Context, next graphql.Resolver) (interface{}, error) { - if !t.ShouldTrace { + if !t.shouldTrace(ctx) { return next(ctx) } - - t.TreeBuilder.WillResolveField(ctx) + if tb := t.getTreeBuilder(ctx); tb != nil { + tb.WillResolveField(ctx) + } return next(ctx) } @@ -69,30 +79,30 @@ func (t *Tracer) InterceptField(ctx context.Context, next graphql.Resolver) (int // InterceptResponse is called before the overall response is sent, but before each field resolves; as a result // the final marshaling is deferred to happen at the end of the operation func (t *Tracer) InterceptResponse(ctx context.Context, next graphql.ResponseHandler) *graphql.Response { - if !t.ShouldTrace { + if !t.shouldTrace(ctx) { return next(ctx) } - t.TreeBuilder.StartTimer(ctx) + tb := t.getTreeBuilder(ctx) + if tb != nil { + tb.StartTimer(ctx) + } - // because we need to update the ftv1 string at a later time (as fields resolve before the response is sent), - // we instantiate the string and use a pointer to be able to update later - var ftv1 string - graphql.RegisterExtension(ctx, "ftv1", &ftv1) + val := new(string) + graphql.RegisterExtension(ctx, "ftv1", val) // now that fields have finished resolving, it stops the timer to calculate trace duration - defer func() { - t.TreeBuilder.StopTimer(ctx) + defer func(val *string) { + tb.StopTimer(ctx) // marshal the protobuf ... - p, err := proto.Marshal(t.TreeBuilder.Trace) + p, err := proto.Marshal(tb.Trace) if err != nil { fmt.Print(err) } // ... then set the previously instantiated string as the base64 formatted string as required - ftv1 = base64.StdEncoding.EncodeToString(p) - }() - + *val = base64.StdEncoding.EncodeToString(p) + }(val) resp := next(ctx) return resp } diff --git a/graphql/handler/apollofederatedtracingv1/tracing_test.go b/graphql/handler/apollofederatedtracingv1/tracing_test.go index 625392c4241..61b8b37c4ec 100644 --- a/graphql/handler/apollofederatedtracingv1/tracing_test.go +++ b/graphql/handler/apollofederatedtracingv1/tracing_test.go @@ -8,9 +8,7 @@ import ( "net/http/httptest" "strings" "testing" - "time" - "github.com/99designs/gqlgen/graphql" "github.com/99designs/gqlgen/graphql/handler/apollofederatedtracingv1" "github.com/99designs/gqlgen/graphql/handler/apollofederatedtracingv1/generated" "github.com/99designs/gqlgen/graphql/handler/apollotracing" @@ -25,15 +23,6 @@ import ( ) func TestApolloTracing(t *testing.T) { - now := time.Unix(0, 0) - - graphql.Now = func() time.Time { - defer func() { - now = now.Add(100 * time.Nanosecond) - }() - return now - } - h := testserver.New() h.AddTransport(transport.POST{}) h.Use(&apollofederatedtracingv1.Tracer{}) @@ -48,14 +37,16 @@ func TestApolloTracing(t *testing.T) { require.NoError(t, json.Unmarshal(resp.Body.Bytes(), &respData)) tracing := respData.Extensions.FTV1 - pbuf, _ := base64.StdEncoding.DecodeString(tracing) + pbuf, err := base64.StdEncoding.DecodeString(tracing) + require.Nil(t, err) + ftv1 := &generated.Trace{} - err := proto.Unmarshal(pbuf, ftv1) + err = proto.Unmarshal(pbuf, ftv1) require.Nil(t, err) - require.Zero(t, ftv1.StartTime.Nanos, ftv1.StartTime.Nanos) - require.EqualValues(t, 900, ftv1.EndTime.Nanos) - require.EqualValues(t, 900, ftv1.DurationNs) + require.NotZero(t, ftv1.StartTime.Nanos) + require.Less(t, ftv1.StartTime.Nanos, ftv1.EndTime.Nanos) + require.EqualValues(t, ftv1.EndTime.Nanos-ftv1.StartTime.Nanos, ftv1.DurationNs) fmt.Printf("%#v\n", resp.Body.String()) require.Equal(t, "Query", ftv1.Root.Child[0].ParentType) @@ -63,16 +54,34 @@ func TestApolloTracing(t *testing.T) { require.Equal(t, "String!", ftv1.Root.Child[0].Type) } -func TestApolloTracing_withFail(t *testing.T) { - now := time.Unix(0, 0) +func TestApolloTracing_Concurrent(t *testing.T) { + h := testserver.New() + h.AddTransport(transport.POST{}) + h.Use(&apollofederatedtracingv1.Tracer{}) + for i := 0; i < 2; i++ { + go func() { + resp := doRequest(h, http.MethodPost, "/graphql", `{"query":"{ name }"}`) + assert.Equal(t, http.StatusOK, resp.Code, resp.Body.String()) + var respData struct { + Extensions struct { + FTV1 string `json:"ftv1"` + } `json:"extensions"` + } + require.NoError(t, json.Unmarshal(resp.Body.Bytes(), &respData)) - graphql.Now = func() time.Time { - defer func() { - now = now.Add(100 * time.Nanosecond) + tracing := respData.Extensions.FTV1 + pbuf, err := base64.StdEncoding.DecodeString(tracing) + require.Nil(t, err) + + ftv1 := &generated.Trace{} + err = proto.Unmarshal(pbuf, ftv1) + require.Nil(t, err) + require.NotZero(t, ftv1.StartTime.Nanos) }() - return now } +} +func TestApolloTracing_withFail(t *testing.T) { h := testserver.New() h.AddTransport(transport.POST{}) h.Use(extension.AutomaticPersistedQuery{Cache: lru.New(100)})