Skip to content

Commit

Permalink
fix: apollo federation tracer was race prone (#2366)
Browse files Browse the repository at this point in the history
The tracer was using a global state across different goroutines
Added req headers to operation context to allow it to be fetched in InterceptOperation
  • Loading branch information
roeest authored Sep 13, 2022
1 parent fc01855 commit 59493af
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 51 deletions.
2 changes: 2 additions & 0 deletions graphql/context_operation.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package graphql
import (
"context"
"errors"
"net/http"

"github.com/vektah/gqlparser/v2/ast"
)
Expand All @@ -15,6 +16,7 @@ type OperationContext struct {
Variables map[string]interface{}
OperationName string
Doc *ast.QueryDocument
Headers http.Header

Operation *ast.OperationDefinition
DisableIntrospection bool
Expand Down
1 change: 1 addition & 0 deletions graphql/executor/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ func (e *Executor) CreateOperationContext(ctx context.Context, params *graphql.R

rc.RawQuery = params.Query
rc.OperationName = params.OperationName
rc.Headers = params.Headers

var listErr gqlerror.List
rc.Doc, listErr = e.parseQuery(ctx, &rc.Stats, params.Query)
Expand Down
68 changes: 39 additions & 29 deletions graphql/handler/apollofederatedtracingv1/tracing.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,28 @@ import (
"fmt"

"github.com/99designs/gqlgen/graphql"
"github.com/vektah/gqlparser/v2/gqlerror"
"google.golang.org/protobuf/proto"
)

type (
Tracer struct {
ClientName string
Version string
Hostname string
TreeBuilder *TreeBuilder
ShouldTrace bool
ClientName string
Version string
Hostname string
}

treeBuilderKey string
)

const (
key = treeBuilderKey("treeBuilder")
)

var _ interface {
graphql.HandlerExtension
graphql.ResponseInterceptor
graphql.FieldInterceptor
graphql.OperationInterceptor
graphql.OperationParameterMutator
} = &Tracer{}

// ExtensionName returns the name of the extension
Expand All @@ -38,61 +40,69 @@ func (Tracer) Validate(graphql.ExecutableSchema) error {
return nil
}

func (t *Tracer) MutateOperationParameters(ctx context.Context, request *graphql.RawParams) *gqlerror.Error {
t.ShouldTrace = request.Headers.Get("apollo-federation-include-trace") == "ftv1" // check for header
func (t *Tracer) shouldTrace(ctx context.Context) bool {
return graphql.GetOperationContext(ctx).Headers.Get("apollo-federation-include-trace") == "ftv1"
}

func (t *Tracer) getTreeBuilder(ctx context.Context) *TreeBuilder {
val := ctx.Value(key)
if val == nil {
return nil
}
if tb, ok := val.(*TreeBuilder); ok {
return tb
}
return nil
}

// InterceptOperation acts on each Graph operation; on each operation, start a tree builder and start the tree's timer for tracing
func (t *Tracer) InterceptOperation(ctx context.Context, next graphql.OperationHandler) graphql.ResponseHandler {
if !t.ShouldTrace {
if !t.shouldTrace(ctx) {
return next(ctx)
}

t.TreeBuilder = NewTreeBuilder()

return next(ctx)
return next(context.WithValue(ctx, key, NewTreeBuilder()))
}

// InterceptField is called on each field's resolution, including information about the path and parent node.
// This information is then used to build the relevant Node Tree used in the FTV1 tracing format
func (t *Tracer) InterceptField(ctx context.Context, next graphql.Resolver) (interface{}, error) {
if !t.ShouldTrace {
if !t.shouldTrace(ctx) {
return next(ctx)
}

t.TreeBuilder.WillResolveField(ctx)
if tb := t.getTreeBuilder(ctx); tb != nil {
tb.WillResolveField(ctx)
}

return next(ctx)
}

// InterceptResponse is called before the overall response is sent, but before each field resolves; as a result
// the final marshaling is deferred to happen at the end of the operation
func (t *Tracer) InterceptResponse(ctx context.Context, next graphql.ResponseHandler) *graphql.Response {
if !t.ShouldTrace {
if !t.shouldTrace(ctx) {
return next(ctx)
}
t.TreeBuilder.StartTimer(ctx)
tb := t.getTreeBuilder(ctx)
if tb != nil {
tb.StartTimer(ctx)
}

// because we need to update the ftv1 string at a later time (as fields resolve before the response is sent),
// we instantiate the string and use a pointer to be able to update later
var ftv1 string
graphql.RegisterExtension(ctx, "ftv1", &ftv1)
val := new(string)
graphql.RegisterExtension(ctx, "ftv1", val)

// now that fields have finished resolving, it stops the timer to calculate trace duration
defer func() {
t.TreeBuilder.StopTimer(ctx)
defer func(val *string) {
tb.StopTimer(ctx)

// marshal the protobuf ...
p, err := proto.Marshal(t.TreeBuilder.Trace)
p, err := proto.Marshal(tb.Trace)
if err != nil {
fmt.Print(err)
}

// ... then set the previously instantiated string as the base64 formatted string as required
ftv1 = base64.StdEncoding.EncodeToString(p)
}()

*val = base64.StdEncoding.EncodeToString(p)
}(val)
resp := next(ctx)
return resp
}
53 changes: 31 additions & 22 deletions graphql/handler/apollofederatedtracingv1/tracing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@ import (
"net/http/httptest"
"strings"
"testing"
"time"

"github.com/99designs/gqlgen/graphql"
"github.com/99designs/gqlgen/graphql/handler/apollofederatedtracingv1"
"github.com/99designs/gqlgen/graphql/handler/apollofederatedtracingv1/generated"
"github.com/99designs/gqlgen/graphql/handler/apollotracing"
Expand All @@ -25,15 +23,6 @@ import (
)

func TestApolloTracing(t *testing.T) {
now := time.Unix(0, 0)

graphql.Now = func() time.Time {
defer func() {
now = now.Add(100 * time.Nanosecond)
}()
return now
}

h := testserver.New()
h.AddTransport(transport.POST{})
h.Use(&apollofederatedtracingv1.Tracer{})
Expand All @@ -48,31 +37,51 @@ func TestApolloTracing(t *testing.T) {
require.NoError(t, json.Unmarshal(resp.Body.Bytes(), &respData))

tracing := respData.Extensions.FTV1
pbuf, _ := base64.StdEncoding.DecodeString(tracing)
pbuf, err := base64.StdEncoding.DecodeString(tracing)
require.Nil(t, err)

ftv1 := &generated.Trace{}
err := proto.Unmarshal(pbuf, ftv1)
err = proto.Unmarshal(pbuf, ftv1)
require.Nil(t, err)

require.Zero(t, ftv1.StartTime.Nanos, ftv1.StartTime.Nanos)
require.EqualValues(t, 900, ftv1.EndTime.Nanos)
require.EqualValues(t, 900, ftv1.DurationNs)
require.NotZero(t, ftv1.StartTime.Nanos)
require.Less(t, ftv1.StartTime.Nanos, ftv1.EndTime.Nanos)
require.EqualValues(t, ftv1.EndTime.Nanos-ftv1.StartTime.Nanos, ftv1.DurationNs)

fmt.Printf("%#v\n", resp.Body.String())
require.Equal(t, "Query", ftv1.Root.Child[0].ParentType)
require.Equal(t, "name", ftv1.Root.Child[0].GetResponseName())
require.Equal(t, "String!", ftv1.Root.Child[0].Type)
}

func TestApolloTracing_withFail(t *testing.T) {
now := time.Unix(0, 0)
func TestApolloTracing_Concurrent(t *testing.T) {
h := testserver.New()
h.AddTransport(transport.POST{})
h.Use(&apollofederatedtracingv1.Tracer{})
for i := 0; i < 2; i++ {
go func() {
resp := doRequest(h, http.MethodPost, "/graphql", `{"query":"{ name }"}`)
assert.Equal(t, http.StatusOK, resp.Code, resp.Body.String())
var respData struct {
Extensions struct {
FTV1 string `json:"ftv1"`
} `json:"extensions"`
}
require.NoError(t, json.Unmarshal(resp.Body.Bytes(), &respData))

graphql.Now = func() time.Time {
defer func() {
now = now.Add(100 * time.Nanosecond)
tracing := respData.Extensions.FTV1
pbuf, err := base64.StdEncoding.DecodeString(tracing)
require.Nil(t, err)

ftv1 := &generated.Trace{}
err = proto.Unmarshal(pbuf, ftv1)
require.Nil(t, err)
require.NotZero(t, ftv1.StartTime.Nanos)
}()
return now
}
}

func TestApolloTracing_withFail(t *testing.T) {
h := testserver.New()
h.AddTransport(transport.POST{})
h.Use(extension.AutomaticPersistedQuery{Cache: lru.New(100)})
Expand Down

0 comments on commit 59493af

Please sign in to comment.