From 4f4d10bc950048ab2724ac85b518eeb129e829a2 Mon Sep 17 00:00:00 2001 From: Miral Gadani Date: Sat, 1 Apr 2023 18:15:54 +0000 Subject: [PATCH 1/6] roachprod: make c.Parallel fail fast when a command on any of the nodes fails. Previously, when encountering an error on a node, we would wait until the command finished on all the specified nodes until returning to the caller. This meant long-running workloads would continue until completion despite having already failed on at least one of the nodes. The context is now threaded through the call stack, allowing the c.Parallel to issue a cancellation for early termination. Resolves: #98520 Release note: none Epic: none --- pkg/cmd/roachprod/main.go | 4 +- pkg/cmd/roachtest/cluster.go | 2 +- pkg/roachprod/install/cluster_synced.go | 94 +++++++++++--------- pkg/roachprod/install/cluster_synced_test.go | 17 ++-- pkg/roachprod/install/cockroach.go | 4 +- pkg/roachprod/prometheus/prometheus.go | 1 + pkg/roachprod/roachprod.go | 15 ++-- 7 files changed, 73 insertions(+), 64 deletions(-) diff --git a/pkg/cmd/roachprod/main.go b/pkg/cmd/roachprod/main.go index 299571a3de19..f0a5be5128d6 100644 --- a/pkg/cmd/roachprod/main.go +++ b/pkg/cmd/roachprod/main.go @@ -802,7 +802,7 @@ multiple nodes the destination file name will be prefixed with the node number. if len(args) == 3 { dest = args[2] } - return roachprod.Get(config.Logger, args[0], src, dest) + return roachprod.Get(context.Background(), config.Logger, args[0], src, dest) }), } @@ -858,7 +858,7 @@ Examples: if cmd.CalledAs() == "pprof-heap" { pprofOpts.Heap = true } - return roachprod.Pprof(config.Logger, args[0], pprofOpts) + return roachprod.Pprof(context.Background(), config.Logger, args[0], pprofOpts) }), } diff --git a/pkg/cmd/roachtest/cluster.go b/pkg/cmd/roachtest/cluster.go index 3a8e6ae740ac..1659d426bcde 100644 --- a/pkg/cmd/roachtest/cluster.go +++ b/pkg/cmd/roachtest/cluster.go @@ -1969,7 +1969,7 @@ func (c *clusterImpl) Get( } c.status(fmt.Sprintf("getting %v", src)) defer c.status("") - return errors.Wrap(roachprod.Get(l, c.MakeNodes(opts...), src, dest), "cluster.Get") + return errors.Wrap(roachprod.Get(ctx, l, c.MakeNodes(opts...), src, dest), "cluster.Get") } // Put a string into the specified file on the remote(s). diff --git a/pkg/roachprod/install/cluster_synced.go b/pkg/roachprod/install/cluster_synced.go index 303d5fa0dee8..cd8fe04cc841 100644 --- a/pkg/roachprod/install/cluster_synced.go +++ b/pkg/roachprod/install/cluster_synced.go @@ -155,10 +155,13 @@ var defaultSCPRetry = newRunRetryOpts(defaultRetryOpt, // captured in a *RunResultDetails[] in Run, but here we may enrich with attempt // number and a wrapper error. 
func runWithMaybeRetry( - l *logger.Logger, retryOpts *RunRetryOpts, f func() (*RunResultDetails, error), + ctx context.Context, + l *logger.Logger, + retryOpts *RunRetryOpts, + f func(ctx context.Context) (*RunResultDetails, error), ) (*RunResultDetails, error) { if retryOpts == nil { - res, err := f() + res, err := f(ctx) res.Attempt = 1 return res, err } @@ -167,8 +170,8 @@ func runWithMaybeRetry( var err error var cmdErr error - for r := retry.Start(retryOpts.Options); r.Next(); { - res, err = f() + for r := retry.StartWithCtx(ctx, retryOpts.Options); r.Next(); { + res, err = f(ctx) res.Attempt = r.CurrentAttempt() + 1 // nil err (denoting a roachprod error) indicates a potentially retryable res.Err if err == nil && res.Err != nil { @@ -193,8 +196,10 @@ func runWithMaybeRetry( return res, err } -func scpWithRetry(l *logger.Logger, src, dest string) (*RunResultDetails, error) { - return runWithMaybeRetry(l, defaultSCPRetry, func() (*RunResultDetails, error) { return scp(l, src, dest) }) +func scpWithRetry( + ctx context.Context, l *logger.Logger, src, dest string, +) (*RunResultDetails, error) { + return runWithMaybeRetry(ctx, l, defaultSCPRetry, func(ctx context.Context) (*RunResultDetails, error) { return scp(l, src, dest) }) } // Host returns the public IP of a node. @@ -431,7 +436,7 @@ func (c *SyncedCluster) kill( // `kill -9` without wait is never what a caller wants. See #77334. wait = true } - return c.Parallel(l, display, len(c.Nodes), 0, func(i int) (*RunResultDetails, error) { + return c.Parallel(ctx, l, display, len(c.Nodes), 0, func(ctx context.Context, i int) (*RunResultDetails, error) { node := c.Nodes[i] var waitCmd string @@ -494,7 +499,7 @@ func (c *SyncedCluster) Wipe(ctx context.Context, l *logger.Logger, preserveCert if err := c.Stop(ctx, l, 9, true /* wait */, 0 /* maxWait */); err != nil { return err } - return c.Parallel(l, display, len(c.Nodes), 0, func(i int) (*RunResultDetails, error) { + return c.Parallel(ctx, l, display, len(c.Nodes), 0, func(ctx context.Context, i int) (*RunResultDetails, error) { node := c.Nodes[i] var cmd string if c.IsLocal() { @@ -540,7 +545,7 @@ type NodeStatus struct { func (c *SyncedCluster) Status(ctx context.Context, l *logger.Logger) ([]NodeStatus, error) { display := fmt.Sprintf("%s: status", c.Name) results := make([]NodeStatus, len(c.Nodes)) - if err := c.Parallel(l, display, len(c.Nodes), 0, func(i int) (*RunResultDetails, error) { + if err := c.Parallel(ctx, l, display, len(c.Nodes), 0, func(ctx context.Context, i int) (*RunResultDetails, error) { node := c.Nodes[i] binary := cockroachNodeBinary(c, node) @@ -877,7 +882,7 @@ func (c *SyncedCluster) Run( results := make([]*RunResultDetails, len(nodes)) // A result is the output of running a command (could be interpreted as an error) - if _, err := c.ParallelE(l, display, len(nodes), 0, func(i int) (*RunResultDetails, error) { + if _, err := c.ParallelE(ctx, l, display, len(nodes), 0, func(ctx context.Context, i int) (*RunResultDetails, error) { // An err returned here is an unexpected state within roachprod (non-command error). // For errors that occur as part of running a command over ssh, the `result` will contain // the actual error on a specific node. 
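(Aside: the retry loop above uses cockroach's `util/retry` package via `retry.StartWithCtx`; the same context-aware shape can be sketched with only the standard library. `retryWithCtx` and its backoff constants below are illustrative stand-ins, not part of roachprod.)

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// retryWithCtx calls f up to maxAttempts times with a simple linear backoff,
// and gives up early if ctx is cancelled between attempts.
func retryWithCtx(ctx context.Context, maxAttempts int, f func(context.Context) error) error {
	var err error
	for attempt := 1; attempt <= maxAttempts; attempt++ {
		if err = f(ctx); err == nil {
			return nil
		}
		select {
		case <-ctx.Done():
			// Stop retrying as soon as the caller cancels; surface both errors.
			return errors.Join(err, ctx.Err())
		case <-time.After(time.Duration(attempt) * 100 * time.Millisecond):
		}
	}
	return err
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond)
	defer cancel()
	err := retryWithCtx(ctx, 5, func(ctx context.Context) error {
		return errors.New("transient ssh failure")
	})
	fmt.Println(err)
}
```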
@@ -929,7 +934,7 @@ func (c *SyncedCluster) RunWithDetails( // Both return values are explicitly ignored because, in this case, resultPtrs // capture both error and result state for each node - _, _ = c.ParallelE(l, display, len(nodes), 0, func(i int) (*RunResultDetails, error) { //nolint:errcheck + _, _ = c.ParallelE(ctx, l, display, len(nodes), 0, func(ctx context.Context, i int) (*RunResultDetails, error) { //nolint:errcheck result, err := c.runCmdOnSingleNode(ctx, l, nodes[i], cmd, false, l.Stdout, l.Stderr) resultPtrs[i] = result return result, err @@ -978,7 +983,7 @@ func (c *SyncedCluster) RepeatRun( func (c *SyncedCluster) Wait(ctx context.Context, l *logger.Logger) error { display := fmt.Sprintf("%s: waiting for nodes to start", c.Name) errs := make([]error, len(c.Nodes)) - if err := c.Parallel(l, display, len(c.Nodes), 0, func(i int) (*RunResultDetails, error) { + if err := c.Parallel(ctx, l, display, len(c.Nodes), 0, func(ctx context.Context, i int) (*RunResultDetails, error) { node := c.Nodes[i] res := &RunResultDetails{Node: node} cmd := "test -e /mnt/data1/.roachprod-initialized" @@ -1041,7 +1046,7 @@ func (c *SyncedCluster) SetupSSH(ctx context.Context, l *logger.Logger) error { // Generate an ssh key that we'll distribute to all of the nodes in the // cluster in order to allow inter-node ssh. var sshTar []byte - if err := c.Parallel(l, "generating ssh key", 1, 0, func(i int) (*RunResultDetails, error) { + if err := c.Parallel(ctx, l, "generating ssh key", 1, 0, func(ctx context.Context, i int) (*RunResultDetails, error) { // Create the ssh key and then tar up the public, private and // authorized_keys files and output them to stdout. We'll take this output // and pipe it back into tar on the other nodes in the cluster. @@ -1075,7 +1080,7 @@ tar cf - .ssh/id_rsa .ssh/id_rsa.pub .ssh/authorized_keys // Skip the first node which is where we generated the key. nodes := c.Nodes[1:] - if err := c.Parallel(l, "distributing ssh key", len(nodes), 0, func(i int) (*RunResultDetails, error) { + if err := c.Parallel(ctx, l, "distributing ssh key", len(nodes), 0, func(ctx context.Context, i int) (*RunResultDetails, error) { node := nodes[i] cmd := `tar xf -` @@ -1124,7 +1129,7 @@ tar cf - .ssh/id_rsa .ssh/id_rsa.pub .ssh/authorized_keys providerKnownHostData := make(map[string][]byte) providers := maps.Keys(providerPrivateIPs) // Only need to scan on the first node of each provider. - if err := c.Parallel(l, "scanning hosts", len(providers), 0, func(i int) (*RunResultDetails, error) { + if err := c.Parallel(ctx, l, "scanning hosts", len(providers), 0, func(ctx context.Context, i int) (*RunResultDetails, error) { provider := providers[i] node := providerPrivateIPs[provider][0].node // Scan a combination of all remote IPs and local IPs pertaining to this @@ -1180,7 +1185,7 @@ exit 1 return err } - if err := c.Parallel(l, "distributing known_hosts", len(c.Nodes), 0, func(i int) (*RunResultDetails, error) { + if err := c.Parallel(ctx, l, "distributing known_hosts", len(c.Nodes), 0, func(ctx context.Context, i int) (*RunResultDetails, error) { node := c.Nodes[i] provider := c.VMs[node-1].Provider const cmd = ` @@ -1233,7 +1238,7 @@ fi // additional authorized_keys to both the current user (your username on // gce and the shared user on aws) as well as to the shared user on both // platforms. 
- if err := c.Parallel(l, "adding additional authorized keys", len(c.Nodes), 0, func(i int) (*RunResultDetails, error) { + if err := c.Parallel(ctx, l, "adding additional authorized keys", len(c.Nodes), 0, func(ctx context.Context, i int) (*RunResultDetails, error) { node := c.Nodes[i] const cmd = ` keys_data="$(cat)" @@ -1298,7 +1303,7 @@ func (c *SyncedCluster) DistributeCerts(ctx context.Context, l *logger.Logger) e // Generate the ca, client and node certificates on the first node. var msg string display := fmt.Sprintf("%s: initializing certs", c.Name) - if err := c.Parallel(l, display, 1, 0, func(i int) (*RunResultDetails, error) { + if err := c.Parallel(ctx, l, display, 1, 0, func(ctx context.Context, i int) (*RunResultDetails, error) { var cmd string if c.IsLocal() { cmd = fmt.Sprintf(`cd %s ; `, c.localVMDir(1)) @@ -1346,7 +1351,7 @@ tar cvf %[3]s certs exit.WithCode(exit.UnspecifiedError()) } - tarfile, cleanup, err := c.getFileFromFirstNode(l, certsTarName) + tarfile, cleanup, err := c.getFileFromFirstNode(ctx, l, certsTarName) if err != nil { return err } @@ -1379,7 +1384,7 @@ func (c *SyncedCluster) DistributeTenantCerts( return err } - tarfile, cleanup, err := hostCluster.getFileFromFirstNode(l, tenantCertsTarName) + tarfile, cleanup, err := hostCluster.getFileFromFirstNode(ctx, l, tenantCertsTarName) if err != nil { return err } @@ -1397,7 +1402,7 @@ func (c *SyncedCluster) createTenantCertBundle( ctx context.Context, l *logger.Logger, bundleName string, tenantID int, nodeNames []string, ) error { display := fmt.Sprintf("%s: initializing tenant certs", c.Name) - return c.Parallel(l, display, 1, 0, func(i int) (*RunResultDetails, error) { + return c.Parallel(ctx, l, display, 1, 0, func(ctx context.Context, i int) (*RunResultDetails, error) { node := c.Nodes[i] cmd := "set -e;" @@ -1448,7 +1453,7 @@ fi // filename is assumed to be relative from the home directory of the node's // user. 
func (c *SyncedCluster) getFileFromFirstNode( - l *logger.Logger, name string, + ctx context.Context, l *logger.Logger, name string, ) (string, func(), error) { var tmpfileName string cleanup := func() {} @@ -1465,7 +1470,7 @@ func (c *SyncedCluster) getFileFromFirstNode( } srcFileName := fmt.Sprintf("%s@%s:%s", c.user(1), c.Host(1), name) - if res, _ := scpWithRetry(l, srcFileName, tmpfile.Name()); res.Err != nil { + if res, _ := scpWithRetry(ctx, l, srcFileName, tmpfile.Name()); res.Err != nil { cleanup() return "", nil, res.Err } @@ -1499,7 +1504,7 @@ func (c *SyncedCluster) fileExistsOnFirstNode( ) bool { var existsErr error display := fmt.Sprintf("%s: checking %s", c.Name, path) - if err := c.Parallel(l, display, 1, 0, func(i int) (*RunResultDetails, error) { + if err := c.Parallel(ctx, l, display, 1, 0, func(ctx context.Context, i int) (*RunResultDetails, error) { node := c.Nodes[i] sess := c.newSession(l, node, `test -e `+path) defer sess.Close() @@ -1556,7 +1561,7 @@ func (c *SyncedCluster) distributeLocalCertsTar( } display := c.Name + ": distributing certs" - return c.Parallel(l, display, len(nodes), 0, func(i int) (*RunResultDetails, error) { + return c.Parallel(ctx, l, display, len(nodes), 0, func(ctx context.Context, i int) (*RunResultDetails, error) { node := nodes[i] var cmd string if c.IsLocal() { @@ -1777,7 +1782,7 @@ func (c *SyncedCluster) Put( return } - res, _ := scpWithRetry(l, from, to) + res, _ := scpWithRetry(ctx, l, from, to) results <- result{i, res.Err} if res.Err != nil { @@ -2028,7 +2033,9 @@ func (c *SyncedCluster) Logs( } // Get TODO(peter): document -func (c *SyncedCluster) Get(l *logger.Logger, nodes Nodes, src, dest string) error { +func (c *SyncedCluster) Get( + ctx context.Context, l *logger.Logger, nodes Nodes, src, dest string, +) error { if err := c.validateHost(context.TODO(), l, nodes[0]); err != nil { return err } @@ -2146,7 +2153,7 @@ func (c *SyncedCluster) Get(l *logger.Logger, nodes Nodes, src, dest string) err return } - res, _ := scpWithRetry(l, fmt.Sprintf("%s@%s:%s", c.user(nodes[0]), c.Host(nodes[i]), src), dest) + res, _ := scpWithRetry(ctx, l, fmt.Sprintf("%s@%s:%s", c.user(nodes[0]), c.Host(nodes[i]), src), dest) if res.Err == nil { // Make sure all created files and directories are world readable. // The CRDB process intentionally sets a 0007 umask (resulting in @@ -2262,7 +2269,7 @@ func (c *SyncedCluster) pghosts( ctx context.Context, l *logger.Logger, nodes Nodes, ) (map[Node]string, error) { ips := make([]string, len(nodes)) - if err := c.Parallel(l, "", len(nodes), 0, func(i int) (*RunResultDetails, error) { + if err := c.Parallel(ctx, l, "", len(nodes), 0, func(ctx context.Context, i int) (*RunResultDetails, error) { node := nodes[i] res := &RunResultDetails{Node: node} res.Stdout, res.Err = c.GetInternalIP(node) @@ -2389,13 +2396,14 @@ type ParallelResult struct { // // See ParallelE for more information. 
func (c *SyncedCluster) Parallel( + ctx context.Context, l *logger.Logger, display string, count, concurrency int, - fn func(i int) (*RunResultDetails, error), + fn func(ctx context.Context, i int) (*RunResultDetails, error), runRetryOpts *RunRetryOpts, ) error { - failed, err := c.ParallelE(l, display, count, concurrency, fn, runRetryOpts) + failed, err := c.ParallelE(ctx, l, display, count, concurrency, fn, runRetryOpts) if err != nil { sort.Slice(failed, func(i, j int) bool { return failed[i].Index < failed[j].Index }) for _, f := range failed { @@ -2420,10 +2428,11 @@ func (c *SyncedCluster) Parallel( // If err is non-nil, the slice of ParallelResults will contain the // results from any of the failed invocations. func (c *SyncedCluster) ParallelE( + ctx context.Context, l *logger.Logger, display string, count, concurrency int, - fn func(i int) (*RunResultDetails, error), + fn func(ctx context.Context, i int) (*RunResultDetails, error), runRetryOpts *RunRetryOpts, ) ([]ParallelResult, error) { if concurrency == 0 || concurrency > count { @@ -2437,12 +2446,18 @@ func (c *SyncedCluster) ParallelE( var wg sync.WaitGroup wg.Add(count) + groupCtx, groupCancel := context.WithCancel(ctx) + defer groupCancel() var index int startNext := func() { + // If we needed to react to a context cancellation here we would need to + // nest this goroutine in another one and select on the groupCtx. However, + // since anything intensive here is a command over ssh, and we are threading + // the context through, a cancellation will be handled by the command itself. go func(i int) { defer wg.Done() - res, err := runWithMaybeRetry(l, runRetryOpts, func() (*RunResultDetails, error) { return fn(i) }) - results <- ParallelResult{i, res.CombinedOut, err} + res, err := runWithMaybeRetry(groupCtx, l, runRetryOpts, func(ctx context.Context) (*RunResultDetails, error) { return fn(ctx, i) }) + results <- ParallelResult{i, res.CombinedOut, errors.CombineErrors(err, res.Err)} }(index) index++ } @@ -2474,7 +2489,6 @@ func (c *SyncedCluster) ParallelE( } defer ticker.Stop() complete := make([]bool, count) - var failed []ParallelResult var spinner = []string{"|", "/", "-", "\\"} spinnerIdx := 0 @@ -2487,7 +2501,8 @@ func (c *SyncedCluster) ParallelE( } case r, ok := <-results: if r.Err != nil { - failed = append(failed, r) + groupCancel() + return nil, errors.Wrap(r.Err, "parallel execution failure") } done = !ok if ok { @@ -2520,13 +2535,6 @@ func (c *SyncedCluster) ParallelE( fmt.Fprintf(out, "\n") } - if len(failed) > 0 { - var err error - for _, res := range failed { - err = errors.CombineErrors(err, res.Err) - } - return failed, errors.Wrap(err, "parallel execution failure") - } return nil, nil } diff --git a/pkg/roachprod/install/cluster_synced_test.go b/pkg/roachprod/install/cluster_synced_test.go index 89f941679b64..f52d77e8a69c 100644 --- a/pkg/roachprod/install/cluster_synced_test.go +++ b/pkg/roachprod/install/cluster_synced_test.go @@ -11,6 +11,7 @@ package install import ( + "context" "fmt" "io" "testing" @@ -96,21 +97,21 @@ func TestRunWithMaybeRetry(t *testing.T) { attempt := 0 cases := []struct { - f func() (*RunResultDetails, error) + f func(ctx context.Context) (*RunResultDetails, error) shouldRetryFn func(*RunResultDetails) bool nilRetryOpts bool expectedAttempts int shouldError bool }{ { // 1. 
Happy path: no error, no retry required - f: func() (*RunResultDetails, error) { + f: func(ctx context.Context) (*RunResultDetails, error) { return newResult(0), nil }, expectedAttempts: 1, shouldError: false, }, { // 2. Error, but with no retries - f: func() (*RunResultDetails, error) { + f: func(ctx context.Context) (*RunResultDetails, error) { return newResult(1), nil }, shouldRetryFn: func(*RunResultDetails) bool { @@ -120,14 +121,14 @@ func TestRunWithMaybeRetry(t *testing.T) { shouldError: true, }, { // 3. Error, but no retry function specified - f: func() (*RunResultDetails, error) { + f: func(ctx context.Context) (*RunResultDetails, error) { return newResult(1), nil }, expectedAttempts: 3, shouldError: true, }, { // 4. Error, with retries exhausted - f: func() (*RunResultDetails, error) { + f: func(ctx context.Context) (*RunResultDetails, error) { return newResult(255), nil }, shouldRetryFn: func(d *RunResultDetails) bool { return d.RemoteExitStatus == 255 }, @@ -135,7 +136,7 @@ func TestRunWithMaybeRetry(t *testing.T) { shouldError: true, }, { // 5. Eventual success after retries - f: func() (*RunResultDetails, error) { + f: func(ctx context.Context) (*RunResultDetails, error) { attempt++ if attempt == 3 { return newResult(0), nil @@ -147,7 +148,7 @@ func TestRunWithMaybeRetry(t *testing.T) { shouldError: false, }, { // 6. Error, runs once because nil retryOpts - f: func() (*RunResultDetails, error) { + f: func(ctx context.Context) (*RunResultDetails, error) { return newResult(255), nil }, nilRetryOpts: true, @@ -163,7 +164,7 @@ func TestRunWithMaybeRetry(t *testing.T) { if !tc.nilRetryOpts { retryOpts = newRunRetryOpts(testRetryOpts, tc.shouldRetryFn) } - res, _ := runWithMaybeRetry(l, retryOpts, tc.f) + res, _ := runWithMaybeRetry(context.Background(), l, retryOpts, tc.f) require.Equal(t, tc.shouldError, res.Err != nil) require.Equal(t, tc.expectedAttempts, res.Attempt) diff --git a/pkg/roachprod/install/cockroach.go b/pkg/roachprod/install/cockroach.go index f35ecde466eb..81060f1df77a 100644 --- a/pkg/roachprod/install/cockroach.go +++ b/pkg/roachprod/install/cockroach.go @@ -183,7 +183,7 @@ func (c *SyncedCluster) Start(ctx context.Context, l *logger.Logger, startOpts S l.Printf("%s: starting nodes", c.Name) // SSH retries are disabled by passing nil RunRetryOpts - if err := c.Parallel(l, "", len(nodes), parallelism, func(nodeIdx int) (*RunResultDetails, error) { + if err := c.Parallel(ctx, l, "", len(nodes), parallelism, func(ctx context.Context, nodeIdx int) (*RunResultDetails, error) { node := nodes[nodeIdx] res := &RunResultDetails{Node: node} // NB: if cockroach started successfully, we ignore the output as it is @@ -329,7 +329,7 @@ func (c *SyncedCluster) ExecSQL( resultChan := make(chan result, len(c.Nodes)) display := fmt.Sprintf("%s: executing sql", c.Name) - if err := c.Parallel(l, display, len(c.Nodes), 0, func(nodeIdx int) (*RunResultDetails, error) { + if err := c.Parallel(ctx, l, display, len(c.Nodes), 0, func(ctx context.Context, nodeIdx int) (*RunResultDetails, error) { node := c.Nodes[nodeIdx] var cmd string diff --git a/pkg/roachprod/prometheus/prometheus.go b/pkg/roachprod/prometheus/prometheus.go index ee6c5cea6b8c..22a59af72a24 100644 --- a/pkg/roachprod/prometheus/prometheus.go +++ b/pkg/roachprod/prometheus/prometheus.go @@ -468,6 +468,7 @@ docker run --privileged -p 9090:9090 \ } return c.Get( + ctx, l, promNode, "/tmp/prometheus/prometheus-snapshot.tar.gz", diff --git a/pkg/roachprod/roachprod.go b/pkg/roachprod/roachprod.go index 
6207af516a07..31d8489be57e 100644 --- a/pkg/roachprod/roachprod.go +++ b/pkg/roachprod/roachprod.go @@ -877,7 +877,7 @@ func Put( // Get copies a remote file from the nodes in a cluster. // If the file is retrieved from multiple nodes the destination // file name will be prefixed with the node number. -func Get(l *logger.Logger, clusterName, src, dest string) error { +func Get(ctx context.Context, l *logger.Logger, clusterName, src, dest string) error { if err := LoadClusters(); err != nil { return err } @@ -885,7 +885,7 @@ func Get(l *logger.Logger, clusterName, src, dest string) error { if err != nil { return err } - return c.Get(l, c.Nodes, src, dest) + return c.Get(ctx, l, c.Nodes, src, dest) } type PGURLOptions struct { @@ -913,8 +913,7 @@ func PgURL( ips[i] = c.VMs[nodes[i]-1].PublicIP } } else { - var err error - if err := c.Parallel(l, "", len(nodes), 0, func(i int) (*install.RunResultDetails, error) { + if err := c.Parallel(ctx, l, "", len(nodes), 0, func(ctx context.Context, i int) (*install.RunResultDetails, error) { node := nodes[i] res := &install.RunResultDetails{Node: node} res.Stdout, res.Err = c.GetInternalIP(node) @@ -1016,7 +1015,7 @@ type PprofOpts struct { } // Pprof TODO -func Pprof(l *logger.Logger, clusterName string, opts PprofOpts) error { +func Pprof(ctx context.Context, l *logger.Logger, clusterName string, opts PprofOpts) error { if err := LoadClusters(); err != nil { return err } @@ -1048,7 +1047,7 @@ func Pprof(l *logger.Logger, clusterName string, opts PprofOpts) error { httpClient := httputil.NewClientWithTimeout(timeout) startTime := timeutil.Now().Unix() nodes := c.TargetNodes() - failed, err := c.ParallelE(l, description, len(nodes), 0, func(i int) (*install.RunResultDetails, error) { + failed, err := c.ParallelE(ctx, l, description, len(nodes), 0, func(ctx context.Context, i int) (*install.RunResultDetails, error) { node := nodes[i] res := &install.RunResultDetails{Node: node} host := c.Host(node) @@ -1715,11 +1714,11 @@ func sendCaptureCommand( ) error { nodes := c.TargetNodes() httpClient := httputil.NewClientWithTimeout(0 /* timeout: None */) - _, err := c.ParallelE(l, + _, err := c.ParallelE(ctx, l, fmt.Sprintf("Performing workload capture %s", action), len(nodes), 0, - func(i int) (*install.RunResultDetails, error) { + func(ctx context.Context, i int) (*install.RunResultDetails, error) { node := nodes[i] res := &install.RunResultDetails{Node: node} host := c.Host(node) From 9d4495fc681e6dc86a6a240b0d5e7e56705e866e Mon Sep 17 00:00:00 2001 From: Miral Gadani Date: Sun, 2 Apr 2023 00:04:48 +0000 Subject: [PATCH 2/6] roachprod: modify signature of c.Parallel to return ParallelResult,err as we no longer wait for all the nodes to complete in the event of an error on any node. 
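For context, the fail-fast behaviour these first two patches introduce reduces to a familiar fan-out/cancel pattern. The following standalone sketch is illustrative only; it is not roachprod code, and `runOnNode`/`parallel` are made-up stand-ins for the per-node ssh invocation and for `c.ParallelE`:

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"sync"
	"time"
)

// runOnNode stands in for executing a command on node i over ssh; like the
// real commands, it is assumed to respect cancellation of ctx.
func runOnNode(ctx context.Context, i int) error {
	if i == 2 {
		time.Sleep(100 * time.Millisecond)
		return errors.New("command failed on one node")
	}
	select {
	case <-ctx.Done():
		return ctx.Err() // remaining nodes stop early instead of running to completion
	case <-time.After(10 * time.Second): // stands in for a long-running workload
		return nil
	}
}

// parallel fans runOnNode out across count nodes and cancels the rest as
// soon as any node reports an error.
func parallel(ctx context.Context, count int) error {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	var (
		wg       sync.WaitGroup
		once     sync.Once
		firstErr error
	)
	for i := 0; i < count; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			if err := runOnNode(ctx, i); err != nil && !errors.Is(err, context.Canceled) {
				once.Do(func() {
					firstErr = err
					cancel() // fail fast: terminate the other invocations
				})
			}
		}(i)
	}
	wg.Wait()
	return firstErr
}

func main() {
	start := time.Now()
	err := parallel(context.Background(), 4)
	fmt.Printf("returned after %s: %v\n", time.Since(start).Round(time.Millisecond), err)
}
```

Here the failing node cancels the shared context, so the long-running invocations on the other nodes return almost immediately instead of running to completion.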
--- pkg/roachprod/install/cluster_synced.go | 27 ++++++++++++++----------- pkg/roachprod/roachprod.go | 6 +----- 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/pkg/roachprod/install/cluster_synced.go b/pkg/roachprod/install/cluster_synced.go index cd8fe04cc841..78ac631533e9 100644 --- a/pkg/roachprod/install/cluster_synced.go +++ b/pkg/roachprod/install/cluster_synced.go @@ -22,7 +22,6 @@ import ( "os/exec" "os/signal" "path/filepath" - "sort" "strings" "sync" "syscall" @@ -2405,17 +2404,16 @@ func (c *SyncedCluster) Parallel( ) error { failed, err := c.ParallelE(ctx, l, display, count, concurrency, fn, runRetryOpts) if err != nil { - sort.Slice(failed, func(i, j int) bool { return failed[i].Index < failed[j].Index }) - for _, f := range failed { - l.Errorf("%d: %+v: %s", f.Index, f.Err, f.Out) - } + l.Errorf("%d: %+v: %s\n", failed.Index, failed.Err, failed.Out) return err } return nil } // ParallelE runs the given function in parallel across the given -// nodes, returning an error if function returns an error. +// nodes. In the event of an error on any of the nodes, the function +// will fail fast and return the error and the accompanying ParallelResult. +// Successful invocation will return a nil result and nil error. // // ParallelE runs the user-defined functions on the first `count` // nodes in the cluster. It runs at most `concurrency` (or @@ -2423,10 +2421,12 @@ func (c *SyncedCluster) Parallel( // 0, then it defaults to `count`. // // The function returns a pointer to RunResultDetails as we may enrich -// the result with retry information (attempt number, wrapper error) +// the result with retry information (attempt number, wrapper error). // -// If err is non-nil, the slice of ParallelResults will contain the -// results from any of the failed invocations. +// RunRetryOpts controls the retry behavior in the case that +// the function fails, but returns a nil error. A non-nil error returned by the +// function denotes a roachprod error and will not be retried regardless of the +// retry options. func (c *SyncedCluster) ParallelE( ctx context.Context, l *logger.Logger, @@ -2434,7 +2434,7 @@ func (c *SyncedCluster) ParallelE( count, concurrency int, fn func(ctx context.Context, i int) (*RunResultDetails, error), runRetryOpts *RunRetryOpts, -) ([]ParallelResult, error) { +) (ParallelResult, error) { if concurrency == 0 || concurrency > count { concurrency = count } @@ -2501,8 +2501,11 @@ func (c *SyncedCluster) ParallelE( } case r, ok := <-results: if r.Err != nil { + // We no longer wait for failures from other goroutines but instead cancel the context. + // If required, we could restore or control the previous behavior by not cancelling and + // and returning here, but instead return a slice at the end. groupCancel() - return nil, errors.Wrap(r.Err, "parallel execution failure") + return r, errors.Wrap(r.Err, "parallel execution failure") } done = !ok if ok { @@ -2535,7 +2538,7 @@ func (c *SyncedCluster) ParallelE( fmt.Fprintf(out, "\n") } - return nil, nil + return ParallelResult{}, nil } // Init initializes the cluster. 
It does it through node 1 (as per TargetNodes) diff --git a/pkg/roachprod/roachprod.go b/pkg/roachprod/roachprod.go index 31d8489be57e..4f7c9ba149a2 100644 --- a/pkg/roachprod/roachprod.go +++ b/pkg/roachprod/roachprod.go @@ -1047,7 +1047,7 @@ func Pprof(ctx context.Context, l *logger.Logger, clusterName string, opts Pprof httpClient := httputil.NewClientWithTimeout(timeout) startTime := timeutil.Now().Unix() nodes := c.TargetNodes() - failed, err := c.ParallelE(ctx, l, description, len(nodes), 0, func(ctx context.Context, i int) (*install.RunResultDetails, error) { + err = c.Parallel(ctx, l, description, len(nodes), 0, func(ctx context.Context, i int) (*install.RunResultDetails, error) { node := nodes[i] res := &install.RunResultDetails{Node: node} host := c.Host(node) @@ -1116,10 +1116,6 @@ func Pprof(ctx context.Context, l *logger.Logger, clusterName string, opts Pprof } if err != nil { - sort.Slice(failed, func(i, j int) bool { return failed[i].Index < failed[j].Index }) - for _, f := range failed { - l.Errorf("%d: %+v: %s\n", f.Index, f.Err, f.Out) - } exit.WithCode(exit.UnspecifiedError()) } From be5ea434d7190410de05e8ca250ae1c2f75e4cae Mon Sep 17 00:00:00 2001 From: Miral Gadani Date: Tue, 11 Apr 2023 16:33:33 +0000 Subject: [PATCH 3/6] roachprod: `c.ParallelE` can be made optionally made to fail slow however sometimes it is desirable for the command to complete running on all nodes despite failures. e.g. `roachprod run 'ls /mnt/data1/cockroach'` or when collecting dmesg logs after a test. This commit has also switched to functional options for c.Parallel for more flexibility and allow calls to specify WithWaitOnFail(). Epic: none Fixes: #101150 Release note: None --- pkg/cmd/roachprod/main.go | 3 +- pkg/roachprod/install/cluster_synced.go | 221 +++++++++++++++++------- pkg/roachprod/install/cockroach.go | 10 +- pkg/roachprod/roachprod.go | 20 +-- 4 files changed, 170 insertions(+), 84 deletions(-) diff --git a/pkg/cmd/roachprod/main.go b/pkg/cmd/roachprod/main.go index f0a5be5128d6..cb0c28a47684 100644 --- a/pkg/cmd/roachprod/main.go +++ b/pkg/cmd/roachprod/main.go @@ -655,7 +655,8 @@ var runCmd = &cobra.Command{ `, Args: cobra.MinimumNArgs(1), Run: wrap(func(_ *cobra.Command, args []string) error { - return roachprod.Run(context.Background(), config.Logger, args[0], extraSSHOptions, tag, secure, os.Stdout, os.Stderr, args[1:]) + return roachprod.Run(context.Background(), config.Logger, args[0], extraSSHOptions, tag, + secure, os.Stdout, os.Stderr, args[1:], install.WithWaitOnFail()) }), } diff --git a/pkg/roachprod/install/cluster_synced.go b/pkg/roachprod/install/cluster_synced.go index 78ac631533e9..13658872f351 100644 --- a/pkg/roachprod/install/cluster_synced.go +++ b/pkg/roachprod/install/cluster_synced.go @@ -22,6 +22,7 @@ import ( "os/exec" "os/signal" "path/filepath" + "sort" "strings" "sync" "syscall" @@ -172,7 +173,7 @@ func runWithMaybeRetry( for r := retry.StartWithCtx(ctx, retryOpts.Options); r.Next(); { res, err = f(ctx) res.Attempt = r.CurrentAttempt() + 1 - // nil err (denoting a roachprod error) indicates a potentially retryable res.Err + // nil err (non-nil denotes a roachprod error) indicates a potentially retryable res.Err if err == nil && res.Err != nil { cmdErr = errors.CombineErrors(cmdErr, res.Err) if retryOpts.shouldRetryFn == nil || retryOpts.shouldRetryFn(res) { @@ -435,7 +436,7 @@ func (c *SyncedCluster) kill( // `kill -9` without wait is never what a caller wants. See #77334. 
wait = true } - return c.Parallel(ctx, l, display, len(c.Nodes), 0, func(ctx context.Context, i int) (*RunResultDetails, error) { + return c.Parallel(ctx, l, len(c.Nodes), func(ctx context.Context, i int) (*RunResultDetails, error) { node := c.Nodes[i] var waitCmd string @@ -489,7 +490,7 @@ fi`, res := newRunResultDetails(node, cmdErr) res.CombinedOut = out return res, res.Err - }, nil) // Disable SSH Retries + }, WithDisplay(display), WithRetryOpts(nil)) // Disable SSH Retries } // Wipe TODO(peter): document @@ -498,7 +499,7 @@ func (c *SyncedCluster) Wipe(ctx context.Context, l *logger.Logger, preserveCert if err := c.Stop(ctx, l, 9, true /* wait */, 0 /* maxWait */); err != nil { return err } - return c.Parallel(ctx, l, display, len(c.Nodes), 0, func(ctx context.Context, i int) (*RunResultDetails, error) { + return c.Parallel(ctx, l, len(c.Nodes), func(ctx context.Context, i int) (*RunResultDetails, error) { node := c.Nodes[i] var cmd string if c.IsLocal() { @@ -528,7 +529,7 @@ sudo rm -fr logs && res := newRunResultDetails(node, cmdErr) res.CombinedOut = out return res, res.Err - }, DefaultSSHRetryOpts) + }, WithDisplay(display)) } // NodeStatus contains details about the status of a node. @@ -544,7 +545,7 @@ type NodeStatus struct { func (c *SyncedCluster) Status(ctx context.Context, l *logger.Logger) ([]NodeStatus, error) { display := fmt.Sprintf("%s: status", c.Name) results := make([]NodeStatus, len(c.Nodes)) - if err := c.Parallel(ctx, l, display, len(c.Nodes), 0, func(ctx context.Context, i int) (*RunResultDetails, error) { + if err := c.Parallel(ctx, l, len(c.Nodes), func(ctx context.Context, i int) (*RunResultDetails, error) { node := c.Nodes[i] binary := cockroachNodeBinary(c, node) @@ -580,7 +581,7 @@ fi results[i] = NodeStatus{Running: true, Version: info[0], Pid: info[1]} return res, nil - }, DefaultSSHRetryOpts); err != nil { + }, WithDisplay(display)); err != nil { return nil, err } for i := 0; i < len(results); i++ { @@ -869,7 +870,12 @@ func (c *SyncedCluster) runCmdOnSingleNode( // title: A description of the command being run that is output to the logs. // cmd: The command to run. func (c *SyncedCluster) Run( - ctx context.Context, l *logger.Logger, stdout, stderr io.Writer, nodes Nodes, title, cmd string, + ctx context.Context, + l *logger.Logger, + stdout, stderr io.Writer, + nodes Nodes, + title, cmd string, + opts ...ParallelOption, ) error { // Stream output if we're running the command on only 1 node. stream := len(nodes) == 1 @@ -881,14 +887,14 @@ func (c *SyncedCluster) Run( results := make([]*RunResultDetails, len(nodes)) // A result is the output of running a command (could be interpreted as an error) - if _, err := c.ParallelE(ctx, l, display, len(nodes), 0, func(ctx context.Context, i int) (*RunResultDetails, error) { + if _, err := c.ParallelE(ctx, l, len(nodes), func(ctx context.Context, i int) (*RunResultDetails, error) { // An err returned here is an unexpected state within roachprod (non-command error). // For errors that occur as part of running a command over ssh, the `result` will contain // the actual error on a specific node. 
result, err := c.runCmdOnSingleNode(ctx, l, nodes[i], cmd, !stream, stdout, stderr) results[i] = result return result, err - }, DefaultSSHRetryOpts); err != nil { + }, append(opts, WithDisplay(display))...); err != nil { return err } @@ -899,6 +905,12 @@ func (c *SyncedCluster) Run( func processResults(results []*RunResultDetails, stream bool, stdout io.Writer) error { var resultWithError *RunResultDetails for i, r := range results { + // We no longer wait for all nodes to complete before returning in the case of an error (#100403) + // which means that some node results may be nil. + if r == nil { + continue + } + if !stream { fmt.Fprintf(stdout, " %2d: %s\n%v\n", i+1, strings.TrimSpace(string(r.CombinedOut)), r.Err) } @@ -922,6 +934,7 @@ func processResults(results []*RunResultDetails, stream bool, stdout io.Writer) } // RunWithDetails runs a command on the specified nodes and returns results details and an error. +// This will wait for all commands to complete before returning unless encountering a roachprod error. func (c *SyncedCluster) RunWithDetails( ctx context.Context, l *logger.Logger, nodes Nodes, title, cmd string, ) ([]RunResultDetails, error) { @@ -931,13 +944,14 @@ func (c *SyncedCluster) RunWithDetails( // be processed further by the caller. resultPtrs := make([]*RunResultDetails, len(nodes)) - // Both return values are explicitly ignored because, in this case, resultPtrs - // capture both error and result state for each node - _, _ = c.ParallelE(ctx, l, display, len(nodes), 0, func(ctx context.Context, i int) (*RunResultDetails, error) { //nolint:errcheck + // Failing slow here allows us to capture the output of all nodes even if one fails with a command error. + if _, err := c.ParallelE(ctx, l, len(nodes), func(ctx context.Context, i int) (*RunResultDetails, error) { //nolint:errcheck result, err := c.runCmdOnSingleNode(ctx, l, nodes[i], cmd, false, l.Stdout, l.Stderr) resultPtrs[i] = result return result, err - }, DefaultSSHRetryOpts) + }, WithDisplay(display), WithWaitOnFail()); err != nil { + return nil, err + } // Return values to preserve API results := make([]RunResultDetails, len(nodes)) @@ -982,7 +996,7 @@ func (c *SyncedCluster) RepeatRun( func (c *SyncedCluster) Wait(ctx context.Context, l *logger.Logger) error { display := fmt.Sprintf("%s: waiting for nodes to start", c.Name) errs := make([]error, len(c.Nodes)) - if err := c.Parallel(ctx, l, display, len(c.Nodes), 0, func(ctx context.Context, i int) (*RunResultDetails, error) { + if err := c.Parallel(ctx, l, len(c.Nodes), func(ctx context.Context, i int) (*RunResultDetails, error) { node := c.Nodes[i] res := &RunResultDetails{Node: node} cmd := "test -e /mnt/data1/.roachprod-initialized" @@ -1000,7 +1014,7 @@ func (c *SyncedCluster) Wait(ctx context.Context, l *logger.Logger) error { errs[i] = errors.New("timed out after 5m") res.Err = errs[i] return res, nil - }, nil); err != nil { + }, WithDisplay(display), WithRetryOpts(nil)); err != nil { return err } @@ -1045,7 +1059,7 @@ func (c *SyncedCluster) SetupSSH(ctx context.Context, l *logger.Logger) error { // Generate an ssh key that we'll distribute to all of the nodes in the // cluster in order to allow inter-node ssh. 
var sshTar []byte - if err := c.Parallel(ctx, l, "generating ssh key", 1, 0, func(ctx context.Context, i int) (*RunResultDetails, error) { + if err := c.Parallel(ctx, l, 1, func(ctx context.Context, i int) (*RunResultDetails, error) { // Create the ssh key and then tar up the public, private and // authorized_keys files and output them to stdout. We'll take this output // and pipe it back into tar on the other nodes in the cluster. @@ -1073,13 +1087,13 @@ tar cf - .ssh/id_rsa .ssh/id_rsa.pub .ssh/authorized_keys } sshTar = []byte(res.Stdout) return res, nil - }, DefaultSSHRetryOpts); err != nil { + }, WithDisplay("generating ssh key")); err != nil { return err } // Skip the first node which is where we generated the key. nodes := c.Nodes[1:] - if err := c.Parallel(ctx, l, "distributing ssh key", len(nodes), 0, func(ctx context.Context, i int) (*RunResultDetails, error) { + if err := c.Parallel(ctx, l, len(nodes), func(ctx context.Context, i int) (*RunResultDetails, error) { node := nodes[i] cmd := `tar xf -` @@ -1096,7 +1110,7 @@ tar cf - .ssh/id_rsa .ssh/id_rsa.pub .ssh/authorized_keys return res, errors.Wrapf(res.Err, "%s: output:\n%s", cmd, res.CombinedOut) } return res, nil - }, DefaultSSHRetryOpts); err != nil { + }, WithDisplay("distributing ssh key")); err != nil { return err } @@ -1128,7 +1142,7 @@ tar cf - .ssh/id_rsa .ssh/id_rsa.pub .ssh/authorized_keys providerKnownHostData := make(map[string][]byte) providers := maps.Keys(providerPrivateIPs) // Only need to scan on the first node of each provider. - if err := c.Parallel(ctx, l, "scanning hosts", len(providers), 0, func(ctx context.Context, i int) (*RunResultDetails, error) { + if err := c.Parallel(ctx, l, len(providers), func(ctx context.Context, i int) (*RunResultDetails, error) { provider := providers[i] node := providerPrivateIPs[provider][0].node // Scan a combination of all remote IPs and local IPs pertaining to this @@ -1180,11 +1194,11 @@ exit 1 providerKnownHostData[provider] = stdout.Bytes() mu.Unlock() return res, nil - }, DefaultSSHRetryOpts); err != nil { + }, WithDisplay("scanning hosts")); err != nil { return err } - if err := c.Parallel(ctx, l, "distributing known_hosts", len(c.Nodes), 0, func(ctx context.Context, i int) (*RunResultDetails, error) { + if err := c.Parallel(ctx, l, len(c.Nodes), func(ctx context.Context, i int) (*RunResultDetails, error) { node := c.Nodes[i] provider := c.VMs[node-1].Provider const cmd = ` @@ -1227,7 +1241,7 @@ fi return res, errors.Wrapf(res.Err, "%s: output:\n%s", cmd, res.CombinedOut) } return res, nil - }, DefaultSSHRetryOpts); err != nil { + }, WithDisplay("distributing known_hosts")); err != nil { return err } @@ -1237,7 +1251,7 @@ fi // additional authorized_keys to both the current user (your username on // gce and the shared user on aws) as well as to the shared user on both // platforms. 
- if err := c.Parallel(ctx, l, "adding additional authorized keys", len(c.Nodes), 0, func(ctx context.Context, i int) (*RunResultDetails, error) { + if err := c.Parallel(ctx, l, len(c.Nodes), func(ctx context.Context, i int) (*RunResultDetails, error) { node := c.Nodes[i] const cmd = ` keys_data="$(cat)" @@ -1274,7 +1288,7 @@ fi return res, errors.Wrapf(res.Err, "~ %s\n%s", cmd, res.CombinedOut) } return res, nil - }, DefaultSSHRetryOpts); err != nil { + }, WithDisplay("adding additional authorized keys")); err != nil { return err } } @@ -1302,7 +1316,7 @@ func (c *SyncedCluster) DistributeCerts(ctx context.Context, l *logger.Logger) e // Generate the ca, client and node certificates on the first node. var msg string display := fmt.Sprintf("%s: initializing certs", c.Name) - if err := c.Parallel(ctx, l, display, 1, 0, func(ctx context.Context, i int) (*RunResultDetails, error) { + if err := c.Parallel(ctx, l, 1, func(ctx context.Context, i int) (*RunResultDetails, error) { var cmd string if c.IsLocal() { cmd = fmt.Sprintf(`cd %s ; `, c.localVMDir(1)) @@ -1341,7 +1355,7 @@ tar cvf %[3]s certs msg = fmt.Sprintf("%s: %v", res.CombinedOut, res.Err) } return res, nil - }, DefaultSSHRetryOpts); err != nil { + }, WithDisplay(display)); err != nil { return err } @@ -1401,7 +1415,7 @@ func (c *SyncedCluster) createTenantCertBundle( ctx context.Context, l *logger.Logger, bundleName string, tenantID int, nodeNames []string, ) error { display := fmt.Sprintf("%s: initializing tenant certs", c.Name) - return c.Parallel(ctx, l, display, 1, 0, func(ctx context.Context, i int) (*RunResultDetails, error) { + return c.Parallel(ctx, l, 1, func(ctx context.Context, i int) (*RunResultDetails, error) { node := c.Nodes[i] cmd := "set -e;" @@ -1445,7 +1459,7 @@ fi return res, errors.Wrapf(res.Err, "certificate creation error: %s", res.CombinedOut) } return res, nil - }, DefaultSSHRetryOpts) + }, WithDisplay(display)) } // getFile retrieves the given file from the first node in the cluster. 
The @@ -1503,7 +1517,7 @@ func (c *SyncedCluster) fileExistsOnFirstNode( ) bool { var existsErr error display := fmt.Sprintf("%s: checking %s", c.Name, path) - if err := c.Parallel(ctx, l, display, 1, 0, func(ctx context.Context, i int) (*RunResultDetails, error) { + if err := c.Parallel(ctx, l, 1, func(ctx context.Context, i int) (*RunResultDetails, error) { node := c.Nodes[i] sess := c.newSession(l, node, `test -e `+path) defer sess.Close() @@ -1514,7 +1528,7 @@ func (c *SyncedCluster) fileExistsOnFirstNode( existsErr = res.Err return res, nil - }, DefaultSSHRetryOpts); err != nil { + }, WithDisplay(display)); err != nil { return false } return existsErr == nil @@ -1560,7 +1574,7 @@ func (c *SyncedCluster) distributeLocalCertsTar( } display := c.Name + ": distributing certs" - return c.Parallel(ctx, l, display, len(nodes), 0, func(ctx context.Context, i int) (*RunResultDetails, error) { + return c.Parallel(ctx, l, len(nodes), func(ctx context.Context, i int) (*RunResultDetails, error) { node := nodes[i] var cmd string if c.IsLocal() { @@ -1584,7 +1598,7 @@ func (c *SyncedCluster) distributeLocalCertsTar( return res, errors.Wrapf(res.Err, "~ %s\n%s", cmd, res.CombinedOut) } return res, nil - }, DefaultSSHRetryOpts) + }, WithDisplay(display)) } const progressDone = "=======================================>" @@ -2268,14 +2282,14 @@ func (c *SyncedCluster) pghosts( ctx context.Context, l *logger.Logger, nodes Nodes, ) (map[Node]string, error) { ips := make([]string, len(nodes)) - if err := c.Parallel(ctx, l, "", len(nodes), 0, func(ctx context.Context, i int) (*RunResultDetails, error) { + if err := c.Parallel(ctx, l, len(nodes), func(ctx context.Context, i int) (*RunResultDetails, error) { node := nodes[i] res := &RunResultDetails{Node: node} res.Stdout, res.Err = c.GetInternalIP(node) ips[i] = res.Stdout - return res, errors.Wrapf(res.Err, "pghosts") - }, DefaultSSHRetryOpts); err != nil { - return nil, err + return res, nil + }); err != nil { + return nil, errors.Wrapf(err, "pghosts") } m := make(map[Node]string, len(ips)) @@ -2389,6 +2403,43 @@ type ParallelResult struct { Err error } +type ParallelOptions struct { + concurrency int + display string + retryOpts *RunRetryOpts + // waitOnFail will cause the Parallel function to wait for all nodes to + // finish when encountering a command error on any node. The default + // behaviour is to exit immediately on the first error, in which case the + // slice of ParallelResults will only contain the one error result. + waitOnFail bool +} + +type ParallelOption func(result *ParallelOptions) + +func WithConcurrency(concurrency int) ParallelOption { + return func(result *ParallelOptions) { + result.concurrency = concurrency + } +} + +func WithRetryOpts(retryOpts *RunRetryOpts) ParallelOption { + return func(result *ParallelOptions) { + result.retryOpts = retryOpts + } +} + +func WithWaitOnFail() ParallelOption { + return func(result *ParallelOptions) { + result.waitOnFail = true + } +} + +func WithDisplay(display string) ParallelOption { + return func(result *ParallelOptions) { + result.display = display + } +} + // Parallel runs a user-defined function across the nodes in the // cluster. If any of the commands fail, Parallel will log an error // and exit the program. 
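(Aside: `WithConcurrency`, `WithRetryOpts`, `WithWaitOnFail` and `WithDisplay` follow the standard Go functional-options pattern. A minimal standalone sketch of that pattern, with simplified types that only mirror the option names above, looks like this:)

```go
package main

import "fmt"

type options struct {
	concurrency int
	display     string
	waitOnFail  bool
}

type Option func(*options)

func WithConcurrency(n int) Option { return func(o *options) { o.concurrency = n } }
func WithDisplay(d string) Option  { return func(o *options) { o.display = d } }
func WithWaitOnFail() Option       { return func(o *options) { o.waitOnFail = true } }

// Parallel applies defaults first, then any caller-supplied options.
func Parallel(count int, opts ...Option) {
	o := options{concurrency: count} // default: one worker slot per node
	for _, opt := range opts {
		opt(&o)
	}
	fmt.Printf("count=%d concurrency=%d display=%q waitOnFail=%v\n",
		count, o.concurrency, o.display, o.waitOnFail)
}

func main() {
	// A caller that needs output from every node opts in to fail-slow behaviour.
	Parallel(3, WithDisplay("cluster: executing sql"), WithWaitOnFail())
}
```

Defaults are applied first and then overridden by whatever options the caller passes, which is what lets most call sites in this patch pass only the options they care about.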
@@ -2397,23 +2448,43 @@ type ParallelResult struct { func (c *SyncedCluster) Parallel( ctx context.Context, l *logger.Logger, - display string, - count, concurrency int, + count int, fn func(ctx context.Context, i int) (*RunResultDetails, error), - runRetryOpts *RunRetryOpts, + opts ...ParallelOption, ) error { - failed, err := c.ParallelE(ctx, l, display, count, concurrency, fn, runRetryOpts) + // failed will contain command errors if any occur. + // err is an unexpected roachprod error, which we return immediately. + failed, err := c.ParallelE(ctx, l, count, fn, opts...) if err != nil { - l.Errorf("%d: %+v: %s\n", failed.Index, failed.Err, failed.Out) return err } + + if len(failed) > 0 { + sort.Slice(failed, func(i, j int) bool { return failed[i].Index < failed[j].Index }) + for _, f := range failed { + // Since this function is potentially returning a single error despite + // having run on multiple nodes, we combine all the errors into a single + // error. + err = errors.CombineErrors(err, f.Err) + l.Errorf("%d: %+v: %s", f.Index, f.Err, f.Out) + } + return errors.Wrap(err, "one or more parallel execution failure") + } return nil } // ParallelE runs the given function in parallel across the given -// nodes. In the event of an error on any of the nodes, the function -// will fail fast and return the error and the accompanying ParallelResult. -// Successful invocation will return a nil result and nil error. +// nodes. +// +// By default, this will fail fast if a command error occurs on any node, in which +// case the function will return a slice containing the erroneous result. +// +// If `WithWaitOnFail()` is passed in, then the function will wait for all +// invocations to complete before returning a slice with all failed results. +// +// ParallelE only returns an error for roachprod itself, not any command errors run +// on the cluster. It is up to the caller to check the slice for command errors. Any +// such roachprod error will always be returned immediately. // // ParallelE runs the user-defined functions on the first `count` // nodes in the cluster. It runs at most `concurrency` (or @@ -2430,19 +2501,24 @@ func (c *SyncedCluster) Parallel( func (c *SyncedCluster) ParallelE( ctx context.Context, l *logger.Logger, - display string, - count, concurrency int, + count int, fn func(ctx context.Context, i int) (*RunResultDetails, error), - runRetryOpts *RunRetryOpts, -) (ParallelResult, error) { - if concurrency == 0 || concurrency > count { - concurrency = count + opts ...ParallelOption, +) ([]ParallelResult, error) { + options := ParallelOptions{retryOpts: DefaultSSHRetryOpts} + for _, opt := range opts { + opt(&options) } - if config.MaxConcurrency > 0 && concurrency > config.MaxConcurrency { - concurrency = config.MaxConcurrency + + if options.concurrency == 0 || options.concurrency > count { + options.concurrency = count + } + if config.MaxConcurrency > 0 && options.concurrency > config.MaxConcurrency { + options.concurrency = config.MaxConcurrency } results := make(chan ParallelResult, count) + errorChannel := make(chan error) var wg sync.WaitGroup wg.Add(count) @@ -2456,24 +2532,31 @@ func (c *SyncedCluster) ParallelE( // the context through, a cancellation will be handled by the command itself. 
go func(i int) { defer wg.Done() - res, err := runWithMaybeRetry(groupCtx, l, runRetryOpts, func(ctx context.Context) (*RunResultDetails, error) { return fn(ctx, i) }) - results <- ParallelResult{i, res.CombinedOut, errors.CombineErrors(err, res.Err)} + // This is rarely expected to return an error, but we fail fast in case. + // Command errors, which are far more common, will be contained within the result. + res, err := runWithMaybeRetry(groupCtx, l, options.retryOpts, func(ctx context.Context) (*RunResultDetails, error) { return fn(ctx, i) }) + if err != nil { + errorChannel <- err + return + } + results <- ParallelResult{i, res.CombinedOut, res.Err} }(index) index++ } - for index < concurrency { + for index < options.concurrency { startNext() } go func() { wg.Wait() close(results) + close(errorChannel) }() var writer ui.Writer out := l.Stdout - if display == "" { + if options.display == "" { out = io.Discard } @@ -2482,13 +2565,14 @@ func (c *SyncedCluster) ParallelE( ticker = time.NewTicker(100 * time.Millisecond) } else { ticker = time.NewTicker(1000 * time.Millisecond) - fmt.Fprintf(out, "%s", display) + fmt.Fprintf(out, "%s", options.display) if l.File != nil { fmt.Fprintf(out, "\n") } } defer ticker.Stop() complete := make([]bool, count) + var failed []ParallelResult var spinner = []string{"|", "/", "-", "\\"} spinnerIdx := 0 @@ -2500,12 +2584,12 @@ func (c *SyncedCluster) ParallelE( fmt.Fprintf(out, ".") } case r, ok := <-results: - if r.Err != nil { - // We no longer wait for failures from other goroutines but instead cancel the context. - // If required, we could restore or control the previous behavior by not cancelling and - // and returning here, but instead return a slice at the end. - groupCancel() - return r, errors.Wrap(r.Err, "parallel execution failure") + if r.Err != nil { // Command error + failed = append(failed, r) + if !options.waitOnFail { + groupCancel() + return failed, nil + } } done = !ok if ok { @@ -2514,10 +2598,13 @@ func (c *SyncedCluster) ParallelE( if index < count { startNext() } + case err := <-errorChannel: // Roachprod error + groupCancel() + return nil, err } if !config.Quiet && l.File == nil { - fmt.Fprint(&writer, display) + fmt.Fprint(&writer, options.display) var n int for i := range complete { if complete[i] { @@ -2538,7 +2625,7 @@ func (c *SyncedCluster) ParallelE( fmt.Fprintf(out, "\n") } - return ParallelResult{}, nil + return failed, nil } // Init initializes the cluster. 
It does it through node 1 (as per TargetNodes) diff --git a/pkg/roachprod/install/cockroach.go b/pkg/roachprod/install/cockroach.go index 81060f1df77a..e644b4dd41d7 100644 --- a/pkg/roachprod/install/cockroach.go +++ b/pkg/roachprod/install/cockroach.go @@ -183,7 +183,7 @@ func (c *SyncedCluster) Start(ctx context.Context, l *logger.Logger, startOpts S l.Printf("%s: starting nodes", c.Name) // SSH retries are disabled by passing nil RunRetryOpts - if err := c.Parallel(ctx, l, "", len(nodes), parallelism, func(ctx context.Context, nodeIdx int) (*RunResultDetails, error) { + if err := c.Parallel(ctx, l, len(nodes), func(ctx context.Context, nodeIdx int) (*RunResultDetails, error) { node := nodes[nodeIdx] res := &RunResultDetails{Node: node} // NB: if cockroach started successfully, we ignore the output as it is @@ -228,7 +228,7 @@ func (c *SyncedCluster) Start(ctx context.Context, l *logger.Logger, startOpts S return res, errors.Wrap(err, "failed to set cluster settings") } return res, nil - }, DefaultSSHRetryOpts); err != nil { + }, WithConcurrency(parallelism)); err != nil { return err } @@ -329,7 +329,7 @@ func (c *SyncedCluster) ExecSQL( resultChan := make(chan result, len(c.Nodes)) display := fmt.Sprintf("%s: executing sql", c.Name) - if err := c.Parallel(ctx, l, display, len(c.Nodes), 0, func(ctx context.Context, nodeIdx int) (*RunResultDetails, error) { + if err := c.Parallel(ctx, l, len(c.Nodes), func(ctx context.Context, nodeIdx int) (*RunResultDetails, error) { node := c.Nodes[nodeIdx] var cmd string @@ -348,11 +348,11 @@ func (c *SyncedCluster) ExecSQL( res.CombinedOut = out if res.Err != nil { - return res, errors.Wrapf(res.Err, "~ %s\n%s", cmd, res.CombinedOut) + res.Err = errors.Wrapf(res.Err, "~ %s\n%s", cmd, res.CombinedOut) } resultChan <- result{node: node, output: string(res.CombinedOut)} return res, nil - }, DefaultSSHRetryOpts); err != nil { + }, WithDisplay(display)); err != nil { return err } diff --git a/pkg/roachprod/roachprod.go b/pkg/roachprod/roachprod.go index 4f7c9ba149a2..ff643fb0248f 100644 --- a/pkg/roachprod/roachprod.go +++ b/pkg/roachprod/roachprod.go @@ -366,6 +366,7 @@ func Run( secure bool, stdout, stderr io.Writer, cmdArray []string, + opts ...install.ParallelOption, ) error { if err := LoadClusters(); err != nil { return err @@ -386,7 +387,7 @@ func Run( if len(title) > 30 { title = title[:27] + "..." } - return c.Run(ctx, l, stdout, stderr, c.Nodes, title, cmd) + return c.Run(ctx, l, stdout, stderr, c.Nodes, title, cmd, opts...) } // RunWithDetails runs a command on the nodes in a cluster. 
@@ -913,13 +914,13 @@ func PgURL( ips[i] = c.VMs[nodes[i]-1].PublicIP } } else { - if err := c.Parallel(ctx, l, "", len(nodes), 0, func(ctx context.Context, i int) (*install.RunResultDetails, error) { + if err := c.Parallel(ctx, l, len(nodes), func(ctx context.Context, i int) (*install.RunResultDetails, error) { node := nodes[i] res := &install.RunResultDetails{Node: node} res.Stdout, res.Err = c.GetInternalIP(node) ips[i] = res.Stdout - return res, err - }, install.DefaultSSHRetryOpts); err != nil { + return res, nil + }); err != nil { return nil, err } } @@ -1047,7 +1048,7 @@ func Pprof(ctx context.Context, l *logger.Logger, clusterName string, opts Pprof httpClient := httputil.NewClientWithTimeout(timeout) startTime := timeutil.Now().Unix() nodes := c.TargetNodes() - err = c.Parallel(ctx, l, description, len(nodes), 0, func(ctx context.Context, i int) (*install.RunResultDetails, error) { + err = c.Parallel(ctx, l, len(nodes), func(ctx context.Context, i int) (*install.RunResultDetails, error) { node := nodes[i] res := &install.RunResultDetails{Node: node} host := c.Host(node) @@ -1109,7 +1110,7 @@ func Pprof(ctx context.Context, l *logger.Logger, clusterName string, opts Pprof outputFiles = append(outputFiles, outputFile) mu.Unlock() return res, nil - }, install.DefaultSSHRetryOpts) + }, install.WithDisplay(description)) for _, s := range outputFiles { l.Printf("Created %s", s) @@ -1710,10 +1711,7 @@ func sendCaptureCommand( ) error { nodes := c.TargetNodes() httpClient := httputil.NewClientWithTimeout(0 /* timeout: None */) - _, err := c.ParallelE(ctx, l, - fmt.Sprintf("Performing workload capture %s", action), - len(nodes), - 0, + _, err := c.ParallelE(ctx, l, len(nodes), func(ctx context.Context, i int) (*install.RunResultDetails, error) { node := nodes[i] res := &install.RunResultDetails{Node: node} @@ -1778,7 +1776,7 @@ func sendCaptureCommand( } } return res, res.Err - }, install.DefaultSSHRetryOpts) + }, install.WithDisplay(fmt.Sprintf("Performing workload capture %s", action))) return err } From d45e15115402762ab944031ab38c959850087f85 Mon Sep 17 00:00:00 2001 From: Miral Gadani Date: Mon, 5 Jun 2023 17:01:10 +0000 Subject: [PATCH 4/6] roachprod: add WithWaitOnFail() when executing sql across nodes Previously, we added fail fast behaviour when executing commands across nodes using roachtest/roachprod. There are some instances, specifically from the CLI, that should wait for all results to be returned. This PR adds `WithWaitOnFail()` to `ExecSQL()` in roachprod. Epic: none Fixes: #104342 Release note: None --- pkg/roachprod/install/cockroach.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/roachprod/install/cockroach.go b/pkg/roachprod/install/cockroach.go index e644b4dd41d7..892f59396b6f 100644 --- a/pkg/roachprod/install/cockroach.go +++ b/pkg/roachprod/install/cockroach.go @@ -352,7 +352,7 @@ func (c *SyncedCluster) ExecSQL( } resultChan <- result{node: node, output: string(res.CombinedOut)} return res, nil - }, WithDisplay(display)); err != nil { + }, WithDisplay(display), WithWaitOnFail()); err != nil { return err } From c75821100445c0c83aa9102e753a5918384fa3fd Mon Sep 17 00:00:00 2001 From: Herko Lategan Date: Tue, 13 Jun 2023 14:49:26 +0100 Subject: [PATCH 5/6] roachprod: fix confusing start-up error Starting a cluster locally does a check for certificates. In the event the certificates are not found, which is a valid case, an error is printed. 
The start command works correctly, but the error can cause confusion: ```bash /bin/roachprod start local --secure 12:47:56 cluster_synced.go:2475: 0: COMMAND_PROBLEM: exit status 1 (1) COMMAND_PROBLEM Wraps: (2) exit status 1 Error types: (1) errors.Cmd (2) *exec.ExitError: local: initializing certs 1/1 / local: distributing certs 2/2 ``` This change modifies the command to do a check without causing an error, and also propagates any real errors that could occur while doing the check. --- pkg/roachprod/install/cluster_synced.go | 40 +++++++++++-------------- 1 file changed, 17 insertions(+), 23 deletions(-) diff --git a/pkg/roachprod/install/cluster_synced.go b/pkg/roachprod/install/cluster_synced.go index 13658872f351..a3cb9ec8c76c 100644 --- a/pkg/roachprod/install/cluster_synced.go +++ b/pkg/roachprod/install/cluster_synced.go @@ -1304,7 +1304,9 @@ const ( // DistributeCerts will generate and distribute certificates to all of the // nodes. func (c *SyncedCluster) DistributeCerts(ctx context.Context, l *logger.Logger) error { - if c.checkForCertificates(ctx, l) { + if found, err := c.checkForCertificates(ctx, l); err != nil { + return err + } else if found { return nil } @@ -1380,11 +1382,15 @@ tar cvf %[3]s certs func (c *SyncedCluster) DistributeTenantCerts( ctx context.Context, l *logger.Logger, hostCluster *SyncedCluster, tenantID int, ) error { - if hostCluster.checkForTenantCertificates(ctx, l) { + if found, err := hostCluster.checkForTenantCertificates(ctx, l); err != nil { + return err + } else if found { return nil } - if !hostCluster.checkForCertificates(ctx, l) { + if found, err := hostCluster.checkForCertificates(ctx, l); err != nil { + return err + } else if !found { return errors.New("host cluster missing certificate bundle") } @@ -1494,7 +1500,7 @@ func (c *SyncedCluster) getFileFromFirstNode( // checkForCertificates checks if the cluster already has a certs bundle created // on the first node. -func (c *SyncedCluster) checkForCertificates(ctx context.Context, l *logger.Logger) bool { +func (c *SyncedCluster) checkForCertificates(ctx context.Context, l *logger.Logger) (bool, error) { dir := "" if c.IsLocal() { dir = c.localVMDir(1) @@ -1504,7 +1510,9 @@ func (c *SyncedCluster) checkForCertificates(ctx context.Context, l *logger.Logg // checkForTenantCertificates checks if the cluster already has a tenant-certs bundle created // on the first node. 
-func (c *SyncedCluster) checkForTenantCertificates(ctx context.Context, l *logger.Logger) bool { +func (c *SyncedCluster) checkForTenantCertificates( + ctx context.Context, l *logger.Logger, +) (bool, error) { dir := "" if c.IsLocal() { dir = c.localVMDir(1) @@ -1514,24 +1522,10 @@ func (c *SyncedCluster) checkForTenantCertificates(ctx context.Context, l *logge func (c *SyncedCluster) fileExistsOnFirstNode( ctx context.Context, l *logger.Logger, path string, -) bool { - var existsErr error - display := fmt.Sprintf("%s: checking %s", c.Name, path) - if err := c.Parallel(ctx, l, 1, func(ctx context.Context, i int) (*RunResultDetails, error) { - node := c.Nodes[i] - sess := c.newSession(l, node, `test -e `+path) - defer sess.Close() - - out, cmdErr := sess.CombinedOutput(ctx) - res := newRunResultDetails(node, cmdErr) - res.CombinedOut = out - - existsErr = res.Err - return res, nil - }, WithDisplay(display)); err != nil { - return false - } - return existsErr == nil +) (bool, error) { + l.Printf("%s: checking %s", c.Name, path) + result, err := c.runCmdOnSingleNode(ctx, l, c.Nodes[0], `$(test -e `+path+`); echo $?`, false, l.Stdout, l.Stderr) + return result.Stdout == "0", err } // createNodeCertArguments returns a list of strings appropriate for use as From 3a12d96f3e82ce1f12fa0e79a3f15ffa9cd6dbee Mon Sep 17 00:00:00 2001 From: healthy-pod Date: Sat, 24 Jun 2023 13:41:50 -0700 Subject: [PATCH 6/6] roachprod: fix `fileExistsOnFirstNode` check For some reason, the current form of `fileExistsOnFirstNode` can return `found=true` when it should return `found=false`. This can be reproduced by running the `multitenant-upgrade` roachtest and seeing it hang at: ``` multitenant_upgrade.go:154: test status: checking the pre-upgrade sql server still works after the system tenant binary upgrade ``` because of a `TLS handshake error`. Note: the test is also broken because of another issue so with this fix it should now fail with: ``` (assertions.go:333).Fail: Error Trace: github.com/cockroachdb/cockroach/pkg/cmd/roachtest/tests/multitenant_upgrade.go:390 github.com/cockroachdb/cockroach/pkg/cmd/roachtest/tests/multitenant_upgrade.go:189 github.com/cockroachdb/cockroach/pkg/cmd/roachtest/tests/multitenant_upgrade.go:38 main/pkg/cmd/roachtest/test_runner.go:1060 GOROOT/src/runtime/asm_arm64.s:1172 Error: Not equal: expected: [][]string{[]string{"23.1"}} actual : [][]string{[]string{"22.2"}} Diff: --- Expected +++ Actual @@ -2,3 +2,3 @@ ([]string) (len=1) { - (string) (len=4) "23.1" + (string) (len=4) "22.2" } Test: multitenant-upgrade ``` Release note: None Epic: none --- pkg/roachprod/install/cluster_synced.go | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/pkg/roachprod/install/cluster_synced.go b/pkg/roachprod/install/cluster_synced.go index a3cb9ec8c76c..32256dd3ade9 100644 --- a/pkg/roachprod/install/cluster_synced.go +++ b/pkg/roachprod/install/cluster_synced.go @@ -1524,8 +1524,19 @@ func (c *SyncedCluster) fileExistsOnFirstNode( ctx context.Context, l *logger.Logger, path string, ) (bool, error) { l.Printf("%s: checking %s", c.Name, path) - result, err := c.runCmdOnSingleNode(ctx, l, c.Nodes[0], `$(test -e `+path+`); echo $?`, false, l.Stdout, l.Stderr) - return result.Stdout == "0", err + testCmd := `$(test -e ` + path + `);` + // Do not log output to stdout/stderr because in some cases this call will be expected to exit 1. 
+ result, err := c.runCmdOnSingleNode(ctx, l, c.Nodes[0], testCmd, true, nil, nil) + if (result.RemoteExitStatus != 0 && result.RemoteExitStatus != 1) || err != nil { + // Unexpected exit status (neither 0 nor 1) or non-nil error. Return combined output along with err returned + // from the call if it's not nil. + if err != nil { + return false, errors.Wrapf(err, "running '%s' failed with exit code=%d: got %s", testCmd, result.RemoteExitStatus, string(result.CombinedOut)) + } else { + return false, errors.Newf("running '%s' failed with exit code=%d: got %s", testCmd, result.RemoteExitStatus, string(result.CombinedOut)) + } + } + return result.RemoteExitStatus == 0, nil } // createNodeCertArguments returns a list of strings appropriate for use as
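For reference, the exit-status handling introduced by this last commit boils down to the following standalone sketch. It shells out to a local `test` binary rather than running the command over ssh, and `fileExists` here is an illustrative helper rather than roachprod's actual implementation (which also captures combined output for the error message):

```go
package main

import (
	"errors"
	"fmt"
	"os/exec"
)

// fileExists runs `test -e <path>` and maps exit status 0 to "exists" and
// exit status 1 to "does not exist"; anything else is treated as a real error.
func fileExists(path string) (bool, error) {
	err := exec.Command("test", "-e", path).Run()
	if err == nil {
		return true, nil // exit status 0: the file exists
	}
	var exitErr *exec.ExitError
	if errors.As(err, &exitErr) && exitErr.ExitCode() == 1 {
		return false, nil // exit status 1: the file does not exist
	}
	return false, err // could not run the check at all, or unexpected status
}

func main() {
	found, err := fileExists("/etc/hosts")
	fmt.Println(found, err)
}
```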