diff --git a/lib/srv/ctx.go b/lib/srv/ctx.go index 12ed439833853..556f6497a87a0 100644 --- a/lib/srv/ctx.go +++ b/lib/srv/ctx.go @@ -389,6 +389,11 @@ type ServerContext struct { contr *os.File contw *os.File + // ready{r,w} are used to send the ready signal from the child process + // to the parent process. + readyr *os.File + readyw *os.File + // killShell{r,w} are used to send kill signal to the child process // to terminate the shell. killShellr *os.File @@ -551,6 +556,15 @@ func NewServerContext(ctx context.Context, parent *sshutils.ConnectionContext, s child.AddCloser(child.contr) child.AddCloser(child.contw) + // Create pipe used to signal ready to the parent process. + child.readyr, child.readyw, err = os.Pipe() + if err != nil { + childErr := child.Close() + return nil, nil, trace.NewAggregate(err, childErr) + } + child.AddCloser(child.readyr) + child.AddCloser(child.readyw) + child.killShellr, child.killShellw, err = os.Pipe() if err != nil { childErr := child.Close() diff --git a/lib/srv/exec.go b/lib/srv/exec.go index 571a0485dfa5f..784edef1b24d6 100644 --- a/lib/srv/exec.go +++ b/lib/srv/exec.go @@ -75,6 +75,10 @@ type Exec interface { // Wait will block while the command executes. Wait() *ExecResult + // WaitForChild blocks until the child process has completed any required + // setup operations before proceeding with execution. + WaitForChild() error + // Continue will resume execution of the process after it completes its // pre-processing routine (placed in a cgroup). Continue() @@ -173,6 +177,12 @@ func (e *localExec) Start(ctx context.Context, channel ssh.Channel) (*ExecResult Code: exitCode(err), }, trace.ConvertSystemError(err) } + // Close our half of the write pipe since it is only to be used by the child process. + // Not closing prevents being signaled when the child closes its half. 
+ if err := e.Ctx.readyw.Close(); err != nil { + e.Ctx.Logger.WithError(err).Warn("Failed to close parent process ready signal write fd") + } + e.Ctx.readyw = nil go func() { if _, err := io.Copy(inputWriter, channel); err != nil { @@ -211,6 +221,14 @@ func (e *localExec) Wait() *ExecResult { return execResult } +func (e *localExec) WaitForChild() error { + err := waitForSignal(e.Ctx.readyr, 20*time.Second) + closeErr := e.Ctx.readyr.Close() + // Set to nil so the close in the context doesn't attempt to re-close. + e.Ctx.readyr = nil + return trace.NewAggregate(err, closeErr) +} + // Continue will resume execution of the process after it completes its // pre-processing routine (placed in a cgroup). func (e *localExec) Continue() { @@ -262,28 +280,26 @@ func (e *localExec) transformSecureCopy() error { return nil } -// waitForContinue will wait 10 seconds for the continue signal, if not +// waitForSignal will wait for the provided timeout for the signal, if not // received, it will stop waiting and exit. -func waitForContinue(contfd *os.File) error { +func waitForSignal(fd *os.File, timeout time.Duration) error { waitCh := make(chan error, 1) go func() { - // Reading from the continue file descriptor will block until it's closed. It - // won't be closed until the parent has placed it in a cgroup. - buf := make([]byte, 1) - _, err := contfd.Read(buf) + // Reading from the file descriptor will block until it's closed. + _, err := fd.Read(make([]byte, 1)) if errors.Is(err, io.EOF) { err = nil } waitCh <- err }() - // Wait for 10 seconds and then timeout if no continue signal has been sent. - timeout := time.NewTimer(10 * time.Second) - defer timeout.Stop() + // Timeout if no signal has been sent within the provided duration. 
+ timer := time.NewTimer(timeout) + defer timer.Stop() select { - case <-timeout.C: - return trace.BadParameter("timed out waiting for continue signal") + case <-timer.C: + return trace.LimitExceeded("timed out waiting for continue signal") case err := <-waitCh: return err } @@ -357,6 +373,8 @@ func (e *remoteExec) Wait() *ExecResult { } } +func (e *remoteExec) WaitForChild() error { return nil } + // Continue does nothing for remote command execution. func (e *remoteExec) Continue() {} diff --git a/lib/srv/exec_linux_test.go b/lib/srv/exec_linux_test.go index 23f1e0c0b359b..f91abdfb48123 100644 --- a/lib/srv/exec_linux_test.go +++ b/lib/srv/exec_linux_test.go @@ -29,6 +29,7 @@ import ( "testing" "time" + "github.com/gravitational/trace" "github.com/stretchr/testify/require" ) @@ -135,7 +136,13 @@ func TestContinue(t *testing.T) { // Re-execute Teleport and run "ls". Signal over the context when execution // is complete. go func() { - cmdDone <- cmd.Run() + if err := cmd.Start(); err != nil { + cmdDone <- err + } + + // Close the write half of the pipe to unblock the ready signal. + closeErr := scx.readyw.Close() + cmdDone <- trace.NewAggregate(closeErr, cmd.Wait()) }() // Wait for the process. Since the continue pipe has not been closed, the @@ -146,10 +153,11 @@ func TestContinue(t *testing.T) { case <-time.After(5 * time.Second): } - // Close the continue and terminate pipe to signal to Teleport to now execute the - // requested program. - err = scx.contw.Close() - require.NoError(t, err) + // Wait for the child process to indicate it has completed initialization. + require.NoError(t, scx.execRequest.WaitForChild()) + + // Signal to child that it may execute the requested program. + scx.execRequest.Continue() // Program should have executed now. If the complete signal has not come // over the context, something failed. 
diff --git a/lib/srv/mock.go b/lib/srv/mock.go index 12a8b8f0f02db..0d402b40398d3 100644 --- a/lib/srv/mock.go +++ b/lib/srv/mock.go @@ -95,6 +95,9 @@ func newTestServerContext(t *testing.T, srv Server, roleSet services.RoleSet) *S scx.contr, scx.contw, err = os.Pipe() require.NoError(t, err) + scx.readyr, scx.readyw, err = os.Pipe() + require.NoError(t, err) + scx.killShellr, scx.killShellw, err = os.Pipe() require.NoError(t, err) diff --git a/lib/srv/reexec.go b/lib/srv/reexec.go index f9050aec5a97f..9eded12530a1e 100644 --- a/lib/srv/reexec.go +++ b/lib/srv/reexec.go @@ -58,6 +58,10 @@ const ( // it can continue after the parent process assigns a cgroup to the // child process. ContinueFile + // ReadyFile is used to communicate to the parent process that + // the child has completed any setup operations that must occur before + // the child is placed into its cgroup. + ReadyFile // TerminateFile is used to communicate to the child process that // the interactive terminal should be killed as the client ended the // SSH session and without termination the terminal process will be assigned @@ -67,7 +71,7 @@ const ( // X11File is used to communicate to the parent process that the child // process has set up X11 forwarding. X11File - // ErrorFile is used to communicate any errors terminating the child process + // ErrorFile is used to communicate any errors terminating the child process // to the parent process ErrorFile // PTYFile is a PTY the parent process passes to the child process. 
@@ -205,6 +209,21 @@ func RunCommand() (errw io.Writer, code int, err error) { if contfd == nil { return errorWriter, teleport.RemoteCommandFailure, trace.BadParameter("continue pipe not found") } + readyfd := os.NewFile(ReadyFile, fdName(ReadyFile)) + if readyfd == nil { + return errorWriter, teleport.RemoteCommandFailure, trace.BadParameter("ready pipe not found") + } + + // Ensure that the ready signal is sent if a failure causes execution + // to terminate prior to actually becoming ready to unblock the parent process. + defer func() { + if readyfd == nil { + return + } + + _ = readyfd.Close() + }() + termiantefd := os.NewFile(TerminateFile, fdName(TerminateFile)) if termiantefd == nil { return errorWriter, teleport.RemoteCommandFailure, trace.BadParameter("terminate pipe not found") @@ -309,6 +328,14 @@ func RunCommand() (errw io.Writer, code int, err error) { pamEnvironment = pamContext.Environment() } + // Alert the parent process that the child process has completed any setup operations, + // and that we are now waiting for the continue signal before proceeding. This is needed + // to ensure that PAM changing the cgroup doesn't bypass enhanced recording. + if err := readyfd.Close(); err != nil { + return errorWriter, teleport.RemoteCommandFailure, trace.Wrap(err) + } + readyfd = nil + localUser, err := user.Lookup(c.Login) if err != nil { return errorWriter, teleport.RemoteCommandFailure, trace.Wrap(err) @@ -322,7 +349,7 @@ func RunCommand() (errw io.Writer, code int, err error) { // Wait until the continue signal is received from Teleport signaling that // the child process has been placed in a cgroup. 
- err = waitForContinue(contfd) + err = waitForSignal(contfd, 10*time.Second) if err != nil { return errorWriter, teleport.RemoteCommandFailure, trace.Wrap(err) } @@ -890,6 +917,7 @@ func ConfigureCommand(ctx *ServerContext, extraFiles ...*os.File) (*exec.Cmd, er ExtraFiles: []*os.File{ ctx.cmdr, ctx.contr, + ctx.readyw, ctx.killShellr, ctx.x11rdyw, ctx.errw, diff --git a/lib/srv/sess.go b/lib/srv/sess.go index 364a4358e6b63..401e49e3e300d 100644 --- a/lib/srv/sess.go +++ b/lib/srv/sess.go @@ -1007,6 +1007,11 @@ func (s *session) startInteractive(ctx context.Context, ch ssh.Channel, scx *Ser Events: scx.Identity.AccessChecker.EnhancedRecordingSet(), } + if err := s.term.WaitForChild(); err != nil { + s.log.WithError(err).Error("Child process never became ready") + return trace.Wrap(err) + } + if cgroupID, err := scx.srv.GetBPF().OpenSession(sessionContext); err != nil { s.log.WithError(err).Error("Failed to open enhanced recording (interactive) session") return trace.Wrap(err) @@ -1198,6 +1203,12 @@ func (s *session) startExec(ctx context.Context, channel ssh.Channel, scx *Serve User: scx.Identity.TeleportUser, Events: scx.Identity.AccessChecker.EnhancedRecordingSet(), } + + if err := execRequest.WaitForChild(); err != nil { + s.log.WithError(err).Error("Child process never became ready") + return trace.Wrap(err) + } + cgroupID, err := scx.srv.GetBPF().OpenSession(sessionContext) if err != nil { s.log.WithError(err).Errorf("Failed to open enhanced recording (exec) session: %v", execRequest.GetCommand()) diff --git a/lib/srv/term.go b/lib/srv/term.go index 16d179ef491a7..c7d9a40355958 100644 --- a/lib/srv/term.go +++ b/lib/srv/term.go @@ -60,6 +60,10 @@ type Terminal interface { // Wait will block until the terminal is complete. Wait() (*ExecResult, error) + // WaitForChild blocks until the child process has completed any required + // setup operations before proceeding with execution. 
+ WaitForChild() error + // Continue will resume execution of the process after it completes its // pre-processing routine (placed in a cgroup). Continue() @@ -207,6 +211,13 @@ func (t *terminal) Run(ctx context.Context) error { return trace.Wrap(err) } + // Close our half of the write pipe since it is only to be used by the child process. + // Not closing prevents being signaled when the child closes its half. + if err := t.serverContext.readyw.Close(); err != nil { + t.serverContext.Logger.WithError(err).Warn("Failed to close parent process ready signal write fd") + } + t.serverContext.readyw = nil + // Save off the PID of the Teleport process under which the shell is executing. t.pid = t.cmd.Process.Pid @@ -235,6 +246,14 @@ func (t *terminal) Wait() (*ExecResult, error) { }, nil } +func (t *terminal) WaitForChild() error { + err := waitForSignal(t.serverContext.readyr, 20*time.Second) + closeErr := t.serverContext.readyr.Close() + // Set to nil so the close in the context doesn't attempt to re-close. + t.serverContext.readyr = nil + return trace.NewAggregate(err, closeErr) +} + // Continue will resume execution of the process after it completes its // pre-processing routine (placed in a cgroup). func (t *terminal) Continue() { @@ -591,6 +610,10 @@ func (t *remoteTerminal) Wait() (*ExecResult, error) { }, nil } +func (t *remoteTerminal) WaitForChild() error { + return nil +} + // Continue does nothing for remote command execution. func (t *remoteTerminal) Continue() {} diff --git a/lib/srv/term_test.go b/lib/srv/term_test.go index 90f420ae5174c..7489737f1d5a8 100644 --- a/lib/srv/term_test.go +++ b/lib/srv/term_test.go @@ -116,9 +116,11 @@ func TestTerminal_KillUnderlyingShell(t *testing.T) { errors <- err }() + // Wait for the child process to indicate it has completed initialization. 
+ require.NoError(t, scx.execRequest.WaitForChild()) + // Continue execution - err = scx.contw.Close() - require.NoError(t, err) + scx.execRequest.Continue() ctx, cancel := context.WithTimeout(ctx, 5*time.Second) t.Cleanup(cancel)