Skip to content

Commit

Permalink
Add scrubbing to octtrpc
Browse files Browse the repository at this point in the history
Allow opting in to adding the request/response to the ttrpc span as
attributes, since there are other consumers of our code.

Allow scrubbing the `proto.Message` request and response messages in the
ttrpc server interceptor by specifying an arbitrary function to update
a clone of the payloads.

Signed-off-by: Hamza El-Saawy <[email protected]>
  • Loading branch information
helsaawy committed Jan 26, 2024
1 parent f35aaad commit d455f6d
Show file tree
Hide file tree
Showing 11 changed files with 854 additions and 16 deletions.
2 changes: 1 addition & 1 deletion cmd/containerd-shim-runhcs-v1/exec_hcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ func (he *hcsExec) waitForContainerExit() {
trace.StringAttribute("tid", he.tid),
trace.StringAttribute("eid", he.id))

// wait for container or process to exit and ckean up resrources
// wait for container or process to exit and clean up resources
select {
case <-he.c.WaitChannel():
// Container exited first. We need to force the process into the exited
Expand Down
8 changes: 5 additions & 3 deletions cmd/containerd-shim-runhcs-v1/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ var serveCommand = cli.Command{
// TODO: JTERRY75 we need this to be the reconnect log listener or
// switch to events
// TODO: JTERRY75 switch containerd to use the protected path.
//const logAddrFmt = "\\\\.\\pipe\\ProtectedPrefix\\Administrators\\containerd-shim-%s-%s-log"
// const logAddrFmt = "\\\\.\\pipe\\ProtectedPrefix\\Administrators\\containerd-shim-%s-%s-log"
const logAddrFmt = "\\\\.\\pipe\\containerd-shim-%s-%s-log"
logl, err := winio.ListenPipe(fmt.Sprintf(logAddrFmt, namespaceFlag, idFlag), nil)
if err != nil {
Expand Down Expand Up @@ -197,10 +197,12 @@ var serveCommand = cli.Command{

s, err := ttrpc.NewServer(
ttrpc.WithUnaryServerInterceptor(octtrpc.ServerInterceptor(
octtrpc.WithAttributes( // todo (helsaawy) set these in resource when we switch to OTel
octtrpc.WithAttributes( // TODO (helsaawy) set these in resource when we switch to OTel
trace.StringAttribute(logfields.ShimID, svc.tid),
trace.BoolAttribute(logfields.IsSandbox, svc.isSandbox),
),
octtrpc.WithAddMessage(),
octtrpc.WithAddMessageHook(hcslog.ScrubShimTTRPC),
)))
if err != nil {
return err
Expand Down Expand Up @@ -318,7 +320,7 @@ func createEvent(event string) (windows.Handle, error) {
}

// setupDebuggerEvent listens for an event to allow a debugger such as delve
// to attach for advanced debugging. It's called when handling a ContainerCreate
// to attach for advanced debugging. It's called when handling a ContainerCreate.
func setupDebuggerEvent() {
if os.Getenv("CONTAINERD_SHIM_RUNHCS_V1_WAIT_DEBUGGER") == "" {
return
Expand Down
3 changes: 3 additions & 0 deletions cmd/containerd-shim-runhcs-v1/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,9 @@ func (s *service) IsShutdown() bool {
}

func (s *service) logEntry(ctx context.Context) *logrus.Entry {
if s == nil {
return log.G(ctx)
}
return log.G(ctx).WithFields(logrus.Fields{
logfields.ShimID: s.tid,
logfields.IsSandbox: s.isSandbox,
Expand Down
4 changes: 2 additions & 2 deletions internal/gcs/bridge.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ func (brdg *bridge) RPC(ctx context.Context, proc rpcProc, req requestMessage, r
brdg.log.WithField("reason", ctx.Err()).Warn("ignoring response to bridge message")
return ctx.Err()
case <-t.C:
//todo: dont kill bridge on message timeout
// TODO: don't kill bridge on message timeout
brdg.kill(errors.New("message timeout"))
<-call.ch
return call.Err()
Expand Down Expand Up @@ -390,7 +390,7 @@ func (brdg *bridge) writeMessage(buf *bytes.Buffer, enc *json.Encoder, typ msgTy
// Update the message header with the size.
binary.LittleEndian.PutUint32(buf.Bytes()[hdrOffSize:], uint32(buf.Len()))

if brdg.log.Logger.GetLevel() >= logrus.DebugLevel {
if brdg.log.Logger.IsLevelEnabled(logrus.DebugLevel) {
b := buf.Bytes()[hdrSize:]
switch typ {
// container environment vars are in rpcCreate for linux; rpcExecuteProcess for windows
Expand Down
2 changes: 1 addition & 1 deletion internal/guest/bridge/bridge.go
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ func (b *Bridge) ListenAndServe(bridgeIn io.ReadCloser, bridgeOut io.WriteCloser
trace.StringAttribute(logfields.ContainerID, base.ContainerID))

entry := log.G(ctx)
if entry.Logger.GetLevel() >= logrus.DebugLevel {
if entry.Logger.IsLevelEnabled(logrus.DebugLevel) {
s := string(message)
switch header.Type {
case prot.ComputeSystemCreateV1:
Expand Down
2 changes: 1 addition & 1 deletion internal/guest/network/netns.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ func NetNSConfig(ctx context.Context, ifStr string, nsPid int, adapter *prot.Net
}

// Add some debug logging
if entry.Logger.GetLevel() >= logrus.DebugLevel {
if entry.Logger.IsLevelEnabled(logrus.DebugLevel) {
curNS, _ := netns.Get()
// Refresh link attributes/state
link, _ = netlink.LinkByIndex(link.Attrs().Index)
Expand Down
3 changes: 1 addition & 2 deletions internal/log/format.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func formatAddr(a net.Addr) string {
func Format(ctx context.Context, v interface{}) string {
b, err := encode(v)
if err != nil {
// logging errors aren't really warning worthy, and can potentially spam a lot of logs out
// logging-related errors aren't really warning worthy, and can potentially spam a lot of logs out
G(ctx).WithFields(logrus.Fields{
logrus.ErrorKey: err,
"type": fmt.Sprintf("%T", v),
Expand Down Expand Up @@ -93,7 +93,6 @@ func encode(v interface{}) (_ []byte, err error) {
// more robust to fall back on json marshalling for errors in general
return b, nil
}

}

buf := &bytes.Buffer{}
Expand Down
27 changes: 26 additions & 1 deletion internal/log/scrub.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ import (
"errors"
"sync/atomic"

task "github.com/containerd/containerd/api/runtime/task/v2"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/wrapperspb"

hcsschema "github.com/Microsoft/hcsshim/internal/hcs/schema2"
)

Expand All @@ -22,6 +27,7 @@ var (
// case sensitive keywords, so "env" is not a substring of "Environment"
_scrubKeywords = [][]byte{[]byte("env"), []byte("Environment")}

// TODO (go1.19) atomic.Bool
_scrub int32
)

Expand All @@ -40,8 +46,27 @@ func IsScrubbingEnabled() bool {
return v != 0
}

// ScrubShimTTRPC scrubs sensitive fields from a shim ttrpc request or response message.
//
// The message is modified in place, so callers that still need the original payload
// should pass a copy (e.g., one created with [proto.Clone]).
// It is a no-op if scrubbing is disabled (see [IsScrubbingEnabled]) or if m is not
// one of the message types handled below.
//
// Currently, only [task.ExecProcessRequest] is scrubbed: its Spec field is replaced wholesale.
func ScrubShimTTRPC(m proto.Message) {
	if !IsScrubbingEnabled() {
		return
	}

	switch t := m.(type) {
	case *task.ExecProcessRequest:
		// a typed nil pointer still matches this case; there is nothing to scrub,
		// and assigning to one of its fields would panic
		if t == nil {
			return
		}

		// the spec will be logged elsewhere, so scrub the entire spec wholesale to avoid
		// needing to unmarshal it and then re-marshal it back into an anypb.Any
		//
		// ignore errors from creating the anypb.Any and use the nil result, since nil is a
		// "valid" proto.Message that signifies an invalid message
		// (see the documentation for proto.Equal in google.golang.org/protobuf/proto)
		a, _ := anypb.New(wrapperspb.String(_scrubbedReplacement))
		t.Spec = a
	default:
	}
}

// ScrubProcessParameters scrubs HCS Create Process requests with config parameters of
// type internal/hcs/schema2.ScrubProcessParameters (aka hcsshema.ScrubProcessParameters)
// type internal/hcs/schema2.ScrubProcessParameters (aka hcsschema.ScrubProcessParameters)
func ScrubProcessParameters(s string) (string, error) {
// todo: deal with v1 ProcessConfig
b := []byte(s)
Expand Down
58 changes: 53 additions & 5 deletions pkg/octtrpc/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"go.opencensus.io/trace/propagation"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"

"github.com/Microsoft/hcsshim/internal/log"
"github.com/Microsoft/hcsshim/internal/oc"
Expand All @@ -18,6 +19,21 @@ import (
// options holds the interceptor settings that [Option] functions modify.
type options struct {
// NOTE(review): presumably the sampler applied when the interceptor starts its span;
// its use is not visible in this section, so confirm against the interceptors.
sampler trace.Sampler
// additional attributes that the interceptor adds to the span
attrs []trace.Attribute
// add the request/response messages as span attributes
addMsg bool
// hook to update/modify a copy of the request/response messages before adding them as span attributes
msgAttrHook func(proto.Message)
}

// msgHook returns the value that should be formatted as the span attribute for v.
//
// When a message hook is set and v is a [proto.Message], the hook is run on a clone
// of v (so the original message is never modified) and that clone is returned.
// In every other case, v is returned untouched.
func (o *options) msgHook(v any) any {
	if o.msgAttrHook == nil {
		return v
	}

	if msg, isProto := v.(proto.Message); isProto {
		clone := proto.Clone(msg)
		o.msgAttrHook(clone)
		return clone
	}
	return v
}

// Option represents an option function that can be used with the OC TTRPC
Expand All @@ -39,6 +55,37 @@ func WithAttributes(attr ...trace.Attribute) Option {
}
}

// these are (currently) ServerInterceptor-specific options, but we cannot create a new [ServerOption] type
// since that would break our API

// WithAddMessage enables adding the request and response messages to the
// ttrpc method span as attributes.
//
// [ServerInterceptor] only.
func WithAddMessage() Option {
	return func(o *options) { o.addMsg = true }
}

// WithAddMessageHook registers f as a hook that can modify the ttrpc request and response
// messages before they are added to the ttrpc method span.
// Its intended use is scrubbing sensitive fields from the messages.
//
// f is invoked with a clone of the original message (via [proto.Clone]), and only if:
//   - the interceptor is created with the [WithAddMessage] option
//   - f is non-nil
//   - the ttrpc request or response is of type [proto.Message]
//
// Since ttrpc is a gRPC replacement, we are guaranteed that the messages will
// implement [proto.Message].
//
// [ServerInterceptor] only.
func WithAddMessageHook(f func(proto.Message)) Option {
	return func(o *options) { o.msgAttrHook = f }
}

const metadataTraceContextKey = "octtrpc.tracecontext"

func convertMethodName(name string) string {
Expand Down Expand Up @@ -124,18 +171,19 @@ func ServerInterceptor(opts ...Option) ttrpc.UnaryServerInterceptor {
}
defer span.End()
defer func() {
if err == nil {
span.AddAttributes(trace.StringAttribute("response", log.Format(ctx, resp)))
if o.addMsg && err == nil {
span.AddAttributes(trace.StringAttribute("response", log.Format(ctx, o.msgHook(resp))))
}
setSpanStatus(span, err)
}()
if len(o.attrs) > 0 {
span.AddAttributes(o.attrs...)
}

return method(ctx, func(req interface{}) (err error) {
if err = unmarshal(req); err == nil {
span.AddAttributes(trace.StringAttribute("request", log.Format(ctx, req)))
return method(ctx, func(req interface{}) error {
err := unmarshal(req)
if o.addMsg {
span.AddAttributes(trace.StringAttribute("request", log.Format(ctx, o.msgHook(req))))
}
return err
})
Expand Down
Loading

0 comments on commit d455f6d

Please sign in to comment.