From 4f31c708d8ed592f5bd1cb04497a504861cb3afb Mon Sep 17 00:00:00 2001 From: Mike Verbanic Date: Sun, 25 Jun 2023 10:03:09 -0400 Subject: [PATCH 1/5] fix: allow cli prompt to be canceled and add default value --- cli/command.go | 48 ++++++++++++++----- cli/command_test.go | 114 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 151 insertions(+), 11 deletions(-) diff --git a/cli/command.go b/cli/command.go index 15dec233..19633a6f 100644 --- a/cli/command.go +++ b/cli/command.go @@ -23,8 +23,10 @@ import ( "fmt" "io" "os" + "os/signal" "sort" "strings" + "syscall" "github.com/mattn/go-isatty" "github.com/posener/complete/v2" @@ -66,8 +68,10 @@ type Command interface { // Prompt provides a mechanism for asking for user input. It reads from // [Stdin]. If there's an input stream (e.g. a pipe), it will read the pipe. // If the terminal is a TTY, it will prompt. Otherwise it will fail if there's - // no pipe and the terminal is not a tty. - Prompt(msg string) (string, error) + // no pipe and the terminal is not a tty. If a default value is provided and the + // user entered an empty string, the default value will be used instead, otherwise + // it returns the empty string. + Prompt(ctx context.Context, msg string, defaultValue *string) (string, error) // Stdout returns the stdout stream. SetStdout sets the stdout stream. Stdout() io.Writer @@ -274,19 +278,41 @@ func (c *BaseCommand) Hidden() bool { // Prompt prompts the user for a value. If stdin is a tty, it prompts. Otherwise // it reads from the reader. -func (c *BaseCommand) Prompt(msg string) (string, error) { - scanner := bufio.NewScanner(io.LimitReader(c.Stdin(), 64*1_000)) +func (c *BaseCommand) Prompt(ctx context.Context, msg string, defaultValue *string) (string, error) { + // guarantee Prompt is cancelable with SIGTERM and SIGINT + ctx, _ = signal.NotifyContext(ctx, syscall.SIGTERM, syscall.SIGINT) - if c.Stdin() == os.Stdin && isatty.IsTerminal(os.Stdin.Fd()) { - fmt.Fprint(c.Stdout(), msg) - } + resultChan := make(chan string) + errorChan := make(chan error) + + go func() { + scanner := bufio.NewScanner(io.LimitReader(c.Stdin(), 64*1_000)) + + if c.Stdin() == os.Stdin && isatty.IsTerminal(os.Stdin.Fd()) { + fmt.Fprint(c.Stdout(), msg) + } + + scanner.Scan() + + if err := scanner.Err(); err != nil { + errorChan <- fmt.Errorf("failed to read stdin: %w", err) + } + + if defaultValue != nil && strings.TrimSpace(scanner.Text()) == "" { + resultChan <- *defaultValue + } - scanner.Scan() + resultChan <- scanner.Text() + }() - if err := scanner.Err(); err != nil { - return "", fmt.Errorf("failed to read stdin: %w", err) + select { + case <-ctx.Done(): + return "", ctx.Err() + case err := <-errorChan: + return "", err + case result := <-resultChan: + return result, nil } - return scanner.Text(), nil } // Outf is a shortcut to write to [BaseCommand.Stdout]. diff --git a/cli/command_test.go b/cli/command_test.go index cf88c995..587e640e 100644 --- a/cli/command_test.go +++ b/cli/command_test.go @@ -269,6 +269,120 @@ func TestRootCommand_Run(t *testing.T) { } } +func TestRootCommand_Prompt_Values(t *testing.T) { + t.Parallel() + + ctx := logging.WithLogger(context.Background(), logging.TestLogger(t)) + + rootCmd := func() Command { + return &RootCommand{} + } + + defaultValue := "default" + + cases := []struct { + name string + args []string + msg string + defaultValue *string + inputValue string + exp string + }{ + { + name: "base_success", + args: []string{"prompt"}, + msg: "enter input value:", + defaultValue: nil, + inputValue: "input", + exp: "input", + }, + { + name: "defaults_input", + args: []string{"prompt"}, + msg: "enter default value:", + defaultValue: &defaultValue, + inputValue: "", + exp: "default", + }, + { + name: "trims_defaults_spaces", + args: []string{"prompt"}, + msg: "enter default value:", + defaultValue: &defaultValue, + inputValue: " ", + exp: "default", + }, + } + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + cmd := rootCmd() + stdin, _, stderr := cmd.Pipe() + stdin.WriteString(tc.inputValue) + + v, err := cmd.Prompt(ctx, tc.msg, tc.defaultValue) + if diff := testutil.DiffErrString(err, ""); diff != "" { + t.Errorf("Unexpected err: %s", diff) + } + + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + + if got, want := v, tc.exp; got != want { + t.Errorf("expected\n\n%s\n\nto equal\n\n%s\n\n", got, want) + } + if got := strings.TrimSpace(stderr.String()); got != "" { + t.Errorf("unexpected stderr: %s", got) + } + }) + } +} + +func TestRootCommand_Prompt_Cancel(t *testing.T) { + t.Parallel() + + ctx := logging.WithLogger(context.Background(), logging.TestLogger(t)) + + rootCmd := func() Command { + return &RootCommand{} + } + + cases := []struct { + name string + msg string + err string + }{ + { + name: "contex_times_out", + msg: "enter value:", + err: "context canceled", + }, + } + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + cmd := rootCmd() + + cancelCtx, cancel := context.WithCancel(ctx) + cancel() + + _, err := cmd.Prompt(cancelCtx, tc.msg, nil) + if diff := testutil.DiffErrString(err, tc.err); diff != "" { + t.Errorf("Unexpected err: %s", diff) + } + }) + } +} + type TestCommand struct { BaseCommand From f7fe48f051713d4b276b713fd8117de54737c1ba Mon Sep 17 00:00:00 2001 From: Mike Verbanic Date: Mon, 26 Jun 2023 09:11:25 -0400 Subject: [PATCH 2/5] address pr comments --- cli/command.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/cli/command.go b/cli/command.go index 19633a6f..fff810ca 100644 --- a/cli/command.go +++ b/cli/command.go @@ -280,7 +280,8 @@ func (c *BaseCommand) Hidden() bool { // it reads from the reader. func (c *BaseCommand) Prompt(ctx context.Context, msg string, defaultValue *string) (string, error) { // guarantee Prompt is cancelable with SIGTERM and SIGINT - ctx, _ = signal.NotifyContext(ctx, syscall.SIGTERM, syscall.SIGINT) + ctx, done := signal.NotifyContext(ctx, syscall.SIGTERM, syscall.SIGINT) + defer done() resultChan := make(chan string) errorChan := make(chan error) @@ -307,9 +308,9 @@ func (c *BaseCommand) Prompt(ctx context.Context, msg string, defaultValue *stri select { case <-ctx.Done(): - return "", ctx.Err() + return "", fmt.Errorf("failed to prompt: %w", ctx.Err()) case err := <-errorChan: - return "", err + return "", fmt.Errorf("failed to prompt: %w", err) case result := <-resultChan: return result, nil } From e98561e99f4d7942507515bcf51eab07907b1cdf Mon Sep 17 00:00:00 2001 From: Mike Verbanic Date: Mon, 26 Jun 2023 10:35:36 -0400 Subject: [PATCH 3/5] address pr comments --- cli/command.go | 42 ++++++++++++++++++------------------- cli/command_test.go | 51 +++++++++++++-------------------------------- 2 files changed, 36 insertions(+), 57 deletions(-) diff --git a/cli/command.go b/cli/command.go index fff810ca..57c0057d 100644 --- a/cli/command.go +++ b/cli/command.go @@ -23,10 +23,8 @@ import ( "fmt" "io" "os" - "os/signal" "sort" "strings" - "syscall" "github.com/mattn/go-isatty" "github.com/posener/complete/v2" @@ -68,10 +66,8 @@ type Command interface { // Prompt provides a mechanism for asking for user input. It reads from // [Stdin]. If there's an input stream (e.g. a pipe), it will read the pipe. // If the terminal is a TTY, it will prompt. Otherwise it will fail if there's - // no pipe and the terminal is not a tty. If a default value is provided and the - // user entered an empty string, the default value will be used instead, otherwise - // it returns the empty string. - Prompt(ctx context.Context, msg string, defaultValue *string) (string, error) + // no pipe and the terminal is not a tty. + Prompt(ctx context.Context, msg string, args ...any) (string, error) // Stdout returns the stdout stream. SetStdout sets the stdout stream. Stdout() io.Writer @@ -278,40 +274,44 @@ func (c *BaseCommand) Hidden() bool { // Prompt prompts the user for a value. If stdin is a tty, it prompts. Otherwise // it reads from the reader. -func (c *BaseCommand) Prompt(ctx context.Context, msg string, defaultValue *string) (string, error) { - // guarantee Prompt is cancelable with SIGTERM and SIGINT - ctx, done := signal.NotifyContext(ctx, syscall.SIGTERM, syscall.SIGINT) - defer done() - - resultChan := make(chan string) - errorChan := make(chan error) +func (c *BaseCommand) Prompt(ctx context.Context, msg string, args ...any) (string, error) { + resultCh := make(chan string, 1) + errCh := make(chan error, 1) + // Prompts for user input and returns the value or an error if there is one. go func() { + defer close(resultCh) + defer close(errCh) + scanner := bufio.NewScanner(io.LimitReader(c.Stdin(), 64*1_000)) if c.Stdin() == os.Stdin && isatty.IsTerminal(os.Stdin.Fd()) { - fmt.Fprint(c.Stdout(), msg) + fmt.Fprint(c.Stdout(), msg, args) } scanner.Scan() if err := scanner.Err(); err != nil { - errorChan <- fmt.Errorf("failed to read stdin: %w", err) + select { + case errCh <- fmt.Errorf("failed to read stdin: %w", err): + default: + } } - if defaultValue != nil && strings.TrimSpace(scanner.Text()) == "" { - resultChan <- *defaultValue + select { + case resultCh <- scanner.Text(): + default: } - - resultChan <- scanner.Text() }() + // If input is provided and the provided context are cancelled simultaneously, + // the behavior is undefined. See https://www.sethvargo.com/what-id-like-to-see-in-go-2/#deterministic-select. select { case <-ctx.Done(): return "", fmt.Errorf("failed to prompt: %w", ctx.Err()) - case err := <-errorChan: + case err := <-errCh: return "", fmt.Errorf("failed to prompt: %w", err) - case result := <-resultChan: + case result := <-resultCh: return result, nil } } diff --git a/cli/command_test.go b/cli/command_test.go index 587e640e..af80f3cf 100644 --- a/cli/command_test.go +++ b/cli/command_test.go @@ -278,39 +278,19 @@ func TestRootCommand_Prompt_Values(t *testing.T) { return &RootCommand{} } - defaultValue := "default" - cases := []struct { - name string - args []string - msg string - defaultValue *string - inputValue string - exp string + name string + args []string + msg string + inputValue string + exp string }{ { - name: "base_success", - args: []string{"prompt"}, - msg: "enter input value:", - defaultValue: nil, - inputValue: "input", - exp: "input", - }, - { - name: "defaults_input", - args: []string{"prompt"}, - msg: "enter default value:", - defaultValue: &defaultValue, - inputValue: "", - exp: "default", - }, - { - name: "trims_defaults_spaces", - args: []string{"prompt"}, - msg: "enter default value:", - defaultValue: &defaultValue, - inputValue: " ", - exp: "default", + name: "base_success", + args: []string{"prompt"}, + msg: "enter input value:", + inputValue: "input", + exp: "input", }, } @@ -324,7 +304,7 @@ func TestRootCommand_Prompt_Values(t *testing.T) { stdin, _, stderr := cmd.Pipe() stdin.WriteString(tc.inputValue) - v, err := cmd.Prompt(ctx, tc.msg, tc.defaultValue) + v, err := cmd.Prompt(ctx, tc.msg) if diff := testutil.DiffErrString(err, ""); diff != "" { t.Errorf("Unexpected err: %s", diff) } @@ -343,7 +323,7 @@ func TestRootCommand_Prompt_Values(t *testing.T) { } } -func TestRootCommand_Prompt_Cancel(t *testing.T) { +func TestRootCommand_Prompt_Cancels(t *testing.T) { t.Parallel() ctx := logging.WithLogger(context.Background(), logging.TestLogger(t)) @@ -358,7 +338,7 @@ func TestRootCommand_Prompt_Cancel(t *testing.T) { err string }{ { - name: "contex_times_out", + name: "context_cancels", msg: "enter value:", err: "context canceled", }, @@ -370,12 +350,11 @@ func TestRootCommand_Prompt_Cancel(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - cmd := rootCmd() - cancelCtx, cancel := context.WithCancel(ctx) cancel() - _, err := cmd.Prompt(cancelCtx, tc.msg, nil) + cmd := rootCmd() + _, err := cmd.Prompt(cancelCtx, tc.msg) if diff := testutil.DiffErrString(err, tc.err); diff != "" { t.Errorf("Unexpected err: %s", diff) } From b11b43275112c3d57ee3be01bac754cea7320c81 Mon Sep 17 00:00:00 2001 From: Mike Verbanic Date: Mon, 26 Jun 2023 14:38:04 -0400 Subject: [PATCH 4/5] address pr comments --- cli/command.go | 3 ++ cli/command_test.go | 92 +++++++++++++++++++++++++++++---------------- 2 files changed, 62 insertions(+), 33 deletions(-) diff --git a/cli/command.go b/cli/command.go index 57c0057d..1195cffd 100644 --- a/cli/command.go +++ b/cli/command.go @@ -279,6 +279,9 @@ func (c *BaseCommand) Prompt(ctx context.Context, msg string, args ...any) (stri errCh := make(chan error, 1) // Prompts for user input and returns the value or an error if there is one. + // If the context is canceled, this function leaves the c.Stdin in a bad state. + // This is currently acceptable as the cli process should ideallly be terminated + // after the context is canceled. go func() { defer close(resultCh) defer close(errCh) diff --git a/cli/command_test.go b/cli/command_test.go index af80f3cf..1b31c1ee 100644 --- a/cli/command_test.go +++ b/cli/command_test.go @@ -279,18 +279,23 @@ func TestRootCommand_Prompt_Values(t *testing.T) { } cases := []struct { - name string - args []string - msg string - inputValue string - exp string + name string + args []string + msg string + inputs []string + err string + expStderr string }{ { - name: "base_success", - args: []string{"prompt"}, - msg: "enter input value:", - inputValue: "input", - exp: "input", + name: "sets_input_value", + args: []string{"prompt"}, + msg: "enter input value:", + inputs: []string{"input"}, + }, { + name: "handles_multiple_prompts", + args: []string{"prompt"}, + msg: "enter input value:", + inputs: []string{"input1", "input2"}, }, } @@ -302,22 +307,20 @@ func TestRootCommand_Prompt_Values(t *testing.T) { cmd := rootCmd() stdin, _, stderr := cmd.Pipe() - stdin.WriteString(tc.inputValue) - - v, err := cmd.Prompt(ctx, tc.msg) - if diff := testutil.DiffErrString(err, ""); diff != "" { - t.Errorf("Unexpected err: %s", diff) - } - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - - if got, want := v, tc.exp; got != want { - t.Errorf("expected\n\n%s\n\nto equal\n\n%s\n\n", got, want) - } - if got := strings.TrimSpace(stderr.String()); got != "" { - t.Errorf("unexpected stderr: %s", got) + for _, input := range tc.inputs { + stdin.WriteString(input) + + v, err := cmd.Prompt(ctx, tc.msg) + if diff := testutil.DiffErrString(err, tc.err); diff != "" { + t.Errorf("unexpected err: %s", diff) + } + if got, want := v, input; got != want { + t.Errorf("expected\n\n%s\n\nto equal\n\n%s\n\n", got, want) + } + if got, want := strings.TrimSpace(stderr.String()), strings.TrimSpace(tc.expStderr); !strings.Contains(got, want) { + t.Errorf("expected\n\n%s\n\nto contain\n\n%s\n\n", got, want) + } } }) } @@ -333,14 +336,20 @@ func TestRootCommand_Prompt_Cancels(t *testing.T) { } cases := []struct { - name string - msg string - err string + name string + args []string + msg string + inputs []string + exp string + err string + expStderr string }{ { - name: "context_cancels", - msg: "enter value:", - err: "context canceled", + name: "context_cancels", + args: []string{"prompt"}, + msg: "enter value:", + inputs: []string{"input1", "input2"}, + err: "context canceled", }, } @@ -350,13 +359,30 @@ func TestRootCommand_Prompt_Cancels(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() + cmd := rootCmd() + stdin, _, stderr := cmd.Pipe() + + for _, input := range tc.inputs { + stdin.WriteString(input) + + _, err := cmd.Prompt(ctx, tc.msg) + if diff := testutil.DiffErrString(err, ""); diff != "" { + t.Errorf("unexpected err: %s", diff) + } + if got, want := strings.TrimSpace(stderr.String()), strings.TrimSpace(tc.expStderr); !strings.Contains(got, want) { + t.Errorf("expected\n\n%s\n\nto contain\n\n%s\n\n", got, want) + } + } + cancelCtx, cancel := context.WithCancel(ctx) cancel() - cmd := rootCmd() _, err := cmd.Prompt(cancelCtx, tc.msg) if diff := testutil.DiffErrString(err, tc.err); diff != "" { - t.Errorf("Unexpected err: %s", diff) + t.Errorf("unexpected err: %s", diff) + } + if got, want := strings.TrimSpace(stderr.String()), strings.TrimSpace(tc.expStderr); !strings.Contains(got, want) { + t.Errorf("expected\n\n%s\n\nto contain\n\n%s\n\n", got, want) } }) } From da1611411aa02afaba50bfa5c6055cc273608e1d Mon Sep 17 00:00:00 2001 From: Mike Verbanic Date: Mon, 26 Jun 2023 22:11:08 -0400 Subject: [PATCH 5/5] pr comments --- cli/command.go | 57 ++++++++++++++------------------------------- cli/command_test.go | 4 ++-- 2 files changed, 20 insertions(+), 41 deletions(-) diff --git a/cli/command.go b/cli/command.go index 1195cffd..006bf0e6 100644 --- a/cli/command.go +++ b/cli/command.go @@ -63,10 +63,7 @@ type Command interface { // Run executes the command. Run(ctx context.Context, args []string) error - // Prompt provides a mechanism for asking for user input. It reads from - // [Stdin]. If there's an input stream (e.g. a pipe), it will read the pipe. - // If the terminal is a TTY, it will prompt. Otherwise it will fail if there's - // no pipe and the terminal is not a tty. + // Prompt provides a mechanism for asking for user input. Prompt(ctx context.Context, msg string, args ...any) (string, error) // Stdout returns the stdout stream. SetStdout sets the stdout stream. @@ -272,51 +269,33 @@ func (c *BaseCommand) Hidden() bool { return false } -// Prompt prompts the user for a value. If stdin is a tty, it prompts. Otherwise -// it reads from the reader. +// Prompt prompts the user for a value. It reads from [Stdin], up to 64k bytes. +// If there's an input stream (e.g. a pipe), it will read the pipe. +// If the terminal is a TTY, it will prompt. Otherwise it will fail if there's +// no pipe and the terminal is not a tty. If the context is canceled, this function +// leaves the c.Stdin in a bad state. func (c *BaseCommand) Prompt(ctx context.Context, msg string, args ...any) (string, error) { - resultCh := make(chan string, 1) - errCh := make(chan error, 1) + if c.Stdin() == os.Stdin && isatty.IsTerminal(os.Stdin.Fd()) { + fmt.Fprint(c.Stdout(), msg, args) + } - // Prompts for user input and returns the value or an error if there is one. - // If the context is canceled, this function leaves the c.Stdin in a bad state. - // This is currently acceptable as the cli process should ideallly be terminated - // after the context is canceled. + scanner := bufio.NewScanner(io.LimitReader(c.Stdin(), 64*1_000)) + finished := make(chan struct{}) go func() { - defer close(resultCh) - defer close(errCh) - - scanner := bufio.NewScanner(io.LimitReader(c.Stdin(), 64*1_000)) - - if c.Stdin() == os.Stdin && isatty.IsTerminal(os.Stdin.Fd()) { - fmt.Fprint(c.Stdout(), msg, args) - } - + defer close(finished) scanner.Scan() - - if err := scanner.Err(); err != nil { - select { - case errCh <- fmt.Errorf("failed to read stdin: %w", err): - default: - } - } - - select { - case resultCh <- scanner.Text(): - default: - } }() - // If input is provided and the provided context are cancelled simultaneously, - // the behavior is undefined. See https://www.sethvargo.com/what-id-like-to-see-in-go-2/#deterministic-select. select { case <-ctx.Done(): return "", fmt.Errorf("failed to prompt: %w", ctx.Err()) - case err := <-errCh: - return "", fmt.Errorf("failed to prompt: %w", err) - case result := <-resultCh: - return result, nil + case <-finished: + } + + if err := scanner.Err(); err != nil { + return "", fmt.Errorf("failed to read stdin: %w", err) } + return scanner.Text(), nil } // Outf is a shortcut to write to [BaseCommand.Stdout]. diff --git a/cli/command_test.go b/cli/command_test.go index 1b31c1ee..97d2d5d7 100644 --- a/cli/command_test.go +++ b/cli/command_test.go @@ -269,7 +269,7 @@ func TestRootCommand_Run(t *testing.T) { } } -func TestRootCommand_Prompt_Values(t *testing.T) { +func TestBaseCommand_Prompt_Values(t *testing.T) { t.Parallel() ctx := logging.WithLogger(context.Background(), logging.TestLogger(t)) @@ -326,7 +326,7 @@ func TestRootCommand_Prompt_Values(t *testing.T) { } } -func TestRootCommand_Prompt_Cancels(t *testing.T) { +func TestBaseCommand_Prompt_Cancels(t *testing.T) { t.Parallel() ctx := logging.WithLogger(context.Background(), logging.TestLogger(t))