diff --git a/pkg/cmd/github-pull-request-make/main.go b/pkg/cmd/github-pull-request-make/main.go index dffabb1e5b46..5552dc0440c5 100644 --- a/pkg/cmd/github-pull-request-make/main.go +++ b/pkg/cmd/github-pull-request-make/main.go @@ -31,6 +31,7 @@ import ( "os/exec" "path/filepath" "regexp" + "strconv" "strings" "time" @@ -40,9 +41,16 @@ import ( "golang.org/x/oauth2" ) -const githubAPITokenEnv = "GITHUB_API_TOKEN" -const teamcityVCSNumberEnv = "BUILD_VCS_NUMBER" -const targetEnv = "TARGET" +const ( + githubAPITokenEnv = "GITHUB_API_TOKEN" + teamcityVCSNumberEnv = "BUILD_VCS_NUMBER" + targetEnv = "TARGET" + // The following environment variables are for testing and are + // prefixed with GHM_ to help prevent accidentally triggering + // test code inside the CI pipeline. + packageEnv = "GHM_PACKAGES" + forceBazelEnv = "GHM_FORCE_BAZEL" +) // https://github.com/golang/go/blob/go1.7.3/src/cmd/go/test.go#L1260:L1262 // @@ -59,6 +67,24 @@ type pkg struct { tests []string } +func pkgsFromGithubPRForSHA( + ctx context.Context, org string, repo string, sha string, +) (map[string]pkg, error) { + client := ghClient(ctx) + currentPull := findPullRequest(ctx, client, org, repo, sha) + if currentPull == nil { + log.Printf("SHA %s not found in open pull requests, skipping stress", sha) + return nil, nil + } + + diff, err := getDiff(ctx, client, org, repo, *currentPull.Number) + if err != nil { + return nil, err + } + + return pkgsFromDiff(strings.NewReader(diff)) +} + // pkgsFromDiff parses a git-style diff and returns a mapping from directories // to tests added in those directories in the given diff. func pkgsFromDiff(r io.Reader) (map[string]pkg, error) { @@ -160,6 +186,25 @@ func getDiff( return diff, err } +func parsePackagesFromEnvironment(input string) (map[string]pkg, error) { + const expectedFormat = "PACKAGE_NAME=TEST_NAME[,TEST_NAME...][;PACKAGE_NAME=...]" + pkgTestStrs := strings.Split(input, ";") + pkgs := make(map[string]pkg, len(pkgTestStrs)) + for _, pts := range pkgTestStrs { + ptsParts := strings.Split(pts, "=") + if len(ptsParts) < 2 { + return nil, fmt.Errorf("invalid format for package environment variable: %q (expected format: %s)", + input, expectedFormat) + } + pkgName := ptsParts[0] + tests := ptsParts[1] + pkgs[pkgName] = pkg{ + tests: strings.Split(tests, ","), + } + } + return pkgs, nil +} + func main() { sha, ok := os.LookupEnv(teamcityVCSNumberEnv) if !ok { @@ -174,32 +219,34 @@ func main() { log.Fatalf("environment variable %s is %s; expected 'stress' or 'stressrace'", targetEnv, target) } - const org = "cockroachdb" - const repo = "cockroach" + forceBazel := false + if forceBazelStr, ok := os.LookupEnv(forceBazelEnv); ok { + forceBazel, _ = strconv.ParseBool(forceBazelStr) + } crdb, err := os.Getwd() if err != nil { log.Fatal(err) } - ctx := context.Background() - client := ghClient(ctx) - - currentPull := findPullRequest(ctx, client, org, repo, sha) - if currentPull == nil { - log.Printf("SHA %s not found in open pull requests, skipping stress", sha) - return - } + var pkgs map[string]pkg + if pkgStr, ok := os.LookupEnv(packageEnv); ok { + log.Printf("Using packages from environment variable %s", packageEnv) + pkgs, err = parsePackagesFromEnvironment(pkgStr) + if err != nil { + log.Fatal(err) + } - diff, err := getDiff(ctx, client, org, repo, *currentPull.Number) - if err != nil { - log.Fatal(err) + } else { + ctx := context.Background() + const org = "cockroachdb" + const repo = "cockroach" + pkgs, err = pkgsFromGithubPRForSHA(ctx, org, repo, sha) + if err != nil { + log.Fatal(err) + } } - pkgs, err := pkgsFromDiff(strings.NewReader(diff)) - if err != nil { - log.Fatal(err) - } if len(pkgs) > 0 { for name, pkg := range pkgs { // 20 minutes total seems OK, but at least 2 minutes per test. @@ -224,7 +271,7 @@ func main() { } var args []string - if bazel.BuiltWithBazel() { + if bazel.BuiltWithBazel() || forceBazel { args = append(args, "test") // NB: We use a pretty dumb technique to list the bazel test // targets: we ask bazel query to enumerate all the tests in this