diff --git a/capability/capability.go b/capability/capability.go index aee08ed..11e47be 100644 --- a/capability/capability.go +++ b/capability/capability.go @@ -142,3 +142,35 @@ func NewFile2(path string) (Capabilities, error) { func LastCap() (Cap, error) { return lastCap() } + +// GetAmbient determines if a specific ambient capability is raised in the +// calling thread. +func GetAmbient(c Cap) (bool, error) { + return getAmbient(c) +} + +// SetAmbient raises or lowers specified ambient capabilities for the calling +// thread. To complete successfully, the prevailing effective capability set +// must have a raised CAP_SETPCAP. Further, to raise a specific ambient +// capability the inheritable and permitted sets of the calling thread must +// already contain the specified capability. +func SetAmbient(raise bool, caps ...Cap) error { + return setAmbient(raise, caps...) +} + +// ResetAmbient resets all of the ambient capabilities for the calling thread +// to their lowered value. +func ResetAmbient() error { + return resetAmbient() +} + +// GetBound determines if a specific bounding capability is raised in the +// calling thread. +func GetBound(c Cap) (bool, error) { + return getBound(c) +} + +// DropBound lowers the specified bounding set capability. +func DropBound(caps ...Cap) error { + return dropBound(caps...) +} diff --git a/capability/capability_linux.go b/capability/capability_linux.go index d79ceee..234b1ef 100644 --- a/capability/capability_linux.go +++ b/capability/capability_linux.go @@ -117,6 +117,13 @@ func newPid(pid int) (c Capabilities, retErr error) { return } +func ignoreEINVAL(err error) error { + if errors.Is(err, syscall.EINVAL) { + err = nil + } + return err +} + type capsV3 struct { hdr capHeader data [2]capData @@ -327,7 +334,7 @@ func (c *capsV3) Load() (err error) { return } -func (c *capsV3) Apply(kind CapType) (err error) { +func (c *capsV3) Apply(kind CapType) error { if c.hdr.pid != 0 { return errors.New("unable to modify capabilities of another process") } @@ -339,21 +346,17 @@ func (c *capsV3) Apply(kind CapType) (err error) { var data [2]capData err = capget(&c.hdr, &data[0]) if err != nil { - return + return err } if (1< 0, nil +} + +func setAmbient(raise bool, caps ...Cap) error { + op := pr_CAP_AMBIENT_RAISE + if !raise { + op = pr_CAP_AMBIENT_LOWER + } + for _, val := range caps { + err := prctl(pr_CAP_AMBIENT, op, uintptr(val)) + if err != nil { + return err + } + } + return nil +} + +func resetAmbient() error { + return prctl(pr_CAP_AMBIENT, pr_CAP_AMBIENT_CLEAR_ALL, 0) +} + +func getBound(c Cap) (bool, error) { + res, err := prctlRetInt(syscall.PR_CAPBSET_READ, uintptr(c), 0) + if err != nil { + return false, err + } + return res > 0, nil +} + +func dropBound(caps ...Cap) error { + for _, val := range caps { + err := prctl(syscall.PR_CAPBSET_DROP, uintptr(val), 0) + if err != nil { + return err + } + } + return nil } func newFile(path string) (c Capabilities, err error) { diff --git a/capability/capability_noop.go b/capability/capability_noop.go index ba819ff..b766e44 100644 --- a/capability/capability_noop.go +++ b/capability/capability_noop.go @@ -24,3 +24,23 @@ func newFile(_ string) (Capabilities, error) { func lastCap() (Cap, error) { return -1, errNotSup } + +func getAmbient(_ Cap) (bool, error) { + return false, errNotSup +} + +func setAmbient(_ bool, _ ...Cap) error { + return errNotSup +} + +func resetAmbient() error { + return errNotSup +} + +func getBound(_ Cap) (bool, error) { + return false, errNotSup +} + +func dropBound(_ ...Cap) error { + return errNotSup +} diff --git a/capability/capability_test.go b/capability/capability_test.go index 4e99185..c9a24f9 100644 --- a/capability/capability_test.go +++ b/capability/capability_test.go @@ -42,7 +42,7 @@ func requirePCapSet(t testing.TB) { } // testInChild runs fn as a separate process, and returns its output. -// This is useful for tests which manipulate capabilties, allowing to +// This is useful for tests which manipulate capabilities, allowing to // preserve those of the main test process. // // The fn is a function which must end with os.Exit. In case exit code @@ -150,6 +150,7 @@ func TestAmbientCapSet(t *testing.T) { } func childAmbientCapSet() { + runtime.LockOSThread() // We can't use t.Log etc. here, yet filename and line number is nice // to have. Set up and use the standard logger for this. log.SetFlags(log.Lshortfile) @@ -227,3 +228,148 @@ func TestApplyOtherProcess(t *testing.T) { } } } + +func TestGetSetResetAmbient(t *testing.T) { + if runtime.GOOS != "linux" { + _, err := GetAmbient(Cap(0)) + if err == nil { + t.Error(runtime.GOOS, ": want error, got nil") + } + err = SetAmbient(false, Cap(0)) + if err == nil { + t.Error(runtime.GOOS, ": want error, got nil") + } + err = ResetAmbient() + if err == nil { + t.Error(runtime.GOOS, ": want error, got nil") + } + return + } + + requirePCapSet(t) + out := testInChild(t, childGetSetResetAmbient) + t.Logf("output from child:\n%s", out) +} + +func childGetSetResetAmbient() { + runtime.LockOSThread() + log.SetFlags(log.Lshortfile) + + pid, err := NewPid2(0) + if err != nil { + log.Fatal(err) + } + + list := []Cap{CAP_KILL, CAP_CHOWN, CAP_SYS_CHROOT} + pid.Set(CAPS, list...) + if err = pid.Apply(CAPS); err != nil { + log.Fatal(err) + } + + // Set ambient caps from list. + if err = SetAmbient(true, list...); err != nil { + log.Fatal(err) + } + + // Check if they were set as expected. + for _, cap := range list { + want := true + got, err := GetAmbient(cap) + if err != nil { + log.Fatalf("GetAmbient(%s): want nil, got error %v", cap, err) + } else if want != got { + log.Fatalf("Get(AMBIENT, %s): want %v, got %v", cap, want, got) + } + } + + // Lower one ambient cap. + const unsetIdx = 1 + if err = SetAmbient(false, list[unsetIdx]); err != nil { + log.Fatal(err) + } + // Check they are set as expected. + for i, cap := range list { + want := i != unsetIdx + got, err := GetAmbient(cap) + if err != nil { + log.Fatalf("GetAmbient(%s): want nil, got error %v", cap, err) + } else if want != got { + log.Fatalf("Get(AMBIENT, %s): want %v, got %v", cap, want, got) + } + } + + // Lower all ambient caps. + if err = ResetAmbient(); err != nil { + log.Fatal(err) + } + for _, cap := range list { + want := false + got, err := GetAmbient(cap) + if err != nil { + log.Fatalf("GetAmbient(%s): want nil, got error %v", cap, err) + } else if want != got { + log.Fatalf("Get(AMBIENT, %s): want %v, got %v", cap, want, got) + } + } + os.Exit(0) +} + +func TestGetBound(t *testing.T) { + if runtime.GOOS != "linux" { + _, err := GetBound(Cap(0)) + if err == nil { + t.Error(runtime.GOOS, ": want error, got nil") + } + return + } + + last, err := LastCap() + if err != nil { + t.Fatalf("LastCap: %v", err) + } + for i := Cap(0); i < Cap(63); i++ { + wantErr := i > last + set, err := GetBound(i) + t.Logf("GetBound(%q): %v, %v", i, set, err) + if wantErr && err == nil { + t.Errorf("GetBound(%q): want err, got nil", i) + } else if !wantErr && err != nil { + t.Errorf("GetBound(%q): want nil, got error %v", i, err) + } + } +} + +func TestDropBound(t *testing.T) { + if runtime.GOOS != "linux" { + err := DropBound(Cap(0)) + if err == nil { + t.Error(runtime.GOOS, ": want error, got nil") + } + return + } + + requirePCapSet(t) + out := testInChild(t, childDropBound) + t.Logf("output from child:\n%s", out) +} + +func childDropBound() { + runtime.LockOSThread() + log.SetFlags(log.Lshortfile) + + for i := Cap(2); i < Cap(4); i++ { + err := DropBound(i) + if err != nil { + log.Fatalf("DropBound(%q): want nil, got error %v", i, err) + } + set, err := GetBound(i) + if err != nil { + log.Fatalf("GetBound(%q): want nil, got error %v", i, err) + } + if set { + log.Fatalf("GetBound(%q): want false, got true", i) + } + } + + os.Exit(0) +} diff --git a/capability/syscall_linux.go b/capability/syscall_linux.go index d6b6932..2d8faa8 100644 --- a/capability/syscall_linux.go +++ b/capability/syscall_linux.go @@ -24,7 +24,7 @@ type capData struct { } func capget(hdr *capHeader, data *capData) (err error) { - _, _, e1 := syscall.Syscall(syscall.SYS_CAPGET, uintptr(unsafe.Pointer(hdr)), uintptr(unsafe.Pointer(data)), 0) + _, _, e1 := syscall.RawSyscall(syscall.SYS_CAPGET, uintptr(unsafe.Pointer(hdr)), uintptr(unsafe.Pointer(data)), 0) if e1 != 0 { err = e1 } @@ -32,7 +32,7 @@ func capget(hdr *capHeader, data *capData) (err error) { } func capset(hdr *capHeader, data *capData) (err error) { - _, _, e1 := syscall.Syscall(syscall.SYS_CAPSET, uintptr(unsafe.Pointer(hdr)), uintptr(unsafe.Pointer(data)), 0) + _, _, e1 := syscall.RawSyscall(syscall.SYS_CAPSET, uintptr(unsafe.Pointer(hdr)), uintptr(unsafe.Pointer(data)), 0) if e1 != 0 { err = e1 } @@ -48,14 +48,22 @@ const ( pr_CAP_AMBIENT_CLEAR_ALL = uintptr(4) ) -func prctl(option int, arg2, arg3, arg4, arg5 uintptr) (err error) { - _, _, e1 := syscall.Syscall6(syscall.SYS_PRCTL, uintptr(option), arg2, arg3, arg4, arg5, 0) +func prctl(option int, arg2, arg3 uintptr) (err error) { + _, _, e1 := syscall.RawSyscall(syscall.SYS_PRCTL, uintptr(option), arg2, arg3) if e1 != 0 { err = e1 } return } +func prctlRetInt(option int, arg2, arg3 uintptr) (int, error) { + ret, _, err := syscall.RawSyscall(syscall.SYS_PRCTL, uintptr(option), arg2, arg3) + if err != 0 { + return 0, err + } + return int(ret), nil +} + const ( vfsXattrName = "security.capability" @@ -92,7 +100,7 @@ func getVfsCap(path string, dest *vfscapData) (err error) { if err != nil { return } - r0, _, e1 := syscall.Syscall6(syscall.SYS_GETXATTR, uintptr(unsafe.Pointer(_p0)), uintptr(unsafe.Pointer(_vfsXattrName)), uintptr(unsafe.Pointer(dest)), vfscapDataSizeV2, 0, 0) + r0, _, e1 := syscall.RawSyscall6(syscall.SYS_GETXATTR, uintptr(unsafe.Pointer(_p0)), uintptr(unsafe.Pointer(_vfsXattrName)), uintptr(unsafe.Pointer(dest)), vfscapDataSizeV2, 0, 0) if e1 != 0 { if e1 == syscall.ENODATA { dest.version = 2 @@ -145,7 +153,7 @@ func setVfsCap(path string, data *vfscapData) (err error) { } else { return syscall.EINVAL } - _, _, e1 := syscall.Syscall6(syscall.SYS_SETXATTR, uintptr(unsafe.Pointer(_p0)), uintptr(unsafe.Pointer(_vfsXattrName)), uintptr(unsafe.Pointer(data)), size, 0, 0) + _, _, e1 := syscall.RawSyscall6(syscall.SYS_SETXATTR, uintptr(unsafe.Pointer(_p0)), uintptr(unsafe.Pointer(_vfsXattrName)), uintptr(unsafe.Pointer(data)), size, 0, 0) if e1 != 0 { err = e1 }