Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

daemon: make ucrednetGet() return a *ucrednet structure #10126

Merged
merged 2 commits into from
Apr 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions daemon/api_snapctl.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func runSnapctl(c *Command, r *http.Request, user *auth.UserState) Response {
return BadRequest("snapctl cannot run without args")
}

_, uid, _, err := ucrednetGet(r.RemoteAddr)
ucred, err := ucrednetGet(r.RemoteAddr)
if err != nil {
return Forbidden("cannot get remote user: %s", err)
}
Expand All @@ -69,7 +69,7 @@ func runSnapctl(c *Command, r *http.Request, user *auth.UserState) Response {
context.Unlock()
}

stdout, stderr, err := ctlcmdRun(context, snapctlPostData.Args, uid)
stdout, stderr, err := ctlcmdRun(context, snapctlPostData.Args, ucred.Uid)
if err != nil {
if e, ok := err.(*ctlcmd.UnsuccessfulError); ok {
result := map[string]interface{}{
Expand Down
8 changes: 4 additions & 4 deletions daemon/api_snapctl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ func (s *snapctlSuite) TestSnapctlGetNoUID(c *check.C) {
func (s *snapctlSuite) TestSnapctlForbiddenError(c *check.C) {
s.daemon(c)

defer daemon.MockUcrednetGet(func(string) (int32, uint32, string, error) {
return 100, 9999, dirs.SnapSocket, nil
defer daemon.MockUcrednetGet(func(string) (*daemon.Ucrednet, error) {
return &daemon.Ucrednet{Uid: 100, Pid: 9999, Socket: dirs.SnapSocket}, nil
})()

defer daemon.MockCtlcmdRun(func(ctx *hookstate.Context, arg []string, uid uint32) ([]byte, []byte, error) {
Expand All @@ -69,8 +69,8 @@ func (s *snapctlSuite) TestSnapctlForbiddenError(c *check.C) {
func (s *snapctlSuite) TestSnapctlUnsuccesfulError(c *check.C) {
s.daemon(c)

defer daemon.MockUcrednetGet(func(string) (int32, uint32, string, error) {
return 100, 9999, dirs.SnapSocket, nil
defer daemon.MockUcrednetGet(func(string) (*daemon.Ucrednet, error) {
return &daemon.Ucrednet{Uid: 100, Pid: 9999, Socket: dirs.SnapSocket}, nil
})()

defer daemon.MockCtlcmdRun(func(ctx *hookstate.Context, arg []string, uid uint32) ([]byte, []byte, error) {
Expand Down
16 changes: 7 additions & 9 deletions daemon/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,16 +148,14 @@ func (c *Command) canAccess(r *http.Request, user *auth.UserState) accessResult
return accessOK
}

// isUser means we have a UID for the request
isUser := false
pid, uid, socket, err := ucrednetGet(r.RemoteAddr)
if err == nil {
isUser = true
} else if err != errNoID {
ucred, err := ucrednetGet(r.RemoteAddr)
if err != nil && err != errNoID {
logger.Noticef("unexpected error when attempting to get UID: %s", err)
return accessForbidden
}
isSnap := (socket == dirs.SnapSocket)
// isUser means we have a UID for the request
isUser := ucred != nil
isSnap := (ucred != nil && ucred.Socket == dirs.SnapSocket)

// ensure that snaps can only access SnapOK things
if isSnap {
Expand All @@ -184,7 +182,7 @@ func (c *Command) canAccess(r *http.Request, user *auth.UserState) accessResult
return accessUnauthorized
}

if uid == 0 {
if ucred.Uid == 0 {
// Superuser does anything.
return accessOK
}
Expand All @@ -204,7 +202,7 @@ func (c *Command) canAccess(r *http.Request, user *auth.UserState) accessResult
}
}
// Pass both pid and uid from the peer ucred to avoid pid race
if authorized, err := polkitCheckAuthorization(pid, uid, c.PolkitOK, nil, flags); err == nil {
if authorized, err := polkitCheckAuthorization(ucred.Pid, ucred.Uid, c.PolkitOK, nil, flags); err == nil {
if authorized {
// polkit says user is authorised
return accessOK
Expand Down
4 changes: 3 additions & 1 deletion daemon/export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,9 @@ func (d *Daemon) RequestedRestart() state.RestartType {
return d.requestedRestart
}

func MockUcrednetGet(mock func(remoteAddr string) (pid int32, uid uint32, socket string, err error)) (restore func()) {
type Ucrednet = ucrednet

func MockUcrednetGet(mock func(remoteAddr string) (ucred *Ucrednet, err error)) (restore func()) {
oldUcrednetGet := ucrednetGet
ucrednetGet = mock
return func() {
Expand Down
34 changes: 18 additions & 16 deletions daemon/ucrednet.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,40 +40,42 @@ var raddrRegexp = regexp.MustCompile(`^pid=(\d+);uid=(\d+);socket=([^;]*);$`)

var ucrednetGet = ucrednetGetImpl

func ucrednetGetImpl(remoteAddr string) (pid int32, uid uint32, socket string, err error) {
func ucrednetGetImpl(remoteAddr string) (*ucrednet, error) {
// NOTE treat remoteAddr at one point included a user-controlled
// string. In case that happens again by accident, treat it as tainted,
// and be very suspicious of it.
pid = ucrednetNoProcess
uid = ucrednetNobody
u := &ucrednet{
Pid: ucrednetNoProcess,
Uid: ucrednetNobody,
}
subs := raddrRegexp.FindStringSubmatch(remoteAddr)
if subs != nil {
if v, err := strconv.ParseInt(subs[1], 10, 32); err == nil {
pid = int32(v)
u.Pid = int32(v)
}
if v, err := strconv.ParseUint(subs[2], 10, 32); err == nil {
uid = uint32(v)
u.Uid = uint32(v)
}
socket = subs[3]
u.Socket = subs[3]
}
if pid == ucrednetNoProcess || uid == ucrednetNobody {
err = errNoID
if u.Pid == ucrednetNoProcess || u.Uid == ucrednetNobody {
return nil, errNoID
}

return pid, uid, socket, err
return u, nil
}

type ucrednet struct {
pid int32
uid uint32
socket string
Pid int32
Uid uint32
Socket string
}

func (un *ucrednet) String() string {
if un == nil {
return "pid=;uid=;socket=;"
}
return fmt.Sprintf("pid=%d;uid=%d;socket=%s;", un.pid, un.uid, un.socket)
return fmt.Sprintf("pid=%d;uid=%d;socket=%s;", un.Pid, un.Uid, un.Socket)
}

type ucrednetAddr struct {
Expand Down Expand Up @@ -127,9 +129,9 @@ func (wl *ucrednetListener) Accept() (net.Conn, error) {
}

unet = &ucrednet{
pid: ucred.Pid,
uid: ucred.Uid,
socket: ucon.LocalAddr().String(),
Pid: ucred.Pid,
Uid: ucred.Uid,
Socket: ucon.LocalAddr().String(),
}
}

Expand Down
51 changes: 22 additions & 29 deletions daemon/ucrednet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,10 @@ func (s *ucrednetSuite) TestAcceptConnRemoteAddrString(c *check.C) {

remoteAddr := conn.RemoteAddr().String()
c.Check(remoteAddr, check.Matches, "pid=100;uid=42;.*")
pid, uid, _, err := ucrednetGet(remoteAddr)
c.Check(pid, check.Equals, int32(100))
c.Check(uid, check.Equals, uint32(42))
c.Check(err, check.IsNil)
u, err := ucrednetGet(remoteAddr)
c.Assert(err, check.IsNil)
c.Check(u.Pid, check.Equals, int32(100))
c.Check(u.Uid, check.Equals, uint32(42))
}

func (s *ucrednetSuite) TestNonUnix(c *check.C) {
Expand All @@ -101,9 +101,8 @@ func (s *ucrednetSuite) TestNonUnix(c *check.C) {

remoteAddr := conn.RemoteAddr().String()
c.Check(remoteAddr, check.Matches, "pid=;uid=;.*")
pid, uid, _, err := ucrednetGet(remoteAddr)
c.Check(pid, check.Equals, ucrednetNoProcess)
c.Check(uid, check.Equals, ucrednetNobody)
u, err := ucrednetGet(remoteAddr)
c.Check(u, check.IsNil)
c.Check(err, check.Equals, errNoID)
}

Expand Down Expand Up @@ -157,45 +156,39 @@ func (s *ucrednetSuite) TestIdempotentClose(c *check.C) {
}

func (s *ucrednetSuite) TestGetNoUid(c *check.C) {
pid, uid, _, err := ucrednetGet("pid=100;uid=;socket=;")
u, err := ucrednetGet("pid=100;uid=;socket=;")
c.Check(err, check.Equals, errNoID)
c.Check(pid, check.Equals, ucrednetNoProcess)
c.Check(uid, check.Equals, ucrednetNobody)
c.Check(u, check.IsNil)
}

func (s *ucrednetSuite) TestGetBadUid(c *check.C) {
pid, uid, _, err := ucrednetGet("pid=100;uid=4294967296;socket=;")
c.Check(err, check.NotNil)
c.Check(pid, check.Equals, int32(100))
c.Check(uid, check.Equals, ucrednetNobody)
u, err := ucrednetGet("pid=100;uid=4294967296;socket=;")
c.Check(err, check.Equals, errNoID)
c.Check(u, check.IsNil)
}

func (s *ucrednetSuite) TestGetNonUcrednet(c *check.C) {
pid, uid, _, err := ucrednetGet("hello")
u, err := ucrednetGet("hello")
c.Check(err, check.Equals, errNoID)
c.Check(pid, check.Equals, ucrednetNoProcess)
c.Check(uid, check.Equals, ucrednetNobody)
c.Check(u, check.IsNil)
}

func (s *ucrednetSuite) TestGetNothing(c *check.C) {
pid, uid, _, err := ucrednetGet("")
u, err := ucrednetGet("")
c.Check(err, check.Equals, errNoID)
c.Check(pid, check.Equals, ucrednetNoProcess)
c.Check(uid, check.Equals, ucrednetNobody)
c.Check(u, check.IsNil)
}

func (s *ucrednetSuite) TestGet(c *check.C) {
pid, uid, socket, err := ucrednetGet("pid=100;uid=42;socket=/run/snap.socket;")
c.Check(err, check.IsNil)
c.Check(pid, check.Equals, int32(100))
c.Check(uid, check.Equals, uint32(42))
c.Check(socket, check.Equals, "/run/snap.socket")
u, err := ucrednetGet("pid=100;uid=42;socket=/run/snap.socket;")
c.Assert(err, check.IsNil)
c.Check(u.Pid, check.Equals, int32(100))
c.Check(u.Uid, check.Equals, uint32(42))
c.Check(u.Socket, check.Equals, "/run/snap.socket")
}

func (s *ucrednetSuite) TestGetSneak(c *check.C) {
pid, uid, socket, err := ucrednetGet("pid=100;uid=42;socket=/run/snap.socket;pid=0;uid=0;socket=/tmp/my.socket")
u, err := ucrednetGet("pid=100;uid=42;socket=/run/snap.socket;pid=0;uid=0;socket=/tmp/my.socket")
c.Check(err, check.Equals, errNoID)
c.Check(pid, check.Equals, ucrednetNoProcess)
c.Check(uid, check.Equals, ucrednetNobody)
c.Check(socket, check.Equals, "")
c.Check(u, check.IsNil)
}