Skip to content

Commit

Permalink
Merge branch 'main' into juliann.zhou/fix-sns-topic-format
Browse files Browse the repository at this point in the history
  • Loading branch information
darccio authored Jun 4, 2024
2 parents eb1e1f5 + eeaff7c commit 1a078aa
Show file tree
Hide file tree
Showing 6 changed files with 230 additions and 291 deletions.
99 changes: 57 additions & 42 deletions contrib/google.golang.org/grpc/appsec.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,8 @@ import (
// UnaryHandler wrapper to use when AppSec is enabled to monitor its execution.
func appsecUnaryHandlerMiddleware(method string, span ddtrace.Span, handler grpc.UnaryHandler) grpc.UnaryHandler {
trace.SetAppSecEnabledTags(span)
return func(ctx context.Context, req interface{}) (interface{}, error) {
var err error
var blocked bool
return func(ctx context.Context, req any) (res any, rpcErr error) {
var blockedErr error
md, _ := metadata.FromIncomingContext(ctx)
clientIP := setClientIP(ctx, span, md)
args := types.HandlerOperationArgs{
Expand All @@ -41,106 +40,122 @@ func appsecUnaryHandlerMiddleware(method string, span ddtrace.Span, handler grpc
}
ctx, op := grpcsec.StartHandlerOperation(ctx, args, nil, func(op *types.HandlerOperation) {
dyngo.OnData(op, func(a *sharedsec.GRPCAction) {
code, e := a.GRPCWrapper(md)
blocked = a.Blocking()
err = status.Error(codes.Code(code), e.Error())
code, err := a.GRPCWrapper()
blockedErr = status.Error(codes.Code(code), err.Error())
})
})
defer func() {
events := op.Finish(types.HandlerOperationRes{})
if blocked {
if len(events) > 0 {
grpctrace.SetSecurityEventsTags(span, events)
}
if blockedErr != nil {
op.SetTag(trace.BlockedRequestTag, true)
rpcErr = blockedErr
}
grpctrace.SetRequestMetadataTags(span, md)
trace.SetTags(span, op.Tags())
if len(events) > 0 {
grpctrace.SetSecurityEventsTags(span, events)
}
}()

if err != nil {
return nil, err
// Check if a blocking condition was detected so far with the start operation event (ip blocking, metadata blocking, etc.)
if blockedErr != nil {
return nil, blockedErr
}
defer grpcsec.StartReceiveOperation(types.ReceiveOperationArgs{}, op).Finish(types.ReceiveOperationRes{Message: req})

rv, downstreamErr := handler(ctx, req)
if blocked {
return nil, err
// As of our gRPC abstract operation definition, we must fake a receive operation for unary RPCs (the same model fits both unary and streaming RPCs)
grpcsec.StartReceiveOperation(types.ReceiveOperationArgs{}, op).Finish(types.ReceiveOperationRes{Message: req})
// Check if a blocking condition was detected so far with the receive operation events
if blockedErr != nil {
return nil, blockedErr
}

return rv, downstreamErr
// Call the original handler - let the deferred function above handle the blocking condition and return error
return handler(ctx, req)
}
}

// StreamHandler wrapper to use when AppSec is enabled to monitor its execution.
func appsecStreamHandlerMiddleware(method string, span ddtrace.Span, handler grpc.StreamHandler) grpc.StreamHandler {
trace.SetAppSecEnabledTags(span)
return func(srv interface{}, stream grpc.ServerStream) error {
var err error
var blocked bool
return func(srv any, stream grpc.ServerStream) (rpcErr error) {
// Create a ServerStream wrapper with appsec RPC handler operation and the Go context (to implement the ServerStream interface)
appsecStream := &appsecServerStream{
ServerStream: stream,
// note: the blockedErr field is captured by the RPC handler's OnData closure below
}

ctx := stream.Context()
md, _ := metadata.FromIncomingContext(ctx)
clientIP := setClientIP(ctx, span, md)
grpctrace.SetRequestMetadataTags(span, md)

// Create the handler operation and listen to blocking gRPC actions to detect a blocking condition
args := types.HandlerOperationArgs{
Method: method,
Metadata: md,
ClientIP: clientIP,
}
ctx, op := grpcsec.StartHandlerOperation(ctx, args, nil, func(op *types.HandlerOperation) {
dyngo.OnData(op, func(a *sharedsec.GRPCAction) {
code, e := a.GRPCWrapper(md)
blocked = a.Blocking()
err = status.Error(codes.Code(code), e.Error())
code, e := a.GRPCWrapper()
appsecStream.blockedErr = status.Error(codes.Code(code), e.Error())
})
})
stream = appsecServerStream{
ServerStream: stream,
handlerOperation: op,
ctx: ctx,
}

// Finish constructing the appsec stream wrapper and replace the original one
appsecStream.handlerOperation = op
appsecStream.ctx = ctx

defer func() {
events := op.Finish(types.HandlerOperationRes{})
if blocked {
op.SetTag(trace.BlockedRequestTag, true)
}
trace.SetTags(span, op.Tags())

if len(events) > 0 {
grpctrace.SetSecurityEventsTags(span, events)
}
}()

if err != nil {
return err
}
if appsecStream.blockedErr != nil {
op.SetTag(trace.BlockedRequestTag, true)
// Change the RPC return error with appsec's
rpcErr = appsecStream.blockedErr
}

trace.SetTags(span, op.Tags())
}()

downstreamErr := handler(srv, stream)
if blocked {
return err
// Check if a blocking condition was detected so far with the start operation event (ip blocking, metadata blocking, etc.)
if appsecStream.blockedErr != nil {
return appsecStream.blockedErr
}

return downstreamErr
// Call the original handler - let the deferred function above handle the blocking condition and return error
return handler(srv, appsecStream)
}
}

type appsecServerStream struct {
grpc.ServerStream
handlerOperation *types.HandlerOperation
ctx context.Context

// blockedErr is used to store the error to return when a blocking sec event is detected.
blockedErr error
}

// RecvMsg implements grpc.ServerStream interface method to monitor its
// execution with AppSec.
func (ss appsecServerStream) RecvMsg(m interface{}) error {
func (ss *appsecServerStream) RecvMsg(m interface{}) (err error) {
op := grpcsec.StartReceiveOperation(types.ReceiveOperationArgs{}, ss.handlerOperation)
defer func() {
op.Finish(types.ReceiveOperationRes{Message: m})
if ss.blockedErr != nil {
// Change the function call return error with appsec's
err = ss.blockedErr
}
}()
return ss.ServerStream.RecvMsg(m)
}

func (ss appsecServerStream) Context() context.Context {
func (ss *appsecServerStream) Context() context.Context {
return ss.ctx
}

Expand Down
Loading

0 comments on commit 1a078aa

Please sign in to comment.