Skip to content

Commit

Permalink
[CWS] Fix ptrace issues and add tests (DataDog#31510)
Browse files Browse the repository at this point in the history
  • Loading branch information
spikat authored Nov 27, 2024
1 parent 7332199 commit 2e988c0
Show file tree
Hide file tree
Showing 10 changed files with 224 additions and 11 deletions.
1 change: 1 addition & 0 deletions pkg/security/ebpf/c/include/events_definition.h
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,7 @@ struct ptrace_event_t {
u32 request;
u32 pid;
u64 addr;
u32 ns_pid;
};

struct syscall_monitor_event_t {
Expand Down
2 changes: 2 additions & 0 deletions pkg/security/ebpf/c/include/hooks/ptrace.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ HOOK_SYSCALL_ENTRY3(ptrace, u32, request, pid_t, pid, void *, addr) {
.ptrace = {
.request = request,
.pid = 0, // 0 in case the root ns pid resolution failed
.ns_pid = (u32)pid,
.addr = (u64)addr,
}
};
Expand Down Expand Up @@ -59,6 +60,7 @@ int __attribute__((always_inline)) sys_ptrace_ret(void *ctx, int retval) {
.request = syscall->ptrace.request,
.pid = syscall->ptrace.pid,
.addr = syscall->ptrace.addr,
.ns_pid = syscall->ptrace.ns_pid,
};

struct proc_cache_t *entry = fill_process_context(&event.process);
Expand Down
1 change: 1 addition & 0 deletions pkg/security/ebpf/c/include/structs/syscalls.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ struct syscall_cache_t {
u32 request;
u32 pid;
u64 addr;
u32 ns_pid;
} ptrace;

struct {
Expand Down
34 changes: 31 additions & 3 deletions pkg/security/probe/probe_ebpf.go
Original file line number Diff line number Diff line change
Expand Up @@ -1072,14 +1072,42 @@ func (p *EBPFProbe) handleEvent(CPU int, data []byte) {
seclog.Errorf("failed to decode ptrace event: %s (offset %d, len %d)", err, offset, len(data))
return
}

// resolve tracee process context
var pce *model.ProcessCacheEntry
if event.PTrace.PID == 0 { // pid can be 0 for a PTRACE_TRACEME request
if event.PTrace.Request == unix.PTRACE_TRACEME { // pid can be 0 for a PTRACE_TRACEME request
pce = newPlaceholderProcessCacheEntryPTraceMe()
} else if event.PTrace.PID == 0 && event.PTrace.NSPID == 0 {
seclog.Errorf("ptrace event without any PID to resolve")
return
} else {
pce = p.Resolvers.ProcessResolver.Resolve(event.PTrace.PID, event.PTrace.PID, 0, false, newEntryCb)
pidToResolve := event.PTrace.PID

if pidToResolve == 0 { // resolve the PID given as argument instead
if event.ContainerContext.ContainerID == "" {
pidToResolve = event.PTrace.NSPID
} else {
// 1. get the pid namespace of the tracer
ns, err := utils.GetProcessPidNamespace(event.ProcessContext.Process.Pid)
if err != nil {
seclog.Errorf("Failed to resolve PID namespace: %v", err)
return
}

// 2. find the host pid matching the arg pid with he tracer namespace
pid, err := utils.FindPidNamespace(event.PTrace.NSPID, ns)
if err != nil {
seclog.Warnf("Failed to resolve tracee PID namespace: %v", err)
return
}

pidToResolve = pid
}
}

pce = p.Resolvers.ProcessResolver.Resolve(pidToResolve, pidToResolve, 0, false, newEntryCb)
if pce == nil {
pce = model.NewPlaceholderProcessCacheEntry(event.PTrace.PID, event.PTrace.PID, false)
pce = model.NewPlaceholderProcessCacheEntry(pidToResolve, pidToResolve, false)
}
}
event.PTrace.Tracee = &pce.ProcessContext
Expand Down
1 change: 1 addition & 0 deletions pkg/security/secl/model/model_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,7 @@ type PTraceEvent struct {

Request uint32 `field:"request"` // SECLDoc[request] Definition:`ptrace request` Constants:`Ptrace constants`
PID uint32 `field:"-"`
NSPID uint32 `field:"-"`
Address uint64 `field:"-"`
Tracee *ProcessContext `field:"tracee"` // process context of the tracee
}
Expand Down
5 changes: 3 additions & 2 deletions pkg/security/secl/model/unmarshallers_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -833,14 +833,15 @@ func (e *PTraceEvent) UnmarshalBinary(data []byte) (int, error) {
return 0, err
}

if len(data)-read < 16 {
if len(data)-read < 20 {
return 0, ErrNotEnoughData
}

e.Request = binary.NativeEndian.Uint32(data[read : read+4])
e.PID = binary.NativeEndian.Uint32(data[read+4 : read+8])
e.Address = binary.NativeEndian.Uint64(data[read+8 : read+16])
return read + 16, nil
e.NSPID = binary.NativeEndian.Uint32(data[read+16 : read+20])
return read + 20, nil
}

// UnmarshalBinary unmarshals a binary representation of itself
Expand Down
16 changes: 15 additions & 1 deletion pkg/security/serializers/serializers_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -895,10 +895,24 @@ func newMProtectEventSerializer(e *model.Event) *MProtectEventSerializer {
}

func newPTraceEventSerializer(e *model.Event) *PTraceEventSerializer {
if e.PTrace.Tracee == nil {
return nil
}

fakeTraceeEvent := &model.Event{
BaseEvent: model.BaseEvent{
FieldHandlers: e.FieldHandlers,
ProcessContext: e.PTrace.Tracee,
ContainerContext: &model.ContainerContext{
ContainerID: e.PTrace.Tracee.ContainerID,
},
},
}

return &PTraceEventSerializer{
Request: model.PTraceRequest(e.PTrace.Request).String(),
Address: fmt.Sprintf("0x%x", e.PTrace.Address),
Tracee: newProcessContextSerializer(e.PTrace.Tracee, e),
Tracee: newProcessContextSerializer(e.PTrace.Tracee, fakeTraceeEvent),
}
}

Expand Down
76 changes: 71 additions & 5 deletions pkg/security/tests/ptrace_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"fmt"
"os/exec"
"testing"
"time"

"github.com/stretchr/testify/assert"

Expand All @@ -24,9 +25,17 @@ func TestPTraceEvent(t *testing.T) {

ruleDefs := []*rules.RuleDefinition{
{
ID: "test_ptrace",
ID: "test_ptrace_cont",
Expression: `ptrace.request == PTRACE_CONT && ptrace.tracee.file.name == "syscall_tester"`,
},
{
ID: "test_ptrace_me",
Expression: `ptrace.request == PTRACE_TRACEME && process.file.name == "syscall_tester"`,
},
{
ID: "test_ptrace_attach",
Expression: `ptrace.request == PTRACE_ATTACH && ptrace.tracee.file.name == "syscall_tester"`,
},
}

test, err := newTestModule(t, nil, ruleDefs)
Expand All @@ -40,25 +49,82 @@ func TestPTraceEvent(t *testing.T) {
t.Fatal(err)
}

test.Run(t, "ptrace", func(t *testing.T, _ wrapperType, cmdFunc func(cmd string, args []string, envs []string) *exec.Cmd) {
test.Run(t, "ptrace-cont", func(t *testing.T, _ wrapperType, cmdFunc func(cmd string, args []string, envs []string) *exec.Cmd) {
args := []string{"ptrace-traceme"}
envs := []string{}

test.WaitSignal(t, func() error {
err := test.GetEventSent(t, func() error {
cmd := cmdFunc(syscallTester, args, envs)
if out, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("%s: %w", out, err)
}

return nil
}, func(event *model.Event, _ *rules.Rule) {
}, func(_ *rules.Rule, event *model.Event) bool {
assert.Equal(t, "ptrace", event.GetType(), "wrong event type")
assert.Equal(t, uint64(42), event.PTrace.Address, "wrong address")

value, _ := event.GetFieldValue("event.async")
assert.Equal(t, value.(bool), false)

test.validatePTraceSchema(t, event)
})
return true
}, time.Second*3, "test_ptrace_cont")
if err != nil {
t.Error(err)
}
})

test.Run(t, "ptrace-me", func(t *testing.T, _ wrapperType, cmdFunc func(cmd string, args []string, envs []string) *exec.Cmd) {
args := []string{"ptrace-traceme"}
envs := []string{}

err := test.GetEventSent(t, func() error {
cmd := cmdFunc(syscallTester, args, envs)
if out, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("%s: %w", out, err)
}

return nil
}, func(_ *rules.Rule, event *model.Event) bool {
assert.Equal(t, "ptrace", event.GetType(), "wrong event type")
assert.Equal(t, uint64(0), event.PTrace.Address, "wrong address")

value, _ := event.GetFieldValue("event.async")
assert.Equal(t, value.(bool), false)

test.validatePTraceSchema(t, event)
return true
}, time.Second*3, "test_ptrace_me")
if err != nil {
t.Error(err)
}
})

test.Run(t, "ptrace-attach", func(t *testing.T, _ wrapperType, cmdFunc func(cmd string, args []string, envs []string) *exec.Cmd) {
args := []string{"ptrace-attach"}
envs := []string{}

err := test.GetEventSent(t, func() error {
cmd := cmdFunc(syscallTester, args, envs)
if out, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("%s: %w", out, err)
}

return nil
}, func(_ *rules.Rule, event *model.Event) bool {
assert.Equal(t, "ptrace", event.GetType(), "wrong event type")
assert.Equal(t, uint64(0), event.PTrace.Address, "wrong address")
assert.Equal(t, event.PTrace.Tracee.PPid, event.PTrace.Tracee.Parent.Pid, "tracee wrong ppid / parent pid")

value, _ := event.GetFieldValue("event.async")
assert.Equal(t, value.(bool), false)

test.validatePTraceSchema(t, event)
return true
}, time.Second*3, "test_ptrace_attach")
if err != nil {
t.Error(err)
}
})
}
15 changes: 15 additions & 0 deletions pkg/security/tests/syscall_tester/c/syscall_tester.c
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,19 @@ int ptrace_traceme() {
return EXIT_SUCCESS;
}

int ptrace_attach() {
int child = fork();
if (child == 0) {
for (int i = 0; i < 20; i++) {
sleep(1);
}
} else {
ptrace(PTRACE_ATTACH, child, 0, NULL);
wait(NULL);
}
return EXIT_SUCCESS;
}

int test_signal_sigusr(int child, int sig) {
int do_fork = child == 0;
if (do_fork) {
Expand Down Expand Up @@ -885,6 +898,8 @@ int main(int argc, char **argv) {
exit_code = span_exec(sub_argc, sub_argv);
} else if (strcmp(cmd, "ptrace-traceme") == 0) {
exit_code = ptrace_traceme();
} else if (strcmp(cmd, "ptrace-attach") == 0) {
exit_code = ptrace_attach();
} else if (strcmp(cmd, "span-open") == 0) {
exit_code = span_open(sub_argc, sub_argv);
} else if (strcmp(cmd, "pipe-chown") == 0) {
Expand Down
84 changes: 84 additions & 0 deletions pkg/security/utils/proc_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package utils

import (
"bufio"
"errors"
"fmt"
"io"
"os"
Expand All @@ -20,6 +21,7 @@ import (

"github.com/DataDog/datadog-agent/pkg/security/secl/model"
"github.com/DataDog/datadog-agent/pkg/util/kernel"
"github.com/shirou/gopsutil/v3/process"
)

// Getpid returns the current process ID in the host namespace
Expand Down Expand Up @@ -384,3 +386,85 @@ func FetchLoadedModules() (map[string]ProcFSModule, error) {

return output, nil
}

// GetProcessPidNamespace returns the PID namespace of the given PID
func GetProcessPidNamespace(pid uint32) (uint64, error) {
nspidPath := procPidPath(pid, "ns/pid")
link, err := os.Readlink(nspidPath)
if err != nil {
return 0, err
}
// link should be in for of: pid:[4026532294]
if !strings.HasPrefix(link, "pid:[") {
return 0, fmt.Errorf("Failed to retrieve PID NS, pid ns malformated: (%s) err: %v", link, err)
}

link = strings.TrimPrefix(link, "pid:[")
link = strings.TrimSuffix(link, "]")

ns, err := strconv.ParseUint(link, 10, 64)
if err != nil {
return 0, fmt.Errorf("Failed to retrieve PID NS, pid ns malformated: (%s) err: %v", link, err)
}
return ns, nil
}

// GetNsPids returns the namespaced pids of the the givent root pid
func GetNsPids(pid uint32) ([]uint32, error) {
statusFile := StatusPath(pid)
content, err := os.ReadFile(statusFile)
if err != nil {
return nil, fmt.Errorf("failed to read status file: %w", err)
}

lines := strings.Split(string(content), "\n")
for _, line := range lines {
if strings.HasPrefix(line, "NSpid:") {
// Remove "NSpid:" prefix and trim spaces
values := strings.TrimPrefix(line, "NSpid:")
values = strings.TrimSpace(values)

// Split the remaining string into fields
fields := strings.Fields(values)

// Convert string values to integers
nspids := make([]uint32, 0, len(fields))
for _, field := range fields {
val, err := strconv.ParseUint(field, 10, 64)
if err != nil {
return nil, fmt.Errorf("failed to parse NSpid value: %w", err)
}
nspids = append(nspids, uint32(val))
}
return nspids, nil
}
}
return nil, fmt.Errorf("NSpid field not found")
}

// FindPidNamespace search and return the host PID for the given namespaced PID + its namespace
func FindPidNamespace(nspid uint32, ns uint64) (uint32, error) {
procPids, err := process.Pids()
if err != nil {
return 0, err
}

for _, procPid := range procPids {
procNs, err := GetProcessPidNamespace(uint32(procPid))
if err != nil {
continue
}

if procNs == ns {
nspids, err := GetNsPids(uint32(procPid))
if err != nil {
return 0, err
}
// we look only at the last one, as it the most inner one and corresponding to its /proc/pid/ns/pid namespace
if nspids[len(nspids)-1] == nspid {
return uint32(procPid), nil
}
}
}
return 0, errors.New("PID not found")
}

0 comments on commit 2e988c0

Please sign in to comment.