diff --git a/lib/srv/ctx.go b/lib/srv/ctx.go index 3fa5868fc9ab5..d008dea81f152 100644 --- a/lib/srv/ctx.go +++ b/lib/srv/ctx.go @@ -389,6 +389,11 @@ type ServerContext struct { contr *os.File contw *os.File + // ready{r,w} is used to send the ready signal from the child process + // to the parent process. + readyr *os.File + readyw *os.File + // killShell{r,w} are used to send kill signal to the child process // to terminate the shell. killShellr *os.File @@ -561,6 +566,15 @@ func NewServerContext(ctx context.Context, parent *sshutils.ConnectionContext, s child.AddCloser(child.contr) child.AddCloser(child.contw) + // Create pipe used to signal continue to parent process. + child.readyr, child.readyw, err = os.Pipe() + if err != nil { + childErr := child.Close() + return nil, nil, trace.NewAggregate(err, childErr) + } + child.AddCloser(child.readyr) + child.AddCloser(child.readyw) + child.killShellr, child.killShellw, err = os.Pipe() if err != nil { childErr := child.Close() diff --git a/lib/srv/exec.go b/lib/srv/exec.go index db8d81e932e0c..ba1139fdf3c6a 100644 --- a/lib/srv/exec.go +++ b/lib/srv/exec.go @@ -76,6 +76,10 @@ type Exec interface { // Wait will block while the command executes. Wait() *ExecResult + // WaitForChild blocks until the child process has completed any required + // setup operations before proceeding with execution. + WaitForChild() error + // Continue will resume execution of the process after it completes its // pre-processing routine (placed in a cgroup). Continue() @@ -219,6 +223,14 @@ func (e *localExec) Wait() *ExecResult { return execResult } +func (e *localExec) WaitForChild() error { + _, err := io.ReadFull(e.Ctx.readyr, make([]byte, 1)) + closeErr := e.Ctx.readyr.Close() + // Set to nil so the close in the context doesn't attempt to re-close. + e.Ctx.readyr = nil + return trace.Wrap(err, closeErr) +} + // Continue will resume execution of the process after it completes its // pre-processing routine (placed in a cgroup). func (e *localExec) Continue() { @@ -389,6 +401,8 @@ func (e *remoteExec) Wait() *ExecResult { } } +func (e *remoteExec) WaitForChild() error { return nil } + // Continue does nothing for remote command execution. func (e *remoteExec) Continue() {} diff --git a/lib/srv/exec_linux_test.go b/lib/srv/exec_linux_test.go index 687b8a678f25f..b146025a1630a 100644 --- a/lib/srv/exec_linux_test.go +++ b/lib/srv/exec_linux_test.go @@ -145,10 +145,11 @@ func TestContinue(t *testing.T) { case <-time.After(5 * time.Second): } - // Close the continue and terminate pipe to signal to Teleport to now execute the - // requested program. - err = scx.contw.Close() - require.NoError(t, err) + // Wait for the child process to indicate its completed initialization. + require.NoError(t, scx.execRequest.WaitForChild()) + + // Signal to child that it may execute the requested program. + scx.execRequest.Continue() // Program should have executed now. If the complete signal has not come // over the context, something failed. diff --git a/lib/srv/mock.go b/lib/srv/mock.go index 56735f7878d11..a51a3a96fd726 100644 --- a/lib/srv/mock.go +++ b/lib/srv/mock.go @@ -96,6 +96,9 @@ func newTestServerContext(t *testing.T, srv Server, roleSet services.RoleSet) *S scx.contr, scx.contw, err = os.Pipe() require.NoError(t, err) + scx.readyr, scx.readyw, err = os.Pipe() + require.NoError(t, err) + scx.killShellr, scx.killShellw, err = os.Pipe() require.NoError(t, err) diff --git a/lib/srv/reexec.go b/lib/srv/reexec.go index d0e87c9576d46..9d35d7efca20f 100644 --- a/lib/srv/reexec.go +++ b/lib/srv/reexec.go @@ -60,8 +60,10 @@ const ( // it can continue after the parent process assigns a cgroup to the // child process. ContinueFile - - ContinueWriteFile + // ReadyFile is used to communicate to the parent process that + // the child has completed any setup operations that must occur before + // the child is placed into its cgroup. + ReadyFile // TerminateFile is used to communicate to the child process that // the interactive terminal should be killed as the client ended the // SSH session and without termination the terminal process will be assigned @@ -71,7 +73,7 @@ const ( // X11File is used to communicate to the parent process that the child // process has set up X11 forwarding. X11File - // ErrorFile is used to communicate any errors terminating the child process + // ErrorFile is used to communicate any errors terminating the child process // to the parent process ErrorFile // PTYFile is a PTY the parent process passes to the child process. @@ -215,10 +217,23 @@ func RunCommand() (errw io.Writer, code int, err error) { if contfd == nil { return errorWriter, teleport.RemoteCommandFailure, trace.BadParameter("continue pipe not found") } - contWfd := os.NewFile(ContinueWriteFile, fdName(ContinueWriteFile)) - if contfd == nil { + readyfd := os.NewFile(ReadyFile, fdName(ReadyFile)) + if readyfd == nil { return errorWriter, teleport.RemoteCommandFailure, trace.BadParameter("continue write pipe not found") } + + var ready bool + // Ensure that the ready signal is sent if a failure causes execution + // to terminate prior to actually becoming ready to unblock the parent process. + defer func() { + if ready { + return + } + + _, _ = readyfd.Write([]byte{0}) + _ = readyfd.Close() + }() + termiantefd := os.NewFile(TerminateFile, fdName(TerminateFile)) if termiantefd == nil { return errorWriter, teleport.RemoteCommandFailure, trace.BadParameter("terminate pipe not found") @@ -321,11 +336,18 @@ func RunCommand() (errw io.Writer, code int, err error) { // Save off any environment variables that come from PAM. pamEnvironment = pamContext.Environment() + } - if _, err = contWfd.Write([]byte{0}); err != nil { - return errorWriter, teleport.RemoteCommandFailure, trace.Wrap(err) - } - contWfd.Close() // TODO: Handle error maybe? + // Alert the parent process that the child process has completed any setup operations, + // and that we are now waiting for the continue signal before proceeding. This is needed + // to ensure that PAM changing the cgroup doesn't bypass enhanced recording. + ready = true + if _, err := readyfd.Write([]byte{0}); err != nil { + return errorWriter, teleport.RemoteCommandFailure, trace.Wrap(err) + } + + if err := readyfd.Close(); err != nil { + return errorWriter, teleport.RemoteCommandFailure, trace.Wrap(err) } localUser, err := user.Lookup(c.Login) @@ -911,7 +933,7 @@ func ConfigureCommand(ctx *ServerContext, extraFiles ...*os.File) (*exec.Cmd, er ExtraFiles: []*os.File{ ctx.cmdr, ctx.contr, - ctx.contw, + ctx.readyw, ctx.killShellr, ctx.x11rdyw, ctx.errw, diff --git a/lib/srv/sess.go b/lib/srv/sess.go index 35e3255bca488..fa2ddc0319b09 100644 --- a/lib/srv/sess.go +++ b/lib/srv/sess.go @@ -1190,7 +1190,7 @@ func (s *session) startInteractive(ctx context.Context, scx *ServerContext, p *p Events: scx.Identity.AccessChecker.EnhancedRecordingSet(), } - if err := s.term.WaitForPam(); err != nil { + if err := s.term.WaitForChild(); err != nil { return trace.Wrap(err) } @@ -1358,7 +1358,9 @@ func (s *session) startExec(ctx context.Context, channel ssh.Channel, scx *Serve Events: scx.Identity.AccessChecker.EnhancedRecordingSet(), } - // TODO Wait for PAM? + if err := execRequest.WaitForChild(); err != nil { + return trace.Wrap(err) + } cgroupID, err := scx.srv.GetBPF().OpenSession(sessionContext) if err != nil { diff --git a/lib/srv/term.go b/lib/srv/term.go index aab0d3a9efd08..862f114fc2bc6 100644 --- a/lib/srv/term.go +++ b/lib/srv/term.go @@ -61,7 +61,9 @@ type Terminal interface { // Wait will block until the terminal is complete. Wait() (*ExecResult, error) - WaitForPam() error + // WaitForChild blocks until the child process has completed any required + // setup operations before proceeding with execution. + WaitForChild() error // Continue will resume execution of the process after it completes its // pre-processing routine (placed in a cgroup). @@ -238,8 +240,8 @@ func (t *terminal) Wait() (*ExecResult, error) { }, nil } -func (t *terminal) WaitForPam() error { - _, err := io.ReadFull(t.serverContext.contr, make([]byte, 1)) +func (t *terminal) WaitForChild() error { + _, err := io.ReadFull(t.serverContext.readyr, make([]byte, 1)) return trace.Wrap(err) } @@ -599,7 +601,7 @@ func (t *remoteTerminal) Wait() (*ExecResult, error) { }, nil } -func (t *remoteTerminal) WaitForPam() error { +func (t *remoteTerminal) WaitForChild() error { return nil } diff --git a/lib/srv/term_test.go b/lib/srv/term_test.go index 90f420ae5174c..7489737f1d5a8 100644 --- a/lib/srv/term_test.go +++ b/lib/srv/term_test.go @@ -116,9 +116,11 @@ func TestTerminal_KillUnderlyingShell(t *testing.T) { errors <- err }() + // Wait for the child process to indicate its completed initialization. + require.NoError(t, scx.execRequest.WaitForChild()) + // Continue execution - err = scx.contw.Close() - require.NoError(t, err) + scx.execRequest.Continue() ctx, cancel := context.WithTimeout(ctx, 5*time.Second) t.Cleanup(cancel)