From 03e92a9d08443bea92c7a7c68a605521ee483310 Mon Sep 17 00:00:00 2001 From: Andrew Wilkins Date: Thu, 25 Mar 2021 10:41:38 +0800 Subject: [PATCH] module/apmgrpc: add stream interceptors (#918) * module/apmgrpc: add stream interceptors --- docs/instrumenting.asciidoc | 17 +- docs/supported-tech.asciidoc | 6 +- module/apmgrpc/client.go | 119 ++++++++- module/apmgrpc/client_test.go | 80 +++++- module/apmgrpc/go.mod | 1 + module/apmgrpc/ignorer.go | 18 +- .../apmgrpc/internal/testservice/generate.go | 20 ++ .../internal/testservice/testservice.pb.go | 250 ++++++++++++++++++ .../internal/testservice/testservice.proto | 17 ++ module/apmgrpc/server.go | 82 +++++- module/apmgrpc/server_test.go | 113 +++++++- 11 files changed, 689 insertions(+), 34 deletions(-) create mode 100644 module/apmgrpc/internal/testservice/generate.go create mode 100644 module/apmgrpc/internal/testservice/testservice.pb.go create mode 100644 module/apmgrpc/internal/testservice/testservice.proto diff --git a/docs/instrumenting.asciidoc b/docs/instrumenting.asciidoc index 27fecb198..e0ab924c6 100644 --- a/docs/instrumenting.asciidoc +++ b/docs/instrumenting.asciidoc @@ -150,9 +150,15 @@ import ( ) func main() { - server := grpc.NewServer(grpc.UnaryInterceptor(apmgrpc.NewUnaryServerInterceptor())) + server := grpc.NewServer( + grpc.UnaryInterceptor(apmgrpc.NewUnaryServerInterceptor()), + grpc.StreamInterceptor(apmgrpc.NewStreamServerInterceptor()), + ) ... - conn, err := grpc.Dial(addr, grpc.WithUnaryInterceptor(apmgrpc.NewUnaryClientInterceptor())) + conn, err := grpc.Dial(addr, + grpc.WithUnaryInterceptor(apmgrpc.NewUnaryClientInterceptor()), + gprc.WithStreamInterceptor(apmgrpc.NewStreamClientInterceptor()), + ) ... } ---- @@ -172,8 +178,11 @@ server := grpc.NewServer(grpc.UnaryInterceptor( ... ---- -There is currently no support for intercepting at the stream level. Please file an issue and/or -send a pull request if this is something you need. +Stream interceptors emit transactions and spans that represent the entire stream, +and not individual messages. For client streams, spans will be ended when the request +fails; when any of `grpc.ClientStream.RecvMsg`, `grpc.ClientStream.SendMsg`, or +`grpc.ClientStream.Header` return with an error; or when `grpc.ClientStream.RecvMsg` +returns for a non-streaming server method. [[builtin-modules-apmhttp]] ==== module/apmhttp diff --git a/docs/supported-tech.asciidoc b/docs/supported-tech.asciidoc index 2f1726cb3..ea3c2bfd7 100644 --- a/docs/supported-tech.asciidoc +++ b/docs/supported-tech.asciidoc @@ -223,9 +223,9 @@ the MongoDB Go Driver instrumentation. We support https://grpc.io/[gRPC] https://github.com/grpc/grpc-go/releases/tag/v1.3.0[v1.3.0] and greater. -We provide unary interceptors for both the client and server. The server -interceptor will create a transaction for each incoming request, and -the client interceptor will create a span for each outgoing request. +We provide unary and stream interceptors for both the client and server. +The server interceptor will create a transaction for each incoming request, +and the client interceptor will create a span for each outgoing request. See <> for more information about gRPC instrumentation. diff --git a/module/apmgrpc/client.go b/module/apmgrpc/client.go index e0fcc88f9..a1d9c73d6 100644 --- a/module/apmgrpc/client.go +++ b/module/apmgrpc/client.go @@ -21,6 +21,7 @@ package apmgrpc // import "go.elastic.co/apm/module/apmgrpc" import ( "net" + "sync" "golang.org/x/net/context" "google.golang.org/grpc" @@ -35,9 +36,9 @@ import ( // NewUnaryClientInterceptor returns a grpc.UnaryClientInterceptor that // traces gRPC requests with the given options. // -// The interceptor will trace spans with the "grpc" type for each request -// made, for any client method presented with a context containing a sampled -// apm.Transaction. +// The interceptor will trace spans with the "external.grpc" type for each +// request made, for any client method presented with a context containing +// a sampled apm.Transaction. func NewUnaryClientInterceptor(o ...ClientOption) grpc.UnaryClientInterceptor { opts := clientOptions{} for _, o := range o { @@ -75,6 +76,105 @@ func NewUnaryClientInterceptor(o ...ClientOption) grpc.UnaryClientInterceptor { } } +// NewStreamClientInterceptor returns a grpc.UnaryClientInterceptor that +// traces gRPC requests with the given options. +// +// The interceptor will trace spans with the "external.grpc" type for each +// stream request made, for any client method presented with a context +// containing a sampled apm.Transaction. +// +// Spans are ended when the stream is closed, which can happen in various +// ways: the initial stream setup request fails, Header, SendMsg or RecvMsg +// return with an error, or RecvMsg returns for a non-streaming server. +func NewStreamClientInterceptor(o ...ClientOption) grpc.StreamClientInterceptor { + opts := clientOptions{} + for _, o := range o { + o(&opts) + } + return func( + ctx context.Context, + desc *grpc.StreamDesc, + cc *grpc.ClientConn, + method string, + streamer grpc.Streamer, + opts ...grpc.CallOption, + ) (grpc.ClientStream, error) { + var peer peer.Peer + span, ctx := startSpan(ctx, method) + if span != nil { + opts = append(opts, grpc.Peer(&peer)) + } + stream, err := streamer(ctx, desc, cc, method, opts...) + if span != nil { + if err != nil { + setSpanOutcome(span, err) + setSpanContext(span, peer) + span.End() + } else if stream != nil { + wrapped := &clientStream{ClientStream: stream} + go func() { + defer span.End() + // Header blocks until headers are available + // or the stream is ended. Either way, after + // Header returns, it is safe to call Context(). + stream.Header() + <-stream.Context().Done() + err := wrapped.getError() + setSpanOutcome(span, err) + setSpanContext(span, peer) + }() + stream = wrapped + } + } + return stream, err + } +} + +// clientStream wraps grpc.ClientStream to intercept errors. +type clientStream struct { + grpc.ClientStream + mu sync.RWMutex + err error +} + +func (s *clientStream) CloseSend() error { + err := s.ClientStream.CloseSend() + s.setError(err) + return err +} + +func (s *clientStream) Header() (metadata.MD, error) { + md, err := s.ClientStream.Header() + s.setError(err) + return md, err +} + +func (s *clientStream) SendMsg(m interface{}) error { + err := s.ClientStream.SendMsg(m) + s.setError(err) + return err +} + +func (s *clientStream) RecvMsg(m interface{}) error { + err := s.ClientStream.RecvMsg(m) + s.setError(err) + return err +} + +func (s *clientStream) getError() error { + s.mu.RLock() + defer s.mu.RUnlock() + return s.err +} + +func (s *clientStream) setError(err error) { + if err != nil { + s.mu.Lock() + s.err = err + s.mu.Unlock() + } +} + func startSpan(ctx context.Context, name string) (*apm.Span, context.Context) { tx := apm.TransactionFromContext(ctx) if tx == nil { @@ -93,6 +193,19 @@ func startSpan(ctx context.Context, name string) (*apm.Span, context.Context) { return span, outgoingContextWithTraceContext(ctx, traceContext, propagateLegacyHeader) } +func setSpanContext(span *apm.Span, peer peer.Peer) { + if peer.Addr != nil { + if tcpAddr, ok := peer.Addr.(*net.TCPAddr); ok { + span.Context.SetDestinationAddress(tcpAddr.IP.String(), tcpAddr.Port) + } + addrString := peer.Addr.String() + span.Context.SetDestinationService(apm.DestinationServiceSpanContext{ + Name: addrString, + Resource: addrString, + }) + } +} + func outgoingContextWithTraceContext( ctx context.Context, traceContext apm.TraceContext, diff --git a/module/apmgrpc/client_test.go b/module/apmgrpc/client_test.go index d1712f81b..ed54c4725 100644 --- a/module/apmgrpc/client_test.go +++ b/module/apmgrpc/client_test.go @@ -20,9 +20,11 @@ package apmgrpc_test import ( + "io" "net" "os" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -34,6 +36,7 @@ import ( "go.elastic.co/apm" "go.elastic.co/apm/apmtest" "go.elastic.co/apm/model" + "go.elastic.co/apm/module/apmgrpc/internal/testservice" "go.elastic.co/apm/module/apmhttp" "go.elastic.co/apm/transport/transporttest" ) @@ -55,11 +58,11 @@ func testClientSpan(t *testing.T, traceparentHeaders ...string) { serverTracer, serverTransport := transporttest.NewRecorderTracer() defer serverTracer.Close() - s, _, addr := newServer(t, serverTracer) + s, _, addr := newGreeterServer(t, serverTracer) defer s.GracefulStop() tcpAddr := addr.(*net.TCPAddr) - conn, client := newClient(t, addr) + conn, client := newGreeterClient(t, addr) defer conn.Close() resp, err := client.SayHello(context.Background(), &pb.HelloRequest{Name: "birita"}) require.NoError(t, err) @@ -131,10 +134,10 @@ func testClientSpan(t *testing.T, traceparentHeaders ...string) { func TestClientSpanDropped(t *testing.T) { serverTracer := apmtest.NewRecordingTracer() defer serverTracer.Close() - s, _, addr := newServer(t, serverTracer.Tracer) + s, _, addr := newGreeterServer(t, serverTracer.Tracer) defer s.GracefulStop() - conn, client := newClient(t, addr) + conn, client := newGreeterClient(t, addr) defer conn.Close() clientTracer := apmtest.NewRecordingTracer() @@ -162,10 +165,10 @@ func TestClientSpanDropped(t *testing.T) { func TestClientTransactionUnsampled(t *testing.T) { serverTracer := apmtest.NewRecordingTracer() defer serverTracer.Close() - s, _, addr := newServer(t, serverTracer.Tracer) + s, _, addr := newGreeterServer(t, serverTracer.Tracer) defer s.GracefulStop() - conn, client := newClient(t, addr) + conn, client := newGreeterClient(t, addr) defer conn.Close() clientTracer := apmtest.NewRecordingTracer() @@ -186,10 +189,10 @@ func TestClientTransactionUnsampled(t *testing.T) { } func TestClientOutcome(t *testing.T) { - s, helloworldServer, addr := newServer(t, apmtest.DiscardTracer) + s, helloworldServer, addr := newGreeterServer(t, apmtest.DiscardTracer) defer s.GracefulStop() - conn, client := newClient(t, addr) + conn, client := newGreeterClient(t, addr) defer conn.Close() clientTracer := apmtest.NewRecordingTracer() @@ -210,3 +213,64 @@ func TestClientOutcome(t *testing.T) { assert.Equal(t, "failure", spans[1].Outcome) // unknown error assert.Equal(t, "failure", spans[2].Outcome) } + +func TestStreamClientSpan(t *testing.T) { + clientTracer, clientTransport := transporttest.NewRecorderTracer() + defer clientTracer.Close() + + serverTracer, serverTransport := transporttest.NewRecorderTracer() + defer serverTracer.Close() + s, _, addr := newAccumulatorServer(t, serverTracer) + defer s.GracefulStop() + + conn, client := newAccumulatorClient(t, addr) + defer conn.Close() + + clientTransaction := clientTracer.StartTransaction("name", "type") + ctx := apm.ContextWithTransaction(context.Background(), clientTransaction) + + stream, err := client.Accumulate(ctx) + require.NoError(t, err) + err = stream.Send(&testservice.AccumulateRequest{Value: 123}) + require.NoError(t, err) + reply, err := stream.Recv() + require.NoError(t, err) + assert.Equal(t, int64(123), reply.Value) + + err = stream.CloseSend() + require.NoError(t, err) + _, err = stream.Recv() + assert.Equal(t, io.EOF, err) + + timeout := time.NewTimer(10 * time.Second) + ticker := time.NewTicker(100 * time.Millisecond) + defer timeout.Stop() + defer ticker.Stop() + var done bool + for !done { + select { + case <-ticker.C: + clientTracer.Flush(nil) + if len(clientTransport.Payloads().Spans) > 0 { + done = true + } + case <-timeout.C: + t.Fatal("timed out waiting for client span to end") + } + } + clientTransaction.End() + + clientTracer.Flush(nil) + clientPayloads := clientTransport.Payloads() + require.Len(t, clientPayloads.Transactions, 1) + require.Len(t, clientPayloads.Spans, 1) + assert.Equal(t, "/go.elastic.co.apm.module.apmgrpc.testservice.Accumulator/Accumulate", clientPayloads.Spans[0].Name) + assert.Equal(t, "external", clientPayloads.Spans[0].Type) + assert.Equal(t, "grpc", clientPayloads.Spans[0].Subtype) + + serverTracer.Flush(nil) + serverPayloads := serverTransport.Payloads() + require.Len(t, serverPayloads.Transactions, 1) + assert.Equal(t, clientPayloads.Spans[0].ID, serverPayloads.Transactions[0].ParentID) + assert.Equal(t, clientPayloads.Spans[0].TraceID, serverPayloads.Transactions[0].TraceID) +} diff --git a/module/apmgrpc/go.mod b/module/apmgrpc/go.mod index 40da63c43..3db893fa0 100644 --- a/module/apmgrpc/go.mod +++ b/module/apmgrpc/go.mod @@ -1,6 +1,7 @@ module go.elastic.co/apm/module/apmgrpc require ( + github.com/golang/protobuf v1.2.0 github.com/grpc-ecosystem/go-grpc-middleware v1.0.0 github.com/stretchr/testify v1.4.0 go.elastic.co/apm v1.11.0 diff --git a/module/apmgrpc/ignorer.go b/module/apmgrpc/ignorer.go index 030176040..b3b96fdce 100644 --- a/module/apmgrpc/ignorer.go +++ b/module/apmgrpc/ignorer.go @@ -21,22 +21,27 @@ package apmgrpc // import "go.elastic.co/apm/module/apmgrpc" import ( "regexp" - "sync" "google.golang.org/grpc" ) var ( - defaultServerRequestIgnorerOnce sync.Once - defaultServerRequestIgnorer RequestIgnorerFunc = IgnoreNone + defaultServerRequestIgnorer RequestIgnorerFunc = IgnoreNone + defaultServerStreamIgnorer StreamIgnorerFunc = IgnoreNoneStream ) -// DefaultServerRequestIgnorer returns the default RequestIgnorer to use in +// DefaultServerRequestIgnorer returns the default RequestIgnorerFunc to use in // handlers. func DefaultServerRequestIgnorer() RequestIgnorerFunc { return defaultServerRequestIgnorer } +// DefaultServerStreamIgnorer returns the default StreamIgnorerFunc to use in +// handlers. +func DefaultServerStreamIgnorer() StreamIgnorerFunc { + return defaultServerStreamIgnorer +} + // NewRegexpRequestIgnorer returns a RequestIgnorerFunc which matches requests' // URLs against re. Note that for server requests, typically only Path and // possibly RawQuery will be set, so the regular expression should take this @@ -54,3 +59,8 @@ func NewRegexpRequestIgnorer(re *regexp.Regexp) RequestIgnorerFunc { func IgnoreNone(*grpc.UnaryServerInfo) bool { return false } + +// IgnoreNoneStream is a StreamIgnorerFunc which ignores no stream requests. +func IgnoreNoneStream(*grpc.StreamServerInfo) bool { + return false +} diff --git a/module/apmgrpc/internal/testservice/generate.go b/module/apmgrpc/internal/testservice/generate.go new file mode 100644 index 000000000..ae1809db5 --- /dev/null +++ b/module/apmgrpc/internal/testservice/generate.go @@ -0,0 +1,20 @@ +// Licensed to Elasticsearch B.V. under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Elasticsearch B.V. licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//go:generate protoc --go_out=plugins=grpc:. --go_opt=paths=source_relative testservice.proto + +package testservice diff --git a/module/apmgrpc/internal/testservice/testservice.pb.go b/module/apmgrpc/internal/testservice/testservice.pb.go new file mode 100644 index 000000000..a7d5c7dbe --- /dev/null +++ b/module/apmgrpc/internal/testservice/testservice.pb.go @@ -0,0 +1,250 @@ +// Licensed to Elasticsearch B.V. under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Elasticsearch B.V. licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Code generated by protoc-gen-go. DO NOT EDIT. +// source: testservice.proto + +// +build go1.9 + +package testservice // import "go.elastic.co/apm/module/apmgrpc/internal/testservice" + +import ( + fmt "fmt" + + proto "github.com/golang/protobuf/proto" + + math "math" + + context "golang.org/x/net/context" + + grpc "google.golang.org/grpc" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package + +type AccumulateRequest struct { + Value int64 `protobuf:"varint,1,opt,name=value,proto3" json:"value,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *AccumulateRequest) Reset() { *m = AccumulateRequest{} } +func (m *AccumulateRequest) String() string { return proto.CompactTextString(m) } +func (*AccumulateRequest) ProtoMessage() {} +func (*AccumulateRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_testservice_6101db31074a6f55, []int{0} +} +func (m *AccumulateRequest) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_AccumulateRequest.Unmarshal(m, b) +} +func (m *AccumulateRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_AccumulateRequest.Marshal(b, m, deterministic) +} +func (dst *AccumulateRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_AccumulateRequest.Merge(dst, src) +} +func (m *AccumulateRequest) XXX_Size() int { + return xxx_messageInfo_AccumulateRequest.Size(m) +} +func (m *AccumulateRequest) XXX_DiscardUnknown() { + xxx_messageInfo_AccumulateRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_AccumulateRequest proto.InternalMessageInfo + +func (m *AccumulateRequest) GetValue() int64 { + if m != nil { + return m.Value + } + return 0 +} + +type AccumulateReply struct { + Value int64 `protobuf:"varint,1,opt,name=value,proto3" json:"value,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *AccumulateReply) Reset() { *m = AccumulateReply{} } +func (m *AccumulateReply) String() string { return proto.CompactTextString(m) } +func (*AccumulateReply) ProtoMessage() {} +func (*AccumulateReply) Descriptor() ([]byte, []int) { + return fileDescriptor_testservice_6101db31074a6f55, []int{1} +} +func (m *AccumulateReply) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_AccumulateReply.Unmarshal(m, b) +} +func (m *AccumulateReply) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_AccumulateReply.Marshal(b, m, deterministic) +} +func (dst *AccumulateReply) XXX_Merge(src proto.Message) { + xxx_messageInfo_AccumulateReply.Merge(dst, src) +} +func (m *AccumulateReply) XXX_Size() int { + return xxx_messageInfo_AccumulateReply.Size(m) +} +func (m *AccumulateReply) XXX_DiscardUnknown() { + xxx_messageInfo_AccumulateReply.DiscardUnknown(m) +} + +var xxx_messageInfo_AccumulateReply proto.InternalMessageInfo + +func (m *AccumulateReply) GetValue() int64 { + if m != nil { + return m.Value + } + return 0 +} + +func init() { + proto.RegisterType((*AccumulateRequest)(nil), "go.elastic.co.apm.module.apmgrpc.testservice.AccumulateRequest") + proto.RegisterType((*AccumulateReply)(nil), "go.elastic.co.apm.module.apmgrpc.testservice.AccumulateReply") +} + +// Reference imports to suppress errors if they are not otherwise used. +var _ context.Context +var _ grpc.ClientConn + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +const _ = grpc.SupportPackageIsVersion4 + +// AccumulatorClient is the client API for Accumulator service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream. +type AccumulatorClient interface { + Accumulate(ctx context.Context, opts ...grpc.CallOption) (Accumulator_AccumulateClient, error) +} + +type accumulatorClient struct { + cc *grpc.ClientConn +} + +func NewAccumulatorClient(cc *grpc.ClientConn) AccumulatorClient { + return &accumulatorClient{cc} +} + +func (c *accumulatorClient) Accumulate(ctx context.Context, opts ...grpc.CallOption) (Accumulator_AccumulateClient, error) { + stream, err := c.cc.NewStream(ctx, &_Accumulator_serviceDesc.Streams[0], "/go.elastic.co.apm.module.apmgrpc.testservice.Accumulator/Accumulate", opts...) + if err != nil { + return nil, err + } + x := &accumulatorAccumulateClient{stream} + return x, nil +} + +type Accumulator_AccumulateClient interface { + Send(*AccumulateRequest) error + Recv() (*AccumulateReply, error) + grpc.ClientStream +} + +type accumulatorAccumulateClient struct { + grpc.ClientStream +} + +func (x *accumulatorAccumulateClient) Send(m *AccumulateRequest) error { + return x.ClientStream.SendMsg(m) +} + +func (x *accumulatorAccumulateClient) Recv() (*AccumulateReply, error) { + m := new(AccumulateReply) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +// AccumulatorServer is the server API for Accumulator service. +type AccumulatorServer interface { + Accumulate(Accumulator_AccumulateServer) error +} + +func RegisterAccumulatorServer(s *grpc.Server, srv AccumulatorServer) { + s.RegisterService(&_Accumulator_serviceDesc, srv) +} + +func _Accumulator_Accumulate_Handler(srv interface{}, stream grpc.ServerStream) error { + return srv.(AccumulatorServer).Accumulate(&accumulatorAccumulateServer{stream}) +} + +type Accumulator_AccumulateServer interface { + Send(*AccumulateReply) error + Recv() (*AccumulateRequest, error) + grpc.ServerStream +} + +type accumulatorAccumulateServer struct { + grpc.ServerStream +} + +func (x *accumulatorAccumulateServer) Send(m *AccumulateReply) error { + return x.ServerStream.SendMsg(m) +} + +func (x *accumulatorAccumulateServer) Recv() (*AccumulateRequest, error) { + m := new(AccumulateRequest) + if err := x.ServerStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +var _Accumulator_serviceDesc = grpc.ServiceDesc{ + ServiceName: "go.elastic.co.apm.module.apmgrpc.testservice.Accumulator", + HandlerType: (*AccumulatorServer)(nil), + Methods: []grpc.MethodDesc{}, + Streams: []grpc.StreamDesc{ + { + StreamName: "Accumulate", + Handler: _Accumulator_Accumulate_Handler, + ServerStreams: true, + ClientStreams: true, + }, + }, + Metadata: "testservice.proto", +} + +func init() { proto.RegisterFile("testservice.proto", fileDescriptor_testservice_6101db31074a6f55) } + +var fileDescriptor_testservice_6101db31074a6f55 = []byte{ + // 201 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x12, 0x2c, 0x49, 0x2d, 0x2e, + 0x29, 0x4e, 0x2d, 0x2a, 0xcb, 0x4c, 0x4e, 0xd5, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xd2, 0x49, + 0xcf, 0xd7, 0x4b, 0xcd, 0x49, 0x2c, 0x2e, 0xc9, 0x4c, 0xd6, 0x4b, 0xce, 0xd7, 0x4b, 0x2c, 0xc8, + 0xd5, 0xcb, 0xcd, 0x4f, 0x29, 0xcd, 0x49, 0x05, 0x31, 0xd3, 0x8b, 0x0a, 0x92, 0xf5, 0x90, 0xf4, + 0x28, 0x69, 0x72, 0x09, 0x3a, 0x26, 0x27, 0x97, 0xe6, 0x96, 0xe6, 0x24, 0x96, 0xa4, 0x06, 0xa5, + 0x16, 0x96, 0xa6, 0x16, 0x97, 0x08, 0x89, 0x70, 0xb1, 0x96, 0x25, 0xe6, 0x94, 0xa6, 0x4a, 0x30, + 0x2a, 0x30, 0x6a, 0x30, 0x07, 0x41, 0x38, 0x4a, 0xea, 0x5c, 0xfc, 0xc8, 0x4a, 0x0b, 0x72, 0x2a, + 0xb1, 0x2b, 0x34, 0x5a, 0xc4, 0xc8, 0xc5, 0x0d, 0x57, 0x99, 0x5f, 0x24, 0x34, 0x89, 0x91, 0x8b, + 0x0b, 0xa1, 0x53, 0xc8, 0x5e, 0x8f, 0x14, 0x17, 0xea, 0x61, 0x38, 0x4f, 0xca, 0x96, 0x7c, 0x03, + 0x0a, 0x72, 0x2a, 0x95, 0x18, 0x34, 0x18, 0x0d, 0x18, 0x9d, 0xcc, 0xa3, 0x4c, 0x51, 0x4c, 0xd1, + 0x4f, 0x2c, 0xc8, 0xd5, 0x87, 0x98, 0xa2, 0x0f, 0x35, 0x45, 0x3f, 0x33, 0xaf, 0x24, 0xb5, 0x28, + 0x2f, 0x31, 0x47, 0x1f, 0xc9, 0xb8, 0x24, 0x36, 0x70, 0x30, 0x1b, 0x03, 0x02, 0x00, 0x00, 0xff, + 0xff, 0xb0, 0xb9, 0xd9, 0xf0, 0x7b, 0x01, 0x00, 0x00, +} diff --git a/module/apmgrpc/internal/testservice/testservice.proto b/module/apmgrpc/internal/testservice/testservice.proto new file mode 100644 index 000000000..27b1503b7 --- /dev/null +++ b/module/apmgrpc/internal/testservice/testservice.proto @@ -0,0 +1,17 @@ +syntax = "proto3"; + +package go.elastic.co.apm.module.apmgrpc.testservice; + +option go_package = "go.elastic.co/apm/module/apmgrpc/internal/testservice"; + +service Accumulator { + rpc Accumulate(stream AccumulateRequest) returns (stream AccumulateReply) {} +} + +message AccumulateRequest { + int64 value = 1; +} + +message AccumulateReply { + int64 value = 1; +} diff --git a/module/apmgrpc/server.go b/module/apmgrpc/server.go index b423d2bfb..cd6a94c6c 100644 --- a/module/apmgrpc/server.go +++ b/module/apmgrpc/server.go @@ -41,9 +41,9 @@ var ( // NewUnaryServerInterceptor returns a grpc.UnaryServerInterceptor that // traces gRPC requests with the given options. // -// The interceptor will trace transactions with the "grpc" type for each -// incoming request. The transaction will be added to the context, so -// server methods can use apm.StartSpan with the provided context. +// The interceptor will trace transactions with the "request" type for +// each incoming request. The transaction will be added to the context, +// so server methods can use apm.StartSpan with the provided context. // // By default, the interceptor will trace with apm.DefaultTracer, // and will not recover any panics. Use WithTracer to specify an @@ -53,6 +53,7 @@ func NewUnaryServerInterceptor(o ...ServerOption) grpc.UnaryServerInterceptor { tracer: apm.DefaultTracer, recover: false, requestIgnorer: DefaultServerRequestIgnorer(), + streamIgnorer: DefaultServerStreamIgnorer(), } for _, o := range o { o(&opts) @@ -69,7 +70,7 @@ func NewUnaryServerInterceptor(o ...ServerOption) grpc.UnaryServerInterceptor { tx, ctx := startTransaction(ctx, opts.tracer, info.FullMethod) defer tx.End() - // TODO(axw) define context schema for RPC, + // TODO(axw) define span context schema for RPC, // including at least the peer address. defer func() { @@ -94,6 +95,62 @@ func NewUnaryServerInterceptor(o ...ServerOption) grpc.UnaryServerInterceptor { } } +// NewStreamServerInterceptor returns a grpc.StreamServerInterceptor that +// traces gRPC stream requests with the given options. +// +// The interceptor will trace transactions with the "request" type for each +// incoming stream request. The transaction will be added to the context, so +// server methods can use apm.StartSpan with the provided context. +// +// By default, the interceptor will trace with apm.DefaultTracer, and will +// not recover any panics. Use WithTracer to specify an alternative tracer, +// and WithRecovery to enable panic recovery. +func NewStreamServerInterceptor(o ...ServerOption) grpc.StreamServerInterceptor { + opts := serverOptions{ + tracer: apm.DefaultTracer, + recover: false, + streamIgnorer: DefaultServerStreamIgnorer(), + } + for _, o := range o { + o(&opts) + } + + return func( + srv interface{}, + stream grpc.ServerStream, + info *grpc.StreamServerInfo, + handler grpc.StreamHandler, + ) (err error) { + if !opts.tracer.Recording() || opts.streamIgnorer(info) { + return handler(srv, stream) + } + ctx := stream.Context() + tx, ctx := startTransaction(ctx, opts.tracer, info.FullMethod) + defer tx.End() + + // TODO(axw) define span context schema for RPC, + // including at least the peer address. + + defer func() { + r := recover() + if r != nil { + e := opts.tracer.Recovered(r) + e.SetTransaction(tx) + e.Context.SetFramework("grpc", grpc.Version) + e.Handled = opts.recover + e.Send() + if opts.recover { + err = status.Errorf(codes.Internal, "%s", r) + } else { + panic(r) + } + } + setTransactionResult(tx, err) + }() + return handler(srv, stream) + } +} + func startTransaction(ctx context.Context, tracer *apm.Tracer, name string) (*apm.Transaction, context.Context) { var opts apm.TransactionOptions if md, ok := metadata.FromIncomingContext(ctx); ok { @@ -157,6 +214,7 @@ type serverOptions struct { tracer *apm.Tracer recover bool requestIgnorer RequestIgnorerFunc + streamIgnorer StreamIgnorerFunc } // ServerOption sets options for server-side tracing. @@ -201,3 +259,19 @@ func WithServerRequestIgnorer(r RequestIgnorerFunc) ServerOption { o.requestIgnorer = r } } + +// StreamIgnorerFunc is the type of a function for use in +// WithServerStreamIgnorer. +type StreamIgnorerFunc func(*grpc.StreamServerInfo) bool + +// WithServerStreamIgnorer returns a ServerOption which sets s as the +// function to use to determine whether or not a server stream request +// should be ignored. If s is nil, all stream requests will be reported. +func WithServerStreamIgnorer(s StreamIgnorerFunc) ServerOption { + if s == nil { + s = IgnoreNoneStream + } + return func(o *serverOptions) { + o.streamIgnorer = s + } +} diff --git a/module/apmgrpc/server_test.go b/module/apmgrpc/server_test.go index be1a87d38..df13f054c 100644 --- a/module/apmgrpc/server_test.go +++ b/module/apmgrpc/server_test.go @@ -22,6 +22,7 @@ package apmgrpc_test import ( "errors" "fmt" + "io" "net" "reflect" "strings" @@ -41,6 +42,7 @@ import ( "go.elastic.co/apm" "go.elastic.co/apm/model" "go.elastic.co/apm/module/apmgrpc" + "go.elastic.co/apm/module/apmgrpc/internal/testservice" "go.elastic.co/apm/module/apmhttp" "go.elastic.co/apm/stacktrace" "go.elastic.co/apm/transport/transporttest" @@ -59,10 +61,10 @@ func TestServerTransaction(t *testing.T) { tracer, transport := transporttest.NewRecorderTracer() defer tracer.Close() - s, server, addr := newServer(t, tracer) + s, server, addr := newGreeterServer(t, tracer) defer s.GracefulStop() - conn, client := newClient(t, addr) + conn, client := newGreeterClient(t, addr) defer conn.Close() f(t, testParams{ @@ -174,10 +176,10 @@ func TestServerRecovery(t *testing.T) { tracer, transport := transporttest.NewRecorderTracer() defer tracer.Close() - s, server, addr := newServer(t, tracer, apmgrpc.WithRecovery()) + s, server, addr := newGreeterServer(t, tracer, apmgrpc.WithRecovery()) defer s.GracefulStop() - conn, client := newClient(t, addr) + conn, client := newGreeterClient(t, addr) defer conn.Close() server.panic = true @@ -201,12 +203,12 @@ func TestServerIgnorer(t *testing.T) { tracer, transport := transporttest.NewRecorderTracer() defer tracer.Close() - s, _, addr := newServer(t, tracer, apmgrpc.WithRecovery(), apmgrpc.WithServerRequestIgnorer(func(*grpc.UnaryServerInfo) bool { + s, _, addr := newGreeterServer(t, tracer, apmgrpc.WithRecovery(), apmgrpc.WithServerRequestIgnorer(func(*grpc.UnaryServerInfo) bool { return true })) defer s.GracefulStop() - conn, client := newClient(t, addr) + conn, client := newGreeterClient(t, addr) defer conn.Close() resp, err := client.SayHello(context.Background(), &pb.HelloRequest{Name: "birita"}) @@ -217,7 +219,42 @@ func TestServerIgnorer(t *testing.T) { assert.Empty(t, transport.Payloads()) } -func newServer(t *testing.T, tracer *apm.Tracer, opts ...apmgrpc.ServerOption) (*grpc.Server, *helloworldServer, net.Addr) { +func TestServerStream(t *testing.T) { + tracer, transport := transporttest.NewRecorderTracer() + defer tracer.Close() + + s, _, addr := newAccumulatorServer(t, tracer, apmgrpc.WithRecovery()) + defer s.GracefulStop() + + conn, client := newAccumulatorClient(t, addr) + defer conn.Close() + + accumulator, err := client.Accumulate(context.Background()) + require.NoError(t, err) + + var expected int64 + for i := 0; i < 10; i++ { + expected += int64(i) + err = accumulator.Send(&testservice.AccumulateRequest{Value: int64(i)}) + require.NoError(t, err) + reply, err := accumulator.Recv() + require.NoError(t, err) + assert.Equal(t, expected, reply.Value) + } + err = accumulator.CloseSend() + assert.NoError(t, err) + + // Wait for the server to close, ending its transaction. + _, err = accumulator.Recv() + assert.Equal(t, io.EOF, err) + + // There should be just one transaction for the entire stream. + tracer.Flush(nil) + transactions := transport.Payloads().Transactions + require.Len(t, transactions, 1) +} + +func newGreeterServer(t *testing.T, tracer *apm.Tracer, opts ...apmgrpc.ServerOption) (*grpc.Server, *helloworldServer, net.Addr) { // We always install grpc_recovery first to avoid panics // aborting the test process. We install it before the // apmgrpc interceptor so that apmgrpc can recover panics @@ -239,7 +276,7 @@ func newServer(t *testing.T, tracer *apm.Tracer, opts ...apmgrpc.ServerOption) ( return s, server, lis.Addr() } -func newClient(t *testing.T, addr net.Addr) (*grpc.ClientConn, pb.GreeterClient) { +func newGreeterClient(t *testing.T, addr net.Addr) (*grpc.ClientConn, pb.GreeterClient) { conn, err := grpc.Dial( addr.String(), grpc.WithInsecure(), grpc.WithUnaryInterceptor(apmgrpc.NewUnaryClientInterceptor()), @@ -248,6 +285,39 @@ func newClient(t *testing.T, addr net.Addr) (*grpc.ClientConn, pb.GreeterClient) return conn, pb.NewGreeterClient(conn) } +func newAccumulatorServer(t *testing.T, tracer *apm.Tracer, opts ...apmgrpc.ServerOption) (*grpc.Server, *accumulator, net.Addr) { + // We always install grpc_recovery first to avoid panics + // aborting the test process. We install it before the + // apmgrpc interceptor so that apmgrpc can recover panics + // itself if configured to do so. + interceptors := []grpc.StreamServerInterceptor{ + grpc_recovery.StreamServerInterceptor(), + } + serverOpts := []grpc.ServerOption{} + if tracer != nil { + opts = append(opts, apmgrpc.WithTracer(tracer)) + interceptors = append(interceptors, apmgrpc.NewStreamServerInterceptor(opts...)) + } + serverOpts = append(serverOpts, grpc_middleware.WithStreamServerChain(interceptors...)) + + s := grpc.NewServer(serverOpts...) + accumulator := &accumulator{} + testservice.RegisterAccumulatorServer(s, accumulator) + lis, err := net.Listen("tcp", "localhost:0") + require.NoError(t, err) + go s.Serve(lis) + return s, accumulator, lis.Addr() +} + +func newAccumulatorClient(t *testing.T, addr net.Addr) (*grpc.ClientConn, testservice.AccumulatorClient) { + conn, err := grpc.Dial( + addr.String(), grpc.WithInsecure(), + grpc.WithStreamInterceptor(apmgrpc.NewStreamClientInterceptor()), + ) + require.NoError(t, err) + return conn, testservice.NewAccumulatorClient(conn) +} + type helloworldServer struct { panic bool err error @@ -276,3 +346,30 @@ func (s *helloworldServer) SayHello(ctx context.Context, req *pb.HelloRequest) ( } return &pb.HelloReply{Message: "hello, " + req.Name}, nil } + +type accumulator struct { + panic bool + err error +} + +func (a *accumulator) Accumulate(srv testservice.Accumulator_AccumulateServer) error { + if a.panic { + panic(a.err) + } + if a.err != nil { + return a.err + } + var reply testservice.AccumulateReply + for { + req, err := srv.Recv() + if err == io.EOF { + return nil + } else if err != nil { + return err + } + reply.Value += req.Value + if err := srv.Send(&reply); err != nil { + return err + } + } +}