diff --git a/pkg/cli/demo_cluster.go b/pkg/cli/demo_cluster.go index 970d1d97d50a..789a35468d84 100644 --- a/pkg/cli/demo_cluster.go +++ b/pkg/cli/demo_cluster.go @@ -37,23 +37,24 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/catalog/catconstants" "github.com/cockroachdb/cockroach/pkg/util/log" "github.com/cockroachdb/cockroach/pkg/util/log/severity" + "github.com/cockroachdb/cockroach/pkg/util/retry" "github.com/cockroachdb/cockroach/pkg/util/stop" - "github.com/cockroachdb/cockroach/pkg/util/timeutil" "github.com/cockroachdb/cockroach/pkg/workload" "github.com/cockroachdb/cockroach/pkg/workload/histogram" "github.com/cockroachdb/cockroach/pkg/workload/workloadsql" "github.com/cockroachdb/errors" + "github.com/cockroachdb/logtags" "github.com/spf13/cobra" "golang.org/x/time/rate" ) type transientCluster struct { - connURL string - demoDir string - useSockets bool - stopper *stop.Stopper - s *server.TestServer - servers []*server.TestServer + connURL string + demoDir string + useSockets bool + stopper *stop.Stopper + firstServer *server.TestServer + servers []*server.TestServer httpFirstPort int sqlFirstPort int @@ -119,184 +120,427 @@ func (c *transientCluster) checkConfigAndSetupLogging( func (c *transientCluster) start( ctx context.Context, cmd *cobra.Command, gen workload.Generator, ) (err error) { - serverFactory := server.TestServerFactory - var servers []*server.TestServer + ctx = logtags.AddTag(ctx, "start-demo-cluster", nil) + // We now proceed to start all the nodes concurrently. This is + // somewhat a complicated dance. + // + // On the one hand, we use a concurrent start, because the latency + // map needs to be initialized after all the nodes have started + // listening on the network, but before they proceed to initialize + // their RPC context. + // + // On the other hand, we cannot use full concurrency either, because + // we need to wait until the *first* node has started + // listening on the network before we can start the next nodes. + // + // So we proceed in phases, as follows: + // + // 1. create and start the first node asynchronously. + // 2. wait for the first node to listen on RPC and determine its + // listen addr; OR wait for an error in the first node initialization. + // 3. create and start all the other nodes asynchronously. + // 4. wait for all the nodes to listen on RPC, OR wait for an error + // from any node. + // 5. if no error, proceed to initialize the latency map. + // 6. in sequence, let each node initialize then wait for RPC readiness OR error from them. + // This ensures the node IDs are assigned sequentially. + // 7. wait for the SQL readiness from all nodes. + // 8. after all nodes are initialized, initialize SQL and telemetry. + // - // latencyMapWaitCh is used to block test servers after RPC address computation until the artificial - // latency map has been constructed. - latencyMapWaitCh := make(chan struct{}) + timeoutCh := time.After(maxNodeInitTime) - // errCh is used to catch all errors when initializing servers. - // Sending a nil on this channel indicates success. + // errCh is used to catch errors when initializing servers. errCh := make(chan error, demoCtx.nodes) - for i := 0; i < demoCtx.nodes; i++ { - // All the nodes connect to the address of the first server created. - var joinAddr string - if i != 0 { - joinAddr = c.s.ServingRPCAddr() + // rpcAddrReadyChs will be used in steps 2 and 4 below + // to wait until the nodes know their RPC address. + rpcAddrReadyChs := make([]chan struct{}, demoCtx.nodes) + + // latencyMapWaitChs is used to block test servers after RPC address + // computation until the artificial latency map has been constructed. + latencyMapWaitChs := make([]chan struct{}, demoCtx.nodes) + + // Step 1: create the first node. + { + phaseCtx := logtags.AddTag(ctx, "phase", 1) + log.Infof(phaseCtx, "creating the first node") + + latencyMapWaitChs[0] = make(chan struct{}) + firstRPCAddrReadyCh, err := c.createAndAddNode(phaseCtx, 0, latencyMapWaitChs[0], timeoutCh) + if err != nil { + return err + } + rpcAddrReadyChs[0] = firstRPCAddrReadyCh + } + + // Step 2: start the first node asynchronously, then wait for RPC + // listen readiness or error. + { + phaseCtx := logtags.AddTag(ctx, "phase", 2) + + log.Infof(phaseCtx, "starting first node") + if err := c.startNodeAsync(phaseCtx, 0, errCh, timeoutCh); err != nil { + return err } - nodeID := roachpb.NodeID(i + 1) - args := testServerArgsForTransientCluster( - c.sockForServer(nodeID), nodeID, joinAddr, c.demoDir, - c.sqlFirstPort, - c.httpFirstPort, - c.stickyEngineRegistry, - ) - if i == 0 { - // The first node also auto-inits the cluster. - args.NoAutoInitializeCluster = false + log.Infof(phaseCtx, "waiting for first node RPC address") + if err := c.waitForRPCAddrReadinessOrError(phaseCtx, 0, errCh, rpcAddrReadyChs, timeoutCh); err != nil { + return err } + } - // servRPCReadyCh is used if latency simulation is requested to notify that a test server has - // successfully computed its RPC address. - servRPCReadyCh := make(chan struct{}) + // Step 3: create the other nodes and start them asynchronously. + { + phaseCtx := logtags.AddTag(ctx, "phase", 3) + log.Infof(phaseCtx, "starting other nodes") - if demoCtx.simulateLatency { - serverKnobs := args.Knobs.Server.(*server.TestingKnobs) - serverKnobs.PauseAfterGettingRPCAddress = latencyMapWaitCh - serverKnobs.SignalAfterGettingRPCAddress = servRPCReadyCh - serverKnobs.ContextTestingKnobs = rpc.ContextTestingKnobs{ - ArtificialLatencyMap: make(map[string]int), + for i := 1; i < demoCtx.nodes; i++ { + latencyMapWaitChs[i] = make(chan struct{}) + rpcAddrReady, err := c.createAndAddNode(phaseCtx, i, latencyMapWaitChs[i], timeoutCh) + if err != nil { + return err } + rpcAddrReadyChs[i] = rpcAddrReady } - s, err := serverFactory.New(args) - if err != nil { - return err - } - serv := s.(*server.TestServer) - c.stopper.AddCloser(stop.CloserFn(serv.Stop)) - if i == 0 { - c.s = serv - // The first node connects its Settings instance to the `log` - // package for crash reporting. - // - // There's a known shortcoming with this approach: restarting - // node 1 using the \demo commands will break this connection: - // if the user changes the cluster setting after restarting node - // 1, the `log` package will not see this change. - // - // TODO(knz): re-connect the `log` package every time the first - // node is restarted and gets a new `Settings` instance. - settings.SetCanonicalValuesContainer(&serv.ClusterSettings().SV) - } - servers = append(servers, serv) - - // We force a wait for all servers until they are ready. - servReadyFnCh := make(chan struct{}) - serv.Cfg.ReadyFn = func(bool) { - close(servReadyFnCh) - } - - // If latency simulation is requested, start the servers in a background thread. We do this because - // the start routine needs to wait for the latency map construction after their RPC address has been computed. - if demoCtx.simulateLatency { - go func(i int) { - if err := serv.Start(); err != nil { - errCh <- err - } else { - // Block until the ReadyFn has been called before continuing. - <-servReadyFnCh - errCh <- nil - } - }(i) - <-servRPCReadyCh - } else { - if err := serv.Start(); err != nil { + // Ensure we close all sticky stores we've created when the stopper + // instructs the entire cluster to stop. We do this only here + // because we want this closer to be registered after all the + // individual servers' Stop() methods have been registered + // via createAndAddNode() above. + c.stopper.AddCloser(stop.CloserFn(func() { + c.stickyEngineRegistry.CloseAllStickyInMemEngines() + })) + + // Start the remaining nodes asynchronously. + for i := 1; i < demoCtx.nodes; i++ { + if err := c.startNodeAsync(phaseCtx, i, errCh, timeoutCh); err != nil { return err } - // Block until the ReadyFn has been called before continuing. - <-servReadyFnCh - errCh <- nil } } - // Ensure we close all sticky stores we've created. - c.stopper.AddCloser(stop.CloserFn(func() { - c.stickyEngineRegistry.CloseAllStickyInMemEngines() - })) - c.servers = servers + // Step 4: wait for all the nodes to know their RPC address, + // or for an error or premature shutdown. + { + phaseCtx := logtags.AddTag(ctx, "phase", 4) + log.Infof(phaseCtx, "waiting for remaining nodes to get their RPC address") - if demoCtx.simulateLatency { - // Now, all servers have been started enough to know their own RPC serving - // addresses, but nothing else. Assemble the artificial latency map. - for i, src := range servers { - latencyMap := src.Cfg.TestingKnobs.Server.(*server.TestingKnobs).ContextTestingKnobs.ArtificialLatencyMap - srcLocality, ok := src.Cfg.Locality.Find("region") - if !ok { - continue - } - srcLocalityMap, ok := regionToRegionToLatency[srcLocality] - if !ok { - continue + for i := 0; i < demoCtx.nodes; i++ { + if err := c.waitForRPCAddrReadinessOrError(phaseCtx, i, errCh, rpcAddrReadyChs, timeoutCh); err != nil { + return err } - for j, dst := range servers { - if i == j { + } + } + + // Step 5: optionally initialize the latency map, then let the servers + // proceed with their initialization. + + { + phaseCtx := logtags.AddTag(ctx, "phase", 5) + + // If latency simulation is requested, initialize the latency map. + if demoCtx.simulateLatency { + // Now, all servers have been started enough to know their own RPC serving + // addresses, but nothing else. Assemble the artificial latency map. + log.Infof(phaseCtx, "initializing latency map") + for i, serv := range c.servers { + latencyMap := serv.Cfg.TestingKnobs.Server.(*server.TestingKnobs).ContextTestingKnobs.ArtificialLatencyMap + srcLocality, ok := serv.Cfg.Locality.Find("region") + if !ok { continue } - dstLocality, ok := dst.Cfg.Locality.Find("region") + srcLocalityMap, ok := regionToRegionToLatency[srcLocality] if !ok { continue } - latency := srcLocalityMap[dstLocality] - latencyMap[dst.ServingRPCAddr()] = latency + for j, dst := range c.servers { + if i == j { + continue + } + dstLocality, ok := dst.Cfg.Locality.Find("region") + if !ok { + continue + } + latency := srcLocalityMap[dstLocality] + latencyMap[dst.ServingRPCAddr()] = latency + } } } } - // We've assembled our latency maps and are ready for all servers to proceed - // through bootstrapping. - close(latencyMapWaitCh) + { + phaseCtx := logtags.AddTag(ctx, "phase", 6) + + for i := 0; i < demoCtx.nodes; i++ { + log.Infof(phaseCtx, "letting server %d initialize", i) + close(latencyMapWaitChs[i]) + if err := c.waitForNodeIDReadiness(phaseCtx, i, errCh, timeoutCh); err != nil { + return err + } + log.Infof(phaseCtx, "node n%d initialized", c.servers[i].NodeID()) + } + } - // Wait for all servers to respond. { - timeRemaining := maxNodeInitTime - lastUpdateTime := timeutil.Now() - var err error + phaseCtx := logtags.AddTag(ctx, "phase", 7) + for i := 0; i < demoCtx.nodes; i++ { - select { - case e := <-errCh: - err = errors.CombineErrors(err, e) - case <-time.After(timeRemaining): - return errors.New("failed to setup transientCluster in time") + log.Infof(phaseCtx, "waiting for server %d SQL readiness", i) + if err := c.waitForSQLReadiness(phaseCtx, i, errCh, timeoutCh); err != nil { + return err } - updateTime := timeutil.Now() - timeRemaining -= updateTime.Sub(lastUpdateTime) - lastUpdateTime = updateTime + log.Infof(phaseCtx, "node n%d ready", c.servers[i].NodeID()) + } + } + + { + phaseCtx := logtags.AddTag(ctx, "phase", 8) + + // Run the SQL initialization. This takes care of setting up the + // initial replication factor for small clusters and creating the + // admin user. + log.Infof(phaseCtx, "running initial SQL for demo cluster") + + const demoUsername = "demo" + demoPassword := genDemoPassword(demoUsername) + if err := runInitialSQL(phaseCtx, c.firstServer.Server, demoCtx.nodes < 3, demoUsername, demoPassword); err != nil { + return err } + if demoCtx.insecure { + c.adminUser = security.RootUserName() + c.adminPassword = "unused" + } else { + c.adminUser = security.MakeSQLUsernameFromPreNormalizedString(demoUsername) + c.adminPassword = demoPassword + } + + // Prepare the URL for use by the SQL shell. + c.connURL, err = c.getNetworkURLForServer(0, gen, true /* includeAppName */) if err != nil { return err } + + // Start up the update check loop. + // We don't do this in (*server.Server).Start() because we don't want this + // overhead and possible interference in tests. + if !demoCtx.disableTelemetry { + log.Infof(phaseCtx, "starting telemetry") + c.firstServer.StartDiagnostics(phaseCtx) + } } + return nil +} - // Run the SQL initialization. This takes care of setting up the - // initial replication factor for small clusters and creating the - // admin user. - const demoUsername = "demo" - demoPassword := genDemoPassword(demoUsername) - if err := runInitialSQL(ctx, c.s.Server, demoCtx.nodes < 3, demoUsername, demoPassword); err != nil { - return err +// createAndAddNode is responsible for determining node parameters, +// instantiating the server component and connecting it to the +// cluster's stopper. +// +// The caller is responsible for calling createAndAddNode() with idx 0 +// synchronously, then startNodeAsync(), then +// waitForNodeRPCListener(), before using createAndAddNode() with +// other indexes. +func (c *transientCluster) createAndAddNode( + ctx context.Context, idx int, latencyMapWaitCh chan struct{}, timeoutCh <-chan time.Time, +) (rpcAddrReadyCh chan struct{}, err error) { + var joinAddr string + if idx > 0 { + // The caller is responsible for ensuring that the method + // is not called before the first server has finished + // computing its RPC listen address. + joinAddr = c.firstServer.ServingRPCAddr() } - if demoCtx.insecure { - c.adminUser = security.RootUserName() - c.adminPassword = "unused" - } else { - c.adminUser = security.MakeSQLUsernameFromPreNormalizedString(demoUsername) - c.adminPassword = demoPassword + nodeID := roachpb.NodeID(idx + 1) + args := testServerArgsForTransientCluster( + c.sockForServer(nodeID), nodeID, joinAddr, c.demoDir, + c.sqlFirstPort, + c.httpFirstPort, + c.stickyEngineRegistry, + ) + if idx == 0 { + // The first node also auto-inits the cluster. + args.NoAutoInitializeCluster = false + } + + serverKnobs := args.Knobs.Server.(*server.TestingKnobs) + + // SignalAfterGettingRPCAddress will be closed by the server startup routine + // once it has determined its RPC address. + rpcAddrReadyCh = make(chan struct{}) + serverKnobs.SignalAfterGettingRPCAddress = rpcAddrReadyCh + + // The server will wait until PauseAfterGettingRPCAddress is closed + // after it has signaled SignalAfterGettingRPCAddress, and before + // it continues the startup routine. + serverKnobs.PauseAfterGettingRPCAddress = latencyMapWaitCh + + if demoCtx.simulateLatency { + // The latency map will be populated after all servers have + // started listening on RPC, and before they proceed with their + // startup routine. + serverKnobs.ContextTestingKnobs = rpc.ContextTestingKnobs{ + ArtificialLatencyMap: make(map[string]int), + } } - // Prepare the URL for use by the SQL shell. - c.connURL, err = c.getNetworkURLForServer(0, gen, true /* includeAppName */) + // Create the server instance. This also registers the in-memory store + // into the sticky engine registry. + s, err := server.TestServerFactory.New(args) if err != nil { + return nil, err + } + serv := s.(*server.TestServer) + + // Ensure that this server gets stopped when the top level demo + // stopper instructs the cluster to stop. + c.stopper.AddCloser(stop.CloserFn(serv.Stop)) + + if idx == 0 { + // Remember the first server for later use by other APIs on + // transientCluster. + c.firstServer = serv + // The first node connects its Settings instance to the `log` + // package for crash reporting. + // + // There's a known shortcoming with this approach: restarting + // node 1 using the \demo commands will break this connection: + // if the user changes the cluster setting after restarting node + // 1, the `log` package will not see this change. + // + // TODO(knz): re-connect the `log` package every time the first + // node is restarted and gets a new `Settings` instance. + settings.SetCanonicalValuesContainer(&serv.ClusterSettings().SV) + } + + // Remember this server for the stop/restart primitives in the SQL + // shell. + c.servers = append(c.servers, serv) + + return rpcAddrReadyCh, nil +} + +// startNodeAsync starts the node initialization asynchronously. +func (c *transientCluster) startNodeAsync( + ctx context.Context, idx int, errCh chan error, timeoutCh <-chan time.Time, +) error { + if idx > len(c.servers) { + return errors.AssertionFailedf("programming error: server %d not created yet", idx) + } + + serv := c.servers[idx] + tag := fmt.Sprintf("start-n%d", idx+1) + return c.stopper.RunAsyncTask(ctx, tag, func(ctx context.Context) { + ctx = logtags.AddTag(ctx, tag, nil) + err := serv.Start(ctx) + if err != nil { + log.Warningf(ctx, "server %d failed to start: %v", idx, err) + select { + case errCh <- err: + + // Don't block if we are shutting down. + case <-ctx.Done(): + case <-serv.Stopper().ShouldQuiesce(): + case <-c.stopper.ShouldQuiesce(): + case <-timeoutCh: + } + } + }) +} + +// waitForRPCAddrReadinessOrError waits until the given server knows its +// RPC address or fails to initialize. +func (c *transientCluster) waitForRPCAddrReadinessOrError( + ctx context.Context, + idx int, + errCh chan error, + rpcAddrReadyChs []chan struct{}, + timeoutCh <-chan time.Time, +) error { + if idx > len(rpcAddrReadyChs) || idx > len(c.servers) { + return errors.AssertionFailedf("programming error: server %d not created yet", idx) + } + + select { + case <-rpcAddrReadyChs[idx]: + // This server knows its RPC address. Proceed with the next phase. + return nil + + // If we are asked for an early shutdown by the cases below or a + // server startup failure, detect it here. + case err := <-errCh: return err + case <-timeoutCh: + return errors.Newf("demo startup timeout while waiting for server %d", idx) + case <-ctx.Done(): + return errors.CombineErrors(ctx.Err(), errors.Newf("server %d startup aborted due to context cancellation", idx)) + case <-c.servers[idx].Stopper().ShouldQuiesce(): + return errors.Newf("server %d stopped prematurely", idx) + case <-c.stopper.ShouldQuiesce(): + return errors.Newf("demo cluster stopped prematurely while starting server %d", idx) + } +} + +// waitForNodeIDReadiness waits until the given server reports it knows its node ID. +func (c *transientCluster) waitForNodeIDReadiness( + ctx context.Context, idx int, errCh chan error, timeoutCh <-chan time.Time, +) error { + retryOpts := retry.Options{InitialBackoff: 10 * time.Millisecond, MaxBackoff: time.Second} + for r := retry.StartWithCtx(ctx, retryOpts); r.Next(); { + log.Infof(ctx, "waiting for server %d to know its node ID", idx) + select { + // Errors or premature shutdown. + case <-timeoutCh: + return errors.Newf("initialization timeout while waiting for server %d node ID", idx) + case err := <-errCh: + return errors.Wrapf(err, "server %d failed to start", idx) + case <-ctx.Done(): + return errors.CombineErrors(errors.Newf("context cancellation while waiting for server %d to have a node ID", idx), ctx.Err()) + case <-c.servers[idx].Stopper().ShouldQuiesce(): + return errors.Newf("server %s shut down prematurely", idx) + case <-c.stopper.ShouldQuiesce(): + return errors.Newf("demo cluster shut down prematurely while waiting for server %d to have a node ID", idx) + + default: + if c.servers[idx].NodeID() == 0 { + log.Infof(ctx, "server %d does not know its node ID yet", idx) + continue + } else { + log.Infof(ctx, "server %d: n%d", idx, c.servers[idx].NodeID()) + } + } + break } + return nil +} - // Start up the update check loop. - // We don't do this in (*server.Server).Start() because we don't want this - // overhead and possible interference in tests. - if !demoCtx.disableTelemetry { - c.s.StartDiagnostics(ctx) +// waitForSQLReadiness waits until the given server reports it is +// healthy and ready to accept SQL clients. +func (c *transientCluster) waitForSQLReadiness( + baseCtx context.Context, idx int, errCh chan error, timeoutCh <-chan time.Time, +) error { + retryOpts := retry.Options{InitialBackoff: 10 * time.Millisecond, MaxBackoff: time.Second} + for r := retry.StartWithCtx(baseCtx, retryOpts); r.Next(); { + ctx := logtags.AddTag(baseCtx, "n", c.servers[idx].NodeID()) + log.Infof(ctx, "waiting for server %d to become ready", idx) + select { + // Errors or premature shutdown. + case <-timeoutCh: + return errors.Newf("initialization timeout while waiting for server %d readiness", idx) + case err := <-errCh: + return errors.Wrapf(err, "server %d failed to start", idx) + case <-ctx.Done(): + return errors.CombineErrors(errors.Newf("context cancellation while waiting for server %d to become ready", idx), ctx.Err()) + case <-c.servers[idx].Stopper().ShouldQuiesce(): + return errors.Newf("server %s shut down prematurely", idx) + case <-c.stopper.ShouldQuiesce(): + return errors.Newf("demo cluster shut down prematurely while waiting for server %d to become ready", idx) + default: + if err := c.servers[idx].Readiness(ctx); err != nil { + log.Infof(ctx, "server %d not yet ready: %v", idx, err) + continue + } + } + break } return nil } @@ -429,7 +673,7 @@ func (c *transientCluster) Recommission(nodeID roachpb.NodeID) error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - adminClient, finish, err := getAdminClient(ctx, *(c.s.Cfg)) + adminClient, finish, err := getAdminClient(ctx, *(c.firstServer.Cfg)) if err != nil { return err } @@ -454,7 +698,7 @@ func (c *transientCluster) Decommission(nodeID roachpb.NodeID) error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - adminClient, finish, err := getAdminClient(ctx, *(c.s.Cfg)) + adminClient, finish, err := getAdminClient(ctx, *(c.firstServer.Cfg)) if err != nil { return err } @@ -502,7 +746,10 @@ func (c *transientCluster) RestartNode(nodeID roachpb.NodeID) error { } // TODO(#42243): re-compute the latency mapping. - args := testServerArgsForTransientCluster(c.sockForServer(nodeID), nodeID, c.s.ServingRPCAddr(), c.demoDir, + // TODO(...): the RPC address of the first server may not be available + // if the first server was shut down. + args := testServerArgsForTransientCluster(c.sockForServer(nodeID), nodeID, + c.firstServer.ServingRPCAddr(), c.demoDir, c.sqlFirstPort, c.httpFirstPort, c.stickyEngineRegistry) s, err := server.TestServerFactory.New(args) if err != nil { @@ -516,7 +763,7 @@ func (c *transientCluster) RestartNode(nodeID roachpb.NodeID) error { close(readyCh) } - if err := serv.Start(); err != nil { + if err := serv.Start(context.Background()); err != nil { return err } @@ -535,7 +782,7 @@ func (c *transientCluster) RestartNode(nodeID roachpb.NodeID) error { // AddNode create a new node in the cluster and start it. // This function uses RestartNode to perform the actual node // starting. -func (c *transientCluster) AddNode(localityString string) error { +func (c *transientCluster) AddNode(ctx context.Context, localityString string) error { // '\demo add' accepts both strings that are quoted and not quoted. To properly make use of // quoted strings, strip off the quotes. Before we do that though, make sure that the quotes match, // or that there aren't any quotes in the string. @@ -757,17 +1004,23 @@ func (c *transientCluster) runWorkload( log.Warningf(ctx, "error running workload query: %+v", err) } select { - case <-c.s.Stopper().ShouldQuiesce(): + case <-c.firstServer.Stopper().ShouldQuiesce(): + return + case <-ctx.Done(): + log.Warningf(ctx, "workload terminating from context cancellation: %v", ctx.Err()) + return + case <-c.stopper.ShouldQuiesce(): + log.Warningf(ctx, "demo cluster shutting down") return default: } } } } - // As the SQL shell is tied to `c.s`, this means we want to tie the workload + // As the SQL shell is tied to `c.firstServer`, this means we want to tie the workload // onto this as we want the workload to stop when the server dies, // rather than the cluster. Otherwise, interrupts on cockroach demo hangs. - if err := c.s.Stopper().RunAsyncTask(ctx, "workload", workloadFun(workerFn)); err != nil { + if err := c.firstServer.Stopper().RunAsyncTask(ctx, "workload", workloadFun(workerFn)); err != nil { return err } } @@ -797,17 +1050,33 @@ func (c *transientCluster) acquireDemoLicense(ctx context.Context) (chan error, go func() { defer db.Close() - success, err := GetAndApplyLicense(db, c.s.ClusterID(), demoOrg) + success, err := GetAndApplyLicense(db, c.firstServer.ClusterID(), demoOrg) if err != nil { - licenseDone <- err + select { + case licenseDone <- err: + + // Avoid waiting on the license channel write if the + // server or cluster is shutting down. + case <-ctx.Done(): + case <-c.firstServer.Stopper().ShouldQuiesce(): + case <-c.stopper.ShouldQuiesce(): + } return } if !success { if demoCtx.geoPartitionedReplicas { - licenseDone <- errors.WithDetailf( + select { + case licenseDone <- errors.WithDetailf( errors.New("unable to acquire a license for this demo"), "Enterprise features are needed for this demo (--%s).", - cliflags.DemoGeoPartitionedReplicas.Name) + cliflags.DemoGeoPartitionedReplicas.Name): + + // Avoid waiting on the license channel write if the + // server or cluster is shutting down. + case <-ctx.Done(): + case <-c.firstServer.Stopper().ShouldQuiesce(): + case <-c.stopper.ShouldQuiesce(): + } return } } diff --git a/pkg/cli/sql.go b/pkg/cli/sql.go index 4e59d50817d1..6370b0f210e9 100644 --- a/pkg/cli/sql.go +++ b/pkg/cli/sql.go @@ -569,7 +569,7 @@ func (c *cliState) handleDemoAddNode(cmd []string, nextState, errState cliStateE return nextState } - if err := demoCtx.transientCluster.AddNode(cmd[1]); err != nil { + if err := demoCtx.transientCluster.AddNode(context.Background(), cmd[1]); err != nil { return c.internalServerError(errState, err) } addedNodeID := len(demoCtx.transientCluster.servers) diff --git a/pkg/server/connectivity_test.go b/pkg/server/connectivity_test.go index 6274e59ec31d..5483b5610294 100644 --- a/pkg/server/connectivity_test.go +++ b/pkg/server/connectivity_test.go @@ -334,7 +334,8 @@ func TestJoinVersionGate(t *testing.T) { } defer serv.Stop() - if err := serv.Start(); !errors.Is(errors.Cause(err), server.ErrIncompatibleBinaryVersion) { + ctx := context.Background() + if err := serv.Start(ctx); !errors.Is(errors.Cause(err), server.ErrIncompatibleBinaryVersion) { t.Fatalf("expected error %s, got %v", server.ErrIncompatibleBinaryVersion.Error(), err.Error()) } } diff --git a/pkg/server/server.go b/pkg/server/server.go index 2a7f65b65de4..4985b4d1eacc 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -1281,10 +1281,26 @@ func (s *Server) PreStart(ctx context.Context) error { if s.cfg.TestingKnobs.Server != nil { knobs := s.cfg.TestingKnobs.Server.(*TestingKnobs) if knobs.SignalAfterGettingRPCAddress != nil { + log.Infof(ctx, "signaling caller that RPC address is ready") close(knobs.SignalAfterGettingRPCAddress) } if knobs.PauseAfterGettingRPCAddress != nil { - <-knobs.PauseAfterGettingRPCAddress + log.Infof(ctx, "waiting for signal from caller to proceed with initialization") + select { + case <-knobs.PauseAfterGettingRPCAddress: + // Normal case. Just continue below. + + case <-ctx.Done(): + // Test timeout or some other condition in the caller, by which + // we are instructed to stop. + return errors.CombineErrors(errors.New("server stopping prematurely from context shutdown"), ctx.Err()) + + case <-s.stopper.ShouldQuiesce(): + // The server is instructed to stop before it even finished + // starting up. + return errors.New("server stopping prematurely") + } + log.Infof(ctx, "caller is letting us proceed with initialization") } } diff --git a/pkg/server/testing_knobs.go b/pkg/server/testing_knobs.go index e7cc9dac6f9d..b22c0886be22 100644 --- a/pkg/server/testing_knobs.go +++ b/pkg/server/testing_knobs.go @@ -29,12 +29,13 @@ type TestingKnobs struct { DefaultZoneConfigOverride *zonepb.ZoneConfig // DefaultSystemZoneConfigOverride, if set, overrides the default system zone config defined in `pkg/config/zone.go` DefaultSystemZoneConfigOverride *zonepb.ZoneConfig - // PauseAfterGettingRPCAddress, if non-nil, instructs the server to wait until - // the channel is closed after getting an RPC serving address. - PauseAfterGettingRPCAddress chan struct{} // SignalAfterGettingRPCAddress, if non-nil, is closed after the server gets - // an RPC server address. + // an RPC server address, and prior to waiting on PauseAfterGettingRPCAddress below. SignalAfterGettingRPCAddress chan struct{} + // PauseAfterGettingRPCAddress, if non-nil, instructs the server to wait until + // the channel is closed after determining its RPC serving address, and after + // closing SignalAfterGettingRPCAddress. + PauseAfterGettingRPCAddress chan struct{} // ContextTestingKnobs allows customization of the RPC context testing knobs. ContextTestingKnobs rpc.ContextTestingKnobs // DiagnosticsTestingKnobs allows customization of diagnostics testing knobs. diff --git a/pkg/server/testserver.go b/pkg/server/testserver.go index 749a3a70ed5e..9730c22ec305 100644 --- a/pkg/server/testserver.go +++ b/pkg/server/testserver.go @@ -460,8 +460,7 @@ func (ts *TestServer) NodeDialer() *nodedialer.Dialer { // TestServer.ServingRPCAddr() after Start() for client connections. // Use TestServer.Stopper().Stop() to shutdown the server after the test // completes. -func (ts *TestServer) Start() error { - ctx := context.Background() +func (ts *TestServer) Start(ctx context.Context) error { return ts.Server.Start(ctx) } @@ -844,6 +843,12 @@ func (ts *TestServer) DrainClients(ctx context.Context) error { return ts.drainClients(ctx, nil /* reporter */) } +// Readiness returns nil when the server's health probe reports +// readiness, a readiness error otherwise. +func (ts *TestServer) Readiness(ctx context.Context) error { + return ts.admin.checkReadinessForHealthCheck(ctx) +} + // WriteSummaries implements TestServerInterface. func (ts *TestServer) WriteSummaries() error { return ts.node.writeNodeStatus(context.TODO(), time.Hour, false) diff --git a/pkg/testutils/reduce/reducesql/reducesql_test.go b/pkg/testutils/reduce/reducesql/reducesql_test.go index 601aaf6a3f8a..058ac0e2b071 100644 --- a/pkg/testutils/reduce/reducesql/reducesql_test.go +++ b/pkg/testutils/reduce/reducesql/reducesql_test.go @@ -47,7 +47,7 @@ func isInterestingSQL(contains string) reduce.InterestingFn { } serv := ts.(*server.TestServer) defer serv.Stopper().Stop(ctx) - if err := serv.Start(); err != nil { + if err := serv.Start(context.Background()); err != nil { panic(err) } diff --git a/pkg/testutils/serverutils/test_server_shim.go b/pkg/testutils/serverutils/test_server_shim.go index bcd6ee07722f..86f99d44798d 100644 --- a/pkg/testutils/serverutils/test_server_shim.go +++ b/pkg/testutils/serverutils/test_server_shim.go @@ -47,7 +47,7 @@ import ( type TestServerInterface interface { Stopper() *stop.Stopper - Start() error + Start(context.Context) error // Node returns the server.Node as an interface{}. Node() interface{} @@ -260,7 +260,7 @@ func StartServer( if err != nil { t.Fatalf("%+v", err) } - if err := server.Start(); err != nil { + if err := server.Start(context.Background()); err != nil { t.Fatalf("%+v", err) } goDB := OpenDBConn( @@ -328,7 +328,7 @@ func StartServerRaw(args base.TestServerArgs) (TestServerInterface, error) { if err != nil { return nil, err } - if err := server.Start(); err != nil { + if err := server.Start(context.Background()); err != nil { return nil, err } return server, nil diff --git a/pkg/testutils/testcluster/testcluster.go b/pkg/testutils/testcluster/testcluster.go index f426caef44bd..db812d184c35 100644 --- a/pkg/testutils/testcluster/testcluster.go +++ b/pkg/testutils/testcluster/testcluster.go @@ -462,7 +462,7 @@ func (tc *TestCluster) AddServer(serverArgs base.TestServerArgs) (*server.TestSe // actually starting the server. func (tc *TestCluster) startServer(idx int, serverArgs base.TestServerArgs) error { server := tc.Servers[idx] - if err := server.Start(); err != nil { + if err := server.Start(context.Background()); err != nil { return err } @@ -1327,17 +1327,22 @@ func (tc *TestCluster) RestartServerWithInspect(idx int, inspect func(s *server. } s := srv.(*server.TestServer) + ctx := context.Background() if err := func() error { - tc.mu.Lock() - defer tc.mu.Unlock() - tc.Servers[idx] = s - tc.mu.serverStoppers[idx] = s.Stopper() - - if inspect != nil { - inspect(s) - } + func() { + // Only lock the assignment of the server and the stopper and the call to the inspect function. + // This ensures that the stopper's Stop() method can abort an async Start() call. + tc.mu.Lock() + defer tc.mu.Unlock() + tc.Servers[idx] = s + tc.mu.serverStoppers[idx] = s.Stopper() + + if inspect != nil { + inspect(s) + } + }() - if err := srv.Start(); err != nil { + if err := srv.Start(ctx); err != nil { return err } @@ -1357,7 +1362,7 @@ func (tc *TestCluster) RestartServerWithInspect(idx int, inspect func(s *server. // different port, and a cycle of gossip is necessary to make all other nodes // aware. return contextutil.RunWithTimeout( - context.Background(), "check-conn", 15*time.Second, + ctx, "check-conn", 15*time.Second, func(ctx context.Context) error { r := retry.StartWithCtx(ctx, retry.Options{ InitialBackoff: 1 * time.Millisecond,