Skip to content

Commit

Permalink
Pass context to completion (#1265)
Browse files Browse the repository at this point in the history
  • Loading branch information
lukasmalkmus authored May 3, 2021
1 parent 7223a99 commit 6d00909
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 1 deletion.
11 changes: 10 additions & 1 deletion command.go
Original file line number Diff line number Diff line change
Expand Up @@ -887,7 +887,8 @@ func (c *Command) preRun() {
}

// ExecuteContext is the same as Execute(), but sets the ctx on the command.
// Retrieve ctx by calling cmd.Context() inside your *Run lifecycle functions.
// Retrieve ctx by calling cmd.Context() inside your *Run lifecycle or ValidArgs
// functions.
func (c *Command) ExecuteContext(ctx context.Context) error {
c.ctx = ctx
return c.Execute()
Expand All @@ -901,6 +902,14 @@ func (c *Command) Execute() error {
return err
}

// ExecuteContextC is the same as ExecuteC(), but sets the ctx on the command.
// Retrieve ctx by calling cmd.Context() inside your *Run lifecycle or ValidArgs
// functions.
func (c *Command) ExecuteContextC(ctx context.Context) (*Command, error) {
c.ctx = ctx
return c.ExecuteC()
}

// ExecuteC executes the command.
func (c *Command) ExecuteC() (cmd *Command, err error) {
if c.ctx == nil {
Expand Down
40 changes: 40 additions & 0 deletions command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,17 @@ func executeCommandC(root *Command, args ...string) (c *Command, output string,
return c, buf.String(), err
}

func executeCommandWithContextC(ctx context.Context, root *Command, args ...string) (c *Command, output string, err error) {
buf := new(bytes.Buffer)
root.SetOut(buf)
root.SetErr(buf)
root.SetArgs(args)

c, err = root.ExecuteContextC(ctx)

return c, buf.String(), err
}

func resetCommandLineFlagSet() {
pflag.CommandLine = pflag.NewFlagSet(os.Args[0], pflag.ExitOnError)
}
Expand Down Expand Up @@ -178,6 +189,35 @@ func TestExecuteContext(t *testing.T) {
}
}

func TestExecuteContextC(t *testing.T) {
ctx := context.TODO()

ctxRun := func(cmd *Command, args []string) {
if cmd.Context() != ctx {
t.Errorf("Command %q must have context when called with ExecuteContext", cmd.Use)
}
}

rootCmd := &Command{Use: "root", Run: ctxRun, PreRun: ctxRun}
childCmd := &Command{Use: "child", Run: ctxRun, PreRun: ctxRun}
granchildCmd := &Command{Use: "grandchild", Run: ctxRun, PreRun: ctxRun}

childCmd.AddCommand(granchildCmd)
rootCmd.AddCommand(childCmd)

if _, _, err := executeCommandWithContextC(ctx, rootCmd, ""); err != nil {
t.Errorf("Root command must not fail: %+v", err)
}

if _, _, err := executeCommandWithContextC(ctx, rootCmd, "child"); err != nil {
t.Errorf("Subcommand must not fail: %+v", err)
}

if _, _, err := executeCommandWithContextC(ctx, rootCmd, "child", "grandchild"); err != nil {
t.Errorf("Command child must not fail: %+v", err)
}
}

func TestExecute_NoContext(t *testing.T) {
run := func(cmd *Command, args []string) {
if cmd.Context() != context.Background() {
Expand Down
1 change: 1 addition & 0 deletions completions.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ func (c *Command) getCompletions(args []string) (*Command, []string, ShellCompDi
// Unable to find the real command. E.g., <program> someInvalidCmd <TAB>
return c, []string{}, ShellCompDirectiveDefault, fmt.Errorf("Unable to find a command for arguments: %v", trimmedArgs)
}
finalCmd.ctx = c.ctx

// Check if we are doing flag value completion before parsing the flags.
// This is important because if we are completing a flag value, we need to also
Expand Down
43 changes: 43 additions & 0 deletions completions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package cobra

import (
"bytes"
"context"
"strings"
"testing"
)
Expand Down Expand Up @@ -1203,6 +1204,48 @@ func TestFlagDirFilterCompletionInGo(t *testing.T) {
}
}

func TestValidArgsFuncCmdContext(t *testing.T) {
validArgsFunc := func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective) {
ctx := cmd.Context()

if ctx == nil {
t.Error("Received nil context in completion func")
} else if ctx.Value("testKey") != "123" {
t.Error("Received invalid context")
}

return nil, ShellCompDirectiveDefault
}

rootCmd := &Command{
Use: "root",
Run: emptyRun,
}
childCmd := &Command{
Use: "childCmd",
ValidArgsFunction: validArgsFunc,
Run: emptyRun,
}
rootCmd.AddCommand(childCmd)

//nolint:golint,staticcheck // We can safely use a basic type as key in tests.
ctx := context.WithValue(context.Background(), "testKey", "123")

// Test completing an empty string on the childCmd
_, output, err := executeCommandWithContextC(ctx, rootCmd, ShellCompNoDescRequestCmd, "childCmd", "")
if err != nil {
t.Errorf("Unexpected error: %v", err)
}

expected := strings.Join([]string{
":0",
"Completion ended with directive: ShellCompDirectiveDefault", ""}, "\n")

if output != expected {
t.Errorf("expected: %q, got: %q", expected, output)
}
}

func TestValidArgsFuncSingleCmd(t *testing.T) {
rootCmd := &Command{
Use: "root",
Expand Down

0 comments on commit 6d00909

Please sign in to comment.