Skip to content

Commit

Permalink
address pr comments
Browse files Browse the repository at this point in the history
  • Loading branch information
verbanicm committed Jun 26, 2023
1 parent f7fe48f commit e98561e
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 57 deletions.
42 changes: 21 additions & 21 deletions cli/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,8 @@ import (
"fmt"
"io"
"os"
"os/signal"
"sort"
"strings"
"syscall"

"github.com/mattn/go-isatty"
"github.com/posener/complete/v2"
Expand Down Expand Up @@ -68,10 +66,8 @@ type Command interface {
// Prompt provides a mechanism for asking for user input. It reads from
// [Stdin]. If there's an input stream (e.g. a pipe), it will read the pipe.
// If the terminal is a TTY, it will prompt. Otherwise it will fail if there's
// no pipe and the terminal is not a tty. If a default value is provided and the
// user entered an empty string, the default value will be used instead, otherwise
// it returns the empty string.
Prompt(ctx context.Context, msg string, defaultValue *string) (string, error)
// no pipe and the terminal is not a tty.
Prompt(ctx context.Context, msg string, args ...any) (string, error)

// Stdout returns the stdout stream. SetStdout sets the stdout stream.
Stdout() io.Writer
Expand Down Expand Up @@ -278,40 +274,44 @@ func (c *BaseCommand) Hidden() bool {

// Prompt prompts the user for a value. If stdin is a tty, it prompts. Otherwise
// it reads from the reader.
func (c *BaseCommand) Prompt(ctx context.Context, msg string, defaultValue *string) (string, error) {
// guarantee Prompt is cancelable with SIGTERM and SIGINT
ctx, done := signal.NotifyContext(ctx, syscall.SIGTERM, syscall.SIGINT)
defer done()

resultChan := make(chan string)
errorChan := make(chan error)
func (c *BaseCommand) Prompt(ctx context.Context, msg string, args ...any) (string, error) {
resultCh := make(chan string, 1)
errCh := make(chan error, 1)

// Prompts for user input and returns the value or an error if there is one.
go func() {
defer close(resultCh)
defer close(errCh)

scanner := bufio.NewScanner(io.LimitReader(c.Stdin(), 64*1_000))

if c.Stdin() == os.Stdin && isatty.IsTerminal(os.Stdin.Fd()) {
fmt.Fprint(c.Stdout(), msg)
fmt.Fprint(c.Stdout(), msg, args)
}

scanner.Scan()

if err := scanner.Err(); err != nil {
errorChan <- fmt.Errorf("failed to read stdin: %w", err)
select {
case errCh <- fmt.Errorf("failed to read stdin: %w", err):
default:
}
}

if defaultValue != nil && strings.TrimSpace(scanner.Text()) == "" {
resultChan <- *defaultValue
select {
case resultCh <- scanner.Text():
default:
}

resultChan <- scanner.Text()
}()

// If input is provided and the provided context are cancelled simultaneously,
// the behavior is undefined. See https://www.sethvargo.com/what-id-like-to-see-in-go-2/#deterministic-select.
select {
case <-ctx.Done():
return "", fmt.Errorf("failed to prompt: %w", ctx.Err())
case err := <-errorChan:
case err := <-errCh:
return "", fmt.Errorf("failed to prompt: %w", err)
case result := <-resultChan:
case result := <-resultCh:
return result, nil
}
}
Expand Down
51 changes: 15 additions & 36 deletions cli/command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -278,39 +278,19 @@ func TestRootCommand_Prompt_Values(t *testing.T) {
return &RootCommand{}
}

defaultValue := "default"

cases := []struct {
name string
args []string
msg string
defaultValue *string
inputValue string
exp string
name string
args []string
msg string
inputValue string
exp string
}{
{
name: "base_success",
args: []string{"prompt"},
msg: "enter input value:",
defaultValue: nil,
inputValue: "input",
exp: "input",
},
{
name: "defaults_input",
args: []string{"prompt"},
msg: "enter default value:",
defaultValue: &defaultValue,
inputValue: "",
exp: "default",
},
{
name: "trims_defaults_spaces",
args: []string{"prompt"},
msg: "enter default value:",
defaultValue: &defaultValue,
inputValue: " ",
exp: "default",
name: "base_success",
args: []string{"prompt"},
msg: "enter input value:",
inputValue: "input",
exp: "input",
},
}

Expand All @@ -324,7 +304,7 @@ func TestRootCommand_Prompt_Values(t *testing.T) {
stdin, _, stderr := cmd.Pipe()
stdin.WriteString(tc.inputValue)

v, err := cmd.Prompt(ctx, tc.msg, tc.defaultValue)
v, err := cmd.Prompt(ctx, tc.msg)
if diff := testutil.DiffErrString(err, ""); diff != "" {
t.Errorf("Unexpected err: %s", diff)
}
Expand All @@ -343,7 +323,7 @@ func TestRootCommand_Prompt_Values(t *testing.T) {
}
}

func TestRootCommand_Prompt_Cancel(t *testing.T) {
func TestRootCommand_Prompt_Cancels(t *testing.T) {
t.Parallel()

ctx := logging.WithLogger(context.Background(), logging.TestLogger(t))
Expand All @@ -358,7 +338,7 @@ func TestRootCommand_Prompt_Cancel(t *testing.T) {
err string
}{
{
name: "contex_times_out",
name: "context_cancels",
msg: "enter value:",
err: "context canceled",
},
Expand All @@ -370,12 +350,11 @@ func TestRootCommand_Prompt_Cancel(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

cmd := rootCmd()

cancelCtx, cancel := context.WithCancel(ctx)
cancel()

_, err := cmd.Prompt(cancelCtx, tc.msg, nil)
cmd := rootCmd()
_, err := cmd.Prompt(cancelCtx, tc.msg)
if diff := testutil.DiffErrString(err, tc.err); diff != "" {
t.Errorf("Unexpected err: %s", diff)
}
Expand Down

0 comments on commit e98561e

Please sign in to comment.