diff --git a/capability/capability_test.go b/capability/capability_test.go index 81b27e5..840adc3 100644 --- a/capability/capability_test.go +++ b/capability/capability_test.go @@ -5,6 +5,9 @@ package capability_test import ( + "log" + "os" + "os/exec" "runtime" "testing" @@ -37,6 +40,33 @@ func requirePCapSet(t testing.TB) { } } +// testInChild runs fn as a separate process, and returns its output. +// This is useful for tests which manipulate capabilties, allowing to +// preserve those of the main test process. +// +// The fn is a function which must end with os.Exit. In case exit code +// is non-zero, t.Fatal is called. +func testInChild(t *testing.T, fn func()) []byte { + if os.Getenv("GO_WANT_HELPER_PROCESS") == "1" { + fn() + } + + args := []string{"-test.run=^" + t.Name() + "$"} + if testing.Verbose() { + args = append(args, "-test.v") + } + cmd := exec.Command("/proc/self/exe", args...) + cmd.Env = append(cmd.Environ(), "GO_WANT_HELPER_PROCESS=1") + + out, err := cmd.CombinedOutput() + if err != nil { + t.Helper() + t.Fatalf("exec failed: %v\n\n%s\n", err, out) + } + + return out +} + func TestLastCap(t *testing.T) { last, err := LastCap() switch runtime.GOOS { @@ -112,25 +142,33 @@ func TestAmbientCapSet(t *testing.T) { } requirePCapSet(t) + out := testInChild(t, childAmbientCapSet) + + t.Logf("output from child:\n%s", out) +} + +func childAmbientCapSet() { + log.SetFlags(log.Lshortfile) + pid, err := NewPid2(0) if err != nil { - t.Fatal(err) + log.Fatal(err) } list := []Cap{CAP_KILL, CAP_CHOWN, CAP_SYS_CHROOT} pid.Set(CAPS|AMBIENT, list...) if err = pid.Apply(CAPS | AMBIENT); err != nil { - t.Fatal(err) + log.Fatal(err) } // Check if ambient caps were applied. if err = pid.Load(); err != nil { - t.Fatal(err) + log.Fatal(err) } for _, cap := range list { want := true if got := pid.Get(AMBIENT, cap); want != got { - t.Errorf("Get(AMBIENT, %s): want %v, got %v", cap, want, got) + log.Fatalf("Get(AMBIENT, %s): want %v, got %v", cap, want, got) } } @@ -138,16 +176,17 @@ func TestAmbientCapSet(t *testing.T) { const unsetIdx = 1 pid.Unset(AMBIENT, list[unsetIdx]) if err = pid.Apply(AMBIENT); err != nil { - t.Fatal(err) + log.Fatal(err) } if err = pid.Load(); err != nil { - t.Fatal(err) + log.Fatal(err) } for i, cap := range list { want := i != unsetIdx if got := pid.Get(AMBIENT, cap); want != got { - t.Errorf("Get(AMBIENT, %s): want %v, got %v", cap, want, got) + log.Fatalf("Get(AMBIENT, %s): want %v, got %v", cap, want, got) } } + os.Exit(0) }