diff --git a/cli/command.go b/cli/command.go index fff810ca..57c0057d 100644 --- a/cli/command.go +++ b/cli/command.go @@ -23,10 +23,8 @@ import ( "fmt" "io" "os" - "os/signal" "sort" "strings" - "syscall" "github.com/mattn/go-isatty" "github.com/posener/complete/v2" @@ -68,10 +66,8 @@ type Command interface { // Prompt provides a mechanism for asking for user input. It reads from // [Stdin]. If there's an input stream (e.g. a pipe), it will read the pipe. // If the terminal is a TTY, it will prompt. Otherwise it will fail if there's - // no pipe and the terminal is not a tty. If a default value is provided and the - // user entered an empty string, the default value will be used instead, otherwise - // it returns the empty string. - Prompt(ctx context.Context, msg string, defaultValue *string) (string, error) + // no pipe and the terminal is not a tty. + Prompt(ctx context.Context, msg string, args ...any) (string, error) // Stdout returns the stdout stream. SetStdout sets the stdout stream. Stdout() io.Writer @@ -278,40 +274,44 @@ func (c *BaseCommand) Hidden() bool { // Prompt prompts the user for a value. If stdin is a tty, it prompts. Otherwise // it reads from the reader. -func (c *BaseCommand) Prompt(ctx context.Context, msg string, defaultValue *string) (string, error) { - // guarantee Prompt is cancelable with SIGTERM and SIGINT - ctx, done := signal.NotifyContext(ctx, syscall.SIGTERM, syscall.SIGINT) - defer done() - - resultChan := make(chan string) - errorChan := make(chan error) +func (c *BaseCommand) Prompt(ctx context.Context, msg string, args ...any) (string, error) { + resultCh := make(chan string, 1) + errCh := make(chan error, 1) + // Prompts for user input and returns the value or an error if there is one. go func() { + defer close(resultCh) + defer close(errCh) + scanner := bufio.NewScanner(io.LimitReader(c.Stdin(), 64*1_000)) if c.Stdin() == os.Stdin && isatty.IsTerminal(os.Stdin.Fd()) { - fmt.Fprint(c.Stdout(), msg) + fmt.Fprint(c.Stdout(), msg, args) } scanner.Scan() if err := scanner.Err(); err != nil { - errorChan <- fmt.Errorf("failed to read stdin: %w", err) + select { + case errCh <- fmt.Errorf("failed to read stdin: %w", err): + default: + } } - if defaultValue != nil && strings.TrimSpace(scanner.Text()) == "" { - resultChan <- *defaultValue + select { + case resultCh <- scanner.Text(): + default: } - - resultChan <- scanner.Text() }() + // If input is provided and the provided context are cancelled simultaneously, + // the behavior is undefined. See https://www.sethvargo.com/what-id-like-to-see-in-go-2/#deterministic-select. select { case <-ctx.Done(): return "", fmt.Errorf("failed to prompt: %w", ctx.Err()) - case err := <-errorChan: + case err := <-errCh: return "", fmt.Errorf("failed to prompt: %w", err) - case result := <-resultChan: + case result := <-resultCh: return result, nil } } diff --git a/cli/command_test.go b/cli/command_test.go index 587e640e..af80f3cf 100644 --- a/cli/command_test.go +++ b/cli/command_test.go @@ -278,39 +278,19 @@ func TestRootCommand_Prompt_Values(t *testing.T) { return &RootCommand{} } - defaultValue := "default" - cases := []struct { - name string - args []string - msg string - defaultValue *string - inputValue string - exp string + name string + args []string + msg string + inputValue string + exp string }{ { - name: "base_success", - args: []string{"prompt"}, - msg: "enter input value:", - defaultValue: nil, - inputValue: "input", - exp: "input", - }, - { - name: "defaults_input", - args: []string{"prompt"}, - msg: "enter default value:", - defaultValue: &defaultValue, - inputValue: "", - exp: "default", - }, - { - name: "trims_defaults_spaces", - args: []string{"prompt"}, - msg: "enter default value:", - defaultValue: &defaultValue, - inputValue: " ", - exp: "default", + name: "base_success", + args: []string{"prompt"}, + msg: "enter input value:", + inputValue: "input", + exp: "input", }, } @@ -324,7 +304,7 @@ func TestRootCommand_Prompt_Values(t *testing.T) { stdin, _, stderr := cmd.Pipe() stdin.WriteString(tc.inputValue) - v, err := cmd.Prompt(ctx, tc.msg, tc.defaultValue) + v, err := cmd.Prompt(ctx, tc.msg) if diff := testutil.DiffErrString(err, ""); diff != "" { t.Errorf("Unexpected err: %s", diff) } @@ -343,7 +323,7 @@ func TestRootCommand_Prompt_Values(t *testing.T) { } } -func TestRootCommand_Prompt_Cancel(t *testing.T) { +func TestRootCommand_Prompt_Cancels(t *testing.T) { t.Parallel() ctx := logging.WithLogger(context.Background(), logging.TestLogger(t)) @@ -358,7 +338,7 @@ func TestRootCommand_Prompt_Cancel(t *testing.T) { err string }{ { - name: "contex_times_out", + name: "context_cancels", msg: "enter value:", err: "context canceled", }, @@ -370,12 +350,11 @@ func TestRootCommand_Prompt_Cancel(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - cmd := rootCmd() - cancelCtx, cancel := context.WithCancel(ctx) cancel() - _, err := cmd.Prompt(cancelCtx, tc.msg, nil) + cmd := rootCmd() + _, err := cmd.Prompt(cancelCtx, tc.msg) if diff := testutil.DiffErrString(err, tc.err); diff != "" { t.Errorf("Unexpected err: %s", diff) }