diff --git a/cmd/close.go b/cmd/close.go index 9782a1ed..c4107775 100644 --- a/cmd/close.go +++ b/cmd/close.go @@ -19,7 +19,7 @@ func CloseCmd() *cobra.Command { } cmd.Flags().StringP("branch", "B", "multi-gitter-branch", "The name of the branch where changes are committed.") - cmd.Flags().AddFlagSet(platformFlags()) + configurePlatform(cmd) cmd.Flags().AddFlagSet(logFlags("-")) return cmd @@ -30,7 +30,7 @@ func close(cmd *cobra.Command, args []string) error { branchName, _ := flag.GetString("branch") - vc, err := getVersionController(flag) + vc, err := getVersionController(flag, true) if err != nil { return err } diff --git a/cmd/merge.go b/cmd/merge.go index dbf3634b..bae535f5 100644 --- a/cmd/merge.go +++ b/cmd/merge.go @@ -20,7 +20,7 @@ func MergeCmd() *cobra.Command { cmd.Flags().StringP("branch", "B", "multi-gitter-branch", "The name of the branch where changes are committed.") cmd.Flags().StringSliceP("merge-type", "", []string{"merge", "squash", "rebase"}, "The type of merge that should be done (GitHub). Multiple types can be used as backup strategies if the first one is not allowed.") - cmd.Flags().AddFlagSet(platformFlags()) + configurePlatform(cmd) cmd.Flags().AddFlagSet(logFlags("-")) return cmd @@ -31,7 +31,7 @@ func merge(cmd *cobra.Command, args []string) error { branchName, _ := flag.GetString("branch") - vc, err := getVersionController(flag) + vc, err := getVersionController(flag, true) if err != nil { return err } diff --git a/cmd/print.go b/cmd/print.go index 6bf14b97..4d3715cd 100755 --- a/cmd/print.go +++ b/cmd/print.go @@ -35,7 +35,7 @@ func PrintCmd() *cobra.Command { cmd.Flags().IntP("concurrent", "C", 1, "The maximum number of concurrent runs") cmd.Flags().IntP("fetch-depth", "f", 1, "Limit fetching to the specified number of commits. Set to 0 for no limit") cmd.Flags().StringP("error-output", "E", "-", `The file that the output of the script should be outputted to. "-" means stderr`) - cmd.Flags().AddFlagSet(platformFlags()) + configurePlatform(cmd) cmd.Flags().AddFlagSet(logFlags("")) cmd.Flags().AddFlagSet(outputFlag()) @@ -76,7 +76,7 @@ func print(cmd *cobra.Command, args []string) error { return errors.New("could not get the working directory") } - vc, err := getVersionController(flag) + vc, err := getVersionController(flag, true) if err != nil { return err } diff --git a/cmd/root.go b/cmd/root.go index d3620829..ce1de2d1 100755 --- a/cmd/root.go +++ b/cmd/root.go @@ -1,6 +1,7 @@ package cmd import ( + "context" "fmt" "io" "math/rand" @@ -41,8 +42,8 @@ func init() { rand.Seed(time.Now().UTC().UnixNano()) } -func platformFlags() *flag.FlagSet { - flags := flag.NewFlagSet("platform", flag.ExitOnError) +func configurePlatform(cmd *cobra.Command) { + flags := cmd.Flags() flags.StringP("base-url", "g", "", "Base URL of the (v3) GitHub API, needs to be changed if GitHub enterprise is used. Or the url to a self-hosted GitLab instance.") flags.StringP("token", "T", "", "The GitHub/GitLab personal access token. Can also be set using the GITHUB_TOKEN/GITLAB_TOKEN environment variable.") @@ -55,7 +56,77 @@ func platformFlags() *flag.FlagSet { flags.StringP("platform", "p", "github", "The platform that is used. Available values: github, gitlab") - return flags + // Autocompletion for organizations + _ = cmd.RegisterFlagCompletionFunc("org", func(cmd *cobra.Command, _ []string, toComplete string) ([]string, cobra.ShellCompDirective) { + vc, err := getVersionController(cmd.Flags(), false) + if err != nil { + return nil, cobra.ShellCompDirectiveError + } + + type getOrger interface { + GetAutocompleteOrganizations(ctx context.Context, _ string) ([]string, error) + } + + g, ok := vc.(getOrger) + if !ok { + return nil, cobra.ShellCompDirectiveError + } + + orgs, err := g.GetAutocompleteOrganizations(cmd.Root().Context(), toComplete) + if err != nil { + return nil, cobra.ShellCompDirectiveError + } + + return orgs, cobra.ShellCompDirectiveDefault + }) + + // Autocompletion for users + _ = cmd.RegisterFlagCompletionFunc("user", func(cmd *cobra.Command, _ []string, toComplete string) ([]string, cobra.ShellCompDirective) { + vc, err := getVersionController(cmd.Flags(), false) + if err != nil { + return nil, cobra.ShellCompDirectiveError + } + + type getUserser interface { + GetAutocompleteUsers(ctx context.Context, _ string) ([]string, error) + } + + g, ok := vc.(getUserser) + if !ok { + return nil, cobra.ShellCompDirectiveError + } + + users, err := g.GetAutocompleteUsers(cmd.Root().Context(), toComplete) + if err != nil { + return nil, cobra.ShellCompDirectiveError + } + + return users, cobra.ShellCompDirectiveDefault + }) + + // Autocompletion for repositories + _ = cmd.RegisterFlagCompletionFunc("repo", func(cmd *cobra.Command, _ []string, toComplete string) ([]string, cobra.ShellCompDirective) { + vc, err := getVersionController(cmd.Flags(), false) + if err != nil { + return nil, cobra.ShellCompDirectiveError + } + + type getRepositorieser interface { + GetAutocompleteRepositories(ctx context.Context, _ string) ([]string, error) + } + + g, ok := vc.(getRepositorieser) + if !ok { + return nil, cobra.ShellCompDirectiveError + } + + users, err := g.GetAutocompleteRepositories(cmd.Root().Context(), toComplete) + if err != nil { + return nil, cobra.ShellCompDirectiveError + } + + return users, cobra.ShellCompDirectiveDefault + }) } func logFlags(logFile string) *flag.FlagSet { @@ -119,7 +190,9 @@ func outputFlag() *flag.FlagSet { // This is used to override the version controller with a mock, to be used during testing var OverrideVersionController multigitter.VersionController = nil -func getVersionController(flag *flag.FlagSet) (multigitter.VersionController, error) { +// getVersionController gets the complete version controller +// the verifyFlags parameter can be set to false if a complete vc is not required (during autocompletion) +func getVersionController(flag *flag.FlagSet, verifyFlags bool) (multigitter.VersionController, error) { if OverrideVersionController != nil { return OverrideVersionController, nil } @@ -129,21 +202,21 @@ func getVersionController(flag *flag.FlagSet) (multigitter.VersionController, er default: return nil, fmt.Errorf("unknown platform: %s", platform) case "github": - return createGithubClient(flag) + return createGithubClient(flag, verifyFlags) case "gitlab": - return createGitlabClient(flag) + return createGitlabClient(flag, verifyFlags) } } -func createGithubClient(flag *flag.FlagSet) (multigitter.VersionController, error) { +func createGithubClient(flag *flag.FlagSet, verifyFlags bool) (multigitter.VersionController, error) { gitBaseURL, _ := flag.GetString("base-url") orgs, _ := flag.GetStringSlice("org") users, _ := flag.GetStringSlice("user") repos, _ := flag.GetStringSlice("repo") mergeTypeStrs, _ := flag.GetStringSlice("merge-type") // Only used for the merge command - if len(orgs) == 0 && len(users) == 0 && len(repos) == 0 { - return nil, errors.New("no organization or user set") + if verifyFlags && len(orgs) == 0 && len(users) == 0 && len(repos) == 0 { + return nil, errors.New("no organization, user or repo set") } token, err := getToken(flag) @@ -180,12 +253,16 @@ func createGithubClient(flag *flag.FlagSet) (multigitter.VersionController, erro return vc, nil } -func createGitlabClient(flag *flag.FlagSet) (multigitter.VersionController, error) { +func createGitlabClient(flag *flag.FlagSet, verifyFlags bool) (multigitter.VersionController, error) { gitBaseURL, _ := flag.GetString("base-url") groups, _ := flag.GetStringSlice("group") users, _ := flag.GetStringSlice("user") projects, _ := flag.GetStringSlice("project") + if verifyFlags && len(groups) == 0 && len(users) == 0 && len(projects) == 0 { + return nil, errors.New("no group user or project set") + } + token, err := getToken(flag) if err != nil { return nil, err diff --git a/cmd/run.go b/cmd/run.go index 57b0a992..a8afeea4 100755 --- a/cmd/run.go +++ b/cmd/run.go @@ -47,7 +47,7 @@ func RunCmd() *cobra.Command { cmd.Flags().BoolP("dry-run", "d", false, "Run without pushing changes or creating pull requests") cmd.Flags().StringP("author-name", "", "", "Name of the committer. If not set, the global git config setting will be used.") cmd.Flags().StringP("author-email", "", "", "Email of the committer. If not set, the global git config setting will be used.") - cmd.Flags().AddFlagSet(platformFlags()) + configurePlatform(cmd) cmd.Flags().AddFlagSet(logFlags("-")) cmd.Flags().AddFlagSet(outputFlag()) @@ -121,7 +121,7 @@ func run(cmd *cobra.Command, args []string) error { return errors.New("could not get the working directory") } - vc, err := getVersionController(flag) + vc, err := getVersionController(flag, true) if err != nil { return err } diff --git a/cmd/status.go b/cmd/status.go index 93063d1f..281a3f2c 100644 --- a/cmd/status.go +++ b/cmd/status.go @@ -20,7 +20,7 @@ func StatusCmd() *cobra.Command { } cmd.Flags().StringP("branch", "B", "multi-gitter-branch", "The name of the branch where changes are committed.") - cmd.Flags().AddFlagSet(platformFlags()) + configurePlatform(cmd) cmd.Flags().AddFlagSet(logFlags("-")) cmd.Flags().AddFlagSet(outputFlag()) @@ -33,7 +33,7 @@ func status(cmd *cobra.Command, args []string) error { branchName, _ := flag.GetString("branch") strOutput, _ := flag.GetString("output") - vc, err := getVersionController(flag) + vc, err := getVersionController(flag, true) if err != nil { return err } diff --git a/internal/github/github.go b/internal/github/github.go index 9e3cca46..f089fb23 100755 --- a/internal/github/github.go +++ b/internal/github/github.go @@ -405,3 +405,61 @@ func (g Github) ClosePullRequest(ctx context.Context, pullReq domain.PullRequest _, err = g.ghClient.Git.DeleteRef(ctx, pr.ownerName, pr.repoName, fmt.Sprintf("heads/%s", pr.branchName)) return err } + +// GetAutocompleteOrganizations gets organizations for autocompletion +func (g Github) GetAutocompleteOrganizations(ctx context.Context, _ string) ([]string, error) { + orgs, _, err := g.ghClient.Organizations.List(ctx, "", nil) + if err != nil { + return nil, err + } + + ret := make([]string, len(orgs)) + for i, org := range orgs { + ret[i] = org.GetLogin() + } + + return ret, nil +} + +// GetAutocompleteUsers gets users for autocompletion +func (g Github) GetAutocompleteUsers(ctx context.Context, str string) ([]string, error) { + users, _, err := g.ghClient.Search.Users(ctx, str, nil) + if err != nil { + return nil, err + } + + ret := make([]string, len(users.Users)) + for i, user := range users.Users { + ret[i] = user.GetLogin() + } + + return ret, nil +} + +// GetAutocompleteRepositories gets repositories for autocompletion +func (g Github) GetAutocompleteRepositories(ctx context.Context, str string) ([]string, error) { + var q string + + // If the user has already provided a org/user, it's much more effective to search based on that + // comparared to a complete freetext search + splitted := strings.SplitN(str, "/", 2) + switch { + case len(splitted) == 2: + // Search set the user or org (user/org in the search can be used interchangeable) + q = fmt.Sprintf("user:%s %s in:name", splitted[0], splitted[1]) + default: + q = fmt.Sprintf("%s in:name", str) + } + + repos, _, err := g.ghClient.Search.Repositories(ctx, q, nil) + if err != nil { + return nil, err + } + + ret := make([]string, len(repos.Repositories)) + for i, repositories := range repos.Repositories { + ret[i] = repositories.GetFullName() + } + + return ret, nil +} diff --git a/tests/table_test.go b/tests/table_test.go index fc7982ad..75b914c5 100644 --- a/tests/table_test.go +++ b/tests/table_test.go @@ -1,6 +1,7 @@ package tests import ( + "bytes" "fmt" "io/ioutil" "os" @@ -17,6 +18,7 @@ import ( type runData struct { out string logOut string + cmdOut string took time.Duration } @@ -323,6 +325,42 @@ Repositories with a successful run: assert.Equal(t, "i like bananas", readTestFile(t, vcMock.Repositories[0].Path)) }, }, + + { + name: "autocomplete org", + vc: &vcmock.VersionController{}, + args: []string{ + "__complete", "run", + "--org", "dynamic-org", + }, + verify: func(t *testing.T, vcMock *vcmock.VersionController, runData runData) { + assert.Equal(t, "static-org\ndynamic-org\n:0\nCompletion ended with directive: ShellCompDirectiveDefault\n", runData.cmdOut) + }, + }, + + { + name: "autocomplete user", + vc: &vcmock.VersionController{}, + args: []string{ + "__complete", "run", + "--user", "dynamic-user", + }, + verify: func(t *testing.T, vcMock *vcmock.VersionController, runData runData) { + assert.Equal(t, "static-user\ndynamic-user\n:0\nCompletion ended with directive: ShellCompDirectiveDefault\n", runData.cmdOut) + }, + }, + + { + name: "autocomplete repo", + vc: &vcmock.VersionController{}, + args: []string{ + "__complete", "run", + "--repo", "dynamic-repo", + }, + verify: func(t *testing.T, vcMock *vcmock.VersionController, runData runData) { + assert.Equal(t, "static-repo\ndynamic-repo\n:0\nCompletion ended with directive: ShellCompDirectiveDefault\n", runData.cmdOut) + }, + }, } for _, test := range tests { @@ -343,12 +381,17 @@ Repositories with a successful run: } cmd.OverrideVersionController = vc - command := cmd.RootCmd() - command.SetArgs(append( - test.args, + cobraBuf := &bytes.Buffer{} + + staticArgs := []string{ "--log-file", logFile.Name(), "--output", outFile.Name(), - )) + } + + command := cmd.RootCmd() + command.SetOut(cobraBuf) + command.SetErr(cobraBuf) + command.SetArgs(append(staticArgs, test.args...)) before := time.Now() err = command.Execute() took := time.Since(before) @@ -367,6 +410,7 @@ Repositories with a successful run: test.verify(t, vc, runData{ logOut: string(logData), out: string(outData), + cmdOut: cobraBuf.String(), took: took, }) }) diff --git a/tests/vcmock/vcmock.go b/tests/vcmock/vcmock.go index a8542c69..09ca6e36 100644 --- a/tests/vcmock/vcmock.go +++ b/tests/vcmock/vcmock.go @@ -91,6 +91,21 @@ func (vc *VersionController) SetPRStatus(repoName string, branchName string, new } } +// GetAutocompleteOrganizations gets organizations for autocompletion +func (vc *VersionController) GetAutocompleteOrganizations(ctx context.Context, str string) ([]string, error) { + return []string{"static-org", str}, nil +} + +// GetAutocompleteUsers gets users for autocompletion +func (vc *VersionController) GetAutocompleteUsers(ctx context.Context, str string) ([]string, error) { + return []string{"static-user", str}, nil +} + +// GetAutocompleteRepositories gets repositories for autocompletion +func (vc *VersionController) GetAutocompleteRepositories(ctx context.Context, str string) ([]string, error) { + return []string{"static-repo", str}, nil +} + // PullRequest is a mock pr type PullRequest struct { PRStatus domain.PullRequestStatus