Skip to content

Commit

Permalink
flag: remove dependencies on shared variables
Browse files Browse the repository at this point in the history
See #1670 for context
This commit creates a new type separatorSpec that will get passed into
flag parser, instead of reading from shared variables.
  • Loading branch information
zllovesuki committed Jan 29, 2023
1 parent a6b3713 commit dc6dfb7
Show file tree
Hide file tree
Showing 13 changed files with 166 additions and 39 deletions.
24 changes: 15 additions & 9 deletions app.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ type App struct {
// Treat all flags as normal arguments if true
SkipFlagParsing bool

didSetup bool
didSetup bool
separator separatorSpec

rootCommand *Command
}
Expand Down Expand Up @@ -216,15 +217,25 @@ func (a *App) Setup() {
})
}

if len(a.SliceFlagSeparator) != 0 {
a.separator.customized = true
a.separator.sep = a.SliceFlagSeparator
}

if a.DisableSliceFlagSeparator {
a.separator.customized = true
a.separator.disabled = true
}

var newCommands []*Command

for _, c := range a.Commands {
cname := c.Name
if c.HelpName != "" {
cname = c.HelpName
}
c.separator = a.separator
c.HelpName = fmt.Sprintf("%s %s", a.HelpName, cname)

c.flagCategories = newFlagCategoriesFromFlags(c.Flags)
newCommands = append(newCommands, c)
}
Expand Down Expand Up @@ -262,12 +273,6 @@ func (a *App) Setup() {
if a.Metadata == nil {
a.Metadata = make(map[string]interface{})
}

if len(a.SliceFlagSeparator) != 0 {
defaultSliceFlagSeparator = a.SliceFlagSeparator
}

disableSliceFlagSeparator = a.DisableSliceFlagSeparator
}

func (a *App) newRootCommand() *Command {
Expand All @@ -293,11 +298,12 @@ func (a *App) newRootCommand() *Command {
categories: a.categories,
SkipFlagParsing: a.SkipFlagParsing,
isRoot: true,
separator: a.separator,
}
}

func (a *App) newFlagSet() (*flag.FlagSet, error) {
return flagSet(a.Name, a.Flags)
return flagSet(a.Name, a.Flags, a.separator)
}

func (a *App) useShortOptionHandling() bool {
Expand Down
4 changes: 2 additions & 2 deletions app_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2532,7 +2532,7 @@ func TestCustomHelpVersionFlags(t *testing.T) {

func TestHandleExitCoder_Default(t *testing.T) {
app := newTestApp()
fs, err := flagSet(app.Name, app.Flags)
fs, err := flagSet(app.Name, app.Flags, separatorSpec{})
if err != nil {
t.Errorf("error creating FlagSet: %s", err)
}
Expand All @@ -2548,7 +2548,7 @@ func TestHandleExitCoder_Default(t *testing.T) {

func TestHandleExitCoder_Custom(t *testing.T) {
app := newTestApp()
fs, err := flagSet(app.Name, app.Flags)
fs, err := flagSet(app.Name, app.Flags, separatorSpec{})
if err != nil {
t.Errorf("error creating FlagSet: %s", err)
}
Expand Down
4 changes: 3 additions & 1 deletion command.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ type Command struct {

// if this is a root "special" command
isRoot bool

separator separatorSpec
}

type Commands []*Command
Expand Down Expand Up @@ -275,7 +277,7 @@ func (c *Command) Run(cCtx *Context, arguments ...string) (err error) {
}

func (c *Command) newFlagSet() (*flag.FlagSet, error) {
return flagSet(c.Name, c.Flags)
return flagSet(c.Name, c.Flags, c.separator)
}

func (c *Command) useShortOptionHandling() bool {
Expand Down
12 changes: 12 additions & 0 deletions flag-spec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ flag_types:
skip_interfaces:
- fmt.Stringer
struct_fields:
- name: separator
type: separatorSpec
- name: Action
type: "func(*Context, []float64) error"
int:
Expand All @@ -33,6 +35,8 @@ flag_types:
skip_interfaces:
- fmt.Stringer
struct_fields:
- name: separator
type: separatorSpec
- name: Action
type: "func(*Context, []int) error"
int64:
Expand All @@ -46,6 +50,8 @@ flag_types:
skip_interfaces:
- fmt.Stringer
struct_fields:
- name: separator
type: separatorSpec
- name: Action
type: "func(*Context, []int64) error"
uint:
Expand All @@ -59,6 +65,8 @@ flag_types:
skip_interfaces:
- fmt.Stringer
struct_fields:
- name: separator
type: separatorSpec
- name: Action
type: "func(*Context, []uint) error"
uint64:
Expand All @@ -72,6 +80,8 @@ flag_types:
skip_interfaces:
- fmt.Stringer
struct_fields:
- name: separator
type: separatorSpec
- name: Action
type: "func(*Context, []uint64) error"
string:
Expand All @@ -85,6 +95,8 @@ flag_types:
skip_interfaces:
- fmt.Stringer
struct_fields:
- name: separator
type: separatorSpec
- name: TakesFile
type: bool
- name: Action
Expand Down
31 changes: 26 additions & 5 deletions flag.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (

const defaultPlaceholder = "value"

var (
const (
defaultSliceFlagSeparator = ","
disableSliceFlagSeparator = false
)
Expand Down Expand Up @@ -167,10 +167,13 @@ type Countable interface {
Count() int
}

func flagSet(name string, flags []Flag) (*flag.FlagSet, error) {
func flagSet(name string, flags []Flag, spec separatorSpec) (*flag.FlagSet, error) {
set := flag.NewFlagSet(name, flag.ContinueOnError)

for _, f := range flags {
if c, ok := f.(customizedSeparator); ok {
c.WithSeparatorSpec(spec)
}
if err := f.Apply(set); err != nil {
return nil, err
}
Expand Down Expand Up @@ -389,10 +392,28 @@ func flagFromEnvOrFile(envVars []string, filePath string) (value string, fromWhe
return "", "", false
}

func flagSplitMultiValues(val string) []string {
if disableSliceFlagSeparator {
type customizedSeparator interface {
WithSeparatorSpec(separatorSpec)
}

type separatorSpec struct {
sep string
disabled bool
customized bool
}

func (s separatorSpec) flagSplitMultiValues(val string) []string {
var (
disabled bool = s.disabled
sep string = s.sep
)
if !s.customized {
disabled = disableSliceFlagSeparator
sep = defaultSliceFlagSeparator
}
if disabled {
return []string{val}
}

return strings.Split(val, defaultSliceFlagSeparator)
return strings.Split(val, sep)
}
14 changes: 12 additions & 2 deletions flag_float64_slice.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
// Float64Slice wraps []float64 to satisfy flag.Value
type Float64Slice struct {
slice []float64
separator separatorSpec
hasBeenSet bool
}

Expand All @@ -29,6 +30,10 @@ func (f *Float64Slice) clone() *Float64Slice {
return n
}

func (f *Float64Slice) WithSeparatorSpec(spec separatorSpec) {
f.separator = spec
}

// Set parses the value into a float64 and appends it to the list of values
func (f *Float64Slice) Set(value string) error {
if !f.hasBeenSet {
Expand All @@ -43,7 +48,7 @@ func (f *Float64Slice) Set(value string) error {
return nil
}

for _, s := range flagSplitMultiValues(value) {
for _, s := range f.separator.flagSplitMultiValues(value) {
tmp, err := strconv.ParseFloat(strings.TrimSpace(s), 64)
if err != nil {
return err
Expand Down Expand Up @@ -148,11 +153,12 @@ func (f *Float64SliceFlag) Apply(set *flag.FlagSet) error {
setValue = f.Value.clone()
default:
setValue = new(Float64Slice)
setValue.WithSeparatorSpec(f.separator)
}

if val, source, found := flagFromEnvOrFile(f.EnvVars, f.FilePath); found {
if val != "" {
for _, s := range flagSplitMultiValues(val) {
for _, s := range f.separator.flagSplitMultiValues(val) {
if err := setValue.Set(strings.TrimSpace(s)); err != nil {
return fmt.Errorf("could not parse %q as float64 slice value from %s for flag %s: %s", val, source, f.Name, err)
}
Expand All @@ -172,6 +178,10 @@ func (f *Float64SliceFlag) Apply(set *flag.FlagSet) error {
return nil
}

func (f *Float64SliceFlag) WithSeparatorSpec(spec separatorSpec) {
f.separator = spec
}

// Get returns the flag’s value in the given Context.
func (f *Float64SliceFlag) Get(ctx *Context) []float64 {
return ctx.Float64Slice(f.Name)
Expand Down
14 changes: 12 additions & 2 deletions flag_int64_slice.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
// Int64Slice wraps []int64 to satisfy flag.Value
type Int64Slice struct {
slice []int64
separator separatorSpec
hasBeenSet bool
}

Expand All @@ -29,6 +30,10 @@ func (i *Int64Slice) clone() *Int64Slice {
return n
}

func (i *Int64Slice) WithSeparatorSpec(spec separatorSpec) {
i.separator = spec
}

// Set parses the value into an integer and appends it to the list of values
func (i *Int64Slice) Set(value string) error {
if !i.hasBeenSet {
Expand All @@ -43,7 +48,7 @@ func (i *Int64Slice) Set(value string) error {
return nil
}

for _, s := range flagSplitMultiValues(value) {
for _, s := range i.separator.flagSplitMultiValues(value) {
tmp, err := strconv.ParseInt(strings.TrimSpace(s), 0, 64)
if err != nil {
return err
Expand Down Expand Up @@ -149,10 +154,11 @@ func (f *Int64SliceFlag) Apply(set *flag.FlagSet) error {
setValue = f.Value.clone()
default:
setValue = new(Int64Slice)
setValue.WithSeparatorSpec(f.separator)
}

if val, source, ok := flagFromEnvOrFile(f.EnvVars, f.FilePath); ok && val != "" {
for _, s := range flagSplitMultiValues(val) {
for _, s := range f.separator.flagSplitMultiValues(val) {
if err := setValue.Set(strings.TrimSpace(s)); err != nil {
return fmt.Errorf("could not parse %q as int64 slice value from %s for flag %s: %s", val, source, f.Name, err)
}
Expand All @@ -171,6 +177,10 @@ func (f *Int64SliceFlag) Apply(set *flag.FlagSet) error {
return nil
}

func (f *Int64SliceFlag) WithSeparatorSpec(spec separatorSpec) {
f.separator = spec
}

// Get returns the flag’s value in the given Context.
func (f *Int64SliceFlag) Get(ctx *Context) []int64 {
return ctx.Int64Slice(f.Name)
Expand Down
14 changes: 12 additions & 2 deletions flag_int_slice.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
// IntSlice wraps []int to satisfy flag.Value
type IntSlice struct {
slice []int
separator separatorSpec
hasBeenSet bool
}

Expand Down Expand Up @@ -40,6 +41,10 @@ func (i *IntSlice) SetInt(value int) {
i.slice = append(i.slice, value)
}

func (i *IntSlice) WithSeparatorSpec(spec separatorSpec) {
i.separator = spec
}

// Set parses the value into an integer and appends it to the list of values
func (i *IntSlice) Set(value string) error {
if !i.hasBeenSet {
Expand All @@ -54,7 +59,7 @@ func (i *IntSlice) Set(value string) error {
return nil
}

for _, s := range flagSplitMultiValues(value) {
for _, s := range i.separator.flagSplitMultiValues(value) {
tmp, err := strconv.ParseInt(strings.TrimSpace(s), 0, 64)
if err != nil {
return err
Expand Down Expand Up @@ -160,10 +165,11 @@ func (f *IntSliceFlag) Apply(set *flag.FlagSet) error {
setValue = f.Value.clone()
default:
setValue = new(IntSlice)
setValue.WithSeparatorSpec(f.separator)
}

if val, source, ok := flagFromEnvOrFile(f.EnvVars, f.FilePath); ok && val != "" {
for _, s := range flagSplitMultiValues(val) {
for _, s := range f.separator.flagSplitMultiValues(val) {
if err := setValue.Set(strings.TrimSpace(s)); err != nil {
return fmt.Errorf("could not parse %q as int slice value from %s for flag %s: %s", val, source, f.Name, err)
}
Expand All @@ -182,6 +188,10 @@ func (f *IntSliceFlag) Apply(set *flag.FlagSet) error {
return nil
}

func (f *IntSliceFlag) WithSeparatorSpec(spec separatorSpec) {
f.separator = spec
}

// Get returns the flag’s value in the given Context.
func (f *IntSliceFlag) Get(ctx *Context) []int {
return ctx.IntSlice(f.Name)
Expand Down
Loading

0 comments on commit dc6dfb7

Please sign in to comment.