From d8acf16ca662d8603e6798eba2f6b9225ba479c1 Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Tue, 4 Jun 2024 15:21:11 +0200 Subject: [PATCH 1/2] appsec/grpc: fix rpc message blocking (#2723) Co-authored-by: Julio Guerra Co-authored-by: Eliott Bouhana <47679741+eliottness@users.noreply.github.com> --- contrib/google.golang.org/grpc/appsec.go | 99 +++--- contrib/google.golang.org/grpc/appsec_test.go | 302 ++++++------------ internal/appsec/emitter/sharedsec/actions.go | 24 +- internal/appsec/listener/grpcsec/grpc.go | 53 +-- 4 files changed, 191 insertions(+), 287 deletions(-) diff --git a/contrib/google.golang.org/grpc/appsec.go b/contrib/google.golang.org/grpc/appsec.go index 108d900838..6330cd8f4f 100644 --- a/contrib/google.golang.org/grpc/appsec.go +++ b/contrib/google.golang.org/grpc/appsec.go @@ -29,9 +29,8 @@ import ( // UnaryHandler wrapper to use when AppSec is enabled to monitor its execution. func appsecUnaryHandlerMiddleware(method string, span ddtrace.Span, handler grpc.UnaryHandler) grpc.UnaryHandler { trace.SetAppSecEnabledTags(span) - return func(ctx context.Context, req interface{}) (interface{}, error) { - var err error - var blocked bool + return func(ctx context.Context, req any) (res any, rpcErr error) { + var blockedErr error md, _ := metadata.FromIncomingContext(ctx) clientIP := setClientIP(ctx, span, md) args := types.HandlerOperationArgs{ @@ -41,48 +40,56 @@ func appsecUnaryHandlerMiddleware(method string, span ddtrace.Span, handler grpc } ctx, op := grpcsec.StartHandlerOperation(ctx, args, nil, func(op *types.HandlerOperation) { dyngo.OnData(op, func(a *sharedsec.GRPCAction) { - code, e := a.GRPCWrapper(md) - blocked = a.Blocking() - err = status.Error(codes.Code(code), e.Error()) + code, err := a.GRPCWrapper() + blockedErr = status.Error(codes.Code(code), err.Error()) }) }) defer func() { events := op.Finish(types.HandlerOperationRes{}) - if blocked { + if len(events) > 0 { + grpctrace.SetSecurityEventsTags(span, events) + } + if blockedErr != nil { op.SetTag(trace.BlockedRequestTag, true) + rpcErr = blockedErr } grpctrace.SetRequestMetadataTags(span, md) trace.SetTags(span, op.Tags()) - if len(events) > 0 { - grpctrace.SetSecurityEventsTags(span, events) - } }() - if err != nil { - return nil, err + // Check if a blocking condition was detected so far with the start operation event (ip blocking, metadata blocking, etc.) + if blockedErr != nil { + return nil, blockedErr } - defer grpcsec.StartReceiveOperation(types.ReceiveOperationArgs{}, op).Finish(types.ReceiveOperationRes{Message: req}) - rv, downstreamErr := handler(ctx, req) - if blocked { - return nil, err + // As of our gRPC abstract operation definition, we must fake a receive operation for unary RPCs (the same model fits both unary and streaming RPCs) + grpcsec.StartReceiveOperation(types.ReceiveOperationArgs{}, op).Finish(types.ReceiveOperationRes{Message: req}) + // Check if a blocking condition was detected so far with the receive operation events + if blockedErr != nil { + return nil, blockedErr } - return rv, downstreamErr + // Call the original handler - let the deferred function above handle the blocking condition and return error + return handler(ctx, req) } } // StreamHandler wrapper to use when AppSec is enabled to monitor its execution. func appsecStreamHandlerMiddleware(method string, span ddtrace.Span, handler grpc.StreamHandler) grpc.StreamHandler { trace.SetAppSecEnabledTags(span) - return func(srv interface{}, stream grpc.ServerStream) error { - var err error - var blocked bool + return func(srv any, stream grpc.ServerStream) (rpcErr error) { + // Create a ServerStream wrapper with appsec RPC handler operation and the Go context (to implement the ServerStream interface) + appsecStream := &appsecServerStream{ + ServerStream: stream, + // note: the blockedErr field is captured by the RPC handler's OnData closure below + } + ctx := stream.Context() md, _ := metadata.FromIncomingContext(ctx) clientIP := setClientIP(ctx, span, md) grpctrace.SetRequestMetadataTags(span, md) + // Create the handler operation and listen to blocking gRPC actions to detect a blocking condition args := types.HandlerOperationArgs{ Method: method, Metadata: md, @@ -90,37 +97,38 @@ func appsecStreamHandlerMiddleware(method string, span ddtrace.Span, handler grp } ctx, op := grpcsec.StartHandlerOperation(ctx, args, nil, func(op *types.HandlerOperation) { dyngo.OnData(op, func(a *sharedsec.GRPCAction) { - code, e := a.GRPCWrapper(md) - blocked = a.Blocking() - err = status.Error(codes.Code(code), e.Error()) + code, e := a.GRPCWrapper() + appsecStream.blockedErr = status.Error(codes.Code(code), e.Error()) }) }) - stream = appsecServerStream{ - ServerStream: stream, - handlerOperation: op, - ctx: ctx, - } + + // Finish constructing the appsec stream wrapper and replace the original one + appsecStream.handlerOperation = op + appsecStream.ctx = ctx + defer func() { events := op.Finish(types.HandlerOperationRes{}) - if blocked { - op.SetTag(trace.BlockedRequestTag, true) - } - trace.SetTags(span, op.Tags()) + if len(events) > 0 { grpctrace.SetSecurityEventsTags(span, events) } - }() - if err != nil { - return err - } + if appsecStream.blockedErr != nil { + op.SetTag(trace.BlockedRequestTag, true) + // Change the RPC return error with appsec's + rpcErr = appsecStream.blockedErr + } + + trace.SetTags(span, op.Tags()) + }() - downstreamErr := handler(srv, stream) - if blocked { - return err + // Check if a blocking condition was detected so far with the start operation event (ip blocking, metadata blocking, etc.) + if appsecStream.blockedErr != nil { + return appsecStream.blockedErr } - return downstreamErr + // Call the original handler - let the deferred function above handle the blocking condition and return error + return handler(srv, appsecStream) } } @@ -128,19 +136,26 @@ type appsecServerStream struct { grpc.ServerStream handlerOperation *types.HandlerOperation ctx context.Context + + // blockedErr is used to store the error to return when a blocking sec event is detected. + blockedErr error } // RecvMsg implements grpc.ServerStream interface method to monitor its // execution with AppSec. -func (ss appsecServerStream) RecvMsg(m interface{}) error { +func (ss *appsecServerStream) RecvMsg(m interface{}) (err error) { op := grpcsec.StartReceiveOperation(types.ReceiveOperationArgs{}, ss.handlerOperation) defer func() { op.Finish(types.ReceiveOperationRes{Message: m}) + if ss.blockedErr != nil { + // Change the function call return error with appsec's + err = ss.blockedErr + } }() return ss.ServerStream.RecvMsg(m) } -func (ss appsecServerStream) Context() context.Context { +func (ss *appsecServerStream) Context() context.Context { return ss.ctx } diff --git a/contrib/google.golang.org/grpc/appsec_test.go b/contrib/google.golang.org/grpc/appsec_test.go index a0de9b998e..911efa7c5d 100644 --- a/contrib/google.golang.org/grpc/appsec_test.go +++ b/contrib/google.golang.org/grpc/appsec_test.go @@ -10,7 +10,6 @@ import ( "encoding/json" "fmt" "net" - "strings" "testing" pappsec "gopkg.in/DataDog/dd-trace-go.v1/appsec" @@ -32,7 +31,7 @@ func TestAppSec(t *testing.T) { } setup := func() (FixtureClient, mocktracer.Tracer, func()) { - rig, err := newRig(false) + rig, err := newAppsecRig(false) require.NoError(t, err) mt := mocktracer.Start() @@ -139,99 +138,6 @@ func TestBlocking(t *testing.T) { t.Skip("appsec disabled") } - setup := func() (FixtureClient, mocktracer.Tracer, func()) { - rig, err := newRig(false) - require.NoError(t, err) - - mt := mocktracer.Start() - - return rig.client, mt, func() { - rig.Close() - mt.Stop() - } - } - - t.Run("unary-block", func(t *testing.T) { - client, mt, cleanup := setup() - defer cleanup() - - // Send a XSS attack in the payload along with the canary value in the RPC metadata - ctx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs("dd-canary", "dd-test-scanner-log", "x-client-ip", "1.2.3.4")) - reply, err := client.Ping(ctx, &FixtureRequest{Name: ""}) - - require.Nil(t, reply) - require.Equal(t, codes.Aborted, status.Code(err)) - - finished := mt.FinishedSpans() - require.Len(t, finished, 1) - // The request should have the attack attempts - event, _ := finished[0].Tag("_dd.appsec.json").(string) - require.NotNil(t, event) - require.True(t, strings.Contains(event, "blk-001-001")) - }) - - t.Run("unary-no-block", func(t *testing.T) { - client, _, cleanup := setup() - defer cleanup() - - // Send a XSS attack in the payload along with the canary value in the RPC metadata - ctx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs("dd-canary", "dd-test-scanner-log", "x-client-ip", "1.2.3.5")) - reply, err := client.Ping(ctx, &FixtureRequest{Name: ""}) - - require.Equal(t, "passed", reply.Message) - require.Equal(t, codes.OK, status.Code(err)) - }) - - t.Run("stream-block", func(t *testing.T) { - client, mt, cleanup := setup() - defer cleanup() - - ctx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs("dd-canary", "dd-test-scanner-log", "x-client-ip", "1.2.3.4")) - stream, err := client.StreamPing(ctx) - require.NoError(t, err) - reply, err := stream.Recv() - - require.Equal(t, codes.Aborted, status.Code(err)) - require.Nil(t, reply) - - finished := mt.FinishedSpans() - require.Len(t, finished, 1) - // The request should have the attack attempts - event, _ := finished[0].Tag("_dd.appsec.json").(string) - require.NotNil(t, event) - require.True(t, strings.Contains(event, "blk-001-001")) - }) - - t.Run("stream-no-block", func(t *testing.T) { - client, _, cleanup := setup() - defer cleanup() - - ctx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs("dd-canary", "dd-test-scanner-log", "x-client-ip", "1.2.3.5")) - stream, err := client.StreamPing(ctx) - require.NoError(t, err) - - // Send a XSS attack - err = stream.Send(&FixtureRequest{Name: ""}) - require.NoError(t, err) - reply, err := stream.Recv() - require.Equal(t, codes.OK, status.Code(err)) - require.Equal(t, "passed", reply.Message) - - err = stream.CloseSend() - require.NoError(t, err) - }) - -} - -// Test that user blocking works by using custom rules/rules data -func TestUserBlocking(t *testing.T) { - t.Setenv("DD_APPSEC_RULES", "../../../internal/appsec/testdata/blocking.json") - appsec.Start() - defer appsec.Stop() - if !appsec.Enabled() { - t.Skip("appsec disabled") - } - setup := func() (FixtureClient, mocktracer.Tracer, func()) { rig, err := newAppsecRig(false) require.NoError(t, err) @@ -245,113 +151,90 @@ func TestUserBlocking(t *testing.T) { } t.Run("unary-block", func(t *testing.T) { - client, mt, cleanup := setup() - defer cleanup() - - // Send a XSS attack in the payload along with the canary value in the RPC metadata - ctx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs("user-id", "blocked-user-1")) - reply, err := client.Ping(ctx, &FixtureRequest{Name: ""}) - - require.Nil(t, reply) - require.Equal(t, codes.Aborted, status.Code(err)) - - finished := mt.FinishedSpans() - require.Len(t, finished, 1) - // The request should have the XSS and user ID attack attempts - event, _ := finished[0].Tag("_dd.appsec.json").(string) - require.NotNil(t, event) - require.True(t, strings.Contains(event, "blk-001-002")) - require.True(t, strings.Contains(event, "crs-941-110")) - }) - - t.Run("unary-no-block", func(t *testing.T) { - client, _, cleanup := setup() - defer cleanup() - - ctx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs("user-id", "legit user")) - reply, err := client.Ping(ctx, &FixtureRequest{Name: ""}) - - require.Equal(t, "passed", reply.Message) - require.Equal(t, codes.OK, status.Code(err)) - }) - - // This test checks that IP blocking happens BEFORE user blocking, since user blocking needs the request handler - // to be invoked while IP blocking doesn't - t.Run("unary-mixed-block", func(t *testing.T) { - client, mt, cleanup := setup() - defer cleanup() - - ctx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs("user-id", "blocked-user-1", "x-forwarded-for", "1.2.3.4")) - reply, err := client.Ping(ctx, &FixtureRequest{}) - - require.Nil(t, reply) - require.Equal(t, codes.Aborted, status.Code(err)) - - finished := mt.FinishedSpans() - require.Len(t, finished, 1) - event, _ := finished[0].Tag("_dd.appsec.json").(string) - require.NotNil(t, event) - require.True(t, strings.Contains(event, "blk-001-001")) - }) - - t.Run("stream-block", func(t *testing.T) { - client, mt, cleanup := setup() - defer cleanup() - - ctx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs("user-id", "blocked-user-1")) - stream, err := client.StreamPing(ctx) - require.NoError(t, err) - reply, err := stream.Recv() - - require.Equal(t, codes.Aborted, status.Code(err)) - require.Nil(t, reply) - - finished := mt.FinishedSpans() - require.Len(t, finished, 1) - // The request should have the attack attempts - event, _ := finished[0].Tag("_dd.appsec.json").(string) - require.NotNil(t, event) - require.True(t, strings.Contains(event, "blk-001-002")) - }) - - t.Run("stream-no-block", func(t *testing.T) { - client, _, cleanup := setup() - defer cleanup() - - ctx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs("user-id", "legit user")) - stream, err := client.StreamPing(ctx) - require.NoError(t, err) - - // Send a XSS attack - err = stream.Send(&FixtureRequest{Name: ""}) - require.NoError(t, err) - reply, err := stream.Recv() - require.Equal(t, codes.OK, status.Code(err)) - require.Equal(t, "passed", reply.Message) - - err = stream.CloseSend() - require.NoError(t, err) - }) - // This test checks that IP blocking happens BEFORE user blocking, since user blocking needs the request handler - // to be invoked while IP blocking doesn't - t.Run("stream-mixed-block", func(t *testing.T) { - client, mt, cleanup := setup() - defer cleanup() - - ctx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs("user-id", "blocked-user-1", "x-forwarded-for", "1.2.3.4")) - stream, err := client.StreamPing(ctx) - require.NoError(t, err) - reply, err := stream.Recv() - - require.Equal(t, codes.Aborted, status.Code(err)) - require.Nil(t, reply) - - finished := mt.FinishedSpans() - require.Len(t, finished, 1) - // The request should have IP related the attack attempts - event, _ := finished[0].Tag("_dd.appsec.json").(string) - require.NotNil(t, event) - require.True(t, strings.Contains(event, "blk-001-001")) + for _, tc := range []struct { + name string + md metadata.MD + message string + expectedBlocked bool + expectedMatchedRules []string + expectedNotMatchedRules []string + }{ + { + name: "ip blocking", + md: metadata.Pairs("m1", "v1", "x-client-ip", "1.2.3.4", "user-id", "blocked-user-1"), + message: "$globals", + expectedMatchedRules: []string{"blk-001-001"}, // ip blocking alone as it comes first + expectedNotMatchedRules: []string{"crs-933-130-block", "blk-001-002"}, // no user blocking or message blocking + }, + { + name: "message blocking", + md: metadata.Pairs("m1", "v1", "x-client-ip", "1.2.3.5", "user-id", "legit-user-1"), + message: "$globals", + expectedMatchedRules: []string{"crs-933-130-block"}, // message blocking alone as it comes before user blocking + expectedNotMatchedRules: []string{"blk-001-002"}, // no user blocking + }, + { + name: "user blocking", + md: metadata.Pairs("m1", "v1", "x-client-ip", "1.2.3.5", "user-id", "blocked-user-1"), + message: "", + expectedMatchedRules: []string{"blk-001-002"}, // user blocking alone as it comes first in our test handler + expectedNotMatchedRules: []string{"crs-933-130-block"}, // message blocking alone as it comes before user blocking + }, + } { + t.Run(tc.name, func(t *testing.T) { + // Helper assertion function to run for the unary and stream tests + assert := func(t *testing.T, do func(client FixtureClient)) { + client, mt, cleanup := setup() + defer cleanup() + + do(client) + + finished := mt.FinishedSpans() + require.True(t, len(finished) >= 1) // streaming RPCs will have two spans, unary RPCs will have one + + // The request should have the security events + events, _ := finished[len(finished)-1 /* root span */].Tag("_dd.appsec.json").(string) + require.NotEmpty(t, events) + for _, rule := range tc.expectedMatchedRules { + require.Contains(t, events, rule) + } + for _, rule := range tc.expectedNotMatchedRules { + require.NotContains(t, events, rule) + } + } + + t.Run("unary", func(t *testing.T) { + assert(t, func(client FixtureClient) { + ctx := metadata.NewOutgoingContext(context.Background(), tc.md) + reply, err := client.Ping(ctx, &FixtureRequest{Name: tc.message}) + require.Nil(t, reply) + require.Equal(t, codes.Aborted, status.Code(err)) + }) + }) + + t.Run("stream", func(t *testing.T) { + assert(t, func(client FixtureClient) { + ctx := metadata.NewOutgoingContext(context.Background(), tc.md) + + // Open the stream + stream, err := client.StreamPing(ctx) + require.NoError(t, err) + defer func() { + require.NoError(t, stream.CloseSend()) + }() + + // Send a message + err = stream.Send(&FixtureRequest{Name: tc.message}) + require.NoError(t, err) + + // Receive a message + reply, err := stream.Recv() + require.Equal(t, codes.Aborted, status.Code(err)) + require.Nil(t, reply) + }) + }) + }) + } }) } @@ -367,7 +250,7 @@ func TestPasslist(t *testing.T) { } setup := func() (FixtureClient, mocktracer.Tracer, func()) { - rig, err := newRig(false) + rig, err := newAppsecRig(false) require.NoError(t, err) mt := mocktracer.Start() @@ -501,17 +384,20 @@ func (s *appsecFixtureServer) StreamPing(stream Fixture_StreamPingServer) (err e ctx := stream.Context() md, _ := metadata.FromIncomingContext(ctx) ids := md.Get("user-id") - if err := pappsec.SetUser(ctx, ids[0]); err != nil { - return err + if len(ids) > 0 { + if err := pappsec.SetUser(ctx, ids[0]); err != nil { + return err + } } return s.s.StreamPing(stream) } func (s *appsecFixtureServer) Ping(ctx context.Context, in *FixtureRequest) (*FixtureReply, error) { md, _ := metadata.FromIncomingContext(ctx) ids := md.Get("user-id") - if err := pappsec.SetUser(ctx, ids[0]); err != nil { - return nil, err + if len(ids) > 0 { + if err := pappsec.SetUser(ctx, ids[0]); err != nil { + return nil, err + } } - return s.s.Ping(ctx, in) } diff --git a/internal/appsec/emitter/sharedsec/actions.go b/internal/appsec/emitter/sharedsec/actions.go index 0a9c0cd3b8..49cd65a5d4 100644 --- a/internal/appsec/emitter/sharedsec/actions.go +++ b/internal/appsec/emitter/sharedsec/actions.go @@ -68,16 +68,18 @@ type ( } // GRPCWrapper is an opaque prototype abstraction for a gRPC handler (to avoid importing grpc) - // that takes metadata as input and returns a status code and an error + // that returns a status code and an error // TODO: rely on strongly typed actions (with the actual grpc types) by introducing WAF constructors // living in the contrib packages, along with their dependencies - something like `appsec.RegisterWAFConstructor(newGRPCWAF)` // Such constructors would receive the full appsec config and rules, so that they would be able to build // specific blocking actions. - GRPCWrapper func(map[string][]string) (uint32, error) + GRPCWrapper func() (uint32, error) // blockActionParams are the dynamic parameters to be provided to a "block_request" // action type upon invocation blockActionParams struct { + // GRPCStatusCode is the gRPC status code to be returned. Since 0 is the OK status, the value is nullable to + // be able to distinguish between unset and defaulting to Abort (10), or set to OK (0). GRPCStatusCode *int `mapstructure:"grpc_status_code,omitempty"` StatusCode int `mapstructure:"status_code"` Type string `mapstructure:"type,omitempty"` @@ -196,27 +198,27 @@ func newBlockRequestHandler(status int, ct string, payload []byte) http.Handler } func newGRPCBlockHandler(status int) GRPCWrapper { - return func(_ map[string][]string) (uint32, error) { + return func() (uint32, error) { return uint32(status), &events.BlockingSecurityEvent{} } } func blockParamsFromMap(params map[string]any) (blockActionParams, error) { - var ( - err error - ) + grpcCode := 10 p := blockActionParams{ - StatusCode: 403, - Type: "auto", + Type: "auto", + StatusCode: 403, + GRPCStatusCode: &grpcCode, } - mapstructure.WeakDecode(params, &p) + if err := mapstructure.WeakDecode(params, &p); err != nil { + return p, err + } - grpcCode := 10 if p.GRPCStatusCode == nil { p.GRPCStatusCode = &grpcCode } - return p, err + return p, nil } diff --git a/internal/appsec/listener/grpcsec/grpc.go b/internal/appsec/listener/grpcsec/grpc.go index 3dc4e16154..5b3326bddd 100644 --- a/internal/appsec/listener/grpcsec/grpc.go +++ b/internal/appsec/listener/grpcsec/grpc.go @@ -85,14 +85,21 @@ func newWafEventListener(wafHandle *waf.Handle, cfg *config.Config, limiter limi func (l *wafEventListener) onEvent(op *types.HandlerOperation, handlerArgs types.HandlerOperationArgs) { // Limit the maximum number of security events, as a streaming RPC could // receive unlimited number of messages where we could find security events - const maxWAFEventsPerRequest = 10 var ( nbEvents atomic.Uint32 logOnce sync.Once // per request - - events []any - mu sync.Mutex // events mutex ) + addEvents := func(events []any) { + const maxWAFEventsPerRequest = 10 + if nbEvents.Load() >= maxWAFEventsPerRequest { + logOnce.Do(func() { + log.Debug("appsec: ignoring new WAF event due to the maximum number of security events per grpc call reached") + }) + return + } + nbEvents.Add(uint32(len(events))) + shared.AddSecurityEvents(&op.SecurityEventsHolder, l.limiter, events) + } wafCtx, err := l.wafHandle.NewContextWithBudget(l.config.WAFTimeout) if err != nil { @@ -114,16 +121,18 @@ func (l *wafEventListener) onEvent(op *types.HandlerOperation, handlerArgs types // UserIDOperation happens when appsec.SetUser() is called. We run the WAF and apply actions to // see if the associated user should be blocked. Since we don't control the execution flow in this case // (SetUser is SDK), we delegate the responsibility of interrupting the handler to the user. - dyngo.On(op, func(userIDOp *sharedsec.UserIDOperation, args sharedsec.UserIDOperationArgs) { + dyngo.On(op, func(op *sharedsec.UserIDOperation, args sharedsec.UserIDOperationArgs) { values := map[string]any{ httpsec.UserIDAddr: args.UserID, } wafResult := shared.RunWAF(wafCtx, waf.RunAddressData{Persistent: values}) - if wafResult.HasActions() || wafResult.HasEvents() { - shared.ProcessActions(userIDOp, wafResult.Actions) - shared.AddSecurityEvents(&op.SecurityEventsHolder, l.limiter, wafResult.Events) + if wafResult.HasEvents() { + addEvents(wafResult.Events) log.Debug("appsec: WAF detected an authenticated user attack: %s", args.UserID) } + if wafResult.HasActions() { + shared.ProcessActions(op, wafResult.Actions) + } }) } @@ -137,10 +146,12 @@ func (l *wafEventListener) onEvent(op *types.HandlerOperation, handlerArgs types } wafResult := shared.RunWAF(wafCtx, waf.RunAddressData{Persistent: values}) - if wafResult.HasActions() || wafResult.HasEvents() { - interrupt := shared.ProcessActions(op, wafResult.Actions) - shared.AddSecurityEvents(&op.SecurityEventsHolder, l.limiter, wafResult.Events) + if wafResult.HasEvents() { + addEvents(wafResult.Events) log.Debug("appsec: WAF detected an attack before executing the request") + } + if wafResult.HasActions() { + interrupt := shared.ProcessActions(op, wafResult.Actions) if interrupt { wafCtx.Close() return @@ -149,13 +160,6 @@ func (l *wafEventListener) onEvent(op *types.HandlerOperation, handlerArgs types // When the gRPC handler receives a message dyngo.OnFinish(op, func(_ types.ReceiveOperation, res types.ReceiveOperationRes) { - if nbEvents.Load() == maxWAFEventsPerRequest { - logOnce.Do(func() { - log.Debug("appsec: ignoring the rpc message due to the maximum number of security events per grpc call reached") - }) - return - } - // Run the WAF on the rule addresses available and listened to by the sec rules var values waf.RunAddressData // Add the gRPC message to the values if the WAF rules are using it. @@ -174,28 +178,25 @@ func (l *wafEventListener) onEvent(op *types.HandlerOperation, handlerArgs types // Run the WAF, ignoring the returned actions - if any - since blocking after the request handler's // response is not supported at the moment. wafResult := shared.RunWAF(wafCtx, values) - if wafResult.HasEvents() { log.Debug("appsec: attack detected by the grpc waf") - nbEvents.Inc() - mu.Lock() - defer mu.Unlock() - events = append(events, wafResult.Events...) + addEvents(wafResult.Events) + } + if wafResult.HasActions() { + shared.ProcessActions(op, wafResult.Actions) } }) // When the gRPC handler finishes dyngo.OnFinish(op, func(op *types.HandlerOperation, _ types.HandlerOperationRes) { defer wafCtx.Close() - shared.AddWAFMonitoringTags(op, l.wafDiags.Version, wafCtx.Stats().Metrics()) + shared.AddWAFMonitoringTags(op, l.wafDiags.Version, wafCtx.Stats().Metrics()) // Log the following metrics once per instantiation of a WAF handle l.once.Do(func() { shared.AddRulesMonitoringTags(op, &l.wafDiags) op.SetTag(ext.ManualKeep, samplernames.AppSec) }) - - shared.AddSecurityEvents(&op.SecurityEventsHolder, l.limiter, events) }) } From eeaff7c2a73156f930187f7b6c795a1947b448b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dario=20Casta=C3=B1=C3=A9?= Date: Tue, 4 Jun 2024 16:15:12 +0200 Subject: [PATCH 2/2] ddtrace/tracer: fix concurrent map writes when applying trace sampling rules and setting tags concurrently (#2727) --- ddtrace/tracer/rules_sampler.go | 11 +++++++---- ddtrace/tracer/span_test.go | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 4 deletions(-) diff --git a/ddtrace/tracer/rules_sampler.go b/ddtrace/tracer/rules_sampler.go index d0b61f8454..2cd911e3f7 100644 --- a/ddtrace/tracer/rules_sampler.go +++ b/ddtrace/tracer/rules_sampler.go @@ -498,20 +498,23 @@ func (rs *traceRulesSampler) sampleRules(span *span) bool { } func (rs *traceRulesSampler) applyRate(span *span, rate float64, now time.Time, sampler samplernames.SamplerName) { + span.Lock() + defer span.Unlock() + span.setMetric(keyRulesSamplerAppliedRate, rate) delete(span.Metrics, keySamplingPriorityRate) if !sampledByRate(span.TraceID, rate) { - span.setSamplingPriority(ext.PriorityUserReject, sampler) + span.setSamplingPriorityLocked(ext.PriorityUserReject, sampler) return } sampled, rate := rs.limiter.allowOne(now) if sampled { - span.setSamplingPriority(ext.PriorityUserKeep, sampler) + span.setSamplingPriorityLocked(ext.PriorityUserKeep, sampler) } else { - span.setSamplingPriority(ext.PriorityUserReject, sampler) + span.setSamplingPriorityLocked(ext.PriorityUserReject, sampler) } - span.SetTag(keyRulesSamplerLimiterRate, rate) + span.setMetric(keyRulesSamplerLimiterRate, rate) } // limit returns the rate limit set in the rules sampler, controlled by DD_TRACE_RATE_LIMIT, and diff --git a/ddtrace/tracer/span_test.go b/ddtrace/tracer/span_test.go index 3aa012e1cf..c0a0b12ca6 100644 --- a/ddtrace/tracer/span_test.go +++ b/ddtrace/tracer/span_test.go @@ -1076,3 +1076,35 @@ type stringer struct{} func (s *stringer) String() string { return "string" } + +// TestConcurrentSpanSetTag tests that setting tags concurrently on a span directly or +// not (through tracer.Inject when trace sampling rules are in place) does not cause +// concurrent map writes. It seems to only be consistently reproduced with the -count=100 +// flag when running go test, but it's a good test to have. +func TestConcurrentSpanSetTag(t *testing.T) { + testConcurrentSpanSetTag(t) + testConcurrentSpanSetTag(t) +} + +func testConcurrentSpanSetTag(t *testing.T) { + tracer, _, _, stop := startTestTracer(t, WithSamplingRules([]SamplingRule{NameRule("root", 1.0)})) + defer stop() + + span := tracer.StartSpan("root") + defer span.Finish() + + const n = 100 + wg := sync.WaitGroup{} + wg.Add(n * 2) + for i := 0; i < n; i++ { + go func() { + tracer.Inject(span.Context(), TextMapCarrier(map[string]string{})) + wg.Done() + }() + go func() { + span.SetTag("key", "value") + wg.Done() + }() + } + wg.Wait() +}