Skip to content

Commit

Permalink
Add App.InvalidFlagAccessHandler (#1446)
Browse files Browse the repository at this point in the history
* Add App.UnknownFlagHandler

* Rename App.UnknownFlagHandler to App.InvalidFlagAccessHandler

* Traverse parent contexts
  • Loading branch information
icholy authored Aug 30, 2022
1 parent 254c38e commit ca9df40
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 1 deletion.
2 changes: 2 additions & 0 deletions app.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ type App struct {
CommandNotFound CommandNotFoundFunc
// Execute this function if a usage error occurs
OnUsageError OnUsageErrorFunc
// Execute this function when an invalid flag is accessed from the context
InvalidFlagAccessHandler InvalidFlagAccessFunc
// Compilation date
Compiled time.Time
// List of all authors who contributed
Expand Down
15 changes: 14 additions & 1 deletion context.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ func (cCtx *Context) NumFlags() int {

// Set sets a context flag to a value.
func (cCtx *Context) Set(name, value string) error {
if cCtx.flagSet.Lookup(name) == nil {
cCtx.onInvalidFlag(name)
}
return cCtx.flagSet.Set(name, value)
}

Expand Down Expand Up @@ -158,7 +161,7 @@ func (cCtx *Context) lookupFlagSet(name string) *flag.FlagSet {
return c.flagSet
}
}

cCtx.onInvalidFlag(name)
return nil
}

Expand Down Expand Up @@ -190,6 +193,16 @@ func (cCtx *Context) checkRequiredFlags(flags []Flag) requiredFlagsErr {
return nil
}

func (cCtx *Context) onInvalidFlag(name string) {
for cCtx != nil {
if cCtx.App != nil && cCtx.App.InvalidFlagAccessHandler != nil {
cCtx.App.InvalidFlagAccessHandler(cCtx, name)
break
}
cCtx = cCtx.parentContext
}
}

func makeFlagNameVisitor(names *[]string) func(*flag.Flag) {
return func(f *flag.Flag) {
nameParts := strings.Split(f.Name, ",")
Expand Down
38 changes: 38 additions & 0 deletions context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,31 @@ func TestContext_Value(t *testing.T) {
expect(t, c.Value("unknown-flag"), nil)
}

func TestContext_Value_InvalidFlagAccessHandler(t *testing.T) {
var flagName string
app := &App{
InvalidFlagAccessHandler: func(_ *Context, name string) {
flagName = name
},
Commands: []*Command{
{
Name: "command",
Subcommands: []*Command{
{
Name: "subcommand",
Action: func(ctx *Context) error {
ctx.Value("missing")
return nil
},
},
},
},
},
}
expect(t, app.Run([]string{"run", "command", "subcommand"}), nil)
expect(t, flagName, "missing")
}

func TestContext_Args(t *testing.T) {
set := flag.NewFlagSet("test", 0)
set.Bool("myflag", false, "doc")
Expand Down Expand Up @@ -258,6 +283,19 @@ func TestContext_Set(t *testing.T) {
expect(t, c.IsSet("int"), true)
}

func TestContext_Set_InvalidFlagAccessHandler(t *testing.T) {
set := flag.NewFlagSet("test", 0)
var flagName string
app := &App{
InvalidFlagAccessHandler: func(_ *Context, name string) {
flagName = name
},
}
c := NewContext(app, set, nil)
c.Set("missing", "")
expect(t, flagName, "missing")
}

func TestContext_LocalFlagNames(t *testing.T) {
set := flag.NewFlagSet("test", 0)
set.Bool("one-flag", false, "doc")
Expand Down
3 changes: 3 additions & 0 deletions funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ type CommandNotFoundFunc func(*Context, string)
// is displayed and the execution is interrupted.
type OnUsageErrorFunc func(cCtx *Context, err error, isSubcommand bool) error

// InvalidFlagAccessFunc is executed when an invalid flag is accessed from the context.
type InvalidFlagAccessFunc func(*Context, string)

// ExitErrHandlerFunc is executed if provided in order to handle exitError values
// returned by Actions and Before/After functions.
type ExitErrHandlerFunc func(cCtx *Context, err error)
Expand Down

0 comments on commit ca9df40

Please sign in to comment.