Skip to content

Commit

Permalink
fix: hoist signal handling from prompt confirmation
Browse files Browse the repository at this point in the history
Signed-off-by: Alano Terblanche <[email protected]>
  • Loading branch information
Benehiko committed Apr 9, 2024
1 parent c23a404 commit 47b6915
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 19 deletions.
10 changes: 1 addition & 9 deletions cli/command/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,9 @@ import (
"fmt"
"io"
"os"
"os/signal"
"path/filepath"
"runtime"
"strings"
"syscall"

"github.com/docker/cli/cli/streams"
"github.com/docker/docker/api/types/filters"
Expand Down Expand Up @@ -103,11 +101,6 @@ func PromptForConfirmation(ctx context.Context, ins io.Reader, outs io.Writer, m

result := make(chan bool)

// Catch the termination signal and exit the prompt gracefully.
// The caller is responsible for properly handling the termination.
notifyCtx, notifyCancel := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM)
defer notifyCancel()

go func() {
var res bool
scanner := bufio.NewScanner(ins)
Expand All @@ -121,8 +114,7 @@ func PromptForConfirmation(ctx context.Context, ins io.Reader, outs io.Writer, m
}()

select {
case <-notifyCtx.Done():
// print a newline on termination
case <-ctx.Done():
_, _ = fmt.Fprintln(outs, "")
return false, ErrPromptTerminated
case r := <-result:
Expand Down
6 changes: 5 additions & 1 deletion cli/command/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"io"
"os"
"os/signal"
"path/filepath"
"strings"
"syscall"
Expand Down Expand Up @@ -135,6 +136,9 @@ func TestPromptForConfirmation(t *testing.T) {
}, promptResult{false, nil}},
} {
t.Run("case="+tc.desc, func(t *testing.T) {
notifyCtx, notifyCancel := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM)
t.Cleanup(notifyCancel)

buf.Reset()
promptReader, promptWriter = io.Pipe()

Expand All @@ -145,7 +149,7 @@ func TestPromptForConfirmation(t *testing.T) {

result := make(chan promptResult, 1)
go func() {
r, err := command.PromptForConfirmation(ctx, promptReader, promptOut, "")
r, err := command.PromptForConfirmation(notifyCtx, promptReader, promptOut, "")
result <- promptResult{r, err}
}()

Expand Down
21 changes: 15 additions & 6 deletions cmd/docker/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,12 @@ import (
)

func main() {
ctx := context.Background()
baseCtx, cancel := context.WithCancel(context.Background())
defer cancel()

ctx, cancelNotify := signal.NotifyContext(baseCtx, platformsignals.TerminationSignals...)
defer cancelNotify()

dockerCli, err := command.NewDockerCli(command.WithBaseContext(ctx))
if err != nil {
fmt.Fprintln(os.Stderr, err)
Expand Down Expand Up @@ -223,7 +228,7 @@ func setValidateArgs(dockerCli command.Cli, cmd *cobra.Command) {
})
}

func tryPluginRun(dockerCli command.Cli, cmd *cobra.Command, subcommand string, envs []string) error {
func tryPluginRun(ctx context.Context, dockerCli command.Cli, cmd *cobra.Command, subcommand string, envs []string) error {
plugincmd, err := pluginmanager.PluginRunCommand(dockerCli, subcommand, cmd)
if err != nil {
return err
Expand All @@ -241,11 +246,15 @@ func tryPluginRun(dockerCli command.Cli, cmd *cobra.Command, subcommand string,

// Background signal handling logic: block on the signals channel, and
// notify the plugin via the PluginServer (or signal) as appropriate.
const exitLimit = 3
signals := make(chan os.Signal, exitLimit)
signal.Notify(signals, platformsignals.TerminationSignals...)
const exitLimit = 2

go func() {
retries := 0
<-ctx.Done()

signals := make(chan os.Signal, exitLimit)
signal.Notify(signals, platformsignals.TerminationSignals...)

for range signals {
// If stdin is a TTY, the kernel will forward
// signals to the subprocess because the shared
Expand Down Expand Up @@ -333,7 +342,7 @@ func runDocker(ctx context.Context, dockerCli *command.DockerCli) error {
ccmd, _, err := cmd.Find(args)
subCommand = ccmd
if err != nil || pluginmanager.IsPluginCommand(ccmd) {
err := tryPluginRun(dockerCli, cmd, args[0], envs)
err := tryPluginRun(ctx, dockerCli, cmd, args[0], envs)
if err == nil {
if dockerCli.HooksEnabled() && dockerCli.Out().IsTerminal() && ccmd != nil {
_ = pluginmanager.RunPluginHooks(dockerCli, cmd, ccmd, args[0], args)
Expand Down
8 changes: 5 additions & 3 deletions internal/test/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package test
import (
"context"
"os"
"syscall"
"testing"
"time"

Expand Down Expand Up @@ -32,8 +31,11 @@ func TerminatePrompt(ctx context.Context, t *testing.T, cmd *cobra.Command, cli
assert.NilError(t, err)
cli.SetIn(streams.NewIn(r))

notifyCtx, notifyCancel := context.WithCancel(ctx)
t.Cleanup(notifyCancel)

go func() {
errChan <- cmd.ExecuteContext(ctx)
errChan <- cmd.ExecuteContext(notifyCtx)
}()

writeCtx, writeCancel := context.WithTimeout(ctx, 100*time.Millisecond)
Expand Down Expand Up @@ -66,7 +68,7 @@ func TerminatePrompt(ctx context.Context, t *testing.T, cmd *cobra.Command, cli

// sigint and sigterm are caught by the prompt
// this allows us to gracefully exit the prompt with a 0 exit code
syscall.Kill(syscall.Getpid(), syscall.SIGINT)
notifyCancel()

select {
case <-errCtx.Done():
Expand Down

0 comments on commit 47b6915

Please sign in to comment.