diff --git a/app.go b/app.go index 10198f4332..04aa725eb7 100644 --- a/app.go +++ b/app.go @@ -121,7 +121,8 @@ type App struct { // Treat all flags as normal arguments if true SkipFlagParsing bool - didSetup bool + didSetup bool + separator separatorSpec rootCommand *Command } @@ -216,6 +217,16 @@ func (a *App) Setup() { }) } + if len(a.SliceFlagSeparator) != 0 { + a.separator.customized = true + a.separator.sep = a.SliceFlagSeparator + } + + if a.DisableSliceFlagSeparator { + a.separator.customized = true + a.separator.disabled = true + } + var newCommands []*Command for _, c := range a.Commands { @@ -223,8 +234,8 @@ func (a *App) Setup() { if c.HelpName != "" { cname = c.HelpName } + c.separator = a.separator c.HelpName = fmt.Sprintf("%s %s", a.HelpName, cname) - c.flagCategories = newFlagCategoriesFromFlags(c.Flags) newCommands = append(newCommands, c) } @@ -262,12 +273,6 @@ func (a *App) Setup() { if a.Metadata == nil { a.Metadata = make(map[string]interface{}) } - - if len(a.SliceFlagSeparator) != 0 { - defaultSliceFlagSeparator = a.SliceFlagSeparator - } - - disableSliceFlagSeparator = a.DisableSliceFlagSeparator } func (a *App) newRootCommand() *Command { @@ -293,11 +298,12 @@ func (a *App) newRootCommand() *Command { categories: a.categories, SkipFlagParsing: a.SkipFlagParsing, isRoot: true, + separator: a.separator, } } func (a *App) newFlagSet() (*flag.FlagSet, error) { - return flagSet(a.Name, a.Flags) + return flagSet(a.Name, a.Flags, a.separator) } func (a *App) useShortOptionHandling() bool { diff --git a/app_test.go b/app_test.go index 7f2a2144e5..c409eb9c95 100644 --- a/app_test.go +++ b/app_test.go @@ -2532,7 +2532,7 @@ func TestCustomHelpVersionFlags(t *testing.T) { func TestHandleExitCoder_Default(t *testing.T) { app := newTestApp() - fs, err := flagSet(app.Name, app.Flags) + fs, err := flagSet(app.Name, app.Flags, separatorSpec{}) if err != nil { t.Errorf("error creating FlagSet: %s", err) } @@ -2548,7 +2548,7 @@ func TestHandleExitCoder_Default(t *testing.T) { func TestHandleExitCoder_Custom(t *testing.T) { app := newTestApp() - fs, err := flagSet(app.Name, app.Flags) + fs, err := flagSet(app.Name, app.Flags, separatorSpec{}) if err != nil { t.Errorf("error creating FlagSet: %s", err) } diff --git a/command.go b/command.go index da9cf5302a..f978b4a43e 100644 --- a/command.go +++ b/command.go @@ -69,6 +69,8 @@ type Command struct { // if this is a root "special" command isRoot bool + + separator separatorSpec } type Commands []*Command @@ -275,7 +277,7 @@ func (c *Command) Run(cCtx *Context, arguments ...string) (err error) { } func (c *Command) newFlagSet() (*flag.FlagSet, error) { - return flagSet(c.Name, c.Flags) + return flagSet(c.Name, c.Flags, c.separator) } func (c *Command) useShortOptionHandling() bool { diff --git a/flag-spec.yaml b/flag-spec.yaml index cfe47df06d..ed4db985d0 100644 --- a/flag-spec.yaml +++ b/flag-spec.yaml @@ -20,6 +20,8 @@ flag_types: skip_interfaces: - fmt.Stringer struct_fields: + - name: separator + type: separatorSpec - name: Action type: "func(*Context, []float64) error" int: @@ -33,6 +35,8 @@ flag_types: skip_interfaces: - fmt.Stringer struct_fields: + - name: separator + type: separatorSpec - name: Action type: "func(*Context, []int) error" int64: @@ -46,6 +50,8 @@ flag_types: skip_interfaces: - fmt.Stringer struct_fields: + - name: separator + type: separatorSpec - name: Action type: "func(*Context, []int64) error" uint: @@ -59,6 +65,8 @@ flag_types: skip_interfaces: - fmt.Stringer struct_fields: + - name: separator + type: separatorSpec - name: Action type: "func(*Context, []uint) error" uint64: @@ -72,6 +80,8 @@ flag_types: skip_interfaces: - fmt.Stringer struct_fields: + - name: separator + type: separatorSpec - name: Action type: "func(*Context, []uint64) error" string: @@ -85,6 +95,8 @@ flag_types: skip_interfaces: - fmt.Stringer struct_fields: + - name: separator + type: separatorSpec - name: TakesFile type: bool - name: Action diff --git a/flag.go b/flag.go index 5260f7f9e2..fc3744ec8e 100644 --- a/flag.go +++ b/flag.go @@ -15,7 +15,7 @@ import ( const defaultPlaceholder = "value" -var ( +const ( defaultSliceFlagSeparator = "," disableSliceFlagSeparator = false ) @@ -167,10 +167,13 @@ type Countable interface { Count() int } -func flagSet(name string, flags []Flag) (*flag.FlagSet, error) { +func flagSet(name string, flags []Flag, spec separatorSpec) (*flag.FlagSet, error) { set := flag.NewFlagSet(name, flag.ContinueOnError) for _, f := range flags { + if c, ok := f.(customizedSeparator); ok { + c.WithSeparatorSpec(spec) + } if err := f.Apply(set); err != nil { return nil, err } @@ -389,10 +392,28 @@ func flagFromEnvOrFile(envVars []string, filePath string) (value string, fromWhe return "", "", false } -func flagSplitMultiValues(val string) []string { - if disableSliceFlagSeparator { +type customizedSeparator interface { + WithSeparatorSpec(separatorSpec) +} + +type separatorSpec struct { + sep string + disabled bool + customized bool +} + +func (s separatorSpec) flagSplitMultiValues(val string) []string { + var ( + disabled bool = s.disabled + sep string = s.sep + ) + if !s.customized { + disabled = disableSliceFlagSeparator + sep = defaultSliceFlagSeparator + } + if disabled { return []string{val} } - return strings.Split(val, defaultSliceFlagSeparator) + return strings.Split(val, sep) } diff --git a/flag_float64_slice.go b/flag_float64_slice.go index 413aa50e9f..0bc4612c82 100644 --- a/flag_float64_slice.go +++ b/flag_float64_slice.go @@ -11,6 +11,7 @@ import ( // Float64Slice wraps []float64 to satisfy flag.Value type Float64Slice struct { slice []float64 + separator separatorSpec hasBeenSet bool } @@ -29,6 +30,10 @@ func (f *Float64Slice) clone() *Float64Slice { return n } +func (f *Float64Slice) WithSeparatorSpec(spec separatorSpec) { + f.separator = spec +} + // Set parses the value into a float64 and appends it to the list of values func (f *Float64Slice) Set(value string) error { if !f.hasBeenSet { @@ -43,7 +48,7 @@ func (f *Float64Slice) Set(value string) error { return nil } - for _, s := range flagSplitMultiValues(value) { + for _, s := range f.separator.flagSplitMultiValues(value) { tmp, err := strconv.ParseFloat(strings.TrimSpace(s), 64) if err != nil { return err @@ -148,11 +153,12 @@ func (f *Float64SliceFlag) Apply(set *flag.FlagSet) error { setValue = f.Value.clone() default: setValue = new(Float64Slice) + setValue.WithSeparatorSpec(f.separator) } if val, source, found := flagFromEnvOrFile(f.EnvVars, f.FilePath); found { if val != "" { - for _, s := range flagSplitMultiValues(val) { + for _, s := range f.separator.flagSplitMultiValues(val) { if err := setValue.Set(strings.TrimSpace(s)); err != nil { return fmt.Errorf("could not parse %q as float64 slice value from %s for flag %s: %s", val, source, f.Name, err) } @@ -172,6 +178,10 @@ func (f *Float64SliceFlag) Apply(set *flag.FlagSet) error { return nil } +func (f *Float64SliceFlag) WithSeparatorSpec(spec separatorSpec) { + f.separator = spec +} + // Get returns the flag’s value in the given Context. func (f *Float64SliceFlag) Get(ctx *Context) []float64 { return ctx.Float64Slice(f.Name) diff --git a/flag_int64_slice.go b/flag_int64_slice.go index c45c43d3ab..d45c2dd440 100644 --- a/flag_int64_slice.go +++ b/flag_int64_slice.go @@ -11,6 +11,7 @@ import ( // Int64Slice wraps []int64 to satisfy flag.Value type Int64Slice struct { slice []int64 + separator separatorSpec hasBeenSet bool } @@ -29,6 +30,10 @@ func (i *Int64Slice) clone() *Int64Slice { return n } +func (i *Int64Slice) WithSeparatorSpec(spec separatorSpec) { + i.separator = spec +} + // Set parses the value into an integer and appends it to the list of values func (i *Int64Slice) Set(value string) error { if !i.hasBeenSet { @@ -43,7 +48,7 @@ func (i *Int64Slice) Set(value string) error { return nil } - for _, s := range flagSplitMultiValues(value) { + for _, s := range i.separator.flagSplitMultiValues(value) { tmp, err := strconv.ParseInt(strings.TrimSpace(s), 0, 64) if err != nil { return err @@ -149,10 +154,11 @@ func (f *Int64SliceFlag) Apply(set *flag.FlagSet) error { setValue = f.Value.clone() default: setValue = new(Int64Slice) + setValue.WithSeparatorSpec(f.separator) } if val, source, ok := flagFromEnvOrFile(f.EnvVars, f.FilePath); ok && val != "" { - for _, s := range flagSplitMultiValues(val) { + for _, s := range f.separator.flagSplitMultiValues(val) { if err := setValue.Set(strings.TrimSpace(s)); err != nil { return fmt.Errorf("could not parse %q as int64 slice value from %s for flag %s: %s", val, source, f.Name, err) } @@ -171,6 +177,10 @@ func (f *Int64SliceFlag) Apply(set *flag.FlagSet) error { return nil } +func (f *Int64SliceFlag) WithSeparatorSpec(spec separatorSpec) { + f.separator = spec +} + // Get returns the flag’s value in the given Context. func (f *Int64SliceFlag) Get(ctx *Context) []int64 { return ctx.Int64Slice(f.Name) diff --git a/flag_int_slice.go b/flag_int_slice.go index d4006e594c..da9c09bc73 100644 --- a/flag_int_slice.go +++ b/flag_int_slice.go @@ -11,6 +11,7 @@ import ( // IntSlice wraps []int to satisfy flag.Value type IntSlice struct { slice []int + separator separatorSpec hasBeenSet bool } @@ -40,6 +41,10 @@ func (i *IntSlice) SetInt(value int) { i.slice = append(i.slice, value) } +func (i *IntSlice) WithSeparatorSpec(spec separatorSpec) { + i.separator = spec +} + // Set parses the value into an integer and appends it to the list of values func (i *IntSlice) Set(value string) error { if !i.hasBeenSet { @@ -54,7 +59,7 @@ func (i *IntSlice) Set(value string) error { return nil } - for _, s := range flagSplitMultiValues(value) { + for _, s := range i.separator.flagSplitMultiValues(value) { tmp, err := strconv.ParseInt(strings.TrimSpace(s), 0, 64) if err != nil { return err @@ -160,10 +165,11 @@ func (f *IntSliceFlag) Apply(set *flag.FlagSet) error { setValue = f.Value.clone() default: setValue = new(IntSlice) + setValue.WithSeparatorSpec(f.separator) } if val, source, ok := flagFromEnvOrFile(f.EnvVars, f.FilePath); ok && val != "" { - for _, s := range flagSplitMultiValues(val) { + for _, s := range f.separator.flagSplitMultiValues(val) { if err := setValue.Set(strings.TrimSpace(s)); err != nil { return fmt.Errorf("could not parse %q as int slice value from %s for flag %s: %s", val, source, f.Name, err) } @@ -182,6 +188,10 @@ func (f *IntSliceFlag) Apply(set *flag.FlagSet) error { return nil } +func (f *IntSliceFlag) WithSeparatorSpec(spec separatorSpec) { + f.separator = spec +} + // Get returns the flag’s value in the given Context. func (f *IntSliceFlag) Get(ctx *Context) []int { return ctx.IntSlice(f.Name) diff --git a/flag_string_slice.go b/flag_string_slice.go index 82eeac06d8..82410dbc8b 100644 --- a/flag_string_slice.go +++ b/flag_string_slice.go @@ -11,6 +11,7 @@ import ( // StringSlice wraps a []string to satisfy flag.Value type StringSlice struct { slice []string + separator separatorSpec hasBeenSet bool } @@ -43,13 +44,17 @@ func (s *StringSlice) Set(value string) error { return nil } - for _, t := range flagSplitMultiValues(value) { + for _, t := range s.separator.flagSplitMultiValues(value) { s.slice = append(s.slice, t) } return nil } +func (s *StringSlice) WithSeparatorSpec(spec separatorSpec) { + s.separator = spec +} + // String returns a readable representation of this value (for usage defaults) func (s *StringSlice) String() string { return fmt.Sprintf("%s", s.slice) @@ -141,10 +146,11 @@ func (f *StringSliceFlag) Apply(set *flag.FlagSet) error { setValue = f.Value.clone() default: setValue = new(StringSlice) + setValue.WithSeparatorSpec(f.separator) } if val, source, found := flagFromEnvOrFile(f.EnvVars, f.FilePath); found { - for _, s := range flagSplitMultiValues(val) { + for _, s := range f.separator.flagSplitMultiValues(val) { if err := setValue.Set(strings.TrimSpace(s)); err != nil { return fmt.Errorf("could not parse %q as string value from %s for flag %s: %s", val, source, f.Name, err) } @@ -163,6 +169,10 @@ func (f *StringSliceFlag) Apply(set *flag.FlagSet) error { return nil } +func (f *StringSliceFlag) WithSeparatorSpec(spec separatorSpec) { + f.separator = spec +} + // Get returns the flag’s value in the given Context. func (f *StringSliceFlag) Get(ctx *Context) []string { return ctx.StringSlice(f.Name) diff --git a/flag_test.go b/flag_test.go index c1523717ff..132a9ede4a 100644 --- a/flag_test.go +++ b/flag_test.go @@ -3425,13 +3425,27 @@ func TestSliceShortOptionHandle(t *testing.T) { } // Test issue #1541 +func TestDefaultSliceFlagSeparator(t *testing.T) { + separator := separatorSpec{} + opts := []string{"opt1", "opt2", "opt3", "opt4"} + ret := separator.flagSplitMultiValues(strings.Join(opts, ",")) + if len(ret) != 4 { + t.Fatalf("split slice flag failed, want: 4, but get: %d", len(ret)) + } + for idx, r := range ret { + if r != opts[idx] { + t.Fatalf("get %dth failed, wanted: %s, but get: %s", idx, opts[idx], r) + } + } +} + func TestCustomizedSliceFlagSeparator(t *testing.T) { - defaultSliceFlagSeparator = ";" - defer func() { - defaultSliceFlagSeparator = "," - }() + separator := separatorSpec{ + customized: true, + sep: ";", + } opts := []string{"opt1", "opt2", "opt3,op", "opt4"} - ret := flagSplitMultiValues(strings.Join(opts, ";")) + ret := separator.flagSplitMultiValues(strings.Join(opts, ";")) if len(ret) != 4 { t.Fatalf("split slice flag failed, want: 4, but get: %d", len(ret)) } @@ -3443,13 +3457,13 @@ func TestCustomizedSliceFlagSeparator(t *testing.T) { } func TestFlagSplitMultiValues_Disabled(t *testing.T) { - disableSliceFlagSeparator = true - defer func() { - disableSliceFlagSeparator = false - }() + separator := separatorSpec{ + customized: true, + disabled: true, + } opts := []string{"opt1", "opt2", "opt3,op", "opt4"} - ret := flagSplitMultiValues(strings.Join(opts, defaultSliceFlagSeparator)) + ret := separator.flagSplitMultiValues(strings.Join(opts, defaultSliceFlagSeparator)) if len(ret) != 1 { t.Fatalf("failed to disable split slice flag, want: 1, but got: %d", len(ret)) } diff --git a/flag_uint64_slice.go b/flag_uint64_slice.go index 61bb30b551..e845dd5257 100644 --- a/flag_uint64_slice.go +++ b/flag_uint64_slice.go @@ -11,6 +11,7 @@ import ( // Uint64Slice wraps []int64 to satisfy flag.Value type Uint64Slice struct { slice []uint64 + separator separatorSpec hasBeenSet bool } @@ -43,7 +44,7 @@ func (i *Uint64Slice) Set(value string) error { return nil } - for _, s := range flagSplitMultiValues(value) { + for _, s := range i.separator.flagSplitMultiValues(value) { tmp, err := strconv.ParseUint(strings.TrimSpace(s), 0, 64) if err != nil { return err @@ -55,6 +56,10 @@ func (i *Uint64Slice) Set(value string) error { return nil } +func (i *Uint64Slice) WithSeparatorSpec(spec separatorSpec) { + i.separator = spec +} + // String returns a readable representation of this value (for usage defaults) func (i *Uint64Slice) String() string { v := i.slice @@ -153,10 +158,11 @@ func (f *Uint64SliceFlag) Apply(set *flag.FlagSet) error { setValue = f.Value.clone() default: setValue = new(Uint64Slice) + setValue.WithSeparatorSpec(f.separator) } if val, source, ok := flagFromEnvOrFile(f.EnvVars, f.FilePath); ok && val != "" { - for _, s := range flagSplitMultiValues(val) { + for _, s := range f.separator.flagSplitMultiValues(val) { if err := setValue.Set(strings.TrimSpace(s)); err != nil { return fmt.Errorf("could not parse %q as uint64 slice value from %s for flag %s: %s", val, source, f.Name, err) } @@ -175,6 +181,10 @@ func (f *Uint64SliceFlag) Apply(set *flag.FlagSet) error { return nil } +func (f *Uint64SliceFlag) WithSeparatorSpec(spec separatorSpec) { + f.separator = spec +} + // Get returns the flag’s value in the given Context. func (f *Uint64SliceFlag) Get(ctx *Context) []uint64 { return ctx.Uint64Slice(f.Name) diff --git a/flag_uint_slice.go b/flag_uint_slice.go index 363aa657f3..d2aed480d9 100644 --- a/flag_uint_slice.go +++ b/flag_uint_slice.go @@ -11,6 +11,7 @@ import ( // UintSlice wraps []int to satisfy flag.Value type UintSlice struct { slice []uint + separator separatorSpec hasBeenSet bool } @@ -54,7 +55,7 @@ func (i *UintSlice) Set(value string) error { return nil } - for _, s := range flagSplitMultiValues(value) { + for _, s := range i.separator.flagSplitMultiValues(value) { tmp, err := strconv.ParseUint(strings.TrimSpace(s), 0, 32) if err != nil { return err @@ -66,6 +67,10 @@ func (i *UintSlice) Set(value string) error { return nil } +func (i *UintSlice) WithSeparatorSpec(spec separatorSpec) { + i.separator = spec +} + // String returns a readable representation of this value (for usage defaults) func (i *UintSlice) String() string { v := i.slice @@ -164,10 +169,11 @@ func (f *UintSliceFlag) Apply(set *flag.FlagSet) error { setValue = f.Value.clone() default: setValue = new(UintSlice) + setValue.WithSeparatorSpec(f.separator) } if val, source, ok := flagFromEnvOrFile(f.EnvVars, f.FilePath); ok && val != "" { - for _, s := range flagSplitMultiValues(val) { + for _, s := range f.separator.flagSplitMultiValues(val) { if err := setValue.Set(strings.TrimSpace(s)); err != nil { return fmt.Errorf("could not parse %q as uint slice value from %s for flag %s: %s", val, source, f.Name, err) } @@ -186,6 +192,10 @@ func (f *UintSliceFlag) Apply(set *flag.FlagSet) error { return nil } +func (f *UintSliceFlag) WithSeparatorSpec(spec separatorSpec) { + f.separator = spec +} + // Get returns the flag’s value in the given Context. func (f *UintSliceFlag) Get(ctx *Context) []uint { return ctx.UintSlice(f.Name) diff --git a/zz_generated.flags.go b/zz_generated.flags.go index 016951a3f0..8c29f6ee03 100644 --- a/zz_generated.flags.go +++ b/zz_generated.flags.go @@ -25,6 +25,8 @@ type Float64SliceFlag struct { defaultValue *Float64Slice + separator separatorSpec + Action func(*Context, []float64) error } @@ -120,6 +122,8 @@ type Int64SliceFlag struct { defaultValue *Int64Slice + separator separatorSpec + Action func(*Context, []int64) error } @@ -164,6 +168,8 @@ type IntSliceFlag struct { defaultValue *IntSlice + separator separatorSpec + Action func(*Context, []int) error } @@ -259,6 +265,8 @@ type StringSliceFlag struct { defaultValue *StringSlice + separator separatorSpec + TakesFile bool Action func(*Context, []string) error @@ -358,6 +366,8 @@ type Uint64SliceFlag struct { defaultValue *Uint64Slice + separator separatorSpec + Action func(*Context, []uint64) error } @@ -402,6 +412,8 @@ type UintSliceFlag struct { defaultValue *UintSlice + separator separatorSpec + Action func(*Context, []uint) error }