Skip to content

Commit

Permalink
contrib/google.golang.org/grpc: fix client stream context propagation (
Browse files Browse the repository at this point in the history
…#919)

Fix the client stream context returned by method Context() on the traced
client stream. It now correctly returns the context containing the
grpc.client span instead of its parent context.

This was leading to misconnected or orphan grpc.message spans.

Fixes: APMS-5154
  • Loading branch information
Julio-Guerra authored Jun 29, 2021
1 parent ae6b43f commit bfaa50c
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 46 deletions.
22 changes: 16 additions & 6 deletions contrib/google.golang.org/grpc/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,15 @@ import (

type clientStream struct {
grpc.ClientStream
ctx context.Context
cfg *config
method string
}

func (cs *clientStream) Context() context.Context {
return cs.ctx
}

func (cs *clientStream) RecvMsg(m interface{}) (err error) {
if cs.cfg.traceStreamMessages {
span, _ := startSpanFromContext(
Expand Down Expand Up @@ -86,7 +91,11 @@ func StreamClientInterceptor(opts ...Option) grpc.StreamClientInterceptor {
}
var stream grpc.ClientStream
if cfg.traceStreamCalls {
span, err := doClientRequest(ctx, cfg, method, methodKind, opts,
var (
span tracer.Span
err error
)
span, ctx, err = doClientRequest(ctx, cfg, method, methodKind, opts,
func(ctx context.Context, opts []grpc.CallOption) error {
var err error
stream, err = streamer(ctx, desc, cc, method, opts...)
Expand Down Expand Up @@ -125,6 +134,7 @@ func StreamClientInterceptor(opts ...Option) grpc.StreamClientInterceptor {
ClientStream: stream,
cfg: cfg,
method: method,
ctx: ctx,
}, nil
}
}
Expand All @@ -139,7 +149,7 @@ func UnaryClientInterceptor(opts ...Option) grpc.UnaryClientInterceptor {
}
log.Debug("contrib/google.golang.org/grpc: Configuring UnaryClientInterceptor: %#v", cfg)
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
span, err := doClientRequest(ctx, cfg, method, methodKindUnary, opts,
span, _, err := doClientRequest(ctx, cfg, method, methodKindUnary, opts,
func(ctx context.Context, opts []grpc.CallOption) error {
return invoker(ctx, method, req, reply, cc, opts...)
})
Expand All @@ -153,7 +163,7 @@ func UnaryClientInterceptor(opts ...Option) grpc.UnaryClientInterceptor {
func doClientRequest(
ctx context.Context, cfg *config, method string, methodKind string, opts []grpc.CallOption,
handler func(ctx context.Context, opts []grpc.CallOption) error,
) (ddtrace.Span, error) {
) (ddtrace.Span, context.Context, error) {
// inject the trace id into the metadata
span, ctx := startSpanFromContext(
ctx,
Expand All @@ -165,17 +175,17 @@ func doClientRequest(
if methodKind != "" {
span.SetTag(tagMethodKind, methodKind)
}
ctx = injectSpanIntoContext(ctx)

// fill in the peer so we can add it to the tags
var p peer.Peer
opts = append(opts, grpc.Peer(&p))

err := handler(ctx, opts)
handlerCtx := injectSpanIntoContext(ctx)
err := handler(handlerCtx, opts)

setSpanTargetFromPeer(span, p)

return span, err
return span, ctx, err
}

// setSpanTargetFromPeer sets the target tags in a span based on the gRPC peer.
Expand Down
153 changes: 113 additions & 40 deletions contrib/google.golang.org/grpc/grpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,48 +262,117 @@ func TestStreaming(t *testing.T) {
})
}

func TestChild(t *testing.T) {
assert := assert.New(t)
mt := mocktracer.Start()
defer mt.Stop()

rig, err := newRig(false)
if err != nil {
t.Fatalf("error setting up rig: %s", err)
func TestSpanTree(t *testing.T) {
assertSpan := func(t *testing.T, span, parent mocktracer.Span, operationName, resourceName string) {
require.NotNil(t, span)
assert.Nil(t, span.Tag(ext.Error))
assert.Equal(t, operationName, span.OperationName())
assert.Equal(t, "grpc", span.Tag(ext.ServiceName))
assert.Equal(t, span.Tag(ext.ResourceName), resourceName)
assert.True(t, span.FinishTime().Sub(span.StartTime()) > 0)

if parent == nil {
return
}
assert.Equal(t, parent.SpanID(), span.ParentID(), "unexpected parent id")
}
defer rig.Close()

client := rig.client
resp, err := client.Ping(context.Background(), &FixtureRequest{Name: "child"})
assert.Nil(err)
assert.Equal(resp.Message, "child")
t.Run("unary", func(t *testing.T) {
assert := assert.New(t)
mt := mocktracer.Start()
defer mt.Stop()

spans := mt.FinishedSpans()
assert.Len(spans, 2)
rig, err := newRig(true)
if err != nil {
t.Fatalf("error setting up rig: %s", err)
}
defer rig.Close()

var serverSpan, clientSpan mocktracer.Span
{
// Unary Ping rpc leading to trace:
// root span -> client Ping span -> server Ping span -> child span
rootSpan, ctx := tracer.StartSpanFromContext(context.Background(), "root")
client := rig.client
resp, err := client.Ping(ctx, &FixtureRequest{Name: "child"})
assert.NoError(err)
assert.Equal("child", resp.Message)
rootSpan.Finish()
}

for _, s := range spans {
// order of traces in buffer is not garanteed
switch s.OperationName() {
case "grpc.server":
serverSpan = s
case "child":
clientSpan = s
assert.Empty(mt.OpenSpans())
spans := mt.FinishedSpans()
assert.Len(spans, 4)

rootSpan := spans[3]
clientPingSpan := spans[2]
serverPingSpan := spans[1]
serverPingChildSpan := spans[0]

assert.Zero(0, rootSpan.ParentID())
assertSpan(t, serverPingChildSpan, serverPingSpan, "child", "child")
assertSpan(t, serverPingSpan, clientPingSpan, "grpc.server", "/grpc.Fixture/Ping")
assertSpan(t, clientPingSpan, rootSpan, "grpc.client", "/grpc.Fixture/Ping")
})

t.Run("stream", func(t *testing.T) {
assert := assert.New(t)
mt := mocktracer.Start()
defer mt.Stop()

rig, err := newRig(true)
if err != nil {
t.Fatalf("error setting up rig: %s", err)
}
}
defer rig.Close()
client := rig.client

{
rootSpan, ctx := tracer.StartSpanFromContext(context.Background(), "root")

// Streaming RPC leading to trace:
// root -> client stream -> client send message -> server stream
// -> server receive message -> server send message
// -> client receive message
ctx, cancel := context.WithCancel(ctx)
stream, err := client.StreamPing(ctx)
assert.NoError(err)
err = stream.SendMsg(&FixtureRequest{Name: "break"})
assert.NoError(err)
resp, err := stream.Recv()
assert.Nil(err)
assert.Equal(resp.Message, "passed")
err = stream.CloseSend()
assert.NoError(err)
cancel()

// Wait until the client stream tracer goroutine gets awoken by the context
// cancellation and finishes its span
waitForSpans(mt, 6, time.Second)

assert.NotNil(clientSpan)
assert.Nil(clientSpan.Tag(ext.Error))
assert.Equal(clientSpan.Tag(ext.ServiceName), "grpc")
assert.Equal(clientSpan.Tag(ext.ResourceName), "child")
assert.True(clientSpan.FinishTime().Sub(clientSpan.StartTime()) > 0)

assert.NotNil(serverSpan)
assert.Nil(serverSpan.Tag(ext.Error))
assert.Equal(serverSpan.Tag(ext.ServiceName), "grpc")
assert.Equal(serverSpan.Tag(ext.ResourceName), "/grpc.Fixture/Ping")
assert.True(serverSpan.FinishTime().Sub(serverSpan.StartTime()) > 0)
rootSpan.Finish()
}

assert.Empty(mt.OpenSpans())
spans := mt.FinishedSpans()
assert.Len(spans, 7)

// Ping spans
rootSpan := spans[6]
clientStreamSpan := spans[5]
clientStreamSendMsgSpan := spans[4]
serverStreamSpan := spans[3]
serverStreamRecvMsgSpan := spans[2]
serverStreamSendMsgSpan := spans[1]
clientStreamRecvMsgSpan := spans[0]

assert.Zero(rootSpan.ParentID())
assertSpan(t, clientStreamSpan, rootSpan, "grpc.client", "/grpc.Fixture/StreamPing")
assertSpan(t, clientStreamSendMsgSpan, clientStreamSpan, "grpc.message", "/grpc.Fixture/StreamPing")
assertSpan(t, serverStreamSpan, clientStreamSpan, "grpc.server", "/grpc.Fixture/StreamPing")
assertSpan(t, serverStreamRecvMsgSpan, serverStreamSpan, "grpc.message", "/grpc.Fixture/StreamPing")
assertSpan(t, serverStreamSendMsgSpan, serverStreamSpan, "grpc.message", "/grpc.Fixture/StreamPing")
assertSpan(t, clientStreamRecvMsgSpan, clientStreamSpan, "grpc.message", "/grpc.Fixture/StreamPing")
})
}

func TestPass(t *testing.T) {
Expand Down Expand Up @@ -411,27 +480,31 @@ func TestStreamSendsErrorCode(t *testing.T) {
assert.Equal(t, gotLastSpanCode, wantCode, "last span should contain error code")
}

// fixtureServer a dummy implemenation of our grpc fixtureServer.
// fixtureServer a dummy implementation of our grpc fixtureServer.
type fixtureServer struct {
lastRequestMetadata atomic.Value
}

func (s *fixtureServer) StreamPing(srv Fixture_StreamPingServer) error {
func (s *fixtureServer) StreamPing(stream Fixture_StreamPingServer) (err error) {
for {
msg, err := srv.Recv()
msg, err := stream.Recv()
if err != nil {
return err
}

reply, err := s.Ping(srv.Context(), msg)
reply, err := s.Ping(stream.Context(), msg)
if err != nil {
return err
}

err = srv.Send(reply)
err = stream.Send(reply)
if err != nil {
return err
}

if msg.Name == "break" {
return nil
}
}
}

Expand Down

0 comments on commit bfaa50c

Please sign in to comment.