Skip to content

Commit

Permalink
add 'CheckErr' and 'assertNoErr'
Browse files Browse the repository at this point in the history
  • Loading branch information
umarcor committed Dec 29, 2020
1 parent df7929d commit f3a7c61
Show file tree
Hide file tree
Showing 23 changed files with 122 additions and 123 deletions.
9 changes: 1 addition & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -234,21 +234,14 @@ func init() {
rootCmd.AddCommand(initCmd)
}

func er(msg interface{}) {
fmt.Println("Error:", msg)
os.Exit(1)
}

func initConfig() {
if cfgFile != "" {
// Use config file from the flag.
viper.SetConfigFile(cfgFile)
} else {
// Find home directory.
home, err := homedir.Dir()
if err != nil {
er(err)
}
cobra.CheckErr(err)

// Search config in home directory with name ".cobra" (without extension).
viper.AddConfigPath(home)
Expand Down
34 changes: 18 additions & 16 deletions bash_completions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ func runShellCheck(s string) error {
return err
}
go func() {
stdin.Write([]byte(s))
_, err := stdin.Write([]byte(s))
CheckErr(err)

stdin.Close()
}()

Expand All @@ -73,26 +75,26 @@ func TestBashCompletions(t *testing.T) {
Run: emptyRun,
}
rootCmd.Flags().IntP("introot", "i", -1, "help message for flag introot")
rootCmd.MarkFlagRequired("introot")
assertNoErr(t, rootCmd.MarkFlagRequired("introot"))

// Filename.
rootCmd.Flags().String("filename", "", "Enter a filename")
rootCmd.MarkFlagFilename("filename", "json", "yaml", "yml")
assertNoErr(t, rootCmd.MarkFlagFilename("filename", "json", "yaml", "yml"))

// Persistent filename.
rootCmd.PersistentFlags().String("persistent-filename", "", "Enter a filename")
rootCmd.MarkPersistentFlagFilename("persistent-filename")
rootCmd.MarkPersistentFlagRequired("persistent-filename")
assertNoErr(t, rootCmd.MarkPersistentFlagFilename("persistent-filename"))
assertNoErr(t, rootCmd.MarkPersistentFlagRequired("persistent-filename"))

// Filename extensions.
rootCmd.Flags().String("filename-ext", "", "Enter a filename (extension limited)")
rootCmd.MarkFlagFilename("filename-ext")
assertNoErr(t, rootCmd.MarkFlagFilename("filename-ext"))
rootCmd.Flags().String("custom", "", "Enter a filename (extension limited)")
rootCmd.MarkFlagCustom("custom", "__complete_custom")
assertNoErr(t, rootCmd.MarkFlagCustom("custom", "__complete_custom"))

// Subdirectories in a given directory.
rootCmd.Flags().String("theme", "", "theme to use (located in /themes/THEMENAME/)")
rootCmd.Flags().SetAnnotation("theme", BashCompSubdirsInDir, []string{"themes"})
assertNoErr(t, rootCmd.Flags().SetAnnotation("theme", BashCompSubdirsInDir, []string{"themes"}))

// For two word flags check
rootCmd.Flags().StringP("two", "t", "", "this is two word flags")
Expand All @@ -108,9 +110,9 @@ func TestBashCompletions(t *testing.T) {
}

echoCmd.Flags().String("filename", "", "Enter a filename")
echoCmd.MarkFlagFilename("filename", "json", "yaml", "yml")
assertNoErr(t, echoCmd.MarkFlagFilename("filename", "json", "yaml", "yml"))
echoCmd.Flags().String("config", "", "config to use (located in /config/PROFILE/)")
echoCmd.Flags().SetAnnotation("config", BashCompSubdirsInDir, []string{"config"})
assertNoErr(t, echoCmd.Flags().SetAnnotation("config", BashCompSubdirsInDir, []string{"config"}))

printCmd := &Command{
Use: "print [string to print]",
Expand Down Expand Up @@ -148,7 +150,7 @@ func TestBashCompletions(t *testing.T) {
rootCmd.AddCommand(echoCmd, printCmd, deprecatedCmd, colonCmd)

buf := new(bytes.Buffer)
rootCmd.GenBashCompletion(buf)
assertNoErr(t, rootCmd.GenBashCompletion(buf))
output := buf.String()

check(t, output, "_root")
Expand Down Expand Up @@ -215,10 +217,10 @@ func TestBashCompletionHiddenFlag(t *testing.T) {

const flagName = "hiddenFlag"
c.Flags().Bool(flagName, false, "")
c.Flags().MarkHidden(flagName)
assertNoErr(t, c.Flags().MarkHidden(flagName))

buf := new(bytes.Buffer)
c.GenBashCompletion(buf)
assertNoErr(t, c.GenBashCompletion(buf))
output := buf.String()

if strings.Contains(output, flagName) {
Expand All @@ -231,10 +233,10 @@ func TestBashCompletionDeprecatedFlag(t *testing.T) {

const flagName = "deprecated-flag"
c.Flags().Bool(flagName, false, "")
c.Flags().MarkDeprecated(flagName, "use --not-deprecated instead")
assertNoErr(t, c.Flags().MarkDeprecated(flagName, "use --not-deprecated instead"))

buf := new(bytes.Buffer)
c.GenBashCompletion(buf)
assertNoErr(t, c.GenBashCompletion(buf))
output := buf.String()

if strings.Contains(output, flagName) {
Expand All @@ -249,7 +251,7 @@ func TestBashCompletionTraverseChildren(t *testing.T) {
c.Flags().BoolP("bool-flag", "b", false, "bool flag")

buf := new(bytes.Buffer)
c.GenBashCompletion(buf)
assertNoErr(t, c.GenBashCompletion(buf))
output := buf.String()

// check that local nonpersistent flag are not set since we have TraverseChildren set to true
Expand Down
9 changes: 9 additions & 0 deletions cobra.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package cobra
import (
"fmt"
"io"
"os"
"reflect"
"strconv"
"strings"
Expand Down Expand Up @@ -205,3 +206,11 @@ func stringInSlice(a string, list []string) bool {
}
return false
}

// CheckErr prints the msg with the prefix 'Error:' and exits with error code 1. If the msg is nil, it does nothing.
func CheckErr(msg interface{}) {
if msg != nil {
fmt.Fprintln(os.Stderr, "Error:", msg)
os.Exit(1)
}
}
13 changes: 4 additions & 9 deletions cobra/cmd/add.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,11 @@ Example: cobra add server -> resulting in a new cmd/server.go`,

Run: func(cmd *cobra.Command, args []string) {
if len(args) < 1 {
er("add needs a name for the command")
cobra.CheckErr(fmt.Errorf("add needs a name for the command"))
}

wd, err := os.Getwd()
if err != nil {
er(err)
}
cobra.CheckErr(err)

commandName := validateCmdName(args[0])
command := &Command{
Expand All @@ -59,10 +57,7 @@ Example: cobra add server -> resulting in a new cmd/server.go`,
},
}

err = command.Create()
if err != nil {
er(err)
}
cobra.CheckErr(command.Create())

fmt.Printf("%s created at %s\n", command.CmdName, command.AbsolutePath)
},
Expand All @@ -72,7 +67,7 @@ Example: cobra add server -> resulting in a new cmd/server.go`,
func init() {
addCmd.Flags().StringVarP(&packageName, "package", "t", "", "target package name (e.g. github.com/spf13/hugo)")
addCmd.Flags().StringVarP(&parentName, "parent", "p", "rootCmd", "variable name of parent command for this command")
addCmd.Flags().MarkDeprecated("package", "this operation has been removed.")
cobra.CheckErr(addCmd.Flags().MarkDeprecated("package", "this operation has been removed."))
}

// validateCmdName returns source without any dashes and underscore.
Expand Down
6 changes: 2 additions & 4 deletions cobra/cmd/add_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,8 @@ func TestGoldenAddCmd(t *testing.T) {
}
defer os.RemoveAll(command.AbsolutePath)

command.Project.Create()
if err := command.Create(); err != nil {
t.Fatal(err)
}
assertNoErr(t, command.Project.Create())
assertNoErr(t, command.Create())

generatedFile := fmt.Sprintf("%s/cmd/%s.go", command.AbsolutePath, command.CmdName)
goldenFile := fmt.Sprintf("testdata/%s.go.golden", command.CmdName)
Expand Down
14 changes: 4 additions & 10 deletions cobra/cmd/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@
package cmd

import (
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"

"github.com/spf13/cobra"
)

var srcPaths []string
Expand All @@ -40,23 +41,16 @@ func init() {
}

out, err := exec.Command(goExecutable, "env", "GOPATH").Output()
if err != nil {
er(err)
}
cobra.CheckErr(err)

toolchainGoPath := strings.TrimSpace(string(out))
goPaths = filepath.SplitList(toolchainGoPath)
if len(goPaths) == 0 {
er("$GOPATH is not set")
cobra.CheckErr("$GOPATH is not set")
}
}
srcPaths = make([]string, 0, len(goPaths))
for _, goPath := range goPaths {
srcPaths = append(srcPaths, filepath.Join(goPath, "src"))
}
}

func er(msg interface{}) {
fmt.Println("Error:", msg)
os.Exit(1)
}
9 changes: 9 additions & 0 deletions cobra/cmd/helpers_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package cmd

import "testing"

func assertNoErr(t *testing.T, e error) {
if e != nil {
t.Error(e)
}
}
6 changes: 2 additions & 4 deletions cobra/cmd/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,15 @@ and the appropriate structure for a Cobra-based CLI application.
Run: func(_ *cobra.Command, args []string) {

projectPath, err := initializeProject(args)
if err != nil {
er(err)
}
cobra.CheckErr(err)
fmt.Printf("Your Cobra application is ready at\n%s\n", projectPath)
},
}
)

func init() {
initCmd.Flags().StringVar(&pkgName, "pkg-name", "", "fully qualified pkg name")
initCmd.MarkFlagRequired("pkg-name")
cobra.CheckErr(initCmd.MarkFlagRequired("pkg-name"))
}

func initializeProject(args []string) (string, error) {
Expand Down
2 changes: 1 addition & 1 deletion cobra/cmd/init_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func TestGoldenInitCmd(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

initCmd.Flags().Set("pkg-name", tt.pkgName)
assertNoErr(t, initCmd.Flags().Set("pkg-name", tt.pkgName))
viper.Set("useViper", true)
projectPath, err := initializeProject(tt.args)
defer func() {
Expand Down
4 changes: 3 additions & 1 deletion cobra/cmd/licenses.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
package cmd

import (
"fmt"
"strings"
"time"

"github.com/spf13/cobra"
"github.com/spf13/viper"
)

Expand Down Expand Up @@ -92,7 +94,7 @@ func copyrightLine() string {
func findLicense(name string) License {
found := matchLicense(name)
if found == "" {
er("unknown license: " + name)
cobra.CheckErr(fmt.Errorf("unknown license: " + name))
}
return Licenses[found]
}
Expand Down
3 changes: 2 additions & 1 deletion cobra/cmd/project.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"os"
"text/template"

"github.com/spf13/cobra"
"github.com/spf13/cobra/cobra/tpl"
)

Expand Down Expand Up @@ -49,7 +50,7 @@ func (p *Project) Create() error {

// create cmd/root.go
if _, err = os.Stat(fmt.Sprintf("%s/cmd", p.AbsolutePath)); os.IsNotExist(err) {
os.Mkdir(fmt.Sprintf("%s/cmd", p.AbsolutePath), 0751)
cobra.CheckErr(os.Mkdir(fmt.Sprintf("%s/cmd", p.AbsolutePath), 0751))
}
rootFile, err := os.Create(fmt.Sprintf("%s/cmd/root.go", p.AbsolutePath))
if err != nil {
Expand Down
8 changes: 3 additions & 5 deletions cobra/cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ func init() {
rootCmd.PersistentFlags().StringP("author", "a", "YOUR NAME", "author name for copyright attribution")
rootCmd.PersistentFlags().StringVarP(&userLicense, "license", "l", "", "name of license for the project")
rootCmd.PersistentFlags().Bool("viper", true, "use Viper for configuration")
viper.BindPFlag("author", rootCmd.PersistentFlags().Lookup("author"))
viper.BindPFlag("useViper", rootCmd.PersistentFlags().Lookup("viper"))
cobra.CheckErr(viper.BindPFlag("author", rootCmd.PersistentFlags().Lookup("author")))
cobra.CheckErr(viper.BindPFlag("useViper", rootCmd.PersistentFlags().Lookup("viper")))
viper.SetDefault("author", "NAME HERE <EMAIL ADDRESS>")
viper.SetDefault("license", "apache")

Expand All @@ -63,9 +63,7 @@ func initConfig() {
} else {
// Find home directory.
home, err := homedir.Dir()
if err != nil {
er(err)
}
cobra.CheckErr(err)

// Search config in home directory with name ".cobra" (without extension).
viper.AddConfigPath(home)
Expand Down
10 changes: 2 additions & 8 deletions cobra/cmd/testdata/root.go.golden
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,7 @@ to quickly create a Cobra application.`,
// Execute adds all child commands to the root command and sets flags appropriately.
// This is called by main.main(). It only needs to happen once to the rootCmd.
func Execute() {
if err := rootCmd.Execute(); err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
cobra.CheckErr(rootCmd.Execute())
}

func init() {
Expand All @@ -72,10 +69,7 @@ func initConfig() {
} else {
// Find home directory.
home, err := homedir.Dir()
if err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
cobra.CheckErr(err)

// Search config in home directory with name ".testproject" (without extension).
viper.AddConfigPath(home)
Expand Down
10 changes: 2 additions & 8 deletions cobra/tpl/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,7 @@ to quickly create a Cobra application.` + "`" + `,
// Execute adds all child commands to the root command and sets flags appropriately.
// This is called by main.main(). It only needs to happen once to the rootCmd.
func Execute() {
if err := rootCmd.Execute(); err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
cobra.CheckErr(rootCmd.Execute())
}
func init() {
Expand Down Expand Up @@ -85,10 +82,7 @@ func initConfig() {
} else {
// Find home directory.
home, err := homedir.Dir()
if err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
cobra.CheckErr(err)
// Search config in home directory with name ".{{ .AppName }}" (without extension).
viper.AddConfigPath(home)
Expand Down
6 changes: 6 additions & 0 deletions cobra_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ import (
"text/template"
)

func assertNoErr(t *testing.T, e error) {
if e != nil {
t.Error(e)
}
}

func TestAddTemplateFunctions(t *testing.T) {
AddTemplateFunc("t", func() bool { return true })
AddTemplateFuncs(template.FuncMap{
Expand Down
Loading

0 comments on commit f3a7c61

Please sign in to comment.