Skip to content

Commit

Permalink
Update Cmd IO handling (#1937)
Browse files Browse the repository at this point in the history
Update `Cmd.Wait` to return a known error value if it times out waiting
on IO copy after the command exits (and update `TestCmdStuckIo` to check
for that error).
Prior, the test checked for an `io.ErrClosedPipe`, which:
1. is not the best indicator that IO is stuck; and
2. is now ignored as an error value raised during IO relay.

Update `stuckIOProcess` logic in `cmd_test.go` to mirror logic in
`interal/exec.Exec`, using `os.Pipe` for std io that returns an `io.EOF`
(instead of `io.Pipe`, which does not).

Signed-off-by: Hamza El-Saawy <[email protected]>
  • Loading branch information
helsaawy authored Feb 20, 2024
1 parent 5f9910a commit 7458e58
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 29 deletions.
53 changes: 42 additions & 11 deletions internal/cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package cmd
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"strings"
Expand All @@ -20,6 +21,8 @@ import (
"golang.org/x/sys/windows"
)

var errIOTimeOut = errors.New("timed out waiting for stdio relay")

// CmdProcessRequest stores information on command requests made through this package.
type CmdProcessRequest struct {
Args []string
Expand Down Expand Up @@ -62,7 +65,7 @@ type Cmd struct {
// ExitState is filled out after Wait() (or Run() or Output()) completes.
ExitState *ExitState

iogrp errgroup.Group
ioGrp errgroup.Group
stdinErr atomic.Value
allDoneCh chan struct{}
}
Expand Down Expand Up @@ -90,13 +93,13 @@ func (err *ExitError) Error() string {
return fmt.Sprintf("process exited with exit code %d", err.ExitCode())
}

// Additional fields to hcsschema.ProcessParameters used by LCOW
// Additional fields to hcsschema.ProcessParameters used by LCOW.
type lcowProcessParameters struct {
hcsschema.ProcessParameters
OCIProcess *specs.Process `json:"OciProcess,omitempty"`
}

// escapeArgs makes a Windows-style escaped command line from a set of arguments
// escapeArgs makes a Windows-style escaped command line from a set of arguments.
func escapeArgs(args []string) string {
escapedArgs := make([]string, len(args))
for i, a := range args {
Expand Down Expand Up @@ -136,9 +139,19 @@ func CommandContext(ctx context.Context, host cow.ProcessHost, name string, arg
// Start starts a command. The caller must ensure that if Start succeeds,
// Wait is eventually called to clean up resources.
func (c *Cmd) Start() error {
if c.Host == nil {
return errors.New("empty ProcessHost")
}

// closed in (*Cmd).Wait; signals command execution is done
c.allDoneCh = make(chan struct{})

var x interface{}
if !c.Host.IsOCI() {
if c.Spec == nil {
return errors.New("process spec is required for non-OCI ProcessHost")
}

wpp := &hcsschema.ProcessParameters{
CommandLine: c.Spec.CommandLine,
User: c.Spec.User.Username,
Expand Down Expand Up @@ -199,6 +212,7 @@ func (c *Cmd) Start() error {
// Start relaying process IO.
stdin, stdout, stderr := p.Stdio()
if c.Stdin != nil {
c.Log.Info("coping stdin")
// Do not make stdin part of the error group because there is no way for
// us or the caller to reliably unblock the c.Stdin read when the
// process exits.
Expand All @@ -218,20 +232,20 @@ func (c *Cmd) Start() error {
}

if c.Stdout != nil {
c.iogrp.Go(func() error {
c.ioGrp.Go(func() error {
_, err := relayIO(c.Stdout, stdout, c.Log, "stdout")
if err := p.CloseStdout(context.TODO()); err != nil {
c.Log.WithError(err).Warn("failed to close Cmd stdout")
if cErr := p.CloseStdout(context.TODO()); cErr != nil && c.Log != nil {
c.Log.WithError(cErr).Warn("failed to close Cmd stdout")
}
return err
})
}

if c.Stderr != nil {
c.iogrp.Go(func() error {
c.ioGrp.Go(func() error {
_, err := relayIO(c.Stderr, stderr, c.Log, "stderr")
if err := p.CloseStderr(context.TODO()); err != nil {
c.Log.WithError(err).Warn("failed to close Cmd stderr")
if cErr := p.CloseStderr(context.TODO()); cErr != nil && c.Log != nil {
c.Log.WithError(cErr).Warn("failed to close Cmd stderr")
}
return err
})
Expand Down Expand Up @@ -270,27 +284,44 @@ func (c *Cmd) Wait() error {
state.exited = true
state.code = code
}

// Terminate the IO if the copy does not complete in the requested time.
// Closing the process should (eventually) lead to unblocking `ioGrp`, but we still need
// `timeoutErrCh` to:
// 1. communitate that the IO copy timed out; and
// 2. prevent a race condition between setting the timeout err in the goroutine and setting it for `ioErr`.
timeoutErrCh := make(chan error)
if c.CopyAfterExitTimeout != 0 {
go func() {
defer close(timeoutErrCh)
t := time.NewTimer(c.CopyAfterExitTimeout)
defer t.Stop()
select {
case <-c.allDoneCh:
case <-t.C:
// Close the process to cancel any reads to stdout or stderr.
c.Process.Close()
err := errIOTimeOut
// log the timeout, since we may not return it to the caller
if c.Log != nil {
c.Log.Warn("timed out waiting for stdio relay")
c.Log.WithField("timeout", c.CopyAfterExitTimeout).Warn(err.Error())
}
timeoutErrCh <- err
}
}()
} else {
close(timeoutErrCh)
}
ioErr := c.iogrp.Wait()

// TODO (go1.20): use multierror for these
ioErr := c.ioGrp.Wait()
if ioErr == nil {
ioErr, _ = c.stdinErr.Load().(error)
}
close(c.allDoneCh)
if tErr := <-timeoutErrCh; ioErr == nil {
ioErr = tErr
}
c.Process.Close()
c.ExitState = state
if exitErr != nil {
Expand Down
56 changes: 40 additions & 16 deletions internal/cmd/cmd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"bytes"
"context"
"errors"
"fmt"
"io"
"os"
"os/exec"
Expand Down Expand Up @@ -213,46 +214,69 @@ func TestCmdStdinBlocked(t *testing.T) {
}
}

type stuckIoProcessHost struct {
type stuckIOProcessHost struct {
cow.ProcessHost
}

type stuckIoProcess struct {
type stuckIOProcess struct {
cow.Process
stdin, pstdout, pstderr *io.PipeWriter
pstdin, stdout, stderr *io.PipeReader

// don't initialize p.stdin, since it complicates the logic
pstdout, pstderr *os.File
stdout, stderr *os.File
}

func (h *stuckIoProcessHost) CreateProcess(ctx context.Context, cfg interface{}) (cow.Process, error) {
func (h *stuckIOProcessHost) CreateProcess(ctx context.Context, cfg interface{}) (cow.Process, error) {
p, err := h.ProcessHost.CreateProcess(ctx, cfg)
if err != nil {
return nil, err
}
sp := &stuckIoProcess{
sp := &stuckIOProcess{
Process: p,
}
sp.pstdin, sp.stdin = io.Pipe()
sp.stdout, sp.pstdout = io.Pipe()
sp.stderr, sp.pstderr = io.Pipe()

if sp.stdout, sp.pstdout, err = os.Pipe(); err != nil {
return nil, fmt.Errorf("create stdout pipe: %w", err)
}
if sp.stderr, sp.pstderr, err = os.Pipe(); err != nil {
return nil, fmt.Errorf("create stderr pipe: %w", err)
}
return sp, nil
}

func (p *stuckIoProcess) Stdio() (io.Writer, io.Reader, io.Reader) {
return p.stdin, p.stdout, p.stderr
func (p *stuckIOProcess) Stdio() (io.Writer, io.Reader, io.Reader) {
return nil, p.stdout, p.stderr
}

func (p *stuckIoProcess) Close() error {
p.stdin.Close()
func (*stuckIOProcess) CloseStdin(context.Context) error {
return nil
}

func (p *stuckIOProcess) CloseStdout(context.Context) error {
_ = p.pstdout.Close()
return p.stdout.Close()
}

func (p *stuckIOProcess) CloseStderr(context.Context) error {
_ = p.pstderr.Close()
return p.stderr.Close()
}

func (p *stuckIOProcess) Close() error {
p.pstdout.Close()
p.pstderr.Close()

p.stdout.Close()
p.stderr.Close()

return p.Process.Close()
}

func TestCmdStuckIo(t *testing.T) {
cmd := Command(&stuckIoProcessHost{&localProcessHost{}}, "cmd", "/c", "echo", "hello")
cmd := Command(&stuckIOProcessHost{&localProcessHost{}}, "cmd", "/c", "(exit 0)")
cmd.CopyAfterExitTimeout = time.Millisecond * 200
_, err := cmd.Output()
if err != io.ErrClosedPipe { //nolint:errorlint
t.Fatal(err)
if !errors.Is(err, errIOTimeOut) {
t.Fatalf("expected: %v; got: %v", errIOTimeOut, err)
}
}
4 changes: 2 additions & 2 deletions internal/cmd/io.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ package cmd

import (
"context"
"fmt"
"io"
"net/url"
"time"

"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)

Expand Down Expand Up @@ -57,7 +57,7 @@ func NewUpstreamIO(ctx context.Context, id, stdout, stderr, stdin string, termin

// Create IO for binary logging driver.
if u.Scheme != "binary" {
return nil, errors.Errorf("scheme must be 'binary', got: '%s'", u.Scheme)
return nil, fmt.Errorf("scheme must be 'binary', got: '%s'", u.Scheme)
}

return NewBinaryIO(ctx, id, u)
Expand Down

0 comments on commit 7458e58

Please sign in to comment.