diff --git a/cmd.go b/cmd.go index ea8c4940c2..4e855286a5 100644 --- a/cmd.go +++ b/cmd.go @@ -17,7 +17,6 @@ import ( type monitoredCmd struct { cmd *exec.Cmd timeout time.Duration - ctx context.Context stdout *activityBuffer stderr *activityBuffer } @@ -68,7 +67,7 @@ func (c *monitoredCmd) run(ctx context.Context) error { return &killCmdError{err} } } - return c.ctx.Err() + return ctx.Err() case err := <-done: return err } diff --git a/cmd_test.go b/cmd_test.go index 70ffa0ef58..213ae6aa06 100644 --- a/cmd_test.go +++ b/cmd_test.go @@ -17,7 +17,7 @@ func mkTestCmd(iterations int) *monitoredCmd { } func TestMonitoredCmd(t *testing.T) { - // Sleeps make this a bit slow + // Sleeps and compile make this a bit slow if testing.Short() { t.Skip("skipping test with sleeps on short") } @@ -54,4 +54,23 @@ func TestMonitoredCmd(t *testing.T) { if cmd2.stdout.buf.String() != expectedOutput { t.Errorf("Unexpected output:\n\t(GOT): %s\n\t(WNT): %s", cmd2.stdout.buf.String(), expectedOutput) } + + ctx, cancel := context.WithCancel(context.Background()) + sync1, errchan := make(chan struct{}), make(chan error) + cmd3 := mkTestCmd(2) + go func() { + close(sync1) + errchan <- cmd3.run(ctx) + }() + + // Make sure goroutine is at least started before we cancel the context. + <-sync1 + // Give it a bit to get the process started. + <-time.After(5 * time.Millisecond) + cancel() + + err = <-errchan + if err != context.Canceled { + t.Errorf("should have gotten canceled error, got %s", err) + } }