diff --git a/cmd/cosign/cli/generate_key_pair.go b/cmd/cosign/cli/generate_key_pair.go index d016325d403..ad51a6494c5 100644 --- a/cmd/cosign/cli/generate_key_pair.go +++ b/cmd/cosign/cli/generate_key_pair.go @@ -121,30 +121,11 @@ func GenerateKeyPairCmd(ctx context.Context, kmsVal string, args []string) error } func GetPass(confirm bool) ([]byte, error) { - read := Read() - fmt.Fprint(os.Stderr, "Enter password for private key: ") - pw1, err := read() - fmt.Fprintln(os.Stderr) - if err != nil { - return nil, err - } - if !confirm { - return pw1, nil - } - fmt.Fprint(os.Stderr, "Enter again: ") - pw2, err := read() - fmt.Fprintln(os.Stderr) - if err != nil { - return nil, err - } - - if string(pw1) != string(pw2) { - return nil, errors.New("passwords do not match") - } - return pw1, nil + read := Read(confirm) + return read() } -func readPasswordFn() func() ([]byte, error) { +func readPasswordFn(confirm bool) func() ([]byte, error) { pw, ok := os.LookupEnv("COSIGN_PASSWORD") switch { case ok: @@ -153,7 +134,7 @@ func readPasswordFn() func() ([]byte, error) { } case term.IsTerminal(0): return func() ([]byte, error) { - return term.ReadPassword(0) + return getPassFromTerm(confirm) } // Handle piped in passwords. default: @@ -162,3 +143,25 @@ func readPasswordFn() func() ([]byte, error) { } } } + +func getPassFromTerm(confirm bool) ([]byte, error) { + fmt.Fprint(os.Stderr, "Enter password for private key: ") + pw1, err := term.ReadPassword(0) + if err != nil { + return nil, err + } + if !confirm { + return pw1, nil + } + fmt.Fprint(os.Stderr, "Enter again: ") + pw2, err := term.ReadPassword(0) + fmt.Fprintln(os.Stderr) + if err != nil { + return nil, err + } + + if string(pw1) != string(pw2) { + return nil, errors.New("passwords do not match") + } + return pw1, nil +} diff --git a/cmd/cosign/cli/generate_key_pair_test.go b/cmd/cosign/cli/generate_key_pair_test.go index 4efc663083a..7b594f4c81b 100644 --- a/cmd/cosign/cli/generate_key_pair_test.go +++ b/cmd/cosign/cli/generate_key_pair_test.go @@ -25,7 +25,7 @@ import ( func TestReadPasswordFn_env(t *testing.T) { os.Setenv("COSIGN_PASSWORD", "foo") defer os.Unsetenv("COSIGN_PASSWORD") - b, err := readPasswordFn()() + b, err := readPasswordFn(true)() if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -37,7 +37,7 @@ func TestReadPasswordFn_env(t *testing.T) { func TestReadPasswordFn_envEmptyVal(t *testing.T) { os.Setenv("COSIGN_PASSWORD", "") defer os.Unsetenv("COSIGN_PASSWORD") - b, err := readPasswordFn()() + b, err := readPasswordFn(true)() if err != nil { t.Fatalf("unexpected error: %v", err) }