From 98763b80a0bbc3cd393cb1a40aff2129b4a89590 Mon Sep 17 00:00:00 2001 From: bendiktv2 <86664407+bendiktv2@users.noreply.github.com> Date: Tue, 16 Nov 2021 02:11:49 +0100 Subject: [PATCH] feat(apmgrpc): wrap the server-stream with transaction-ctx (#1151) * feat(apmgrpc): wrap the server-stream with transaction-ctx This allows the handler to retrieve the transaction from it's context, to e.g. create spans. * define wrappedServerStream instead of importing from go-grpc-middleware --- module/apmgrpc/go.sum | 2 -- module/apmgrpc/server.go | 22 +++++++++++++++++++++- module/apmgrpc/server_test.go | 12 +++++++++++- 3 files changed, 32 insertions(+), 4 deletions(-) diff --git a/module/apmgrpc/go.sum b/module/apmgrpc/go.sum index 2abe9f7e4..fd870e045 100644 --- a/module/apmgrpc/go.sum +++ b/module/apmgrpc/go.sum @@ -48,8 +48,6 @@ github.com/santhosh-tekuri/jsonschema v1.2.4/go.mod h1:TEAUOeZSmIxTTuHatJzrvARHi github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= -github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/yuin/goldmark v1.1.27 h1:nqDD4MMMQA0lmWq03Z2/myGPYLQoXtmi0rGVs95ntbo= diff --git a/module/apmgrpc/server.go b/module/apmgrpc/server.go index b7a5d4856..21d234c46 100644 --- a/module/apmgrpc/server.go +++ b/module/apmgrpc/server.go @@ -134,6 +134,9 @@ func NewStreamServerInterceptor(o ...ServerOption) grpc.StreamServerInterceptor tx, ctx := startTransaction(ctx, opts.tracer, info.FullMethod) defer tx.End() + wrapped := wrapServerStream(stream) + wrapped.wrappedContext = ctx + // TODO(axw) define span context schema for RPC, // including at least the peer address. @@ -153,7 +156,7 @@ func NewStreamServerInterceptor(o ...ServerOption) grpc.StreamServerInterceptor } setTransactionResult(tx, err) }() - return handler(srv, stream) + return handler(srv, wrapped) } } @@ -311,3 +314,20 @@ func WithServerStreamIgnorer(s StreamIgnorerFunc) ServerOption { o.streamIgnorer = s } } + +// wrappedServerStream is a thin wrapper around grpc.ServerStream that allows modifying context. +type wrappedServerStream struct { + grpc.ServerStream + // wrappedContext is the wrapper's own Context. You can assign it. + wrappedContext context.Context +} + +// Context returns the wrapper's WrappedContext, overwriting the nested grpc.ServerStream.Context() +func (w *wrappedServerStream) Context() context.Context { + return w.wrappedContext +} + +// wrapServerStream returns a ServerStream that has the ability to overwrite context. +func wrapServerStream(stream grpc.ServerStream) *wrappedServerStream { + return &wrappedServerStream{ServerStream: stream, wrappedContext: stream.Context()} +} diff --git a/module/apmgrpc/server_test.go b/module/apmgrpc/server_test.go index 117bb897e..958483227 100644 --- a/module/apmgrpc/server_test.go +++ b/module/apmgrpc/server_test.go @@ -261,7 +261,7 @@ func TestServerStream(t *testing.T) { tracer, transport := transporttest.NewRecorderTracer() defer tracer.Close() - s, _, addr := newAccumulatorServer(t, tracer, apmgrpc.WithRecovery()) + s, accumulatorServer, addr := newAccumulatorServer(t, tracer, apmgrpc.WithRecovery()) defer s.GracefulStop() conn, client := newAccumulatorClient(t, addr) @@ -290,6 +290,13 @@ func TestServerStream(t *testing.T) { tracer.Flush(nil) transactions := transport.Payloads().Transactions require.Len(t, transactions, 1) + + // The transaction should have propagated into the accumulatorServer + require.NotNil(t, accumulatorServer.transactionFromContext) + expectedTraceID := fmt.Sprintf("%x", transactions[0].TraceID) + actualTraceID := accumulatorServer.transactionFromContext.TraceContext().Trace.String() + require.NotEmpty(t, expectedTraceID) + require.Equal(t, expectedTraceID, actualTraceID) } func TestServerTLS(t *testing.T) { @@ -457,9 +464,12 @@ func (s *helloworldServer) SayHello(ctx context.Context, req *pb.HelloRequest) ( type accumulator struct { panic bool err error + + transactionFromContext *apm.Transaction } func (a *accumulator) Accumulate(srv testservice.Accumulator_AccumulateServer) error { + a.transactionFromContext = apm.TransactionFromContext(srv.Context()) if a.panic { panic(a.err) }