Skip to content

Commit

Permalink
feat: added GitHub autocompletion (#84)
Browse files Browse the repository at this point in the history
* feat: added github autocompletion

* test: added autocompletion tests
  • Loading branch information
lindell authored Mar 30, 2021
1 parent 1884847 commit 5fee0c4
Show file tree
Hide file tree
Showing 9 changed files with 218 additions and 24 deletions.
4 changes: 2 additions & 2 deletions cmd/close.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func CloseCmd() *cobra.Command {
}

cmd.Flags().StringP("branch", "B", "multi-gitter-branch", "The name of the branch where changes are committed.")
cmd.Flags().AddFlagSet(platformFlags())
configurePlatform(cmd)
cmd.Flags().AddFlagSet(logFlags("-"))

return cmd
Expand All @@ -30,7 +30,7 @@ func close(cmd *cobra.Command, args []string) error {

branchName, _ := flag.GetString("branch")

vc, err := getVersionController(flag)
vc, err := getVersionController(flag, true)
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions cmd/merge.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func MergeCmd() *cobra.Command {

cmd.Flags().StringP("branch", "B", "multi-gitter-branch", "The name of the branch where changes are committed.")
cmd.Flags().StringSliceP("merge-type", "", []string{"merge", "squash", "rebase"}, "The type of merge that should be done (GitHub). Multiple types can be used as backup strategies if the first one is not allowed.")
cmd.Flags().AddFlagSet(platformFlags())
configurePlatform(cmd)
cmd.Flags().AddFlagSet(logFlags("-"))

return cmd
Expand All @@ -31,7 +31,7 @@ func merge(cmd *cobra.Command, args []string) error {

branchName, _ := flag.GetString("branch")

vc, err := getVersionController(flag)
vc, err := getVersionController(flag, true)
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions cmd/print.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func PrintCmd() *cobra.Command {
cmd.Flags().IntP("concurrent", "C", 1, "The maximum number of concurrent runs")
cmd.Flags().IntP("fetch-depth", "f", 1, "Limit fetching to the specified number of commits. Set to 0 for no limit")
cmd.Flags().StringP("error-output", "E", "-", `The file that the output of the script should be outputted to. "-" means stderr`)
cmd.Flags().AddFlagSet(platformFlags())
configurePlatform(cmd)
cmd.Flags().AddFlagSet(logFlags(""))
cmd.Flags().AddFlagSet(outputFlag())

Expand Down Expand Up @@ -76,7 +76,7 @@ func print(cmd *cobra.Command, args []string) error {
return errors.New("could not get the working directory")
}

vc, err := getVersionController(flag)
vc, err := getVersionController(flag, true)
if err != nil {
return err
}
Expand Down
97 changes: 87 additions & 10 deletions cmd/root.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cmd

import (
"context"
"fmt"
"io"
"math/rand"
Expand Down Expand Up @@ -41,8 +42,8 @@ func init() {
rand.Seed(time.Now().UTC().UnixNano())
}

func platformFlags() *flag.FlagSet {
flags := flag.NewFlagSet("platform", flag.ExitOnError)
func configurePlatform(cmd *cobra.Command) {
flags := cmd.Flags()

flags.StringP("base-url", "g", "", "Base URL of the (v3) GitHub API, needs to be changed if GitHub enterprise is used. Or the url to a self-hosted GitLab instance.")
flags.StringP("token", "T", "", "The GitHub/GitLab personal access token. Can also be set using the GITHUB_TOKEN/GITLAB_TOKEN environment variable.")
Expand All @@ -55,7 +56,77 @@ func platformFlags() *flag.FlagSet {

flags.StringP("platform", "p", "github", "The platform that is used. Available values: github, gitlab")

return flags
// Autocompletion for organizations
_ = cmd.RegisterFlagCompletionFunc("org", func(cmd *cobra.Command, _ []string, toComplete string) ([]string, cobra.ShellCompDirective) {
vc, err := getVersionController(cmd.Flags(), false)
if err != nil {
return nil, cobra.ShellCompDirectiveError
}

type getOrger interface {
GetAutocompleteOrganizations(ctx context.Context, _ string) ([]string, error)
}

g, ok := vc.(getOrger)
if !ok {
return nil, cobra.ShellCompDirectiveError
}

orgs, err := g.GetAutocompleteOrganizations(cmd.Root().Context(), toComplete)
if err != nil {
return nil, cobra.ShellCompDirectiveError
}

return orgs, cobra.ShellCompDirectiveDefault
})

// Autocompletion for users
_ = cmd.RegisterFlagCompletionFunc("user", func(cmd *cobra.Command, _ []string, toComplete string) ([]string, cobra.ShellCompDirective) {
vc, err := getVersionController(cmd.Flags(), false)
if err != nil {
return nil, cobra.ShellCompDirectiveError
}

type getUserser interface {
GetAutocompleteUsers(ctx context.Context, _ string) ([]string, error)
}

g, ok := vc.(getUserser)
if !ok {
return nil, cobra.ShellCompDirectiveError
}

users, err := g.GetAutocompleteUsers(cmd.Root().Context(), toComplete)
if err != nil {
return nil, cobra.ShellCompDirectiveError
}

return users, cobra.ShellCompDirectiveDefault
})

// Autocompletion for repositories
_ = cmd.RegisterFlagCompletionFunc("repo", func(cmd *cobra.Command, _ []string, toComplete string) ([]string, cobra.ShellCompDirective) {
vc, err := getVersionController(cmd.Flags(), false)
if err != nil {
return nil, cobra.ShellCompDirectiveError
}

type getRepositorieser interface {
GetAutocompleteRepositories(ctx context.Context, _ string) ([]string, error)
}

g, ok := vc.(getRepositorieser)
if !ok {
return nil, cobra.ShellCompDirectiveError
}

users, err := g.GetAutocompleteRepositories(cmd.Root().Context(), toComplete)
if err != nil {
return nil, cobra.ShellCompDirectiveError
}

return users, cobra.ShellCompDirectiveDefault
})
}

func logFlags(logFile string) *flag.FlagSet {
Expand Down Expand Up @@ -119,7 +190,9 @@ func outputFlag() *flag.FlagSet {
// This is used to override the version controller with a mock, to be used during testing
var OverrideVersionController multigitter.VersionController = nil

func getVersionController(flag *flag.FlagSet) (multigitter.VersionController, error) {
// getVersionController gets the complete version controller
// the verifyFlags parameter can be set to false if a complete vc is not required (during autocompletion)
func getVersionController(flag *flag.FlagSet, verifyFlags bool) (multigitter.VersionController, error) {
if OverrideVersionController != nil {
return OverrideVersionController, nil
}
Expand All @@ -129,21 +202,21 @@ func getVersionController(flag *flag.FlagSet) (multigitter.VersionController, er
default:
return nil, fmt.Errorf("unknown platform: %s", platform)
case "github":
return createGithubClient(flag)
return createGithubClient(flag, verifyFlags)
case "gitlab":
return createGitlabClient(flag)
return createGitlabClient(flag, verifyFlags)
}
}

func createGithubClient(flag *flag.FlagSet) (multigitter.VersionController, error) {
func createGithubClient(flag *flag.FlagSet, verifyFlags bool) (multigitter.VersionController, error) {
gitBaseURL, _ := flag.GetString("base-url")
orgs, _ := flag.GetStringSlice("org")
users, _ := flag.GetStringSlice("user")
repos, _ := flag.GetStringSlice("repo")
mergeTypeStrs, _ := flag.GetStringSlice("merge-type") // Only used for the merge command

if len(orgs) == 0 && len(users) == 0 && len(repos) == 0 {
return nil, errors.New("no organization or user set")
if verifyFlags && len(orgs) == 0 && len(users) == 0 && len(repos) == 0 {
return nil, errors.New("no organization, user or repo set")
}

token, err := getToken(flag)
Expand Down Expand Up @@ -180,12 +253,16 @@ func createGithubClient(flag *flag.FlagSet) (multigitter.VersionController, erro
return vc, nil
}

func createGitlabClient(flag *flag.FlagSet) (multigitter.VersionController, error) {
func createGitlabClient(flag *flag.FlagSet, verifyFlags bool) (multigitter.VersionController, error) {
gitBaseURL, _ := flag.GetString("base-url")
groups, _ := flag.GetStringSlice("group")
users, _ := flag.GetStringSlice("user")
projects, _ := flag.GetStringSlice("project")

if verifyFlags && len(groups) == 0 && len(users) == 0 && len(projects) == 0 {
return nil, errors.New("no group user or project set")
}

token, err := getToken(flag)
if err != nil {
return nil, err
Expand Down
4 changes: 2 additions & 2 deletions cmd/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func RunCmd() *cobra.Command {
cmd.Flags().BoolP("dry-run", "d", false, "Run without pushing changes or creating pull requests")
cmd.Flags().StringP("author-name", "", "", "Name of the committer. If not set, the global git config setting will be used.")
cmd.Flags().StringP("author-email", "", "", "Email of the committer. If not set, the global git config setting will be used.")
cmd.Flags().AddFlagSet(platformFlags())
configurePlatform(cmd)
cmd.Flags().AddFlagSet(logFlags("-"))
cmd.Flags().AddFlagSet(outputFlag())

Expand Down Expand Up @@ -121,7 +121,7 @@ func run(cmd *cobra.Command, args []string) error {
return errors.New("could not get the working directory")
}

vc, err := getVersionController(flag)
vc, err := getVersionController(flag, true)
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions cmd/status.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func StatusCmd() *cobra.Command {
}

cmd.Flags().StringP("branch", "B", "multi-gitter-branch", "The name of the branch where changes are committed.")
cmd.Flags().AddFlagSet(platformFlags())
configurePlatform(cmd)
cmd.Flags().AddFlagSet(logFlags("-"))
cmd.Flags().AddFlagSet(outputFlag())

Expand All @@ -33,7 +33,7 @@ func status(cmd *cobra.Command, args []string) error {
branchName, _ := flag.GetString("branch")
strOutput, _ := flag.GetString("output")

vc, err := getVersionController(flag)
vc, err := getVersionController(flag, true)
if err != nil {
return err
}
Expand Down
58 changes: 58 additions & 0 deletions internal/github/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -405,3 +405,61 @@ func (g Github) ClosePullRequest(ctx context.Context, pullReq domain.PullRequest
_, err = g.ghClient.Git.DeleteRef(ctx, pr.ownerName, pr.repoName, fmt.Sprintf("heads/%s", pr.branchName))
return err
}

// GetAutocompleteOrganizations gets organizations for autocompletion
func (g Github) GetAutocompleteOrganizations(ctx context.Context, _ string) ([]string, error) {
orgs, _, err := g.ghClient.Organizations.List(ctx, "", nil)
if err != nil {
return nil, err
}

ret := make([]string, len(orgs))
for i, org := range orgs {
ret[i] = org.GetLogin()
}

return ret, nil
}

// GetAutocompleteUsers gets users for autocompletion
func (g Github) GetAutocompleteUsers(ctx context.Context, str string) ([]string, error) {
users, _, err := g.ghClient.Search.Users(ctx, str, nil)
if err != nil {
return nil, err
}

ret := make([]string, len(users.Users))
for i, user := range users.Users {
ret[i] = user.GetLogin()
}

return ret, nil
}

// GetAutocompleteRepositories gets repositories for autocompletion
func (g Github) GetAutocompleteRepositories(ctx context.Context, str string) ([]string, error) {
var q string

// If the user has already provided a org/user, it's much more effective to search based on that
// comparared to a complete freetext search
splitted := strings.SplitN(str, "/", 2)
switch {
case len(splitted) == 2:
// Search set the user or org (user/org in the search can be used interchangeable)
q = fmt.Sprintf("user:%s %s in:name", splitted[0], splitted[1])
default:
q = fmt.Sprintf("%s in:name", str)
}

repos, _, err := g.ghClient.Search.Repositories(ctx, q, nil)
if err != nil {
return nil, err
}

ret := make([]string, len(repos.Repositories))
for i, repositories := range repos.Repositories {
ret[i] = repositories.GetFullName()
}

return ret, nil
}
52 changes: 48 additions & 4 deletions tests/table_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package tests

import (
"bytes"
"fmt"
"io/ioutil"
"os"
Expand All @@ -17,6 +18,7 @@ import (
type runData struct {
out string
logOut string
cmdOut string
took time.Duration
}

Expand Down Expand Up @@ -323,6 +325,42 @@ Repositories with a successful run:
assert.Equal(t, "i like bananas", readTestFile(t, vcMock.Repositories[0].Path))
},
},

{
name: "autocomplete org",
vc: &vcmock.VersionController{},
args: []string{
"__complete", "run",
"--org", "dynamic-org",
},
verify: func(t *testing.T, vcMock *vcmock.VersionController, runData runData) {
assert.Equal(t, "static-org\ndynamic-org\n:0\nCompletion ended with directive: ShellCompDirectiveDefault\n", runData.cmdOut)
},
},

{
name: "autocomplete user",
vc: &vcmock.VersionController{},
args: []string{
"__complete", "run",
"--user", "dynamic-user",
},
verify: func(t *testing.T, vcMock *vcmock.VersionController, runData runData) {
assert.Equal(t, "static-user\ndynamic-user\n:0\nCompletion ended with directive: ShellCompDirectiveDefault\n", runData.cmdOut)
},
},

{
name: "autocomplete repo",
vc: &vcmock.VersionController{},
args: []string{
"__complete", "run",
"--repo", "dynamic-repo",
},
verify: func(t *testing.T, vcMock *vcmock.VersionController, runData runData) {
assert.Equal(t, "static-repo\ndynamic-repo\n:0\nCompletion ended with directive: ShellCompDirectiveDefault\n", runData.cmdOut)
},
},
}

for _, test := range tests {
Expand All @@ -343,12 +381,17 @@ Repositories with a successful run:
}
cmd.OverrideVersionController = vc

command := cmd.RootCmd()
command.SetArgs(append(
test.args,
cobraBuf := &bytes.Buffer{}

staticArgs := []string{
"--log-file", logFile.Name(),
"--output", outFile.Name(),
))
}

command := cmd.RootCmd()
command.SetOut(cobraBuf)
command.SetErr(cobraBuf)
command.SetArgs(append(staticArgs, test.args...))
before := time.Now()
err = command.Execute()
took := time.Since(before)
Expand All @@ -367,6 +410,7 @@ Repositories with a successful run:
test.verify(t, vc, runData{
logOut: string(logData),
out: string(outData),
cmdOut: cobraBuf.String(),
took: took,
})
})
Expand Down
Loading

0 comments on commit 5fee0c4

Please sign in to comment.