diff --git a/contrib/google.golang.org/grpc/appsec.go b/contrib/google.golang.org/grpc/appsec.go
index 108d900838..6330cd8f4f 100644
--- a/contrib/google.golang.org/grpc/appsec.go
+++ b/contrib/google.golang.org/grpc/appsec.go
@@ -29,9 +29,8 @@ import (
// UnaryHandler wrapper to use when AppSec is enabled to monitor its execution.
func appsecUnaryHandlerMiddleware(method string, span ddtrace.Span, handler grpc.UnaryHandler) grpc.UnaryHandler {
trace.SetAppSecEnabledTags(span)
- return func(ctx context.Context, req interface{}) (interface{}, error) {
- var err error
- var blocked bool
+ return func(ctx context.Context, req any) (res any, rpcErr error) {
+ var blockedErr error
md, _ := metadata.FromIncomingContext(ctx)
clientIP := setClientIP(ctx, span, md)
args := types.HandlerOperationArgs{
@@ -41,48 +40,56 @@ func appsecUnaryHandlerMiddleware(method string, span ddtrace.Span, handler grpc
}
ctx, op := grpcsec.StartHandlerOperation(ctx, args, nil, func(op *types.HandlerOperation) {
dyngo.OnData(op, func(a *sharedsec.GRPCAction) {
- code, e := a.GRPCWrapper(md)
- blocked = a.Blocking()
- err = status.Error(codes.Code(code), e.Error())
+ code, err := a.GRPCWrapper()
+ blockedErr = status.Error(codes.Code(code), err.Error())
})
})
defer func() {
events := op.Finish(types.HandlerOperationRes{})
- if blocked {
+ if len(events) > 0 {
+ grpctrace.SetSecurityEventsTags(span, events)
+ }
+ if blockedErr != nil {
op.SetTag(trace.BlockedRequestTag, true)
+ rpcErr = blockedErr
}
grpctrace.SetRequestMetadataTags(span, md)
trace.SetTags(span, op.Tags())
- if len(events) > 0 {
- grpctrace.SetSecurityEventsTags(span, events)
- }
}()
- if err != nil {
- return nil, err
+ // Check if a blocking condition was detected so far with the start operation event (ip blocking, metadata blocking, etc.)
+ if blockedErr != nil {
+ return nil, blockedErr
}
- defer grpcsec.StartReceiveOperation(types.ReceiveOperationArgs{}, op).Finish(types.ReceiveOperationRes{Message: req})
- rv, downstreamErr := handler(ctx, req)
- if blocked {
- return nil, err
+ // As of our gRPC abstract operation definition, we must fake a receive operation for unary RPCs (the same model fits both unary and streaming RPCs)
+ grpcsec.StartReceiveOperation(types.ReceiveOperationArgs{}, op).Finish(types.ReceiveOperationRes{Message: req})
+ // Check if a blocking condition was detected so far with the receive operation events
+ if blockedErr != nil {
+ return nil, blockedErr
}
- return rv, downstreamErr
+ // Call the original handler - let the deferred function above handle the blocking condition and return error
+ return handler(ctx, req)
}
}
// StreamHandler wrapper to use when AppSec is enabled to monitor its execution.
func appsecStreamHandlerMiddleware(method string, span ddtrace.Span, handler grpc.StreamHandler) grpc.StreamHandler {
trace.SetAppSecEnabledTags(span)
- return func(srv interface{}, stream grpc.ServerStream) error {
- var err error
- var blocked bool
+ return func(srv any, stream grpc.ServerStream) (rpcErr error) {
+ // Create a ServerStream wrapper with appsec RPC handler operation and the Go context (to implement the ServerStream interface)
+ appsecStream := &appsecServerStream{
+ ServerStream: stream,
+ // note: the blockedErr field is captured by the RPC handler's OnData closure below
+ }
+
ctx := stream.Context()
md, _ := metadata.FromIncomingContext(ctx)
clientIP := setClientIP(ctx, span, md)
grpctrace.SetRequestMetadataTags(span, md)
+ // Create the handler operation and listen to blocking gRPC actions to detect a blocking condition
args := types.HandlerOperationArgs{
Method: method,
Metadata: md,
@@ -90,37 +97,38 @@ func appsecStreamHandlerMiddleware(method string, span ddtrace.Span, handler grp
}
ctx, op := grpcsec.StartHandlerOperation(ctx, args, nil, func(op *types.HandlerOperation) {
dyngo.OnData(op, func(a *sharedsec.GRPCAction) {
- code, e := a.GRPCWrapper(md)
- blocked = a.Blocking()
- err = status.Error(codes.Code(code), e.Error())
+ code, e := a.GRPCWrapper()
+ appsecStream.blockedErr = status.Error(codes.Code(code), e.Error())
})
})
- stream = appsecServerStream{
- ServerStream: stream,
- handlerOperation: op,
- ctx: ctx,
- }
+
+ // Finish constructing the appsec stream wrapper and replace the original one
+ appsecStream.handlerOperation = op
+ appsecStream.ctx = ctx
+
defer func() {
events := op.Finish(types.HandlerOperationRes{})
- if blocked {
- op.SetTag(trace.BlockedRequestTag, true)
- }
- trace.SetTags(span, op.Tags())
+
if len(events) > 0 {
grpctrace.SetSecurityEventsTags(span, events)
}
- }()
- if err != nil {
- return err
- }
+ if appsecStream.blockedErr != nil {
+ op.SetTag(trace.BlockedRequestTag, true)
+ // Change the RPC return error with appsec's
+ rpcErr = appsecStream.blockedErr
+ }
+
+ trace.SetTags(span, op.Tags())
+ }()
- downstreamErr := handler(srv, stream)
- if blocked {
- return err
+ // Check if a blocking condition was detected so far with the start operation event (ip blocking, metadata blocking, etc.)
+ if appsecStream.blockedErr != nil {
+ return appsecStream.blockedErr
}
- return downstreamErr
+ // Call the original handler - let the deferred function above handle the blocking condition and return error
+ return handler(srv, appsecStream)
}
}
@@ -128,19 +136,26 @@ type appsecServerStream struct {
grpc.ServerStream
handlerOperation *types.HandlerOperation
ctx context.Context
+
+ // blockedErr is used to store the error to return when a blocking sec event is detected.
+ blockedErr error
}
// RecvMsg implements grpc.ServerStream interface method to monitor its
// execution with AppSec.
-func (ss appsecServerStream) RecvMsg(m interface{}) error {
+func (ss *appsecServerStream) RecvMsg(m interface{}) (err error) {
op := grpcsec.StartReceiveOperation(types.ReceiveOperationArgs{}, ss.handlerOperation)
defer func() {
op.Finish(types.ReceiveOperationRes{Message: m})
+ if ss.blockedErr != nil {
+ // Change the function call return error with appsec's
+ err = ss.blockedErr
+ }
}()
return ss.ServerStream.RecvMsg(m)
}
-func (ss appsecServerStream) Context() context.Context {
+func (ss *appsecServerStream) Context() context.Context {
return ss.ctx
}
diff --git a/contrib/google.golang.org/grpc/appsec_test.go b/contrib/google.golang.org/grpc/appsec_test.go
index a0de9b998e..911efa7c5d 100644
--- a/contrib/google.golang.org/grpc/appsec_test.go
+++ b/contrib/google.golang.org/grpc/appsec_test.go
@@ -10,7 +10,6 @@ import (
"encoding/json"
"fmt"
"net"
- "strings"
"testing"
pappsec "gopkg.in/DataDog/dd-trace-go.v1/appsec"
@@ -32,7 +31,7 @@ func TestAppSec(t *testing.T) {
}
setup := func() (FixtureClient, mocktracer.Tracer, func()) {
- rig, err := newRig(false)
+ rig, err := newAppsecRig(false)
require.NoError(t, err)
mt := mocktracer.Start()
@@ -139,99 +138,6 @@ func TestBlocking(t *testing.T) {
t.Skip("appsec disabled")
}
- setup := func() (FixtureClient, mocktracer.Tracer, func()) {
- rig, err := newRig(false)
- require.NoError(t, err)
-
- mt := mocktracer.Start()
-
- return rig.client, mt, func() {
- rig.Close()
- mt.Stop()
- }
- }
-
- t.Run("unary-block", func(t *testing.T) {
- client, mt, cleanup := setup()
- defer cleanup()
-
- // Send a XSS attack in the payload along with the canary value in the RPC metadata
- ctx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs("dd-canary", "dd-test-scanner-log", "x-client-ip", "1.2.3.4"))
- reply, err := client.Ping(ctx, &FixtureRequest{Name: ""})
-
- require.Nil(t, reply)
- require.Equal(t, codes.Aborted, status.Code(err))
-
- finished := mt.FinishedSpans()
- require.Len(t, finished, 1)
- // The request should have the attack attempts
- event, _ := finished[0].Tag("_dd.appsec.json").(string)
- require.NotNil(t, event)
- require.True(t, strings.Contains(event, "blk-001-001"))
- })
-
- t.Run("unary-no-block", func(t *testing.T) {
- client, _, cleanup := setup()
- defer cleanup()
-
- // Send a XSS attack in the payload along with the canary value in the RPC metadata
- ctx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs("dd-canary", "dd-test-scanner-log", "x-client-ip", "1.2.3.5"))
- reply, err := client.Ping(ctx, &FixtureRequest{Name: ""})
-
- require.Equal(t, "passed", reply.Message)
- require.Equal(t, codes.OK, status.Code(err))
- })
-
- t.Run("stream-block", func(t *testing.T) {
- client, mt, cleanup := setup()
- defer cleanup()
-
- ctx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs("dd-canary", "dd-test-scanner-log", "x-client-ip", "1.2.3.4"))
- stream, err := client.StreamPing(ctx)
- require.NoError(t, err)
- reply, err := stream.Recv()
-
- require.Equal(t, codes.Aborted, status.Code(err))
- require.Nil(t, reply)
-
- finished := mt.FinishedSpans()
- require.Len(t, finished, 1)
- // The request should have the attack attempts
- event, _ := finished[0].Tag("_dd.appsec.json").(string)
- require.NotNil(t, event)
- require.True(t, strings.Contains(event, "blk-001-001"))
- })
-
- t.Run("stream-no-block", func(t *testing.T) {
- client, _, cleanup := setup()
- defer cleanup()
-
- ctx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs("dd-canary", "dd-test-scanner-log", "x-client-ip", "1.2.3.5"))
- stream, err := client.StreamPing(ctx)
- require.NoError(t, err)
-
- // Send a XSS attack
- err = stream.Send(&FixtureRequest{Name: ""})
- require.NoError(t, err)
- reply, err := stream.Recv()
- require.Equal(t, codes.OK, status.Code(err))
- require.Equal(t, "passed", reply.Message)
-
- err = stream.CloseSend()
- require.NoError(t, err)
- })
-
-}
-
-// Test that user blocking works by using custom rules/rules data
-func TestUserBlocking(t *testing.T) {
- t.Setenv("DD_APPSEC_RULES", "../../../internal/appsec/testdata/blocking.json")
- appsec.Start()
- defer appsec.Stop()
- if !appsec.Enabled() {
- t.Skip("appsec disabled")
- }
-
setup := func() (FixtureClient, mocktracer.Tracer, func()) {
rig, err := newAppsecRig(false)
require.NoError(t, err)
@@ -245,113 +151,90 @@ func TestUserBlocking(t *testing.T) {
}
t.Run("unary-block", func(t *testing.T) {
- client, mt, cleanup := setup()
- defer cleanup()
-
- // Send a XSS attack in the payload along with the canary value in the RPC metadata
- ctx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs("user-id", "blocked-user-1"))
- reply, err := client.Ping(ctx, &FixtureRequest{Name: ""})
-
- require.Nil(t, reply)
- require.Equal(t, codes.Aborted, status.Code(err))
-
- finished := mt.FinishedSpans()
- require.Len(t, finished, 1)
- // The request should have the XSS and user ID attack attempts
- event, _ := finished[0].Tag("_dd.appsec.json").(string)
- require.NotNil(t, event)
- require.True(t, strings.Contains(event, "blk-001-002"))
- require.True(t, strings.Contains(event, "crs-941-110"))
- })
-
- t.Run("unary-no-block", func(t *testing.T) {
- client, _, cleanup := setup()
- defer cleanup()
-
- ctx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs("user-id", "legit user"))
- reply, err := client.Ping(ctx, &FixtureRequest{Name: ""})
-
- require.Equal(t, "passed", reply.Message)
- require.Equal(t, codes.OK, status.Code(err))
- })
-
- // This test checks that IP blocking happens BEFORE user blocking, since user blocking needs the request handler
- // to be invoked while IP blocking doesn't
- t.Run("unary-mixed-block", func(t *testing.T) {
- client, mt, cleanup := setup()
- defer cleanup()
-
- ctx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs("user-id", "blocked-user-1", "x-forwarded-for", "1.2.3.4"))
- reply, err := client.Ping(ctx, &FixtureRequest{})
-
- require.Nil(t, reply)
- require.Equal(t, codes.Aborted, status.Code(err))
-
- finished := mt.FinishedSpans()
- require.Len(t, finished, 1)
- event, _ := finished[0].Tag("_dd.appsec.json").(string)
- require.NotNil(t, event)
- require.True(t, strings.Contains(event, "blk-001-001"))
- })
-
- t.Run("stream-block", func(t *testing.T) {
- client, mt, cleanup := setup()
- defer cleanup()
-
- ctx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs("user-id", "blocked-user-1"))
- stream, err := client.StreamPing(ctx)
- require.NoError(t, err)
- reply, err := stream.Recv()
-
- require.Equal(t, codes.Aborted, status.Code(err))
- require.Nil(t, reply)
-
- finished := mt.FinishedSpans()
- require.Len(t, finished, 1)
- // The request should have the attack attempts
- event, _ := finished[0].Tag("_dd.appsec.json").(string)
- require.NotNil(t, event)
- require.True(t, strings.Contains(event, "blk-001-002"))
- })
-
- t.Run("stream-no-block", func(t *testing.T) {
- client, _, cleanup := setup()
- defer cleanup()
-
- ctx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs("user-id", "legit user"))
- stream, err := client.StreamPing(ctx)
- require.NoError(t, err)
-
- // Send a XSS attack
- err = stream.Send(&FixtureRequest{Name: ""})
- require.NoError(t, err)
- reply, err := stream.Recv()
- require.Equal(t, codes.OK, status.Code(err))
- require.Equal(t, "passed", reply.Message)
-
- err = stream.CloseSend()
- require.NoError(t, err)
- })
- // This test checks that IP blocking happens BEFORE user blocking, since user blocking needs the request handler
- // to be invoked while IP blocking doesn't
- t.Run("stream-mixed-block", func(t *testing.T) {
- client, mt, cleanup := setup()
- defer cleanup()
-
- ctx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs("user-id", "blocked-user-1", "x-forwarded-for", "1.2.3.4"))
- stream, err := client.StreamPing(ctx)
- require.NoError(t, err)
- reply, err := stream.Recv()
-
- require.Equal(t, codes.Aborted, status.Code(err))
- require.Nil(t, reply)
-
- finished := mt.FinishedSpans()
- require.Len(t, finished, 1)
- // The request should have IP related the attack attempts
- event, _ := finished[0].Tag("_dd.appsec.json").(string)
- require.NotNil(t, event)
- require.True(t, strings.Contains(event, "blk-001-001"))
+ for _, tc := range []struct {
+ name string
+ md metadata.MD
+ message string
+ expectedBlocked bool
+ expectedMatchedRules []string
+ expectedNotMatchedRules []string
+ }{
+ {
+ name: "ip blocking",
+ md: metadata.Pairs("m1", "v1", "x-client-ip", "1.2.3.4", "user-id", "blocked-user-1"),
+ message: "$globals",
+ expectedMatchedRules: []string{"blk-001-001"}, // ip blocking alone as it comes first
+ expectedNotMatchedRules: []string{"crs-933-130-block", "blk-001-002"}, // no user blocking or message blocking
+ },
+ {
+ name: "message blocking",
+ md: metadata.Pairs("m1", "v1", "x-client-ip", "1.2.3.5", "user-id", "legit-user-1"),
+ message: "$globals",
+ expectedMatchedRules: []string{"crs-933-130-block"}, // message blocking alone as it comes before user blocking
+ expectedNotMatchedRules: []string{"blk-001-002"}, // no user blocking
+ },
+ {
+ name: "user blocking",
+ md: metadata.Pairs("m1", "v1", "x-client-ip", "1.2.3.5", "user-id", "blocked-user-1"),
+ message: "",
+ expectedMatchedRules: []string{"blk-001-002"}, // user blocking alone as it comes first in our test handler
+ expectedNotMatchedRules: []string{"crs-933-130-block"}, // message blocking alone as it comes before user blocking
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ // Helper assertion function to run for the unary and stream tests
+ assert := func(t *testing.T, do func(client FixtureClient)) {
+ client, mt, cleanup := setup()
+ defer cleanup()
+
+ do(client)
+
+ finished := mt.FinishedSpans()
+ require.True(t, len(finished) >= 1) // streaming RPCs will have two spans, unary RPCs will have one
+
+ // The request should have the security events
+ events, _ := finished[len(finished)-1 /* root span */].Tag("_dd.appsec.json").(string)
+ require.NotEmpty(t, events)
+ for _, rule := range tc.expectedMatchedRules {
+ require.Contains(t, events, rule)
+ }
+ for _, rule := range tc.expectedNotMatchedRules {
+ require.NotContains(t, events, rule)
+ }
+ }
+
+ t.Run("unary", func(t *testing.T) {
+ assert(t, func(client FixtureClient) {
+ ctx := metadata.NewOutgoingContext(context.Background(), tc.md)
+ reply, err := client.Ping(ctx, &FixtureRequest{Name: tc.message})
+ require.Nil(t, reply)
+ require.Equal(t, codes.Aborted, status.Code(err))
+ })
+ })
+
+ t.Run("stream", func(t *testing.T) {
+ assert(t, func(client FixtureClient) {
+ ctx := metadata.NewOutgoingContext(context.Background(), tc.md)
+
+ // Open the stream
+ stream, err := client.StreamPing(ctx)
+ require.NoError(t, err)
+ defer func() {
+ require.NoError(t, stream.CloseSend())
+ }()
+
+ // Send a message
+ err = stream.Send(&FixtureRequest{Name: tc.message})
+ require.NoError(t, err)
+
+ // Receive a message
+ reply, err := stream.Recv()
+ require.Equal(t, codes.Aborted, status.Code(err))
+ require.Nil(t, reply)
+ })
+ })
+ })
+ }
})
}
@@ -367,7 +250,7 @@ func TestPasslist(t *testing.T) {
}
setup := func() (FixtureClient, mocktracer.Tracer, func()) {
- rig, err := newRig(false)
+ rig, err := newAppsecRig(false)
require.NoError(t, err)
mt := mocktracer.Start()
@@ -501,17 +384,20 @@ func (s *appsecFixtureServer) StreamPing(stream Fixture_StreamPingServer) (err e
ctx := stream.Context()
md, _ := metadata.FromIncomingContext(ctx)
ids := md.Get("user-id")
- if err := pappsec.SetUser(ctx, ids[0]); err != nil {
- return err
+ if len(ids) > 0 {
+ if err := pappsec.SetUser(ctx, ids[0]); err != nil {
+ return err
+ }
}
return s.s.StreamPing(stream)
}
func (s *appsecFixtureServer) Ping(ctx context.Context, in *FixtureRequest) (*FixtureReply, error) {
md, _ := metadata.FromIncomingContext(ctx)
ids := md.Get("user-id")
- if err := pappsec.SetUser(ctx, ids[0]); err != nil {
- return nil, err
+ if len(ids) > 0 {
+ if err := pappsec.SetUser(ctx, ids[0]); err != nil {
+ return nil, err
+ }
}
-
return s.s.Ping(ctx, in)
}
diff --git a/ddtrace/tracer/rules_sampler.go b/ddtrace/tracer/rules_sampler.go
index d0b61f8454..2cd911e3f7 100644
--- a/ddtrace/tracer/rules_sampler.go
+++ b/ddtrace/tracer/rules_sampler.go
@@ -498,20 +498,23 @@ func (rs *traceRulesSampler) sampleRules(span *span) bool {
}
func (rs *traceRulesSampler) applyRate(span *span, rate float64, now time.Time, sampler samplernames.SamplerName) {
+ span.Lock()
+ defer span.Unlock()
+
span.setMetric(keyRulesSamplerAppliedRate, rate)
delete(span.Metrics, keySamplingPriorityRate)
if !sampledByRate(span.TraceID, rate) {
- span.setSamplingPriority(ext.PriorityUserReject, sampler)
+ span.setSamplingPriorityLocked(ext.PriorityUserReject, sampler)
return
}
sampled, rate := rs.limiter.allowOne(now)
if sampled {
- span.setSamplingPriority(ext.PriorityUserKeep, sampler)
+ span.setSamplingPriorityLocked(ext.PriorityUserKeep, sampler)
} else {
- span.setSamplingPriority(ext.PriorityUserReject, sampler)
+ span.setSamplingPriorityLocked(ext.PriorityUserReject, sampler)
}
- span.SetTag(keyRulesSamplerLimiterRate, rate)
+ span.setMetric(keyRulesSamplerLimiterRate, rate)
}
// limit returns the rate limit set in the rules sampler, controlled by DD_TRACE_RATE_LIMIT, and
diff --git a/ddtrace/tracer/span_test.go b/ddtrace/tracer/span_test.go
index 3aa012e1cf..c0a0b12ca6 100644
--- a/ddtrace/tracer/span_test.go
+++ b/ddtrace/tracer/span_test.go
@@ -1076,3 +1076,35 @@ type stringer struct{}
func (s *stringer) String() string {
return "string"
}
+
+// TestConcurrentSpanSetTag tests that setting tags concurrently on a span directly or
+// not (through tracer.Inject when trace sampling rules are in place) does not cause
+// concurrent map writes. It seems to only be consistently reproduced with the -count=100
+// flag when running go test, but it's a good test to have.
+func TestConcurrentSpanSetTag(t *testing.T) {
+ testConcurrentSpanSetTag(t)
+ testConcurrentSpanSetTag(t)
+}
+
+func testConcurrentSpanSetTag(t *testing.T) {
+ tracer, _, _, stop := startTestTracer(t, WithSamplingRules([]SamplingRule{NameRule("root", 1.0)}))
+ defer stop()
+
+ span := tracer.StartSpan("root")
+ defer span.Finish()
+
+ const n = 100
+ wg := sync.WaitGroup{}
+ wg.Add(n * 2)
+ for i := 0; i < n; i++ {
+ go func() {
+ tracer.Inject(span.Context(), TextMapCarrier(map[string]string{}))
+ wg.Done()
+ }()
+ go func() {
+ span.SetTag("key", "value")
+ wg.Done()
+ }()
+ }
+ wg.Wait()
+}
diff --git a/internal/appsec/emitter/sharedsec/actions.go b/internal/appsec/emitter/sharedsec/actions.go
index 0a9c0cd3b8..49cd65a5d4 100644
--- a/internal/appsec/emitter/sharedsec/actions.go
+++ b/internal/appsec/emitter/sharedsec/actions.go
@@ -68,16 +68,18 @@ type (
}
// GRPCWrapper is an opaque prototype abstraction for a gRPC handler (to avoid importing grpc)
- // that takes metadata as input and returns a status code and an error
+ // that returns a status code and an error
// TODO: rely on strongly typed actions (with the actual grpc types) by introducing WAF constructors
// living in the contrib packages, along with their dependencies - something like `appsec.RegisterWAFConstructor(newGRPCWAF)`
// Such constructors would receive the full appsec config and rules, so that they would be able to build
// specific blocking actions.
- GRPCWrapper func(map[string][]string) (uint32, error)
+ GRPCWrapper func() (uint32, error)
// blockActionParams are the dynamic parameters to be provided to a "block_request"
// action type upon invocation
blockActionParams struct {
+ // GRPCStatusCode is the gRPC status code to be returned. Since 0 is the OK status, the value is nullable to
+ // be able to distinguish between unset and defaulting to Abort (10), or set to OK (0).
GRPCStatusCode *int `mapstructure:"grpc_status_code,omitempty"`
StatusCode int `mapstructure:"status_code"`
Type string `mapstructure:"type,omitempty"`
@@ -196,27 +198,27 @@ func newBlockRequestHandler(status int, ct string, payload []byte) http.Handler
}
func newGRPCBlockHandler(status int) GRPCWrapper {
- return func(_ map[string][]string) (uint32, error) {
+ return func() (uint32, error) {
return uint32(status), &events.BlockingSecurityEvent{}
}
}
func blockParamsFromMap(params map[string]any) (blockActionParams, error) {
- var (
- err error
- )
+ grpcCode := 10
p := blockActionParams{
- StatusCode: 403,
- Type: "auto",
+ Type: "auto",
+ StatusCode: 403,
+ GRPCStatusCode: &grpcCode,
}
- mapstructure.WeakDecode(params, &p)
+ if err := mapstructure.WeakDecode(params, &p); err != nil {
+ return p, err
+ }
- grpcCode := 10
if p.GRPCStatusCode == nil {
p.GRPCStatusCode = &grpcCode
}
- return p, err
+ return p, nil
}
diff --git a/internal/appsec/listener/grpcsec/grpc.go b/internal/appsec/listener/grpcsec/grpc.go
index 3dc4e16154..5b3326bddd 100644
--- a/internal/appsec/listener/grpcsec/grpc.go
+++ b/internal/appsec/listener/grpcsec/grpc.go
@@ -85,14 +85,21 @@ func newWafEventListener(wafHandle *waf.Handle, cfg *config.Config, limiter limi
func (l *wafEventListener) onEvent(op *types.HandlerOperation, handlerArgs types.HandlerOperationArgs) {
// Limit the maximum number of security events, as a streaming RPC could
// receive unlimited number of messages where we could find security events
- const maxWAFEventsPerRequest = 10
var (
nbEvents atomic.Uint32
logOnce sync.Once // per request
-
- events []any
- mu sync.Mutex // events mutex
)
+ addEvents := func(events []any) {
+ const maxWAFEventsPerRequest = 10
+ if nbEvents.Load() >= maxWAFEventsPerRequest {
+ logOnce.Do(func() {
+ log.Debug("appsec: ignoring new WAF event due to the maximum number of security events per grpc call reached")
+ })
+ return
+ }
+ nbEvents.Add(uint32(len(events)))
+ shared.AddSecurityEvents(&op.SecurityEventsHolder, l.limiter, events)
+ }
wafCtx, err := l.wafHandle.NewContextWithBudget(l.config.WAFTimeout)
if err != nil {
@@ -114,16 +121,18 @@ func (l *wafEventListener) onEvent(op *types.HandlerOperation, handlerArgs types
// UserIDOperation happens when appsec.SetUser() is called. We run the WAF and apply actions to
// see if the associated user should be blocked. Since we don't control the execution flow in this case
// (SetUser is SDK), we delegate the responsibility of interrupting the handler to the user.
- dyngo.On(op, func(userIDOp *sharedsec.UserIDOperation, args sharedsec.UserIDOperationArgs) {
+ dyngo.On(op, func(op *sharedsec.UserIDOperation, args sharedsec.UserIDOperationArgs) {
values := map[string]any{
httpsec.UserIDAddr: args.UserID,
}
wafResult := shared.RunWAF(wafCtx, waf.RunAddressData{Persistent: values})
- if wafResult.HasActions() || wafResult.HasEvents() {
- shared.ProcessActions(userIDOp, wafResult.Actions)
- shared.AddSecurityEvents(&op.SecurityEventsHolder, l.limiter, wafResult.Events)
+ if wafResult.HasEvents() {
+ addEvents(wafResult.Events)
log.Debug("appsec: WAF detected an authenticated user attack: %s", args.UserID)
}
+ if wafResult.HasActions() {
+ shared.ProcessActions(op, wafResult.Actions)
+ }
})
}
@@ -137,10 +146,12 @@ func (l *wafEventListener) onEvent(op *types.HandlerOperation, handlerArgs types
}
wafResult := shared.RunWAF(wafCtx, waf.RunAddressData{Persistent: values})
- if wafResult.HasActions() || wafResult.HasEvents() {
- interrupt := shared.ProcessActions(op, wafResult.Actions)
- shared.AddSecurityEvents(&op.SecurityEventsHolder, l.limiter, wafResult.Events)
+ if wafResult.HasEvents() {
+ addEvents(wafResult.Events)
log.Debug("appsec: WAF detected an attack before executing the request")
+ }
+ if wafResult.HasActions() {
+ interrupt := shared.ProcessActions(op, wafResult.Actions)
if interrupt {
wafCtx.Close()
return
@@ -149,13 +160,6 @@ func (l *wafEventListener) onEvent(op *types.HandlerOperation, handlerArgs types
// When the gRPC handler receives a message
dyngo.OnFinish(op, func(_ types.ReceiveOperation, res types.ReceiveOperationRes) {
- if nbEvents.Load() == maxWAFEventsPerRequest {
- logOnce.Do(func() {
- log.Debug("appsec: ignoring the rpc message due to the maximum number of security events per grpc call reached")
- })
- return
- }
-
// Run the WAF on the rule addresses available and listened to by the sec rules
var values waf.RunAddressData
// Add the gRPC message to the values if the WAF rules are using it.
@@ -174,28 +178,25 @@ func (l *wafEventListener) onEvent(op *types.HandlerOperation, handlerArgs types
// Run the WAF, ignoring the returned actions - if any - since blocking after the request handler's
// response is not supported at the moment.
wafResult := shared.RunWAF(wafCtx, values)
-
if wafResult.HasEvents() {
log.Debug("appsec: attack detected by the grpc waf")
- nbEvents.Inc()
- mu.Lock()
- defer mu.Unlock()
- events = append(events, wafResult.Events...)
+ addEvents(wafResult.Events)
+ }
+ if wafResult.HasActions() {
+ shared.ProcessActions(op, wafResult.Actions)
}
})
// When the gRPC handler finishes
dyngo.OnFinish(op, func(op *types.HandlerOperation, _ types.HandlerOperationRes) {
defer wafCtx.Close()
- shared.AddWAFMonitoringTags(op, l.wafDiags.Version, wafCtx.Stats().Metrics())
+ shared.AddWAFMonitoringTags(op, l.wafDiags.Version, wafCtx.Stats().Metrics())
// Log the following metrics once per instantiation of a WAF handle
l.once.Do(func() {
shared.AddRulesMonitoringTags(op, &l.wafDiags)
op.SetTag(ext.ManualKeep, samplernames.AppSec)
})
-
- shared.AddSecurityEvents(&op.SecurityEventsHolder, l.limiter, events)
})
}