diff --git a/powershell_completions.go b/powershell_completions.go new file mode 100644 index 000000000..756c61b9d --- /dev/null +++ b/powershell_completions.go @@ -0,0 +1,100 @@ +// PowerShell completions are based on the amazing work from clap: +// https://github.com/clap-rs/clap/blob/3294d18efe5f264d12c9035f404c7d189d4824e1/src/completions/powershell.rs +// +// The generated scripts require PowerShell v5.0+ (which comes Windows 10, but +// can be downloaded separately for windows 7 or 8.1). + +package cobra + +import ( + "bytes" + "fmt" + "io" + "os" + "strings" + + "github.com/spf13/pflag" +) + +var powerShellCompletionTemplate = `using namespace System.Management.Automation +using namespace System.Management.Automation.Language +Register-ArgumentCompleter -Native -CommandName '%s' -ScriptBlock { + param($wordToComplete, $commandAst, $cursorPosition) + $commandElements = $commandAst.CommandElements + $command = @( + '%s' + for ($i = 1; $i -lt $commandElements.Count; $i++) { + $element = $commandElements[$i] + if ($element -isnot [StringConstantExpressionAst] -or + $element.StringConstantType -ne [StringConstantType]::BareWord -or + $element.Value.StartsWith('-')) { + break + } + $element.Value + } + ) -join ';' + $completions = @(switch ($command) {%s + }) + $completions.Where{ $_.CompletionText -like "$wordToComplete*" } | + Sort-Object -Property ListItemText +}` + +func generatePowerShellSubcommandCases(out io.Writer, cmd *Command, previousCommandName string) { + var cmdName string + if previousCommandName == "" { + cmdName = cmd.Name() + } else { + cmdName = fmt.Sprintf("%s;%s", previousCommandName, cmd.Name()) + } + + fmt.Fprintf(out, "\n '%s' {", cmdName) + + cmd.Flags().VisitAll(func(flag *pflag.Flag) { + if nonCompletableFlag(flag) { + return + } + usage := escapeStringForPowerShell(flag.Usage) + if len(flag.Shorthand) > 0 { + fmt.Fprintf(out, "\n [CompletionResult]::new('-%s', '%s', [CompletionResultType]::ParameterName, '%s')", flag.Shorthand, flag.Shorthand, usage) + } + fmt.Fprintf(out, "\n [CompletionResult]::new('--%s', '%s', [CompletionResultType]::ParameterName, '%s')", flag.Name, flag.Name, usage) + }) + + for _, subCmd := range cmd.Commands() { + usage := escapeStringForPowerShell(subCmd.Short) + fmt.Fprintf(out, "\n [CompletionResult]::new('%s', '%s', [CompletionResultType]::ParameterValue, '%s')", subCmd.Name(), subCmd.Name(), usage) + } + + fmt.Fprint(out, "\n break\n }") + + for _, subCmd := range cmd.Commands() { + generatePowerShellSubcommandCases(out, subCmd, cmdName) + } +} + +func escapeStringForPowerShell(s string) string { + return strings.Replace(s, "'", "''", -1) +} + +// GenPowerShellCompletion generates PowerShell completion file and writes to the passed writer. +func (c *Command) GenPowerShellCompletion(w io.Writer) error { + buf := new(bytes.Buffer) + + var subCommandCases bytes.Buffer + generatePowerShellSubcommandCases(&subCommandCases, c, "") + fmt.Fprintf(buf, powerShellCompletionTemplate, c.Name(), c.Name(), subCommandCases.String()) + + _, err := buf.WriteTo(w) + return err +} + +// GenPowerShellCompletionFile generates PowerShell completion file. +func (c *Command) GenPowerShellCompletionFile(filename string) error { + outFile, err := os.Create(filename) + if err != nil { + return err + } + defer outFile.Close() + + return c.GenPowerShellCompletion(outFile) +} diff --git a/powershell_completions_test.go b/powershell_completions_test.go new file mode 100644 index 000000000..29b609de0 --- /dev/null +++ b/powershell_completions_test.go @@ -0,0 +1,122 @@ +package cobra + +import ( + "bytes" + "strings" + "testing" +) + +func TestPowerShellCompletion(t *testing.T) { + tcs := []struct { + name string + root *Command + expectedExpressions []string + }{ + { + name: "trivial", + root: &Command{Use: "trivialapp"}, + expectedExpressions: []string{ + "Register-ArgumentCompleter -Native -CommandName 'trivialapp' -ScriptBlock", + "$command = @(\n 'trivialapp'\n", + }, + }, + { + name: "tree", + root: func() *Command { + r := &Command{Use: "tree"} + + sub1 := &Command{Use: "sub1"} + r.AddCommand(sub1) + + sub11 := &Command{Use: "sub11"} + sub12 := &Command{Use: "sub12"} + + sub1.AddCommand(sub11) + sub1.AddCommand(sub12) + + sub2 := &Command{Use: "sub2"} + r.AddCommand(sub2) + + sub21 := &Command{Use: "sub21"} + sub22 := &Command{Use: "sub22"} + + sub2.AddCommand(sub21) + sub2.AddCommand(sub22) + + return r + }(), + expectedExpressions: []string{ + "'tree'", + "[CompletionResult]::new('sub1', 'sub1', [CompletionResultType]::ParameterValue, '')", + "[CompletionResult]::new('sub2', 'sub2', [CompletionResultType]::ParameterValue, '')", + "'tree;sub1'", + "[CompletionResult]::new('sub11', 'sub11', [CompletionResultType]::ParameterValue, '')", + "[CompletionResult]::new('sub12', 'sub12', [CompletionResultType]::ParameterValue, '')", + "'tree;sub1;sub11'", + "'tree;sub1;sub12'", + "'tree;sub2'", + "[CompletionResult]::new('sub21', 'sub21', [CompletionResultType]::ParameterValue, '')", + "[CompletionResult]::new('sub22', 'sub22', [CompletionResultType]::ParameterValue, '')", + "'tree;sub2;sub21'", + "'tree;sub2;sub22'", + }, + }, + { + name: "flags", + root: func() *Command { + r := &Command{Use: "flags"} + r.Flags().StringP("flag1", "a", "", "") + r.Flags().String("flag2", "", "") + + sub1 := &Command{Use: "sub1"} + sub1.Flags().StringP("flag3", "c", "", "") + r.AddCommand(sub1) + + return r + }(), + expectedExpressions: []string{ + "'flags'", + "[CompletionResult]::new('-a', 'a', [CompletionResultType]::ParameterName, '')", + "[CompletionResult]::new('--flag1', 'flag1', [CompletionResultType]::ParameterName, '')", + "[CompletionResult]::new('--flag2', 'flag2', [CompletionResultType]::ParameterName, '')", + "[CompletionResult]::new('sub1', 'sub1', [CompletionResultType]::ParameterValue, '')", + "'flags;sub1'", + "[CompletionResult]::new('-c', 'c', [CompletionResultType]::ParameterName, '')", + "[CompletionResult]::new('--flag3', 'flag3', [CompletionResultType]::ParameterName, '')", + }, + }, + { + name: "usage", + root: func() *Command { + r := &Command{Use: "usage"} + r.Flags().String("flag", "", "this describes the usage of the 'flag' flag") + + sub1 := &Command{ + Use: "sub1", + Short: "short describes 'sub1'", + } + r.AddCommand(sub1) + + return r + }(), + expectedExpressions: []string{ + "[CompletionResult]::new('--flag', 'flag', [CompletionResultType]::ParameterName, 'this describes the usage of the ''flag'' flag')", + "[CompletionResult]::new('sub1', 'sub1', [CompletionResultType]::ParameterValue, 'short describes ''sub1''')", + }, + }, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + buf := new(bytes.Buffer) + tc.root.GenPowerShellCompletion(buf) + output := buf.String() + + for _, expectedExpression := range tc.expectedExpressions { + if !strings.Contains(output, expectedExpression) { + t.Errorf("Expected completion to contain %q somewhere; got %q", expectedExpression, output) + } + } + }) + } +}