Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

osutil/sys, client: add sys.RunAsUidGid, use it for auth.json #4983

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 42 additions & 12 deletions client/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"path/filepath"

"github.com/snapcore/snapd/osutil"
"github.com/snapcore/snapd/osutil/sys"
)

// User holds logged in user information.
Expand Down Expand Up @@ -118,30 +119,46 @@ func writeAuthData(user User) error {

targetFile := storeAuthDataFilename(real.HomeDir)

if err := osutil.MkdirAllChown(filepath.Dir(targetFile), 0700, uid, gid); err != nil {
out, err := json.Marshal(user)
if err != nil {
return err
}

outStr, err := json.Marshal(user)
if err != nil {
return nil
}
return sys.RunAsUidGid(uid, gid, func() error {
if err := os.MkdirAll(filepath.Dir(targetFile), 0700); err != nil {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, just return the error directly? :-)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, the diff formatting confused me, never mind :)

return err
}

return osutil.AtomicWriteFileChown(targetFile, []byte(outStr), 0600, 0, uid, gid)
return osutil.AtomicWriteFile(targetFile, out, 0600, 0)
})
}

// readAuthData reads previously written authentication details
func readAuthData() (*User, error) {
sourceFile := storeAuthDataFilename("")
f, err := os.Open(sourceFile)
real, err := osutil.RealUser()
if err != nil {
return nil, err
}

uid, gid, err := osutil.UidGid(real)
if err != nil {
return nil, err
}
defer f.Close()

var user User
dec := json.NewDecoder(f)
if err := dec.Decode(&user); err != nil {
sourceFile := storeAuthDataFilename("")

if err := sys.RunAsUidGid(uid, gid, func() error {
f, err := os.Open(sourceFile)
if err != nil {
return err
}
defer f.Close()

dec := json.NewDecoder(f)

return dec.Decode(&user)
}); err != nil {
return nil, err
}

Expand All @@ -150,6 +167,19 @@ func readAuthData() (*User, error) {

// removeAuthData removes any previously written authentication details.
func removeAuthData() error {
real, err := osutil.RealUser()
if err != nil {
return err
}

uid, gid, err := osutil.UidGid(real)
if err != nil {
return err
}

filename := storeAuthDataFilename("")
return os.Remove(filename)

return sys.RunAsUidGid(uid, gid, func() error {
return os.Remove(filename)
})
}
60 changes: 60 additions & 0 deletions osutil/sys/syscall.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
package sys

import (
"fmt"
"os"
"runtime"
"syscall"
"unsafe"
)
Expand Down Expand Up @@ -107,3 +109,61 @@ func FcntlGetFl(fd int) (int, error) {
}
return int(flags), nil
}

// UnrecoverableError is an error that flags that things have Gone Wrong, the
// runtime is in a bad state, and you should really quit. The intention is that
// if you're trying to recover from a panic and find that the value of the panic
// is an UnrecoverableError, you should just exit ASAP.
type UnrecoverableError struct {
Call string
Err error
}

func (e UnrecoverableError) Error() string {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please document this error.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing docs :)

return fmt.Sprintf("%s: %v", e.Call, e.Err)
}

// RunAsUidGid starts a goroutine, pins it to the OS thread, sets euid and egid,
// and runs the function; after the function returns, it restores euid and egid.
//
// If restoring the original euid and egid fails this function will panic with
// an UnrecoverableError, and you should _not_ try to recover from it: the
// runtime itself is going to be in trouble.
func RunAsUidGid(uid UserID, gid GroupID, f func() error) error {
ch := make(chan error, 1)
go func() {
// from the docs:
// until the goroutine exits or calls UnlockOSThread, it will
// always execute in this thread, and no other goroutine can.
// that last bit means it's safe to setuid/setgid in here, as no
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick, That

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm torn between fixing it now, or landing it now and fixing it later :-)

// other code will run.
runtime.LockOSThread()

ruid := Getuid()
rgid := Getgid()

if _, _, errno := syscall.RawSyscall(_SYS_SETREGID, FlagID, uintptr(gid), 0); errno == 0 {
if _, _, errno := syscall.RawSyscall(_SYS_SETREUID, FlagID, uintptr(uid), 0); errno == 0 {
ch <- f()
// try to restore euid
if _, _, errno := syscall.RawSyscall(_SYS_SETREUID, FlagID, uintptr(ruid), 0); errno != 0 {
// ¯\_(ツ)_/¯
panic(UnrecoverableError{Call: "setreuid", Err: errno})
}
} else {
ch <- fmt.Errorf("setreuid: %v", errno)
}

// try to restore egid
if _, _, errno := syscall.RawSyscall(_SYS_SETREGID, FlagID, uintptr(rgid), 0); errno != 0 {
// ¯\_(ツ)_/¯
panic(UnrecoverableError{Call: "setregid", Err: errno})
}
} else {
ch <- fmt.Errorf("setregid: %v", errno)
}

runtime.UnlockOSThread()
}()
return <-ch
}
3 changes: 3 additions & 0 deletions osutil/sys/sysnum_getpid_16.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,7 @@ const (
_SYS_GETGID = syscall.SYS_GETGID32
_SYS_GETEUID = syscall.SYS_GETEUID32
_SYS_GETEGID = syscall.SYS_GETEGID32

_SYS_SETREUID = syscall.SYS_SETREUID32
_SYS_SETREGID = syscall.SYS_SETREGID32
)
3 changes: 3 additions & 0 deletions osutil/sys/sysnum_getpid_32.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,7 @@ const (
_SYS_GETGID = syscall.SYS_GETGID
_SYS_GETEUID = syscall.SYS_GETEUID
_SYS_GETEGID = syscall.SYS_GETEGID

_SYS_SETREUID = syscall.SYS_SETREUID
_SYS_SETREGID = syscall.SYS_SETREGID
)
7 changes: 7 additions & 0 deletions tests/main/drop-privs/task.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
summary: test chattr
# ubuntu-core doesn't have go :-)
systems: [-ubuntu-core-16-*]
execute: |
go build ./testit.go
test "$(./testit)" = "before: 0/0, during: 12345/12345 (<nil>), after: 0/0; status: OK"
test "$(sudo -u '#12345' -g '#12345' ./testit)" = "before: 12345/12345, during: 12345/12345 (<nil>), after: 12345/12345; status: OK"
57 changes: 57 additions & 0 deletions tests/main/drop-privs/testit.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package main

import (
"fmt"
"runtime"
"sync"

"github.com/snapcore/snapd/osutil/sys"
)

var wg sync.WaitGroup
var mu sync.Mutex

func check(uids []sys.UserID, n int) {
// spin
for i := 0; i < 1<<30; i++ {
}

mu.Lock()
uids[n] = sys.Geteuid()
mu.Unlock()

wg.Done()
}

func main() {
orig := sys.Geteuid()
before := fmt.Sprintf("%d/%d", sys.Geteuid(), sys.Getegid())
var during string
err := sys.RunAsUidGid(12345, 12345, func() error {
during = fmt.Sprintf("%d/%d", sys.Geteuid(), sys.Getegid())
return nil
})
after := fmt.Sprintf("%d/%d", sys.Geteuid(), sys.Getegid())

N := 2 * runtime.NumCPU()
uids := make([]sys.UserID, N)
// launch a lot of goroutines so we cover all threads with space to spare
for i := 0; i < N; i++ {
wg.Add(1)
go check(uids, i)
}
wg.Wait()

bad := 0
for _, uid := range uids {
if uid != orig {
bad++
}
}
status := "OK"
if bad != 0 {
status = fmt.Sprintf("%d BAD!", bad)
}

fmt.Printf("before: %s, during: %s (%v), after: %s; status: %s\n", before, during, err, after, status)
}