diff --git a/pkg/cmd/roachprod/main.go b/pkg/cmd/roachprod/main.go
index 299571a3de19..cb0c28a47684 100644
--- a/pkg/cmd/roachprod/main.go
+++ b/pkg/cmd/roachprod/main.go
@@ -655,7 +655,8 @@ var runCmd = &cobra.Command{
`,
  Args: cobra.MinimumNArgs(1),
  Run: wrap(func(_ *cobra.Command, args []string) error {
-    return roachprod.Run(context.Background(), config.Logger, args[0], extraSSHOptions, tag, secure, os.Stdout, os.Stderr, args[1:])
+    return roachprod.Run(context.Background(), config.Logger, args[0], extraSSHOptions, tag,
+      secure, os.Stdout, os.Stderr, args[1:], install.WithWaitOnFail())
  }),
}

@@ -802,7 +803,7 @@ multiple nodes the destination file name will be prefixed with the node number.
    if len(args) == 3 {
      dest = args[2]
    }
-    return roachprod.Get(config.Logger, args[0], src, dest)
+    return roachprod.Get(context.Background(), config.Logger, args[0], src, dest)
  }),
}

@@ -858,7 +859,7 @@ Examples:
    if cmd.CalledAs() == "pprof-heap" {
      pprofOpts.Heap = true
    }
-    return roachprod.Pprof(config.Logger, args[0], pprofOpts)
+    return roachprod.Pprof(context.Background(), config.Logger, args[0], pprofOpts)
  }),
}

diff --git a/pkg/cmd/roachtest/cluster.go b/pkg/cmd/roachtest/cluster.go
index 3a8e6ae740ac..1659d426bcde 100644
--- a/pkg/cmd/roachtest/cluster.go
+++ b/pkg/cmd/roachtest/cluster.go
@@ -1969,7 +1969,7 @@ func (c *clusterImpl) Get(
  }
  c.status(fmt.Sprintf("getting %v", src))
  defer c.status("")
-  return errors.Wrap(roachprod.Get(l, c.MakeNodes(opts...), src, dest), "cluster.Get")
+  return errors.Wrap(roachprod.Get(ctx, l, c.MakeNodes(opts...), src, dest), "cluster.Get")
}

// Put a string into the specified file on the remote(s).
diff --git a/pkg/roachprod/install/cluster_synced.go b/pkg/roachprod/install/cluster_synced.go
index 303d5fa0dee8..32256dd3ade9 100644
--- a/pkg/roachprod/install/cluster_synced.go
+++ b/pkg/roachprod/install/cluster_synced.go
@@ -155,10 +155,13 @@ var defaultSCPRetry = newRunRetryOpts(defaultRetryOpt,
// captured in a *RunResultDetails[] in Run, but here we may enrich with attempt
// number and a wrapper error.
func runWithMaybeRetry(
-  l *logger.Logger, retryOpts *RunRetryOpts, f func() (*RunResultDetails, error),
+  ctx context.Context,
+  l *logger.Logger,
+  retryOpts *RunRetryOpts,
+  f func(ctx context.Context) (*RunResultDetails, error),
) (*RunResultDetails, error) {
  if retryOpts == nil {
-    res, err := f()
+    res, err := f(ctx)
    res.Attempt = 1
    return res, err
  }
@@ -167,10 +170,10 @@ func runWithMaybeRetry(
  var err error
  var cmdErr error

-  for r := retry.Start(retryOpts.Options); r.Next(); {
-    res, err = f()
+  for r := retry.StartWithCtx(ctx, retryOpts.Options); r.Next(); {
+    res, err = f(ctx)
    res.Attempt = r.CurrentAttempt() + 1
-    // nil err (denoting a roachprod error) indicates a potentially retryable res.Err
+    // nil err (non-nil denotes a roachprod error) indicates a potentially retryable res.Err
    if err == nil && res.Err != nil {
      cmdErr = errors.CombineErrors(cmdErr, res.Err)
      if retryOpts.shouldRetryFn == nil || retryOpts.shouldRetryFn(res) {
@@ -193,8 +196,10 @@ func runWithMaybeRetry(
  return res, err
}

-func scpWithRetry(l *logger.Logger, src, dest string) (*RunResultDetails, error) {
-  return runWithMaybeRetry(l, defaultSCPRetry, func() (*RunResultDetails, error) { return scp(l, src, dest) })
+func scpWithRetry(
+  ctx context.Context, l *logger.Logger, src, dest string,
+) (*RunResultDetails, error) {
+  return runWithMaybeRetry(ctx, l, defaultSCPRetry, func(ctx context.Context) (*RunResultDetails, error) { return scp(l, src, dest) })
}

// Host returns the public IP of a node.
@@ -431,7 +436,7 @@ func (c *SyncedCluster) kill(
    // `kill -9` without wait is never what a caller wants. See #77334.
    wait = true
  }
-  return c.Parallel(l, display, len(c.Nodes), 0, func(i int) (*RunResultDetails, error) {
+  return c.Parallel(ctx, l, len(c.Nodes), func(ctx context.Context, i int) (*RunResultDetails, error) {
    node := c.Nodes[i]

    var waitCmd string
@@ -485,7 +490,7 @@ fi`,
    res := newRunResultDetails(node, cmdErr)
    res.CombinedOut = out
    return res, res.Err
-  }, nil) // Disable SSH Retries
+  }, WithDisplay(display), WithRetryOpts(nil)) // Disable SSH Retries
}

// Wipe TODO(peter): document
@@ -494,7 +499,7 @@ func (c *SyncedCluster) Wipe(ctx context.Context, l *logger.Logger, preserveCert
  if err := c.Stop(ctx, l, 9, true /* wait */, 0 /* maxWait */); err != nil {
    return err
  }
-  return c.Parallel(l, display, len(c.Nodes), 0, func(i int) (*RunResultDetails, error) {
+  return c.Parallel(ctx, l, len(c.Nodes), func(ctx context.Context, i int) (*RunResultDetails, error) {
    node := c.Nodes[i]
    var cmd string
    if c.IsLocal() {
@@ -524,7 +529,7 @@ sudo rm -fr logs &&
    res := newRunResultDetails(node, cmdErr)
    res.CombinedOut = out
    return res, res.Err
-  }, DefaultSSHRetryOpts)
+  }, WithDisplay(display))
}

// NodeStatus contains details about the status of a node.
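Note on the pattern applied throughout this file: call sites stop passing the display string, concurrency count, and retry options positionally to Parallel/ParallelE and instead thread a context.Context plus variadic ParallelOption values. Below is a minimal, self-contained sketch of that functional-options shape. It borrows the option names introduced later in this diff (WithDisplay, WithConcurrency, WithWaitOnFail), but the types and the sequential loop are simplified stand-ins, not the patch's code.

package main

import (
  "context"
  "fmt"
)

// options collects the knobs that used to be positional parameters.
type options struct {
  display     string
  concurrency int
  waitOnFail  bool
}

// Option mutates an options value; the diff's ParallelOption has the same shape.
type Option func(*options)

func WithDisplay(d string) Option  { return func(o *options) { o.display = d } }
func WithConcurrency(n int) Option { return func(o *options) { o.concurrency = n } }
func WithWaitOnFail() Option       { return func(o *options) { o.waitOnFail = true } }

// parallel applies defaults first, then any caller-supplied options.
// The real implementation fans work out across o.concurrency goroutines;
// this sketch stays sequential for brevity.
func parallel(ctx context.Context, count int, fn func(context.Context, int) error, opts ...Option) error {
  o := options{concurrency: count} // defaults
  for _, opt := range opts {
    opt(&o)
  }
  fmt.Printf("running %d tasks, concurrency=%d, display=%q, waitOnFail=%v\n",
    count, o.concurrency, o.display, o.waitOnFail)
  for i := 0; i < count; i++ {
    if err := fn(ctx, i); err != nil && !o.waitOnFail {
      return err // fail fast unless WithWaitOnFail was given
    }
  }
  return nil
}

func main() {
  _ = parallel(context.Background(), 3,
    func(ctx context.Context, i int) error { return nil },
    WithDisplay("demo"), WithConcurrency(2), WithWaitOnFail())
}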
@@ -540,7 +545,7 @@ type NodeStatus struct {
func (c *SyncedCluster) Status(ctx context.Context, l *logger.Logger) ([]NodeStatus, error) {
  display := fmt.Sprintf("%s: status", c.Name)
  results := make([]NodeStatus, len(c.Nodes))
-  if err := c.Parallel(l, display, len(c.Nodes), 0, func(i int) (*RunResultDetails, error) {
+  if err := c.Parallel(ctx, l, len(c.Nodes), func(ctx context.Context, i int) (*RunResultDetails, error) {
    node := c.Nodes[i]

    binary := cockroachNodeBinary(c, node)
@@ -576,7 +581,7 @@ fi

    results[i] = NodeStatus{Running: true, Version: info[0], Pid: info[1]}
    return res, nil
-  }, DefaultSSHRetryOpts); err != nil {
+  }, WithDisplay(display)); err != nil {
    return nil, err
  }
  for i := 0; i < len(results); i++ {
@@ -865,7 +870,12 @@ func (c *SyncedCluster) runCmdOnSingleNode(
// title: A description of the command being run that is output to the logs.
// cmd: The command to run.
func (c *SyncedCluster) Run(
-  ctx context.Context, l *logger.Logger, stdout, stderr io.Writer, nodes Nodes, title, cmd string,
+  ctx context.Context,
+  l *logger.Logger,
+  stdout, stderr io.Writer,
+  nodes Nodes,
+  title, cmd string,
+  opts ...ParallelOption,
) error {
  // Stream output if we're running the command on only 1 node.
  stream := len(nodes) == 1
@@ -877,14 +887,14 @@ func (c *SyncedCluster) Run(
  results := make([]*RunResultDetails, len(nodes))

  // A result is the output of running a command (could be interpreted as an error)
-  if _, err := c.ParallelE(l, display, len(nodes), 0, func(i int) (*RunResultDetails, error) {
+  if _, err := c.ParallelE(ctx, l, len(nodes), func(ctx context.Context, i int) (*RunResultDetails, error) {
    // An err returned here is an unexpected state within roachprod (non-command error).
    // For errors that occur as part of running a command over ssh, the `result` will contain
    // the actual error on a specific node.
    result, err := c.runCmdOnSingleNode(ctx, l, nodes[i], cmd, !stream, stdout, stderr)
    results[i] = result
    return result, err
-  }, DefaultSSHRetryOpts); err != nil {
+  }, append(opts, WithDisplay(display))...); err != nil {
    return err
  }

@@ -895,6 +905,12 @@ func (c *SyncedCluster) Run(
func processResults(results []*RunResultDetails, stream bool, stdout io.Writer) error {
  var resultWithError *RunResultDetails
  for i, r := range results {
+    // We no longer wait for all nodes to complete before returning in the case of an error (#100403)
+    // which means that some node results may be nil.
+    if r == nil {
+      continue
+    }
+
    if !stream {
      fmt.Fprintf(stdout, " %2d: %s\n%v\n", i+1, strings.TrimSpace(string(r.CombinedOut)), r.Err)
    }
@@ -918,6 +934,7 @@ func processResults(results []*RunResultDetails, stream bool, stdout io.Writer)
}

// RunWithDetails runs a command on the specified nodes and returns results details and an error.
+// This will wait for all commands to complete before returning unless encountering a roachprod error.
func (c *SyncedCluster) RunWithDetails(
  ctx context.Context, l *logger.Logger, nodes Nodes, title, cmd string,
) ([]RunResultDetails, error) {
@@ -927,13 +944,14 @@ func (c *SyncedCluster) RunWithDetails(
  // be processed further by the caller.
  resultPtrs := make([]*RunResultDetails, len(nodes))

-  // Both return values are explicitly ignored because, in this case, resultPtrs
-  // capture both error and result state for each node
-  _, _ = c.ParallelE(l, display, len(nodes), 0, func(i int) (*RunResultDetails, error) { //nolint:errcheck
+  // Failing slow here allows us to capture the output of all nodes even if one fails with a command error.
+  if _, err := c.ParallelE(ctx, l, len(nodes), func(ctx context.Context, i int) (*RunResultDetails, error) { //nolint:errcheck
    result, err := c.runCmdOnSingleNode(ctx, l, nodes[i], cmd, false, l.Stdout, l.Stderr)
    resultPtrs[i] = result
    return result, err
-  }, DefaultSSHRetryOpts)
+  }, WithDisplay(display), WithWaitOnFail()); err != nil {
+    return nil, err
+  }

  // Return values to preserve API
  results := make([]RunResultDetails, len(nodes))
@@ -978,7 +996,7 @@ func (c *SyncedCluster) RepeatRun(
func (c *SyncedCluster) Wait(ctx context.Context, l *logger.Logger) error {
  display := fmt.Sprintf("%s: waiting for nodes to start", c.Name)
  errs := make([]error, len(c.Nodes))
-  if err := c.Parallel(l, display, len(c.Nodes), 0, func(i int) (*RunResultDetails, error) {
+  if err := c.Parallel(ctx, l, len(c.Nodes), func(ctx context.Context, i int) (*RunResultDetails, error) {
    node := c.Nodes[i]
    res := &RunResultDetails{Node: node}
    cmd := "test -e /mnt/data1/.roachprod-initialized"
@@ -996,7 +1014,7 @@ func (c *SyncedCluster) Wait(ctx context.Context, l *logger.Logger) error {
    errs[i] = errors.New("timed out after 5m")
    res.Err = errs[i]
    return res, nil
-  }, nil); err != nil {
+  }, WithDisplay(display), WithRetryOpts(nil)); err != nil {
    return err
  }

@@ -1041,7 +1059,7 @@ func (c *SyncedCluster) SetupSSH(ctx context.Context, l *logger.Logger) error {
  // Generate an ssh key that we'll distribute to all of the nodes in the
  // cluster in order to allow inter-node ssh.
  var sshTar []byte
-  if err := c.Parallel(l, "generating ssh key", 1, 0, func(i int) (*RunResultDetails, error) {
+  if err := c.Parallel(ctx, l, 1, func(ctx context.Context, i int) (*RunResultDetails, error) {
    // Create the ssh key and then tar up the public, private and
    // authorized_keys files and output them to stdout. We'll take this output
    // and pipe it back into tar on the other nodes in the cluster.
@@ -1069,13 +1087,13 @@ tar cf - .ssh/id_rsa .ssh/id_rsa.pub .ssh/authorized_keys
    }
    sshTar = []byte(res.Stdout)
    return res, nil
-  }, DefaultSSHRetryOpts); err != nil {
+  }, WithDisplay("generating ssh key")); err != nil {
    return err
  }

  // Skip the first node which is where we generated the key.
  nodes := c.Nodes[1:]
-  if err := c.Parallel(l, "distributing ssh key", len(nodes), 0, func(i int) (*RunResultDetails, error) {
+  if err := c.Parallel(ctx, l, len(nodes), func(ctx context.Context, i int) (*RunResultDetails, error) {
    node := nodes[i]
    cmd := `tar xf -`

@@ -1092,7 +1110,7 @@ tar cf - .ssh/id_rsa .ssh/id_rsa.pub .ssh/authorized_keys
      return res, errors.Wrapf(res.Err, "%s: output:\n%s", cmd, res.CombinedOut)
    }
    return res, nil
-  }, DefaultSSHRetryOpts); err != nil {
+  }, WithDisplay("distributing ssh key")); err != nil {
    return err
  }

@@ -1124,7 +1142,7 @@ tar cf - .ssh/id_rsa .ssh/id_rsa.pub .ssh/authorized_keys
  providerKnownHostData := make(map[string][]byte)
  providers := maps.Keys(providerPrivateIPs)
  // Only need to scan on the first node of each provider.
-  if err := c.Parallel(l, "scanning hosts", len(providers), 0, func(i int) (*RunResultDetails, error) {
+  if err := c.Parallel(ctx, l, len(providers), func(ctx context.Context, i int) (*RunResultDetails, error) {
    provider := providers[i]
    node := providerPrivateIPs[provider][0].node
    // Scan a combination of all remote IPs and local IPs pertaining to this
@@ -1176,11 +1194,11 @@ exit 1
    providerKnownHostData[provider] = stdout.Bytes()
    mu.Unlock()
    return res, nil
-  }, DefaultSSHRetryOpts); err != nil {
+  }, WithDisplay("scanning hosts")); err != nil {
    return err
  }

-  if err := c.Parallel(l, "distributing known_hosts", len(c.Nodes), 0, func(i int) (*RunResultDetails, error) {
+  if err := c.Parallel(ctx, l, len(c.Nodes), func(ctx context.Context, i int) (*RunResultDetails, error) {
    node := c.Nodes[i]
    provider := c.VMs[node-1].Provider
    const cmd = `
@@ -1223,7 +1241,7 @@ fi
      return res, errors.Wrapf(res.Err, "%s: output:\n%s", cmd, res.CombinedOut)
    }
    return res, nil
-  }, DefaultSSHRetryOpts); err != nil {
+  }, WithDisplay("distributing known_hosts")); err != nil {
    return err
  }

@@ -1233,7 +1251,7 @@ fi
    // additional authorized_keys to both the current user (your username on
    // gce and the shared user on aws) as well as to the shared user on both
    // platforms.
-    if err := c.Parallel(l, "adding additional authorized keys", len(c.Nodes), 0, func(i int) (*RunResultDetails, error) {
+    if err := c.Parallel(ctx, l, len(c.Nodes), func(ctx context.Context, i int) (*RunResultDetails, error) {
      node := c.Nodes[i]
      const cmd = `
keys_data="$(cat)"
@@ -1270,7 +1288,7 @@ fi
        return res, errors.Wrapf(res.Err, "~ %s\n%s", cmd, res.CombinedOut)
      }
      return res, nil
-    }, DefaultSSHRetryOpts); err != nil {
+    }, WithDisplay("adding additional authorized keys")); err != nil {
      return err
    }
  }

@@ -1286,7 +1304,9 @@ const (
// DistributeCerts will generate and distribute certificates to all of the
// nodes.
func (c *SyncedCluster) DistributeCerts(ctx context.Context, l *logger.Logger) error {
-  if c.checkForCertificates(ctx, l) {
+  if found, err := c.checkForCertificates(ctx, l); err != nil {
+    return err
+  } else if found {
    return nil
  }

@@ -1298,7 +1318,7 @@ func (c *SyncedCluster) DistributeCerts(ctx context.Context, l *logger.Logger) e
  // Generate the ca, client and node certificates on the first node.
  var msg string
  display := fmt.Sprintf("%s: initializing certs", c.Name)
-  if err := c.Parallel(l, display, 1, 0, func(i int) (*RunResultDetails, error) {
+  if err := c.Parallel(ctx, l, 1, func(ctx context.Context, i int) (*RunResultDetails, error) {
    var cmd string
    if c.IsLocal() {
      cmd = fmt.Sprintf(`cd %s ; `, c.localVMDir(1))
@@ -1337,7 +1357,7 @@ tar cvf %[3]s certs
      msg = fmt.Sprintf("%s: %v", res.CombinedOut, res.Err)
    }
    return res, nil
-  }, DefaultSSHRetryOpts); err != nil {
+  }, WithDisplay(display)); err != nil {
    return err
  }

@@ -1346,7 +1366,7 @@ tar cvf %[3]s certs
    exit.WithCode(exit.UnspecifiedError())
  }

-  tarfile, cleanup, err := c.getFileFromFirstNode(l, certsTarName)
+  tarfile, cleanup, err := c.getFileFromFirstNode(ctx, l, certsTarName)
  if err != nil {
    return err
  }
@@ -1362,11 +1382,15 @@ tar cvf %[3]s certs
func (c *SyncedCluster) DistributeTenantCerts(
  ctx context.Context, l *logger.Logger, hostCluster *SyncedCluster, tenantID int,
) error {
-  if hostCluster.checkForTenantCertificates(ctx, l) {
+  if found, err := hostCluster.checkForTenantCertificates(ctx, l); err != nil {
+    return err
+  } else if found {
    return nil
  }

-  if !hostCluster.checkForCertificates(ctx, l) {
+  if found, err := hostCluster.checkForCertificates(ctx, l); err != nil {
+    return err
+  } else if !found {
    return errors.New("host cluster missing certificate bundle")
  }

@@ -1379,7 +1403,7 @@ func (c *SyncedCluster) DistributeTenantCerts(
    return err
  }

-  tarfile, cleanup, err := hostCluster.getFileFromFirstNode(l, tenantCertsTarName)
+  tarfile, cleanup, err := hostCluster.getFileFromFirstNode(ctx, l, tenantCertsTarName)
  if err != nil {
    return err
  }
@@ -1397,7 +1421,7 @@ func (c *SyncedCluster) createTenantCertBundle(
  ctx context.Context, l *logger.Logger, bundleName string, tenantID int, nodeNames []string,
) error {
  display := fmt.Sprintf("%s: initializing tenant certs", c.Name)
-  return c.Parallel(l, display, 1, 0, func(i int) (*RunResultDetails, error) {
+  return c.Parallel(ctx, l, 1, func(ctx context.Context, i int) (*RunResultDetails, error) {
    node := c.Nodes[i]

    cmd := "set -e;"
@@ -1441,14 +1465,14 @@ fi
      return res, errors.Wrapf(res.Err, "certificate creation error: %s", res.CombinedOut)
    }
    return res, nil
-  }, DefaultSSHRetryOpts)
+  }, WithDisplay(display))
}

// getFile retrieves the given file from the first node in the cluster. The
// filename is assumed to be relative from the home directory of the node's
// user.
func (c *SyncedCluster) getFileFromFirstNode(
-  l *logger.Logger, name string,
+  ctx context.Context, l *logger.Logger, name string,
) (string, func(), error) {
  var tmpfileName string
  cleanup := func() {}
@@ -1465,7 +1489,7 @@ func (c *SyncedCluster) getFileFromFirstNode(
  }

  srcFileName := fmt.Sprintf("%s@%s:%s", c.user(1), c.Host(1), name)
-  if res, _ := scpWithRetry(l, srcFileName, tmpfile.Name()); res.Err != nil {
+  if res, _ := scpWithRetry(ctx, l, srcFileName, tmpfile.Name()); res.Err != nil {
    cleanup()
    return "", nil, res.Err
  }
@@ -1476,7 +1500,7 @@ func (c *SyncedCluster) getFileFromFirstNode(

// checkForCertificates checks if the cluster already has a certs bundle created
// on the first node.
-func (c *SyncedCluster) checkForCertificates(ctx context.Context, l *logger.Logger) bool {
+func (c *SyncedCluster) checkForCertificates(ctx context.Context, l *logger.Logger) (bool, error) {
  dir := ""
  if c.IsLocal() {
    dir = c.localVMDir(1)
@@ -1486,7 +1510,9 @@ func (c *SyncedCluster) checkForCertificates(ctx context.Context, l *logger.Logg

// checkForTenantCertificates checks if the cluster already has a tenant-certs bundle created
// on the first node.
-func (c *SyncedCluster) checkForTenantCertificates(ctx context.Context, l *logger.Logger) bool {
+func (c *SyncedCluster) checkForTenantCertificates(
+  ctx context.Context, l *logger.Logger,
+) (bool, error) {
  dir := ""
  if c.IsLocal() {
    dir = c.localVMDir(1)
@@ -1496,24 +1522,21 @@ func (c *SyncedCluster) checkForTenantCertificates(ctx context.Context, l *logge

func (c *SyncedCluster) fileExistsOnFirstNode(
  ctx context.Context, l *logger.Logger, path string,
-) bool {
-  var existsErr error
-  display := fmt.Sprintf("%s: checking %s", c.Name, path)
-  if err := c.Parallel(l, display, 1, 0, func(i int) (*RunResultDetails, error) {
-    node := c.Nodes[i]
-    sess := c.newSession(l, node, `test -e `+path)
-    defer sess.Close()
-
-    out, cmdErr := sess.CombinedOutput(ctx)
-    res := newRunResultDetails(node, cmdErr)
-    res.CombinedOut = out
-
-    existsErr = res.Err
-    return res, nil
-  }, DefaultSSHRetryOpts); err != nil {
-    return false
+) (bool, error) {
+  l.Printf("%s: checking %s", c.Name, path)
+  testCmd := `$(test -e ` + path + `);`
+  // Do not log output to stdout/stderr because in some cases this call will be expected to exit 1.
+  result, err := c.runCmdOnSingleNode(ctx, l, c.Nodes[0], testCmd, true, nil, nil)
+  if (result.RemoteExitStatus != 0 && result.RemoteExitStatus != 1) || err != nil {
+    // Unexpected exit status (neither 0 nor 1) or non-nil error. Return combined output along with err returned
+    // from the call if it's not nil.
+    if err != nil {
+      return false, errors.Wrapf(err, "running '%s' failed with exit code=%d: got %s", testCmd, result.RemoteExitStatus, string(result.CombinedOut))
+    } else {
+      return false, errors.Newf("running '%s' failed with exit code=%d: got %s", testCmd, result.RemoteExitStatus, string(result.CombinedOut))
+    }
  }
-  return existsErr == nil
+  return result.RemoteExitStatus == 0, nil
}

// createNodeCertArguments returns a list of strings appropriate for use as
@@ -1556,7 +1579,7 @@ func (c *SyncedCluster) distributeLocalCertsTar(
  }

  display := c.Name + ": distributing certs"
-  return c.Parallel(l, display, len(nodes), 0, func(i int) (*RunResultDetails, error) {
+  return c.Parallel(ctx, l, len(nodes), func(ctx context.Context, i int) (*RunResultDetails, error) {
    node := nodes[i]
    var cmd string
    if c.IsLocal() {
@@ -1580,7 +1603,7 @@ func (c *SyncedCluster) distributeLocalCertsTar(
      return res, errors.Wrapf(res.Err, "~ %s\n%s", cmd, res.CombinedOut)
    }
    return res, nil
-  }, DefaultSSHRetryOpts)
+  }, WithDisplay(display))
}

const progressDone = "=======================================>"
@@ -1777,7 +1800,7 @@ func (c *SyncedCluster) Put(
        return
      }

-      res, _ := scpWithRetry(l, from, to)
+      res, _ := scpWithRetry(ctx, l, from, to)
      results <- result{i, res.Err}

      if res.Err != nil {
@@ -2028,7 +2051,9 @@ func (c *SyncedCluster) Logs(
}

// Get TODO(peter): document
-func (c *SyncedCluster) Get(l *logger.Logger, nodes Nodes, src, dest string) error {
+func (c *SyncedCluster) Get(
+  ctx context.Context, l *logger.Logger, nodes Nodes, src, dest string,
+) error {
  if err := c.validateHost(context.TODO(), l, nodes[0]); err != nil {
    return err
  }
@@ -2146,7 +2171,7 @@ func (c *SyncedCluster) Get(l *logger.Logger, nodes Nodes, src, dest string) err
        return
      }

-      res, _ := scpWithRetry(l, fmt.Sprintf("%s@%s:%s", c.user(nodes[0]), c.Host(nodes[i]), src), dest)
+      res, _ := scpWithRetry(ctx, l, fmt.Sprintf("%s@%s:%s", c.user(nodes[0]), c.Host(nodes[i]), src), dest)
      if res.Err == nil {
        // Make sure all created files and directories are world readable.
        // The CRDB process intentionally sets a 0007 umask (resulting in
@@ -2262,14 +2287,14 @@ func (c *SyncedCluster) pghosts(
  ctx context.Context, l *logger.Logger, nodes Nodes,
) (map[Node]string, error) {
  ips := make([]string, len(nodes))
-  if err := c.Parallel(l, "", len(nodes), 0, func(i int) (*RunResultDetails, error) {
+  if err := c.Parallel(ctx, l, len(nodes), func(ctx context.Context, i int) (*RunResultDetails, error) {
    node := nodes[i]
    res := &RunResultDetails{Node: node}
    res.Stdout, res.Err = c.GetInternalIP(node)
    ips[i] = res.Stdout
-    return res, errors.Wrapf(res.Err, "pghosts")
-  }, DefaultSSHRetryOpts); err != nil {
-    return nil, err
+    return res, nil
+  }); err != nil {
+    return nil, errors.Wrapf(err, "pghosts")
  }

  m := make(map[Node]string, len(ips))
@@ -2383,31 +2408,88 @@ type ParallelResult struct {
  Err error
}

+type ParallelOptions struct {
+  concurrency int
+  display     string
+  retryOpts   *RunRetryOpts
+  // waitOnFail will cause the Parallel function to wait for all nodes to
+  // finish when encountering a command error on any node. The default
+  // behaviour is to exit immediately on the first error, in which case the
+  // slice of ParallelResults will only contain the one error result.
+  waitOnFail bool
+}
+
+type ParallelOption func(result *ParallelOptions)
+
+func WithConcurrency(concurrency int) ParallelOption {
+  return func(result *ParallelOptions) {
+    result.concurrency = concurrency
+  }
+}
+
+func WithRetryOpts(retryOpts *RunRetryOpts) ParallelOption {
+  return func(result *ParallelOptions) {
+    result.retryOpts = retryOpts
+  }
+}
+
+func WithWaitOnFail() ParallelOption {
+  return func(result *ParallelOptions) {
+    result.waitOnFail = true
+  }
+}
+
+func WithDisplay(display string) ParallelOption {
+  return func(result *ParallelOptions) {
+    result.display = display
+  }
+}
+
// Parallel runs a user-defined function across the nodes in the
// cluster. If any of the commands fail, Parallel will log an error
// and exit the program.
//
// See ParallelE for more information.
func (c *SyncedCluster) Parallel(
+  ctx context.Context,
  l *logger.Logger,
-  display string,
-  count, concurrency int,
-  fn func(i int) (*RunResultDetails, error),
-  runRetryOpts *RunRetryOpts,
+  count int,
+  fn func(ctx context.Context, i int) (*RunResultDetails, error),
+  opts ...ParallelOption,
) error {
-  failed, err := c.ParallelE(l, display, count, concurrency, fn, runRetryOpts)
+  // failed will contain command errors if any occur.
+  // err is an unexpected roachprod error, which we return immediately.
+  failed, err := c.ParallelE(ctx, l, count, fn, opts...)
  if err != nil {
+    return err
+  }
+
+  if len(failed) > 0 {
    sort.Slice(failed, func(i, j int) bool { return failed[i].Index < failed[j].Index })
    for _, f := range failed {
+      // Since this function is potentially returning a single error despite
+      // having run on multiple nodes, we combine all the errors into a single
+      // error.
+      err = errors.CombineErrors(err, f.Err)
      l.Errorf("%d: %+v: %s", f.Index, f.Err, f.Out)
    }
-    return err
+    return errors.Wrap(err, "one or more parallel execution failure")
  }
  return nil
}

// ParallelE runs the given function in parallel across the given
-// nodes, returning an error if function returns an error.
+// nodes.
+//
+// By default, this will fail fast if a command error occurs on any node, in which
+// case the function will return a slice containing the erroneous result.
+//
+// If `WithWaitOnFail()` is passed in, then the function will wait for all
+// invocations to complete before returning a slice with all failed results.
+//
+// ParallelE only returns an error for roachprod itself, not any command errors run
+// on the cluster. It is up to the caller to check the slice for command errors. Any
+// such roachprod error will always be returned immediately.
//
// ParallelE runs the user-defined functions on the first `count`
// nodes in the cluster. It runs at most `concurrency` (or
@@ -2415,50 +2497,71 @@ func (c *SyncedCluster) Parallel(
// 0, then it defaults to `count`.
//
// The function returns a pointer to RunResultDetails as we may enrich
-// the result with retry information (attempt number, wrapper error)
+// the result with retry information (attempt number, wrapper error).
//
-// If err is non-nil, the slice of ParallelResults will contain the
-// results from any of the failed invocations.
+// RunRetryOpts controls the retry behavior in the case that
+// the function fails, but returns a nil error. A non-nil error returned by the
+// function denotes a roachprod error and will not be retried regardless of the
+// retry options.
func (c *SyncedCluster) ParallelE(
+  ctx context.Context,
  l *logger.Logger,
-  display string,
-  count, concurrency int,
-  fn func(i int) (*RunResultDetails, error),
-  runRetryOpts *RunRetryOpts,
+  count int,
+  fn func(ctx context.Context, i int) (*RunResultDetails, error),
+  opts ...ParallelOption,
) ([]ParallelResult, error) {
-  if concurrency == 0 || concurrency > count {
-    concurrency = count
+  options := ParallelOptions{retryOpts: DefaultSSHRetryOpts}
+  for _, opt := range opts {
+    opt(&options)
+  }
+
+  if options.concurrency == 0 || options.concurrency > count {
+    options.concurrency = count
  }
-  if config.MaxConcurrency > 0 && concurrency > config.MaxConcurrency {
-    concurrency = config.MaxConcurrency
+  if config.MaxConcurrency > 0 && options.concurrency > config.MaxConcurrency {
+    options.concurrency = config.MaxConcurrency
  }

  results := make(chan ParallelResult, count)
+  errorChannel := make(chan error)
  var wg sync.WaitGroup
  wg.Add(count)

+  groupCtx, groupCancel := context.WithCancel(ctx)
+  defer groupCancel()
  var index int
  startNext := func() {
+    // If we needed to react to a context cancellation here we would need to
+    // nest this goroutine in another one and select on the groupCtx. However,
+    // since anything intensive here is a command over ssh, and we are threading
+    // the context through, a cancellation will be handled by the command itself.
    go func(i int) {
      defer wg.Done()
-      res, err := runWithMaybeRetry(l, runRetryOpts, func() (*RunResultDetails, error) { return fn(i) })
-      results <- ParallelResult{i, res.CombinedOut, err}
+      // This is rarely expected to return an error, but we fail fast in case.
+      // Command errors, which are far more common, will be contained within the result.
+      res, err := runWithMaybeRetry(groupCtx, l, options.retryOpts, func(ctx context.Context) (*RunResultDetails, error) { return fn(ctx, i) })
+      if err != nil {
+        errorChannel <- err
+        return
+      }
+      results <- ParallelResult{i, res.CombinedOut, res.Err}
    }(index)
    index++
  }

-  for index < concurrency {
+  for index < options.concurrency {
    startNext()
  }

  go func() {
    wg.Wait()
    close(results)
+    close(errorChannel)
  }()

  var writer ui.Writer
  out := l.Stdout
-  if display == "" {
+  if options.display == "" {
    out = io.Discard
  }

@@ -2467,7 +2570,7 @@ func (c *SyncedCluster) ParallelE(
    ticker = time.NewTicker(100 * time.Millisecond)
  } else {
    ticker = time.NewTicker(1000 * time.Millisecond)
-    fmt.Fprintf(out, "%s", display)
+    fmt.Fprintf(out, "%s", options.display)
    if l.File != nil {
      fmt.Fprintf(out, "\n")
    }
@@ -2486,8 +2589,12 @@ func (c *SyncedCluster) ParallelE(
        fmt.Fprintf(out, ".")
      }
    case r, ok := <-results:
-      if r.Err != nil {
+      if r.Err != nil { // Command error
        failed = append(failed, r)
+        if !options.waitOnFail {
+          groupCancel()
+          return failed, nil
+        }
      }
      done = !ok
      if ok {
@@ -2496,10 +2603,13 @@ func (c *SyncedCluster) ParallelE(
      if index < count {
        startNext()
      }
+    case err := <-errorChannel: // Roachprod error
+      groupCancel()
+      return nil, err
    }

    if !config.Quiet && l.File == nil {
-      fmt.Fprint(&writer, display)
+      fmt.Fprint(&writer, options.display)
      var n int
      for i := range complete {
        if complete[i] {
@@ -2520,14 +2630,7 @@ func (c *SyncedCluster) ParallelE(
    fmt.Fprintf(out, "\n")
  }

-  if len(failed) > 0 {
-    var err error
-    for _, res := range failed {
-      err = errors.CombineErrors(err, res.Err)
-    }
-    return failed, errors.Wrap(err, "parallel execution failure")
-  }
-  return nil, nil
+  return failed, nil
}

// Init initializes the cluster. It does it through node 1 (as per TargetNodes)
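The reworked ParallelE above fans each invocation out to its own goroutine, funnels command results into one channel and roachprod errors into a second, and cancels a shared context on the first failed result unless WithWaitOnFail was supplied. The following is a self-contained sketch of that fan-out shape under those assumptions; names and error handling are simplified and it is not the patch's code.

package main

import (
  "context"
  "fmt"
  "sync"
)

type result struct {
  index int
  err   error
}

// fanOut runs fn for indexes [0, count). On the first failed result it cancels
// the shared context and returns early; with waitOnFail set it instead waits
// for every invocation and returns all failures.
func fanOut(ctx context.Context, count int, waitOnFail bool, fn func(context.Context, int) error) []result {
  ctx, cancel := context.WithCancel(ctx)
  defer cancel()

  // Buffering to count lets the fail-fast path return early without leaking
  // the goroutines that still have a result to send.
  results := make(chan result, count)
  var wg sync.WaitGroup
  wg.Add(count)
  for i := 0; i < count; i++ {
    go func(i int) {
      defer wg.Done()
      results <- result{index: i, err: fn(ctx, i)}
    }(i)
  }
  go func() {
    wg.Wait()
    close(results)
  }()

  var failed []result
  for r := range results {
    if r.err != nil {
      failed = append(failed, r)
      if !waitOnFail {
        cancel() // fail fast: signal the remaining work to stop
        return failed
      }
    }
  }
  return failed
}

func main() {
  failed := fanOut(context.Background(), 4, true, func(ctx context.Context, i int) error {
    if i%2 == 1 {
      return fmt.Errorf("node %d failed", i)
    }
    return nil
  })
  fmt.Println("failures:", len(failed))
}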
diff --git a/pkg/roachprod/install/cluster_synced_test.go b/pkg/roachprod/install/cluster_synced_test.go
index 89f941679b64..f52d77e8a69c 100644
--- a/pkg/roachprod/install/cluster_synced_test.go
+++ b/pkg/roachprod/install/cluster_synced_test.go
@@ -11,6 +11,7 @@ package install

import (
+  "context"
  "fmt"
  "io"
  "testing"
@@ -96,21 +97,21 @@ func TestRunWithMaybeRetry(t *testing.T) {
  attempt := 0

  cases := []struct {
-    f                func() (*RunResultDetails, error)
+    f                func(ctx context.Context) (*RunResultDetails, error)
    shouldRetryFn    func(*RunResultDetails) bool
    nilRetryOpts     bool
    expectedAttempts int
    shouldError      bool
  }{
    { // 1. Happy path: no error, no retry required
-      f: func() (*RunResultDetails, error) {
+      f: func(ctx context.Context) (*RunResultDetails, error) {
        return newResult(0), nil
      },
      expectedAttempts: 1,
      shouldError:      false,
    },
    { // 2. Error, but with no retries
-      f: func() (*RunResultDetails, error) {
+      f: func(ctx context.Context) (*RunResultDetails, error) {
        return newResult(1), nil
      },
      shouldRetryFn: func(*RunResultDetails) bool {
@@ -120,14 +121,14 @@ func TestRunWithMaybeRetry(t *testing.T) {
      shouldError:      true,
    },
    { // 3. Error, but no retry function specified
-      f: func() (*RunResultDetails, error) {
+      f: func(ctx context.Context) (*RunResultDetails, error) {
        return newResult(1), nil
      },
      expectedAttempts: 3,
      shouldError:      true,
    },
    { // 4. Error, with retries exhausted
-      f: func() (*RunResultDetails, error) {
+      f: func(ctx context.Context) (*RunResultDetails, error) {
        return newResult(255), nil
      },
      shouldRetryFn: func(d *RunResultDetails) bool { return d.RemoteExitStatus == 255 },
@@ -135,7 +136,7 @@ func TestRunWithMaybeRetry(t *testing.T) {
      shouldError:      true,
    },
    { // 5. Eventual success after retries
-      f: func() (*RunResultDetails, error) {
+      f: func(ctx context.Context) (*RunResultDetails, error) {
        attempt++
        if attempt == 3 {
          return newResult(0), nil
@@ -147,7 +148,7 @@ func TestRunWithMaybeRetry(t *testing.T) {
      shouldError:      false,
    },
    { // 6. Error, runs once because nil retryOpts
-      f: func() (*RunResultDetails, error) {
+      f: func(ctx context.Context) (*RunResultDetails, error) {
        return newResult(255), nil
      },
      nilRetryOpts:     true,
@@ -163,7 +164,7 @@ func TestRunWithMaybeRetry(t *testing.T) {
      if !tc.nilRetryOpts {
        retryOpts = newRunRetryOpts(testRetryOpts, tc.shouldRetryFn)
      }
-      res, _ := runWithMaybeRetry(l, retryOpts, tc.f)
+      res, _ := runWithMaybeRetry(context.Background(), l, retryOpts, tc.f)

      require.Equal(t, tc.shouldError, res.Err != nil)
      require.Equal(t, tc.expectedAttempts, res.Attempt)
diff --git a/pkg/roachprod/install/cockroach.go b/pkg/roachprod/install/cockroach.go
index f35ecde466eb..892f59396b6f 100644
--- a/pkg/roachprod/install/cockroach.go
+++ b/pkg/roachprod/install/cockroach.go
@@ -183,7 +183,7 @@ func (c *SyncedCluster) Start(ctx context.Context, l *logger.Logger, startOpts S
  l.Printf("%s: starting nodes", c.Name)

  // SSH retries are disabled by passing nil RunRetryOpts
-  if err := c.Parallel(l, "", len(nodes), parallelism, func(nodeIdx int) (*RunResultDetails, error) {
+  if err := c.Parallel(ctx, l, len(nodes), func(ctx context.Context, nodeIdx int) (*RunResultDetails, error) {
    node := nodes[nodeIdx]
    res := &RunResultDetails{Node: node}
    // NB: if cockroach started successfully, we ignore the output as it is
@@ -228,7 +228,7 @@ func (c *SyncedCluster) Start(ctx context.Context, l *logger.Logger, startOpts S
      return res, errors.Wrap(err, "failed to set cluster settings")
    }
    return res, nil
-  }, DefaultSSHRetryOpts); err != nil {
+  }, WithConcurrency(parallelism)); err != nil {
    return err
  }

@@ -329,7 +329,7 @@ func (c *SyncedCluster) ExecSQL(
  resultChan := make(chan result, len(c.Nodes))

  display := fmt.Sprintf("%s: executing sql", c.Name)
-  if err := c.Parallel(l, display, len(c.Nodes), 0, func(nodeIdx int) (*RunResultDetails, error) {
+  if err := c.Parallel(ctx, l, len(c.Nodes), func(ctx context.Context, nodeIdx int) (*RunResultDetails, error) {
    node := c.Nodes[nodeIdx]

    var cmd string
@@ -348,11 +348,11 @@ func (c *SyncedCluster) ExecSQL(
    res.CombinedOut = out

    if res.Err != nil {
-      return res, errors.Wrapf(res.Err, "~ %s\n%s", cmd, res.CombinedOut)
+      res.Err = errors.Wrapf(res.Err, "~ %s\n%s", cmd, res.CombinedOut)
    }
    resultChan <- result{node: node, output: string(res.CombinedOut)}
    return res, nil
-  }, DefaultSSHRetryOpts); err != nil {
+  }, WithDisplay(display), WithWaitOnFail()); err != nil {
    return err
  }

diff --git a/pkg/roachprod/prometheus/prometheus.go b/pkg/roachprod/prometheus/prometheus.go
index ee6c5cea6b8c..22a59af72a24 100644
--- a/pkg/roachprod/prometheus/prometheus.go
+++ b/pkg/roachprod/prometheus/prometheus.go
@@ -468,6 +468,7 @@ docker run --privileged -p 9090:9090 \
  }

  return c.Get(
+    ctx,
    l,
    promNode,
    "/tmp/prometheus/prometheus-snapshot.tar.gz",
diff --git a/pkg/roachprod/roachprod.go b/pkg/roachprod/roachprod.go
index 6207af516a07..ff643fb0248f 100644
--- a/pkg/roachprod/roachprod.go
+++ b/pkg/roachprod/roachprod.go
@@ -366,6 +366,7 @@ func Run(
  secure bool,
  stdout, stderr io.Writer,
  cmdArray []string,
+  opts ...install.ParallelOption,
) error {
  if err := LoadClusters(); err != nil {
    return err
@@ -386,7 +387,7 @@ func Run(
  if len(title) > 30 {
    title = title[:27] + "..."
  }
-  return c.Run(ctx, l, stdout, stderr, c.Nodes, title, cmd)
+  return c.Run(ctx, l, stdout, stderr, c.Nodes, title, cmd, opts...)
}

// RunWithDetails runs a command on the nodes in a cluster.
@@ -877,7 +878,7 @@ func Put(
// Get copies a remote file from the nodes in a cluster.
// If the file is retrieved from multiple nodes the destination
// file name will be prefixed with the node number.
-func Get(l *logger.Logger, clusterName, src, dest string) error {
+func Get(ctx context.Context, l *logger.Logger, clusterName, src, dest string) error {
  if err := LoadClusters(); err != nil {
    return err
  }
@@ -885,7 +886,7 @@ func Get(l *logger.Logger, clusterName, src, dest string) error {
  if err != nil {
    return err
  }
-  return c.Get(l, c.Nodes, src, dest)
+  return c.Get(ctx, l, c.Nodes, src, dest)
}

type PGURLOptions struct {
@@ -913,14 +914,13 @@ func PgURL(
      ips[i] = c.VMs[nodes[i]-1].PublicIP
    }
  } else {
-    var err error
-    if err := c.Parallel(l, "", len(nodes), 0, func(i int) (*install.RunResultDetails, error) {
+    if err := c.Parallel(ctx, l, len(nodes), func(ctx context.Context, i int) (*install.RunResultDetails, error) {
      node := nodes[i]
      res := &install.RunResultDetails{Node: node}
      res.Stdout, res.Err = c.GetInternalIP(node)
      ips[i] = res.Stdout
-      return res, err
-    }, install.DefaultSSHRetryOpts); err != nil {
+      return res, nil
+    }); err != nil {
      return nil, err
    }
  }
@@ -1016,7 +1016,7 @@ type PprofOpts struct {
}

// Pprof TODO
-func Pprof(l *logger.Logger, clusterName string, opts PprofOpts) error {
+func Pprof(ctx context.Context, l *logger.Logger, clusterName string, opts PprofOpts) error {
  if err := LoadClusters(); err != nil {
    return err
  }
@@ -1048,7 +1048,7 @@ func Pprof(l *logger.Logger, clusterName string, opts PprofOpts) error {
  httpClient := httputil.NewClientWithTimeout(timeout)
  startTime := timeutil.Now().Unix()
  nodes := c.TargetNodes()
-  failed, err := c.ParallelE(l, description, len(nodes), 0, func(i int) (*install.RunResultDetails, error) {
+  err = c.Parallel(ctx, l, len(nodes), func(ctx context.Context, i int) (*install.RunResultDetails, error) {
    node := nodes[i]
    res := &install.RunResultDetails{Node: node}
    host := c.Host(node)
@@ -1110,17 +1110,13 @@ func Pprof(l *logger.Logger, clusterName string, opts PprofOpts) error {
    outputFiles = append(outputFiles, outputFile)
    mu.Unlock()
    return res, nil
-  }, install.DefaultSSHRetryOpts)
+  }, install.WithDisplay(description))

  for _, s := range outputFiles {
    l.Printf("Created %s", s)
  }

  if err != nil {
-    sort.Slice(failed, func(i, j int) bool { return failed[i].Index < failed[j].Index })
-    for _, f := range failed {
-      l.Errorf("%d: %+v: %s\n", f.Index, f.Err, f.Out)
-    }
    exit.WithCode(exit.UnspecifiedError())
  }

@@ -1715,11 +1711,8 @@ func sendCaptureCommand(
) error {
  nodes := c.TargetNodes()
  httpClient := httputil.NewClientWithTimeout(0 /* timeout: None */)
-  _, err := c.ParallelE(l,
-    fmt.Sprintf("Performing workload capture %s", action),
-    len(nodes),
-    0,
-    func(i int) (*install.RunResultDetails, error) {
+  _, err := c.ParallelE(ctx, l, len(nodes),
+    func(ctx context.Context, i int) (*install.RunResultDetails, error) {
      node := nodes[i]
      res := &install.RunResultDetails{Node: node}
      host := c.Host(node)
@@ -1783,7 +1776,7 @@ func sendCaptureCommand(
        }
      }
      return res, res.Err
-    }, install.DefaultSSHRetryOpts)
+    }, install.WithDisplay(fmt.Sprintf("Performing workload capture %s", action)))

  return err
}
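For reference, the fileExistsOnFirstNode rewrite earlier in this diff maps a remote `test -e` exit status of 0 to "the file exists", 1 to "it does not", and anything else (or a roachprod error) to a genuine error. The sketch below reproduces that exit-code mapping locally with os/exec; it is illustrative only, since the patch runs the command over SSH via runCmdOnSingleNode and inspects RemoteExitStatus.

package main

import (
  "fmt"
  "os/exec"
)

// fileExists runs `test -e path` and interprets the exit status the same way
// the patch does: 0 => exists, 1 => missing, anything else => unexpected error.
func fileExists(path string) (bool, error) {
  err := exec.Command("test", "-e", path).Run()
  if err == nil {
    return true, nil // exit status 0
  }
  if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 1 {
    return false, nil // exit status 1: the file does not exist
  }
  return false, fmt.Errorf("checking %s: %w", path, err)
}

func main() {
  ok, err := fileExists("/etc/hosts")
  fmt.Println(ok, err)
}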