Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: allow cli prompt to be canceled #143

Merged
merged 5 commits into from
Jun 27, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 37 additions & 10 deletions cli/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ type Command interface {
// [Stdin]. If there's an input stream (e.g. a pipe), it will read the pipe.
// If the terminal is a TTY, it will prompt. Otherwise it will fail if there's
// no pipe and the terminal is not a tty.
Prompt(msg string) (string, error)
Prompt(ctx context.Context, msg string, args ...any) (string, error)

// Stdout returns the stdout stream. SetStdout sets the stdout stream.
Stdout() io.Writer
Expand Down Expand Up @@ -274,19 +274,46 @@ func (c *BaseCommand) Hidden() bool {

// Prompt prompts the user for a value. If stdin is a tty, it prompts. Otherwise
// it reads from the reader.
func (c *BaseCommand) Prompt(msg string) (string, error) {
scanner := bufio.NewScanner(io.LimitReader(c.Stdin(), 64*1_000))
func (c *BaseCommand) Prompt(ctx context.Context, msg string, args ...any) (string, error) {
verbanicm marked this conversation as resolved.
Show resolved Hide resolved
resultCh := make(chan string, 1)
errCh := make(chan error, 1)

if c.Stdin() == os.Stdin && isatty.IsTerminal(os.Stdin.Fd()) {
fmt.Fprint(c.Stdout(), msg)
}
// Prompts for user input and returns the value or an error if there is one.
go func() {
defer close(resultCh)
defer close(errCh)

scanner := bufio.NewScanner(io.LimitReader(c.Stdin(), 64*1_000))
verbanicm marked this conversation as resolved.
Show resolved Hide resolved

if c.Stdin() == os.Stdin && isatty.IsTerminal(os.Stdin.Fd()) {
fmt.Fprint(c.Stdout(), msg, args)
}

scanner.Scan()
scanner.Scan()

if err := scanner.Err(); err != nil {
return "", fmt.Errorf("failed to read stdin: %w", err)
if err := scanner.Err(); err != nil {
verbanicm marked this conversation as resolved.
Show resolved Hide resolved
select {
case errCh <- fmt.Errorf("failed to read stdin: %w", err):
default:
}
}

select {
verbanicm marked this conversation as resolved.
Show resolved Hide resolved
case resultCh <- scanner.Text():
default:
}
}()

// If input is provided and the provided context are cancelled simultaneously,
// the behavior is undefined. See https://www.sethvargo.com/what-id-like-to-see-in-go-2/#deterministic-select.
select {
case <-ctx.Done():
return "", fmt.Errorf("failed to prompt: %w", ctx.Err())
case err := <-errCh:
return "", fmt.Errorf("failed to prompt: %w", err)
case result := <-resultCh:
return result, nil
}
return scanner.Text(), nil
}

// Outf is a shortcut to write to [BaseCommand.Stdout].
Expand Down
93 changes: 93 additions & 0 deletions cli/command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,99 @@ func TestRootCommand_Run(t *testing.T) {
}
}

func TestRootCommand_Prompt_Values(t *testing.T) {
verbanicm marked this conversation as resolved.
Show resolved Hide resolved
t.Parallel()

ctx := logging.WithLogger(context.Background(), logging.TestLogger(t))

rootCmd := func() Command {
return &RootCommand{}
}

cases := []struct {
name string
args []string
msg string
inputValue string
exp string
}{
{
name: "base_success",
args: []string{"prompt"},
msg: "enter input value:",
inputValue: "input",
exp: "input",
},
}

for _, tc := range cases {
tc := tc

t.Run(tc.name, func(t *testing.T) {
t.Parallel()

cmd := rootCmd()
stdin, _, stderr := cmd.Pipe()
stdin.WriteString(tc.inputValue)

v, err := cmd.Prompt(ctx, tc.msg)
if diff := testutil.DiffErrString(err, ""); diff != "" {
t.Errorf("Unexpected err: %s", diff)
}

if err != nil {
t.Fatalf("unexpected error: %s", err)
}

if got, want := v, tc.exp; got != want {
t.Errorf("expected\n\n%s\n\nto equal\n\n%s\n\n", got, want)
}
if got := strings.TrimSpace(stderr.String()); got != "" {
t.Errorf("unexpected stderr: %s", got)
}
})
}
}

func TestRootCommand_Prompt_Cancels(t *testing.T) {
verbanicm marked this conversation as resolved.
Show resolved Hide resolved
t.Parallel()

ctx := logging.WithLogger(context.Background(), logging.TestLogger(t))

rootCmd := func() Command {
return &RootCommand{}
}

cases := []struct {
name string
msg string
err string
}{
{
name: "context_cancels",
msg: "enter value:",
err: "context canceled",
},
}

for _, tc := range cases {
tc := tc

t.Run(tc.name, func(t *testing.T) {
t.Parallel()

cancelCtx, cancel := context.WithCancel(ctx)
cancel()

cmd := rootCmd()
_, err := cmd.Prompt(cancelCtx, tc.msg)
if diff := testutil.DiffErrString(err, tc.err); diff != "" {
t.Errorf("Unexpected err: %s", diff)
}
})
}
}

type TestCommand struct {
BaseCommand

Expand Down