diff --git a/pkg/cmd/roachprod/main.go b/pkg/cmd/roachprod/main.go index beaa1479d236..65c0d26b2dae 100644 --- a/pkg/cmd/roachprod/main.go +++ b/pkg/cmd/roachprod/main.go @@ -1186,7 +1186,7 @@ var updateCmd = &cobra.Command{ if revertUpdate { if upgrade.PromptYesNo("Revert to previous version? Note: this will replace the" + - " current roachprod binary with a previous roachprod.bak binary") { + " current roachprod binary with a previous roachprod.bak binary.") { if err := upgrade.SwapBinary(currentBinary, currentBinary+".bak"); err != nil { return err } @@ -1195,18 +1195,18 @@ var updateCmd = &cobra.Command{ return nil } - newBinary, err := upgrade.DownloadLatestRoadprod() - if err != nil { + newBinary := currentBinary + ".new" + if err := upgrade.DownloadLatestRoadprod(newBinary); err != nil { return err } - if upgrade.PromptYesNo("Download successful. Continue with update? Note: this will " + - "overwrite any existing roachprod.bak binary") { + if upgrade.PromptYesNo("Download successful. Continue with update? Note: " + + "This will overwrite any existing roachprod.bak binary.") { if err := upgrade.SwapBinary(currentBinary, newBinary); err != nil { return errors.WithDetail(err, "unable to update binary") } - fmt.Println("roachprod successfully updated, run `roachprod -v` to confirm.") + fmt.Println("Update successful: run `roachprod -v` to confirm.") } return nil }), diff --git a/pkg/cmd/roachprod/upgrade/teamcity.go b/pkg/cmd/roachprod/upgrade/teamcity.go index 5750f55c1a4c..c4e81b70beb0 100644 --- a/pkg/cmd/roachprod/upgrade/teamcity.go +++ b/pkg/cmd/roachprod/upgrade/teamcity.go @@ -17,6 +17,7 @@ import ( "net/http" "os" "runtime" + "time" "github.com/cockroachdb/cockroach/pkg/util/httputil" ) @@ -33,65 +34,62 @@ var ( apiBase = "https://teamcity.cockroachdb.com/guestAuth/app/rest" ) -func downloadURL(buildId int32) string { - url := fmt.Sprintf("%s%s", - apiBase, - fmt.Sprintf("/builds/id:%v/artifacts/content/bazel-bin/pkg/cmd/roachprod/roachprod_/roachprod", buildId), - ) - fmt.Println(url) - return url -} - // DownloadLatestRoadprod attempts to download the latest binary to the // current binary's directory. It returns the path to the downloaded binary. -func DownloadLatestRoadprod() (string, error) { +// toFile is the path to the file to download to. +func DownloadLatestRoadprod(toFile string) error { if buildType == "" { - panic("unable to find build type for this platform") + fmt.Println("Supported platforms:") + for k := range buildIDs { + fmt.Printf("\t%s\n", k) + } + return fmt.Errorf("unable to find build type for this platform") } - builds, err := GetBuilds("count:1,status:SUCCESS,branch:master,buildType:" + buildType) + // Build are sorted by build date desc, so limiting to 1 will get the latest. + builds, err := getBuilds("count:1,status:SUCCESS,branch:master,buildType:" + buildType) if err != nil { - return "", err + return err } if len(builds.Build) == 0 { - return "", fmt.Errorf("no builds found") - } - - currentRoachprod, err := os.Executable() - if err != nil { - return "", err + return fmt.Errorf("no builds found") } - newRoachprod := currentRoachprod + ".new" - err = DownloadRoachprod(builds.Build[0], newRoachprod) + err = downloadRoachprod(builds.Build[0].Id, toFile) if err != nil { - return "", err + return err } - fmt.Printf("Latest roachprod downloaded to %s\n", newRoachprod) - return newRoachprod, nil + return nil } -func GetBuilds(locator string) (TCBuildResponse, error) { - // Get the latest successful build +// getBuilds returns a list of builds matching the locator +// See https://www.jetbrains.com/help/teamcity/rest/buildlocator.html +func getBuilds(locator string) (TCBuildResponse, error) { urlWithLocator := fmt.Sprintf("%s/builds?locator=%s", apiBase, locator) - fmt.Println("URL: ", urlWithLocator) - buildResp := &TCBuildResponse{} - err := httputil.GetJSONWithOptions(*httputil.DefaultClient.Client, urlWithLocator, buildResp, httputil.IgnoreUnknownFields()) - + err := httputil.GetJSONWithOptions(*httputil.DefaultClient.Client, urlWithLocator, buildResp, + httputil.IgnoreUnknownFields()) return *buildResp, err } -// DownloadRoachprod downloads the roachprod binary from the build +// downloadRoachprod downloads the roachprod binary from the build // to the specified destination file. -func DownloadRoachprod(build *TCBuild, destFile string) error { +func downloadRoachprod(buildId int32, destFile string) error { + if buildId <= 0 { + return fmt.Errorf("invalid build id") + } out, err := os.Create(destFile) if err != nil { return err } defer out.Close() - resp, err := httputil.Get(context.Background(), downloadURL(build.Id)) + url := roachprodDownloadUrl(buildId) + fmt.Printf("Downloading latest roachprod \n\tfrom:\t%s \n\tto :\t%s\n", url, destFile) + + // Set a long timeout here because the download can take a while. + httpClient := httputil.NewClientWithTimeouts(httputil.StandardHTTPTimeout, 10*time.Minute) + resp, err := httpClient.Get(context.Background(), url) if err != nil { return err } @@ -104,6 +102,13 @@ func DownloadRoachprod(build *TCBuild, destFile string) error { if err != nil { return err } - return nil } + +func roachprodDownloadUrl(buildId int32) string { + url := fmt.Sprintf("%s%s", + apiBase, + fmt.Sprintf("/builds/id:%v/artifacts/content/bazel-bin/pkg/cmd/roachprod/roachprod_/roachprod", buildId), + ) + return url +} diff --git a/pkg/cmd/roachprod/upgrade/util.go b/pkg/cmd/roachprod/upgrade/util.go index b991fd0491f6..2ea5524a13ec 100644 --- a/pkg/cmd/roachprod/upgrade/util.go +++ b/pkg/cmd/roachprod/upgrade/util.go @@ -20,7 +20,7 @@ import ( ) func PromptYesNo(msg string) bool { - fmt.Printf("%s (y[default]/n)", msg) + fmt.Printf("%s y[default]/n: ", msg) var answer string _, _ = fmt.Scanln(&answer) answer = strings.TrimSpace(answer) diff --git a/pkg/util/httputil/client.go b/pkg/util/httputil/client.go index 60060dc06db3..bbeebc3ef200 100644 --- a/pkg/util/httputil/client.go +++ b/pkg/util/httputil/client.go @@ -27,12 +27,17 @@ const StandardHTTPTimeout time.Duration = 3 * time.Second // NewClientWithTimeout defines a http.Client with the given timeout. func NewClientWithTimeout(timeout time.Duration) *Client { + return NewClientWithTimeouts(timeout, timeout) +} + +// NewClientWithTimeouts defines a http.Client with the given dialer and client timeouts. +func NewClientWithTimeouts(dialerTimeout, clientTimeout time.Duration) *Client { return &Client{&http.Client{ - Timeout: timeout, + Timeout: clientTimeout, Transport: &http.Transport{ // Don't leak a goroutine on OSX (the TCP level timeout is probably // much higher than on linux). - DialContext: (&net.Dialer{Timeout: timeout}).DialContext, + DialContext: (&net.Dialer{Timeout: dialerTimeout}).DialContext, DisableKeepAlives: true, }, }}