From 891821816ccedf47cde90648f474f4ead8190898 Mon Sep 17 00:00:00 2001 From: James Henstridge Date: Fri, 9 Apr 2021 18:49:41 +0800 Subject: [PATCH 1/2] daemon: make ucrednetGet() return a *ucrednet structure --- daemon/api_snapctl.go | 4 +-- daemon/api_snapctl_test.go | 8 +++--- daemon/daemon.go | 8 +++--- daemon/export_test.go | 4 ++- daemon/ucrednet.go | 34 +++++++++++++------------ daemon/ucrednet_test.go | 51 ++++++++++++++++---------------------- 6 files changed, 53 insertions(+), 56 deletions(-) diff --git a/daemon/api_snapctl.go b/daemon/api_snapctl.go index 1be70f8c829..2a01f7bcbf9 100644 --- a/daemon/api_snapctl.go +++ b/daemon/api_snapctl.go @@ -52,7 +52,7 @@ func runSnapctl(c *Command, r *http.Request, user *auth.UserState) Response { return BadRequest("snapctl cannot run without args") } - _, uid, _, err := ucrednetGet(r.RemoteAddr) + ucred, err := ucrednetGet(r.RemoteAddr) if err != nil { return Forbidden("cannot get remote user: %s", err) } @@ -69,7 +69,7 @@ func runSnapctl(c *Command, r *http.Request, user *auth.UserState) Response { context.Unlock() } - stdout, stderr, err := ctlcmdRun(context, snapctlPostData.Args, uid) + stdout, stderr, err := ctlcmdRun(context, snapctlPostData.Args, ucred.Uid) if err != nil { if e, ok := err.(*ctlcmd.UnsuccessfulError); ok { result := map[string]interface{}{ diff --git a/daemon/api_snapctl_test.go b/daemon/api_snapctl_test.go index 97b29096cf5..9af3b32e368 100644 --- a/daemon/api_snapctl_test.go +++ b/daemon/api_snapctl_test.go @@ -51,8 +51,8 @@ func (s *snapctlSuite) TestSnapctlGetNoUID(c *check.C) { func (s *snapctlSuite) TestSnapctlForbiddenError(c *check.C) { s.daemon(c) - defer daemon.MockUcrednetGet(func(string) (int32, uint32, string, error) { - return 100, 9999, dirs.SnapSocket, nil + defer daemon.MockUcrednetGet(func(string) (*daemon.Ucrednet, error) { + return &daemon.Ucrednet{100, 9999, dirs.SnapSocket}, nil })() defer daemon.MockCtlcmdRun(func(ctx *hookstate.Context, arg []string, uid uint32) ([]byte, []byte, error) { @@ -69,8 +69,8 @@ func (s *snapctlSuite) TestSnapctlForbiddenError(c *check.C) { func (s *snapctlSuite) TestSnapctlUnsuccesfulError(c *check.C) { s.daemon(c) - defer daemon.MockUcrednetGet(func(string) (int32, uint32, string, error) { - return 100, 9999, dirs.SnapSocket, nil + defer daemon.MockUcrednetGet(func(string) (*daemon.Ucrednet, error) { + return &daemon.Ucrednet{100, 9999, dirs.SnapSocket}, nil })() defer daemon.MockCtlcmdRun(func(ctx *hookstate.Context, arg []string, uid uint32) ([]byte, []byte, error) { diff --git a/daemon/daemon.go b/daemon/daemon.go index 79c830c2dcf..7137672adfd 100644 --- a/daemon/daemon.go +++ b/daemon/daemon.go @@ -150,14 +150,14 @@ func (c *Command) canAccess(r *http.Request, user *auth.UserState) accessResult // isUser means we have a UID for the request isUser := false - pid, uid, socket, err := ucrednetGet(r.RemoteAddr) + ucred, err := ucrednetGet(r.RemoteAddr) if err == nil { isUser = true } else if err != errNoID { logger.Noticef("unexpected error when attempting to get UID: %s", err) return accessForbidden } - isSnap := (socket == dirs.SnapSocket) + isSnap := (ucred != nil && ucred.Socket == dirs.SnapSocket) // ensure that snaps can only access SnapOK things if isSnap { @@ -184,7 +184,7 @@ func (c *Command) canAccess(r *http.Request, user *auth.UserState) accessResult return accessUnauthorized } - if uid == 0 { + if ucred.Uid == 0 { // Superuser does anything. return accessOK } @@ -204,7 +204,7 @@ func (c *Command) canAccess(r *http.Request, user *auth.UserState) accessResult } } // Pass both pid and uid from the peer ucred to avoid pid race - if authorized, err := polkitCheckAuthorization(pid, uid, c.PolkitOK, nil, flags); err == nil { + if authorized, err := polkitCheckAuthorization(ucred.Pid, ucred.Uid, c.PolkitOK, nil, flags); err == nil { if authorized { // polkit says user is authorised return accessOK diff --git a/daemon/export_test.go b/daemon/export_test.go index ef42c6b4ea2..1d8f68d19db 100644 --- a/daemon/export_test.go +++ b/daemon/export_test.go @@ -63,7 +63,9 @@ func (d *Daemon) RequestedRestart() state.RestartType { return d.requestedRestart } -func MockUcrednetGet(mock func(remoteAddr string) (pid int32, uid uint32, socket string, err error)) (restore func()) { +type Ucrednet = ucrednet + +func MockUcrednetGet(mock func(remoteAddr string) (ucred *Ucrednet, err error)) (restore func()) { oldUcrednetGet := ucrednetGet ucrednetGet = mock return func() { diff --git a/daemon/ucrednet.go b/daemon/ucrednet.go index 4f7e794e25e..a2eb01a65b5 100644 --- a/daemon/ucrednet.go +++ b/daemon/ucrednet.go @@ -40,40 +40,42 @@ var raddrRegexp = regexp.MustCompile(`^pid=(\d+);uid=(\d+);socket=([^;]*);$`) var ucrednetGet = ucrednetGetImpl -func ucrednetGetImpl(remoteAddr string) (pid int32, uid uint32, socket string, err error) { +func ucrednetGetImpl(remoteAddr string) (*ucrednet, error) { // NOTE treat remoteAddr at one point included a user-controlled // string. In case that happens again by accident, treat it as tainted, // and be very suspicious of it. - pid = ucrednetNoProcess - uid = ucrednetNobody + u := &ucrednet{ + Pid: ucrednetNoProcess, + Uid: ucrednetNobody, + } subs := raddrRegexp.FindStringSubmatch(remoteAddr) if subs != nil { if v, err := strconv.ParseInt(subs[1], 10, 32); err == nil { - pid = int32(v) + u.Pid = int32(v) } if v, err := strconv.ParseUint(subs[2], 10, 32); err == nil { - uid = uint32(v) + u.Uid = uint32(v) } - socket = subs[3] + u.Socket = subs[3] } - if pid == ucrednetNoProcess || uid == ucrednetNobody { - err = errNoID + if u.Pid == ucrednetNoProcess || u.Uid == ucrednetNobody { + return nil, errNoID } - return pid, uid, socket, err + return u, nil } type ucrednet struct { - pid int32 - uid uint32 - socket string + Pid int32 + Uid uint32 + Socket string } func (un *ucrednet) String() string { if un == nil { return "pid=;uid=;socket=;" } - return fmt.Sprintf("pid=%d;uid=%d;socket=%s;", un.pid, un.uid, un.socket) + return fmt.Sprintf("pid=%d;uid=%d;socket=%s;", un.Pid, un.Uid, un.Socket) } type ucrednetAddr struct { @@ -127,9 +129,9 @@ func (wl *ucrednetListener) Accept() (net.Conn, error) { } unet = &ucrednet{ - pid: ucred.Pid, - uid: ucred.Uid, - socket: ucon.LocalAddr().String(), + Pid: ucred.Pid, + Uid: ucred.Uid, + Socket: ucon.LocalAddr().String(), } } diff --git a/daemon/ucrednet_test.go b/daemon/ucrednet_test.go index 1e286980525..548680052f4 100644 --- a/daemon/ucrednet_test.go +++ b/daemon/ucrednet_test.go @@ -74,10 +74,10 @@ func (s *ucrednetSuite) TestAcceptConnRemoteAddrString(c *check.C) { remoteAddr := conn.RemoteAddr().String() c.Check(remoteAddr, check.Matches, "pid=100;uid=42;.*") - pid, uid, _, err := ucrednetGet(remoteAddr) - c.Check(pid, check.Equals, int32(100)) - c.Check(uid, check.Equals, uint32(42)) - c.Check(err, check.IsNil) + u, err := ucrednetGet(remoteAddr) + c.Assert(err, check.IsNil) + c.Check(u.Pid, check.Equals, int32(100)) + c.Check(u.Uid, check.Equals, uint32(42)) } func (s *ucrednetSuite) TestNonUnix(c *check.C) { @@ -101,9 +101,8 @@ func (s *ucrednetSuite) TestNonUnix(c *check.C) { remoteAddr := conn.RemoteAddr().String() c.Check(remoteAddr, check.Matches, "pid=;uid=;.*") - pid, uid, _, err := ucrednetGet(remoteAddr) - c.Check(pid, check.Equals, ucrednetNoProcess) - c.Check(uid, check.Equals, ucrednetNobody) + u, err := ucrednetGet(remoteAddr) + c.Check(u, check.IsNil) c.Check(err, check.Equals, errNoID) } @@ -157,45 +156,39 @@ func (s *ucrednetSuite) TestIdempotentClose(c *check.C) { } func (s *ucrednetSuite) TestGetNoUid(c *check.C) { - pid, uid, _, err := ucrednetGet("pid=100;uid=;socket=;") + u, err := ucrednetGet("pid=100;uid=;socket=;") c.Check(err, check.Equals, errNoID) - c.Check(pid, check.Equals, ucrednetNoProcess) - c.Check(uid, check.Equals, ucrednetNobody) + c.Check(u, check.IsNil) } func (s *ucrednetSuite) TestGetBadUid(c *check.C) { - pid, uid, _, err := ucrednetGet("pid=100;uid=4294967296;socket=;") - c.Check(err, check.NotNil) - c.Check(pid, check.Equals, int32(100)) - c.Check(uid, check.Equals, ucrednetNobody) + u, err := ucrednetGet("pid=100;uid=4294967296;socket=;") + c.Check(err, check.Equals, errNoID) + c.Check(u, check.IsNil) } func (s *ucrednetSuite) TestGetNonUcrednet(c *check.C) { - pid, uid, _, err := ucrednetGet("hello") + u, err := ucrednetGet("hello") c.Check(err, check.Equals, errNoID) - c.Check(pid, check.Equals, ucrednetNoProcess) - c.Check(uid, check.Equals, ucrednetNobody) + c.Check(u, check.IsNil) } func (s *ucrednetSuite) TestGetNothing(c *check.C) { - pid, uid, _, err := ucrednetGet("") + u, err := ucrednetGet("") c.Check(err, check.Equals, errNoID) - c.Check(pid, check.Equals, ucrednetNoProcess) - c.Check(uid, check.Equals, ucrednetNobody) + c.Check(u, check.IsNil) } func (s *ucrednetSuite) TestGet(c *check.C) { - pid, uid, socket, err := ucrednetGet("pid=100;uid=42;socket=/run/snap.socket;") - c.Check(err, check.IsNil) - c.Check(pid, check.Equals, int32(100)) - c.Check(uid, check.Equals, uint32(42)) - c.Check(socket, check.Equals, "/run/snap.socket") + u, err := ucrednetGet("pid=100;uid=42;socket=/run/snap.socket;") + c.Assert(err, check.IsNil) + c.Check(u.Pid, check.Equals, int32(100)) + c.Check(u.Uid, check.Equals, uint32(42)) + c.Check(u.Socket, check.Equals, "/run/snap.socket") } func (s *ucrednetSuite) TestGetSneak(c *check.C) { - pid, uid, socket, err := ucrednetGet("pid=100;uid=42;socket=/run/snap.socket;pid=0;uid=0;socket=/tmp/my.socket") + u, err := ucrednetGet("pid=100;uid=42;socket=/run/snap.socket;pid=0;uid=0;socket=/tmp/my.socket") c.Check(err, check.Equals, errNoID) - c.Check(pid, check.Equals, ucrednetNoProcess) - c.Check(uid, check.Equals, ucrednetNobody) - c.Check(socket, check.Equals, "") + c.Check(u, check.IsNil) } From fae0d5619c6d1ba5faa0680eba84e21901222aab Mon Sep 17 00:00:00 2001 From: James Henstridge Date: Fri, 9 Apr 2021 22:20:04 +0800 Subject: [PATCH 2/2] daemon: cleanups suggested in review --- daemon/api_snapctl_test.go | 4 ++-- daemon/daemon.go | 8 +++----- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/daemon/api_snapctl_test.go b/daemon/api_snapctl_test.go index 9af3b32e368..20442978304 100644 --- a/daemon/api_snapctl_test.go +++ b/daemon/api_snapctl_test.go @@ -52,7 +52,7 @@ func (s *snapctlSuite) TestSnapctlForbiddenError(c *check.C) { s.daemon(c) defer daemon.MockUcrednetGet(func(string) (*daemon.Ucrednet, error) { - return &daemon.Ucrednet{100, 9999, dirs.SnapSocket}, nil + return &daemon.Ucrednet{Uid: 100, Pid: 9999, Socket: dirs.SnapSocket}, nil })() defer daemon.MockCtlcmdRun(func(ctx *hookstate.Context, arg []string, uid uint32) ([]byte, []byte, error) { @@ -70,7 +70,7 @@ func (s *snapctlSuite) TestSnapctlUnsuccesfulError(c *check.C) { s.daemon(c) defer daemon.MockUcrednetGet(func(string) (*daemon.Ucrednet, error) { - return &daemon.Ucrednet{100, 9999, dirs.SnapSocket}, nil + return &daemon.Ucrednet{Uid: 100, Pid: 9999, Socket: dirs.SnapSocket}, nil })() defer daemon.MockCtlcmdRun(func(ctx *hookstate.Context, arg []string, uid uint32) ([]byte, []byte, error) { diff --git a/daemon/daemon.go b/daemon/daemon.go index 7137672adfd..4fb763437cc 100644 --- a/daemon/daemon.go +++ b/daemon/daemon.go @@ -148,15 +148,13 @@ func (c *Command) canAccess(r *http.Request, user *auth.UserState) accessResult return accessOK } - // isUser means we have a UID for the request - isUser := false ucred, err := ucrednetGet(r.RemoteAddr) - if err == nil { - isUser = true - } else if err != errNoID { + if err != nil && err != errNoID { logger.Noticef("unexpected error when attempting to get UID: %s", err) return accessForbidden } + // isUser means we have a UID for the request + isUser := ucred != nil isSnap := (ucred != nil && ucred.Socket == dirs.SnapSocket) // ensure that snaps can only access SnapOK things