diff --git a/cmd/step/main.go b/cmd/step/main.go index fbada2e5b..82777390e 100644 --- a/cmd/step/main.go +++ b/cmd/step/main.go @@ -3,6 +3,7 @@ package main import ( "errors" "fmt" + "io" "os" "reflect" "regexp" @@ -15,10 +16,10 @@ import ( "github.com/smallstep/cli-utils/command" "github.com/smallstep/cli-utils/step" "github.com/smallstep/cli-utils/ui" + "github.com/smallstep/cli-utils/usage" "go.step.sm/crypto/jose" "go.step.sm/crypto/pemutil" - "github.com/smallstep/cli-utils/usage" "github.com/smallstep/cli/command/version" "github.com/smallstep/cli/internal/plugin" "github.com/smallstep/cli/utils" @@ -66,6 +67,42 @@ func main() { defer panicHandler() + // create new instance of app + app := newApp(os.Stdout, os.Stderr) + + if err := app.Run(os.Args); err != nil { + var messenger interface { + Message() string + } + if errors.As(err, &messenger) { + if os.Getenv("STEPDEBUG") == "1" { + fmt.Fprintf(os.Stderr, "%+v\n\n%s", err, messenger.Message()) + } else { + fmt.Fprintln(os.Stderr, messenger.Message()) + fmt.Fprintln(os.Stderr, "Re-run with STEPDEBUG=1 for more info.") + } + } else { + if os.Getenv("STEPDEBUG") == "1" { + fmt.Fprintf(os.Stderr, "%+v\n", err) + } else { + fmt.Fprintln(os.Stderr, err) + } + } + //nolint:gocritic // ignore exitAfterDefer error because the defer is required for recovery. + os.Exit(1) + } +} + +func newApp(stdout, stderr io.Writer) *cli.App { + // Define default file writers and prompters for go.step.sm/crypto + pemutil.WriteFile = utils.WriteFile + pemutil.PromptPassword = func(msg string) ([]byte, error) { + return ui.PromptPassword(msg) + } + jose.PromptPassword = func(msg string) ([]byte, error) { + return ui.PromptPassword(msg) + } + // Override global framework components cli.VersionPrinter = func(c *cli.Context) { version.Command(c) @@ -111,39 +148,10 @@ func main() { } // All non-successful output should be written to stderr - app.Writer = os.Stdout - app.ErrWriter = os.Stderr - - // Define default file writers and prompters for go.step.sm/crypto - pemutil.WriteFile = utils.WriteFile - pemutil.PromptPassword = func(msg string) ([]byte, error) { - return ui.PromptPassword(msg) - } - jose.PromptPassword = func(msg string) ([]byte, error) { - return ui.PromptPassword(msg) - } + app.Writer = stdout + app.ErrWriter = stderr - if err := app.Run(os.Args); err != nil { - var messenger interface { - Message() string - } - if errors.As(err, &messenger) { - if os.Getenv("STEPDEBUG") == "1" { - fmt.Fprintf(os.Stderr, "%+v\n\n%s", err, messenger.Message()) - } else { - fmt.Fprintln(os.Stderr, messenger.Message()) - fmt.Fprintln(os.Stderr, "Re-run with STEPDEBUG=1 for more info.") - } - } else { - if os.Getenv("STEPDEBUG") == "1" { - fmt.Fprintf(os.Stderr, "%+v\n", err) - } else { - fmt.Fprintln(os.Stderr, err) - } - } - //nolint:gocritic // ignore exitAfterDefer error because the defer is required for recovery. - os.Exit(1) - } + return app } func panicHandler() { diff --git a/cmd/step/main_test.go b/cmd/step/main_test.go new file mode 100644 index 000000000..82979e39d --- /dev/null +++ b/cmd/step/main_test.go @@ -0,0 +1,46 @@ +package main + +import ( + "bytes" + "regexp" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestAppHasAllCommands(t *testing.T) { + app := newApp(&bytes.Buffer{}, &bytes.Buffer{}) + require.NotNil(t, app) + + require.Equal(t, "step", app.Name) + require.Equal(t, "step", app.HelpName) + + var names = make([]string, 0, len(app.Commands)) + for _, c := range app.Commands { + names = append(names, c.Name) + } + require.Equal(t, []string{ + "help", "api", "path", "base64", "fileserver", + "certificate", "completion", "context", "crl", + "crypto", "oauth", "version", "ca", "beta", "ssh", + }, names) +} + +const ansi = "[\u001B\u009B][[\\]()#;?]*(?:(?:(?:[a-zA-Z\\d]*(?:;[a-zA-Z\\d]*)*)?\u0007)|(?:(?:\\d{1,4}(?:;\\d{0,4})*)?[\\dA-PRZcf-ntqry=><~]))" + +var ansiRegex = regexp.MustCompile(ansi) + +func TestAppRuns(t *testing.T) { + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + + app := newApp(stdout, stderr) + require.NotNil(t, app) + + err := app.Run([]string{"step"}) + require.NoError(t, err) + require.Empty(t, stderr.Bytes()) + + output := ansiRegex.ReplaceAllString(stdout.String(), "") + require.Contains(t, output, "step -- plumbing for distributed systems") +}