Skip to content

Commit

Permalink
[CWS] fix thread tracking (#21774)
Browse files Browse the repository at this point in the history
[CWS] fix thread tracking with the ptracer
  • Loading branch information
safchain authored Dec 28, 2023
1 parent f75f174 commit e938af4
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 55 deletions.
109 changes: 54 additions & 55 deletions pkg/security/ptracer/cws.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,41 +21,32 @@ import (
"time"

"github.com/avast/retry-go/v4"
lru "github.com/hashicorp/golang-lru/v2"
"github.com/vmihailenco/msgpack/v5"
"golang.org/x/sys/unix"

"github.com/DataDog/datadog-agent/pkg/security/proto/ebpfless"
"github.com/DataDog/datadog-agent/pkg/util/native"
)

// Process represents a process context
type Process struct {
Pid int
Nr map[int]*ebpfless.SyscallMsg
Fd map[int32]string
Cwd string
}

func fillProcessCwd(process *Process) error {
cwd, err := os.Readlink(fmt.Sprintf("/proc/%d/cwd", process.Pid))
if err != nil {
return err
}
process.Cwd = cwd
process.Res.Cwd = cwd
return nil
}

func getFullPathFromFd(process *Process, filename string, fd int32) (string, error) {
if filename[0] != '/' {
if fd == unix.AT_FDCWD { // if use current dir, try to prefix it
if process.Cwd != "" || fillProcessCwd(process) == nil {
filename = filepath.Join(process.Cwd, filename)
if process.Res.Cwd != "" || fillProcessCwd(process) == nil {
filename = filepath.Join(process.Res.Cwd, filename)
} else {
return "", errors.New("fillProcessCwd failed")
}
} else { // if using another dir, prefix it, we should have it in cache
if path, exists := process.Fd[fd]; exists {
if path, exists := process.Res.Fd[fd]; exists {
filename = filepath.Join(path, filename)
} else {
return "", errors.New("process FD cache incomplete during path resolution")
Expand All @@ -67,8 +58,8 @@ func getFullPathFromFd(process *Process, filename string, fd int32) (string, err

func getFullPathFromFilename(process *Process, filename string) (string, error) {
if filename[0] != '/' {
if process.Cwd != "" || fillProcessCwd(process) == nil {
filename = filepath.Join(process.Cwd, filename)
if process.Res.Cwd != "" || fillProcessCwd(process) == nil {
filename = filepath.Join(process.Res.Cwd, filename)
} else {
return "", errors.New("fillProcessCwd failed")
}
Expand Down Expand Up @@ -130,7 +121,7 @@ func handleExecveAt(tracer *Tracer, process *Process, msg *ebpfless.SyscallMsg,

if filename == "" { // in this case, dirfd defines directly the file's FD
var exists bool
if filename, exists = process.Fd[fd]; !exists || filename == "" {
if filename, exists = process.Res.Fd[fd]; !exists || filename == "" {
return errors.New("can't find related file path")
}
} else {
Expand Down Expand Up @@ -217,7 +208,7 @@ func handleChdir(tracer *Tracer, process *Process, msg *ebpfless.SyscallMsg, reg

dirname, err = getFullPathFromFilename(process, dirname)
if err != nil {
process.Cwd = ""
process.Res.Cwd = ""
return err
}

Expand All @@ -229,9 +220,9 @@ func handleChdir(tracer *Tracer, process *Process, msg *ebpfless.SyscallMsg, reg

func handleFchdir(tracer *Tracer, process *Process, msg *ebpfless.SyscallMsg, regs syscall.PtraceRegs) error {
fd := tracer.ReadArgInt32(regs, 0)
dirname, ok := process.Fd[fd]
dirname, ok := process.Res.Fd[fd]
if !ok {
process.Cwd = ""
process.Res.Cwd = ""
return nil
}

Expand Down Expand Up @@ -419,18 +410,11 @@ func StartCWSPtracer(args []string, probeAddr string, creds Creds, verbose bool,
stopChan = make(chan bool, 1)
)

cache, err := lru.New[int, *Process](1024)
if err != nil {
return err
}
pc := NewProcessCache()

// first process
process := &Process{
Pid: tracer.PID,
Nr: make(map[int]*ebpfless.SyscallMsg),
Fd: make(map[int32]string),
}
cache.Add(tracer.PID, process)
process := NewProcess(tracer.PID)
pc.Add(tracer.PID, process)

wg.Add(1)
go func() {
Expand Down Expand Up @@ -488,28 +472,23 @@ func StartCWSPtracer(args []string, probeAddr string, creds Creds, verbose bool,
})

cb := func(cbType CallbackType, nr int, pid int, ppid int, regs syscall.PtraceRegs) {
process := pc.Get(pid)
if process == nil {
process = NewProcess(pid)
pc.Add(pid, process)
}

sendSyscallMsg := func(msg *ebpfless.SyscallMsg) {
if msg == nil {
return
}
msg.PID = uint32(pid)
msg.PID = uint32(process.Tgid)
send(&ebpfless.Message{
Type: ebpfless.MessageTypeSyscall,
Syscall: msg,
})
}

process, exists := cache.Get(pid)
if !exists {
process = &Process{
Pid: pid,
Nr: make(map[int]*ebpfless.SyscallMsg),
Fd: make(map[int32]string),
}

cache.Add(pid, process)
}

switch cbType {
case CallbackPreType:
syscallMsg := &ebpfless.SyscallMsg{}
Expand Down Expand Up @@ -555,13 +534,21 @@ func StartCWSPtracer(args []string, probeAddr string, creds Creds, verbose bool,
EGID: gid,
}
}
sendSyscallMsg(syscallMsg)

// special case for exec since the pre reports the pid while the post reports the tgid
if process.Pid != process.Tgid {
pc.Add(process.Tgid, process)
}
case ExecveatNr:
if err = handleExecveAt(tracer, process, syscallMsg, regs); err != nil {
logErrorf("unable to handle execveat: %v", err)
return
}
sendSyscallMsg(syscallMsg)

// special case for exec since the pre reports the pid while the post reports the tgid
if process.Pid != process.Tgid {
pc.Add(process.Tgid, process)
}
case FcntlNr:
_ = handleFcntl(tracer, process, syscallMsg, regs)
case DupNr, Dup2Nr, Dup3Nr:
Expand Down Expand Up @@ -603,7 +590,10 @@ func StartCWSPtracer(args []string, probeAddr string, creds Creds, verbose bool,
case CallbackPostType:
switch nr {
case ExecveNr, ExecveatNr:
// nothing to do. send was already done at syscall entrance
sendSyscallMsg(process.Nr[nr])

// now the pid is the tgid
process.Pid = process.Tgid
case OpenNr, OpenatNr:
if ret := tracer.ReadRet(regs); !isAcceptedRetval(ret) {
syscallMsg, exists := process.Nr[nr]
Expand All @@ -615,7 +605,7 @@ func StartCWSPtracer(args []string, probeAddr string, creds Creds, verbose bool,
sendSyscallMsg(syscallMsg)

// maintain fd/path mapping
process.Fd[int32(ret)] = syscallMsg.Open.Filename
process.Res.Fd[int32(ret)] = syscallMsg.Open.Filename
}
case SetuidNr, SetgidNr, SetreuidNr, SetregidNr:
if ret := tracer.ReadRet(regs); ret >= 0 {
Expand All @@ -626,7 +616,13 @@ func StartCWSPtracer(args []string, probeAddr string, creds Creds, verbose bool,

sendSyscallMsg(syscallMsg)
}
case ForkNr, VforkNr, CloneNr:
case CloneNr:
if flags := tracer.ReadArgUint64(regs, 0); flags&uint64(unix.SIGCHLD) == 0 {
pc.SetAsThreadOf(process, ppid)
return
}
fallthrough
case ForkNr, VforkNr:
sendSyscallMsg(&ebpfless.SyscallMsg{
Type: ebpfless.SyscallTypeFork,
Fork: &ebpfless.ForkSyscallMsg{
Expand All @@ -642,8 +638,8 @@ func StartCWSPtracer(args []string, probeAddr string, creds Creds, verbose bool,

// maintain fd/path mapping
if syscallMsg.Fcntl.Cmd == unix.F_DUPFD || syscallMsg.Fcntl.Cmd == unix.F_DUPFD_CLOEXEC {
if path, exists := process.Fd[int32(syscallMsg.Fcntl.Fd)]; exists {
process.Fd[int32(ret)] = path
if path, exists := process.Res.Fd[int32(syscallMsg.Fcntl.Fd)]; exists {
process.Res.Fd[int32(ret)] = path
}
}
}
Expand All @@ -653,10 +649,10 @@ func StartCWSPtracer(args []string, probeAddr string, creds Creds, verbose bool,
if !exists {
return
}
path, ok := process.Fd[syscallMsg.Dup.OldFd]
path, ok := process.Res.Fd[syscallMsg.Dup.OldFd]
if ok {
// maintain fd/path in case of dups
process.Fd[int32(ret)] = path
process.Res.Fd[int32(ret)] = path
}
}
case ChdirNr, FchdirNr:
Expand All @@ -665,15 +661,18 @@ func StartCWSPtracer(args []string, probeAddr string, creds Creds, verbose bool,
if !exists || syscallMsg.Chdir == nil {
return
}
process.Cwd = syscallMsg.Chdir.Path
process.Res.Cwd = syscallMsg.Chdir.Path
}
}
case CallbackExitType:
sendSyscallMsg(&ebpfless.SyscallMsg{
Type: ebpfless.SyscallTypeExit,
})
// send exit only for process not threads
if process.Pid == process.Tgid {
sendSyscallMsg(&ebpfless.SyscallMsg{
Type: ebpfless.SyscallTypeExit,
})
}

cache.Remove(pid)
pc.Remove(process)
}
}

Expand Down
108 changes: 108 additions & 0 deletions pkg/security/ptracer/process.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
// Unless explicitly stated otherwise all files in this repository are licensed
// under the Apache License Version 2.0.
// This product includes software developed at Datadog (https://www.datadoghq.com/).
// Copyright 2016-present Datadog, Inc.

//go:build linux

// Package ptracer holds the start command of CWS injector
package ptracer

import (
"github.com/DataDog/datadog-agent/pkg/security/proto/ebpfless"
"golang.org/x/exp/slices"
)

// Resources defines shared process resources
type Resources struct {
Fd map[int32]string
Cwd string
}

// Process represents a process context
type Process struct {
Pid int
Tgid int
Nr map[int]*ebpfless.SyscallMsg
Res *Resources
}

// NewProcess returns a new process
func NewProcess(pid int) *Process {
return &Process{
Pid: pid,
Tgid: pid,
Nr: make(map[int]*ebpfless.SyscallMsg),
Res: &Resources{
Fd: make(map[int32]string),
},
}
}

// ProcessCache defines a thread cache
type ProcessCache struct {
pid2Process map[int]*Process
tgid2Pid map[int][]int
}

// NewProcessCache returns a new thread cache
func NewProcessCache() *ProcessCache {
return &ProcessCache{
pid2Process: make(map[int]*Process),
tgid2Pid: make(map[int][]int),
}
}

// Add a process
func (tc *ProcessCache) Add(pid int, process *Process) {
tc.pid2Process[pid] = process

if process.Pid != process.Tgid {
tc.tgid2Pid[process.Tgid] = append(tc.tgid2Pid[process.Tgid], process.Pid)
}
}

// SetAsThreadOf set the process as thread of the given tgid
func (tc *ProcessCache) SetAsThreadOf(process *Process, ppid int) {
parent := tc.pid2Process[ppid]
if parent == nil {
// this shouldn't happen
return
}

// share resources, parent should never be nil
process.Tgid = parent.Tgid
process.Res = parent.Res

// re-add to update the caches
tc.Add(process.Pid, process)
}

// Remove a pid
func (tc *ProcessCache) Remove(process *Process) {
delete(tc.pid2Process, process.Pid)

if process.Pid == process.Tgid {
pids, ok := tc.tgid2Pid[process.Pid]
if !ok {
return
}
delete(tc.tgid2Pid, process.Pid)

for pid := range pids {
delete(tc.pid2Process, pid)
}
} else {
tc.tgid2Pid[process.Tgid] = slices.DeleteFunc(tc.tgid2Pid[process.Tgid], func(pid int) bool {
return pid == process.Pid
})
if len(tc.tgid2Pid[process.Tgid]) == 0 {
delete(tc.tgid2Pid, process.Tgid)
}
}
}

// Get return the process entry for the given pid
func (tc *ProcessCache) Get(pid int) *Process {
return tc.pid2Process[pid]
}

0 comments on commit e938af4

Please sign in to comment.