From 19e670f0cd603a60ad089631d0ab049822172ae9 Mon Sep 17 00:00:00 2001 From: Raphael 'kena' Poss Date: Sat, 3 Apr 2021 13:15:46 +0200 Subject: [PATCH] cli/demo: refactor the server initialization The latency simulation code needs to be injected as a latency map while the servers are initialized, i.e. concurrently with server startup. The previous initialization code for `cockroach demo` to achieve this was exceedly difficult to understand and was, in fact, incorrect. This commit reworks this code by exposing the overall workings as a comment and then ensuring the structure of the comment follows the explanation. It also add logging. Additionally, this change ensures that the same initialization code is used regardless of whether latency simulation is requested or not. Release note: None Co-authored-by: Raphael 'kena' Poss Co-authored-by: Oliver Tan --- pkg/cli/demo_cluster.go | 573 ++++++++++++++++++++++++++++----------- pkg/cli/sql.go | 2 +- pkg/server/server.go | 7 +- pkg/server/testserver.go | 6 + 4 files changed, 433 insertions(+), 155 deletions(-) diff --git a/pkg/cli/demo_cluster.go b/pkg/cli/demo_cluster.go index 93f6f67a038f..789a35468d84 100644 --- a/pkg/cli/demo_cluster.go +++ b/pkg/cli/demo_cluster.go @@ -37,23 +37,24 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/catalog/catconstants" "github.com/cockroachdb/cockroach/pkg/util/log" "github.com/cockroachdb/cockroach/pkg/util/log/severity" + "github.com/cockroachdb/cockroach/pkg/util/retry" "github.com/cockroachdb/cockroach/pkg/util/stop" - "github.com/cockroachdb/cockroach/pkg/util/timeutil" "github.com/cockroachdb/cockroach/pkg/workload" "github.com/cockroachdb/cockroach/pkg/workload/histogram" "github.com/cockroachdb/cockroach/pkg/workload/workloadsql" "github.com/cockroachdb/errors" + "github.com/cockroachdb/logtags" "github.com/spf13/cobra" "golang.org/x/time/rate" ) type transientCluster struct { - connURL string - demoDir string - useSockets bool - stopper *stop.Stopper - s *server.TestServer - servers []*server.TestServer + connURL string + demoDir string + useSockets bool + stopper *stop.Stopper + firstServer *server.TestServer + servers []*server.TestServer httpFirstPort int sqlFirstPort int @@ -119,184 +120,427 @@ func (c *transientCluster) checkConfigAndSetupLogging( func (c *transientCluster) start( ctx context.Context, cmd *cobra.Command, gen workload.Generator, ) (err error) { - serverFactory := server.TestServerFactory - var servers []*server.TestServer + ctx = logtags.AddTag(ctx, "start-demo-cluster", nil) + // We now proceed to start all the nodes concurrently. This is + // somewhat a complicated dance. + // + // On the one hand, we use a concurrent start, because the latency + // map needs to be initialized after all the nodes have started + // listening on the network, but before they proceed to initialize + // their RPC context. + // + // On the other hand, we cannot use full concurrency either, because + // we need to wait until the *first* node has started + // listening on the network before we can start the next nodes. + // + // So we proceed in phases, as follows: + // + // 1. create and start the first node asynchronously. + // 2. wait for the first node to listen on RPC and determine its + // listen addr; OR wait for an error in the first node initialization. + // 3. create and start all the other nodes asynchronously. + // 4. wait for all the nodes to listen on RPC, OR wait for an error + // from any node. + // 5. if no error, proceed to initialize the latency map. + // 6. in sequence, let each node initialize then wait for RPC readiness OR error from them. + // This ensures the node IDs are assigned sequentially. + // 7. wait for the SQL readiness from all nodes. + // 8. after all nodes are initialized, initialize SQL and telemetry. + // - // latencyMapWaitCh is used to block test servers after RPC address computation until the artificial - // latency map has been constructed. - latencyMapWaitCh := make(chan struct{}) + timeoutCh := time.After(maxNodeInitTime) - // errCh is used to catch all errors when initializing servers. - // Sending a nil on this channel indicates success. + // errCh is used to catch errors when initializing servers. errCh := make(chan error, demoCtx.nodes) - for i := 0; i < demoCtx.nodes; i++ { - // All the nodes connect to the address of the first server created. - var joinAddr string - if i != 0 { - joinAddr = c.s.ServingRPCAddr() + // rpcAddrReadyChs will be used in steps 2 and 4 below + // to wait until the nodes know their RPC address. + rpcAddrReadyChs := make([]chan struct{}, demoCtx.nodes) + + // latencyMapWaitChs is used to block test servers after RPC address + // computation until the artificial latency map has been constructed. + latencyMapWaitChs := make([]chan struct{}, demoCtx.nodes) + + // Step 1: create the first node. + { + phaseCtx := logtags.AddTag(ctx, "phase", 1) + log.Infof(phaseCtx, "creating the first node") + + latencyMapWaitChs[0] = make(chan struct{}) + firstRPCAddrReadyCh, err := c.createAndAddNode(phaseCtx, 0, latencyMapWaitChs[0], timeoutCh) + if err != nil { + return err + } + rpcAddrReadyChs[0] = firstRPCAddrReadyCh + } + + // Step 2: start the first node asynchronously, then wait for RPC + // listen readiness or error. + { + phaseCtx := logtags.AddTag(ctx, "phase", 2) + + log.Infof(phaseCtx, "starting first node") + if err := c.startNodeAsync(phaseCtx, 0, errCh, timeoutCh); err != nil { + return err } - nodeID := roachpb.NodeID(i + 1) - args := testServerArgsForTransientCluster( - c.sockForServer(nodeID), nodeID, joinAddr, c.demoDir, - c.sqlFirstPort, - c.httpFirstPort, - c.stickyEngineRegistry, - ) - if i == 0 { - // The first node also auto-inits the cluster. - args.NoAutoInitializeCluster = false + log.Infof(phaseCtx, "waiting for first node RPC address") + if err := c.waitForRPCAddrReadinessOrError(phaseCtx, 0, errCh, rpcAddrReadyChs, timeoutCh); err != nil { + return err } + } - // servRPCReadyCh is used if latency simulation is requested to notify that a test server has - // successfully computed its RPC address. - servRPCReadyCh := make(chan struct{}) + // Step 3: create the other nodes and start them asynchronously. + { + phaseCtx := logtags.AddTag(ctx, "phase", 3) + log.Infof(phaseCtx, "starting other nodes") - if demoCtx.simulateLatency { - serverKnobs := args.Knobs.Server.(*server.TestingKnobs) - serverKnobs.PauseAfterGettingRPCAddress = latencyMapWaitCh - serverKnobs.SignalAfterGettingRPCAddress = servRPCReadyCh - serverKnobs.ContextTestingKnobs = rpc.ContextTestingKnobs{ - ArtificialLatencyMap: make(map[string]int), + for i := 1; i < demoCtx.nodes; i++ { + latencyMapWaitChs[i] = make(chan struct{}) + rpcAddrReady, err := c.createAndAddNode(phaseCtx, i, latencyMapWaitChs[i], timeoutCh) + if err != nil { + return err } + rpcAddrReadyChs[i] = rpcAddrReady } - s, err := serverFactory.New(args) - if err != nil { - return err - } - serv := s.(*server.TestServer) - c.stopper.AddCloser(stop.CloserFn(serv.Stop)) - if i == 0 { - c.s = serv - // The first node connects its Settings instance to the `log` - // package for crash reporting. - // - // There's a known shortcoming with this approach: restarting - // node 1 using the \demo commands will break this connection: - // if the user changes the cluster setting after restarting node - // 1, the `log` package will not see this change. - // - // TODO(knz): re-connect the `log` package every time the first - // node is restarted and gets a new `Settings` instance. - settings.SetCanonicalValuesContainer(&serv.ClusterSettings().SV) - } - servers = append(servers, serv) - - // We force a wait for all servers until they are ready. - servReadyFnCh := make(chan struct{}) - serv.Cfg.ReadyFn = func(bool) { - close(servReadyFnCh) - } - - // If latency simulation is requested, start the servers in a background thread. We do this because - // the start routine needs to wait for the latency map construction after their RPC address has been computed. - if demoCtx.simulateLatency { - go func(i int) { - if err := serv.Start(ctx); err != nil { - errCh <- err - } else { - // Block until the ReadyFn has been called before continuing. - <-servReadyFnCh - errCh <- nil - } - }(i) - <-servRPCReadyCh - } else { - if err := serv.Start(ctx); err != nil { + // Ensure we close all sticky stores we've created when the stopper + // instructs the entire cluster to stop. We do this only here + // because we want this closer to be registered after all the + // individual servers' Stop() methods have been registered + // via createAndAddNode() above. + c.stopper.AddCloser(stop.CloserFn(func() { + c.stickyEngineRegistry.CloseAllStickyInMemEngines() + })) + + // Start the remaining nodes asynchronously. + for i := 1; i < demoCtx.nodes; i++ { + if err := c.startNodeAsync(phaseCtx, i, errCh, timeoutCh); err != nil { return err } - // Block until the ReadyFn has been called before continuing. - <-servReadyFnCh - errCh <- nil } } - // Ensure we close all sticky stores we've created. - c.stopper.AddCloser(stop.CloserFn(func() { - c.stickyEngineRegistry.CloseAllStickyInMemEngines() - })) - c.servers = servers + // Step 4: wait for all the nodes to know their RPC address, + // or for an error or premature shutdown. + { + phaseCtx := logtags.AddTag(ctx, "phase", 4) + log.Infof(phaseCtx, "waiting for remaining nodes to get their RPC address") - if demoCtx.simulateLatency { - // Now, all servers have been started enough to know their own RPC serving - // addresses, but nothing else. Assemble the artificial latency map. - for i, src := range servers { - latencyMap := src.Cfg.TestingKnobs.Server.(*server.TestingKnobs).ContextTestingKnobs.ArtificialLatencyMap - srcLocality, ok := src.Cfg.Locality.Find("region") - if !ok { - continue - } - srcLocalityMap, ok := regionToRegionToLatency[srcLocality] - if !ok { - continue + for i := 0; i < demoCtx.nodes; i++ { + if err := c.waitForRPCAddrReadinessOrError(phaseCtx, i, errCh, rpcAddrReadyChs, timeoutCh); err != nil { + return err } - for j, dst := range servers { - if i == j { + } + } + + // Step 5: optionally initialize the latency map, then let the servers + // proceed with their initialization. + + { + phaseCtx := logtags.AddTag(ctx, "phase", 5) + + // If latency simulation is requested, initialize the latency map. + if demoCtx.simulateLatency { + // Now, all servers have been started enough to know their own RPC serving + // addresses, but nothing else. Assemble the artificial latency map. + log.Infof(phaseCtx, "initializing latency map") + for i, serv := range c.servers { + latencyMap := serv.Cfg.TestingKnobs.Server.(*server.TestingKnobs).ContextTestingKnobs.ArtificialLatencyMap + srcLocality, ok := serv.Cfg.Locality.Find("region") + if !ok { continue } - dstLocality, ok := dst.Cfg.Locality.Find("region") + srcLocalityMap, ok := regionToRegionToLatency[srcLocality] if !ok { continue } - latency := srcLocalityMap[dstLocality] - latencyMap[dst.ServingRPCAddr()] = latency + for j, dst := range c.servers { + if i == j { + continue + } + dstLocality, ok := dst.Cfg.Locality.Find("region") + if !ok { + continue + } + latency := srcLocalityMap[dstLocality] + latencyMap[dst.ServingRPCAddr()] = latency + } } } } - // We've assembled our latency maps and are ready for all servers to proceed - // through bootstrapping. - close(latencyMapWaitCh) + { + phaseCtx := logtags.AddTag(ctx, "phase", 6) + + for i := 0; i < demoCtx.nodes; i++ { + log.Infof(phaseCtx, "letting server %d initialize", i) + close(latencyMapWaitChs[i]) + if err := c.waitForNodeIDReadiness(phaseCtx, i, errCh, timeoutCh); err != nil { + return err + } + log.Infof(phaseCtx, "node n%d initialized", c.servers[i].NodeID()) + } + } - // Wait for all servers to respond. { - timeRemaining := maxNodeInitTime - lastUpdateTime := timeutil.Now() - var err error + phaseCtx := logtags.AddTag(ctx, "phase", 7) + for i := 0; i < demoCtx.nodes; i++ { - select { - case e := <-errCh: - err = errors.CombineErrors(err, e) - case <-time.After(timeRemaining): - return errors.New("failed to setup transientCluster in time") + log.Infof(phaseCtx, "waiting for server %d SQL readiness", i) + if err := c.waitForSQLReadiness(phaseCtx, i, errCh, timeoutCh); err != nil { + return err } - updateTime := timeutil.Now() - timeRemaining -= updateTime.Sub(lastUpdateTime) - lastUpdateTime = updateTime + log.Infof(phaseCtx, "node n%d ready", c.servers[i].NodeID()) + } + } + + { + phaseCtx := logtags.AddTag(ctx, "phase", 8) + + // Run the SQL initialization. This takes care of setting up the + // initial replication factor for small clusters and creating the + // admin user. + log.Infof(phaseCtx, "running initial SQL for demo cluster") + + const demoUsername = "demo" + demoPassword := genDemoPassword(demoUsername) + if err := runInitialSQL(phaseCtx, c.firstServer.Server, demoCtx.nodes < 3, demoUsername, demoPassword); err != nil { + return err } + if demoCtx.insecure { + c.adminUser = security.RootUserName() + c.adminPassword = "unused" + } else { + c.adminUser = security.MakeSQLUsernameFromPreNormalizedString(demoUsername) + c.adminPassword = demoPassword + } + + // Prepare the URL for use by the SQL shell. + c.connURL, err = c.getNetworkURLForServer(0, gen, true /* includeAppName */) if err != nil { return err } + + // Start up the update check loop. + // We don't do this in (*server.Server).Start() because we don't want this + // overhead and possible interference in tests. + if !demoCtx.disableTelemetry { + log.Infof(phaseCtx, "starting telemetry") + c.firstServer.StartDiagnostics(phaseCtx) + } } + return nil +} - // Run the SQL initialization. This takes care of setting up the - // initial replication factor for small clusters and creating the - // admin user. - const demoUsername = "demo" - demoPassword := genDemoPassword(demoUsername) - if err := runInitialSQL(ctx, c.s.Server, demoCtx.nodes < 3, demoUsername, demoPassword); err != nil { - return err +// createAndAddNode is responsible for determining node parameters, +// instantiating the server component and connecting it to the +// cluster's stopper. +// +// The caller is responsible for calling createAndAddNode() with idx 0 +// synchronously, then startNodeAsync(), then +// waitForNodeRPCListener(), before using createAndAddNode() with +// other indexes. +func (c *transientCluster) createAndAddNode( + ctx context.Context, idx int, latencyMapWaitCh chan struct{}, timeoutCh <-chan time.Time, +) (rpcAddrReadyCh chan struct{}, err error) { + var joinAddr string + if idx > 0 { + // The caller is responsible for ensuring that the method + // is not called before the first server has finished + // computing its RPC listen address. + joinAddr = c.firstServer.ServingRPCAddr() } - if demoCtx.insecure { - c.adminUser = security.RootUserName() - c.adminPassword = "unused" - } else { - c.adminUser = security.MakeSQLUsernameFromPreNormalizedString(demoUsername) - c.adminPassword = demoPassword + nodeID := roachpb.NodeID(idx + 1) + args := testServerArgsForTransientCluster( + c.sockForServer(nodeID), nodeID, joinAddr, c.demoDir, + c.sqlFirstPort, + c.httpFirstPort, + c.stickyEngineRegistry, + ) + if idx == 0 { + // The first node also auto-inits the cluster. + args.NoAutoInitializeCluster = false + } + + serverKnobs := args.Knobs.Server.(*server.TestingKnobs) + + // SignalAfterGettingRPCAddress will be closed by the server startup routine + // once it has determined its RPC address. + rpcAddrReadyCh = make(chan struct{}) + serverKnobs.SignalAfterGettingRPCAddress = rpcAddrReadyCh + + // The server will wait until PauseAfterGettingRPCAddress is closed + // after it has signaled SignalAfterGettingRPCAddress, and before + // it continues the startup routine. + serverKnobs.PauseAfterGettingRPCAddress = latencyMapWaitCh + + if demoCtx.simulateLatency { + // The latency map will be populated after all servers have + // started listening on RPC, and before they proceed with their + // startup routine. + serverKnobs.ContextTestingKnobs = rpc.ContextTestingKnobs{ + ArtificialLatencyMap: make(map[string]int), + } } - // Prepare the URL for use by the SQL shell. - c.connURL, err = c.getNetworkURLForServer(0, gen, true /* includeAppName */) + // Create the server instance. This also registers the in-memory store + // into the sticky engine registry. + s, err := server.TestServerFactory.New(args) if err != nil { + return nil, err + } + serv := s.(*server.TestServer) + + // Ensure that this server gets stopped when the top level demo + // stopper instructs the cluster to stop. + c.stopper.AddCloser(stop.CloserFn(serv.Stop)) + + if idx == 0 { + // Remember the first server for later use by other APIs on + // transientCluster. + c.firstServer = serv + // The first node connects its Settings instance to the `log` + // package for crash reporting. + // + // There's a known shortcoming with this approach: restarting + // node 1 using the \demo commands will break this connection: + // if the user changes the cluster setting after restarting node + // 1, the `log` package will not see this change. + // + // TODO(knz): re-connect the `log` package every time the first + // node is restarted and gets a new `Settings` instance. + settings.SetCanonicalValuesContainer(&serv.ClusterSettings().SV) + } + + // Remember this server for the stop/restart primitives in the SQL + // shell. + c.servers = append(c.servers, serv) + + return rpcAddrReadyCh, nil +} + +// startNodeAsync starts the node initialization asynchronously. +func (c *transientCluster) startNodeAsync( + ctx context.Context, idx int, errCh chan error, timeoutCh <-chan time.Time, +) error { + if idx > len(c.servers) { + return errors.AssertionFailedf("programming error: server %d not created yet", idx) + } + + serv := c.servers[idx] + tag := fmt.Sprintf("start-n%d", idx+1) + return c.stopper.RunAsyncTask(ctx, tag, func(ctx context.Context) { + ctx = logtags.AddTag(ctx, tag, nil) + err := serv.Start(ctx) + if err != nil { + log.Warningf(ctx, "server %d failed to start: %v", idx, err) + select { + case errCh <- err: + + // Don't block if we are shutting down. + case <-ctx.Done(): + case <-serv.Stopper().ShouldQuiesce(): + case <-c.stopper.ShouldQuiesce(): + case <-timeoutCh: + } + } + }) +} + +// waitForRPCAddrReadinessOrError waits until the given server knows its +// RPC address or fails to initialize. +func (c *transientCluster) waitForRPCAddrReadinessOrError( + ctx context.Context, + idx int, + errCh chan error, + rpcAddrReadyChs []chan struct{}, + timeoutCh <-chan time.Time, +) error { + if idx > len(rpcAddrReadyChs) || idx > len(c.servers) { + return errors.AssertionFailedf("programming error: server %d not created yet", idx) + } + + select { + case <-rpcAddrReadyChs[idx]: + // This server knows its RPC address. Proceed with the next phase. + return nil + + // If we are asked for an early shutdown by the cases below or a + // server startup failure, detect it here. + case err := <-errCh: return err + case <-timeoutCh: + return errors.Newf("demo startup timeout while waiting for server %d", idx) + case <-ctx.Done(): + return errors.CombineErrors(ctx.Err(), errors.Newf("server %d startup aborted due to context cancellation", idx)) + case <-c.servers[idx].Stopper().ShouldQuiesce(): + return errors.Newf("server %d stopped prematurely", idx) + case <-c.stopper.ShouldQuiesce(): + return errors.Newf("demo cluster stopped prematurely while starting server %d", idx) + } +} + +// waitForNodeIDReadiness waits until the given server reports it knows its node ID. +func (c *transientCluster) waitForNodeIDReadiness( + ctx context.Context, idx int, errCh chan error, timeoutCh <-chan time.Time, +) error { + retryOpts := retry.Options{InitialBackoff: 10 * time.Millisecond, MaxBackoff: time.Second} + for r := retry.StartWithCtx(ctx, retryOpts); r.Next(); { + log.Infof(ctx, "waiting for server %d to know its node ID", idx) + select { + // Errors or premature shutdown. + case <-timeoutCh: + return errors.Newf("initialization timeout while waiting for server %d node ID", idx) + case err := <-errCh: + return errors.Wrapf(err, "server %d failed to start", idx) + case <-ctx.Done(): + return errors.CombineErrors(errors.Newf("context cancellation while waiting for server %d to have a node ID", idx), ctx.Err()) + case <-c.servers[idx].Stopper().ShouldQuiesce(): + return errors.Newf("server %s shut down prematurely", idx) + case <-c.stopper.ShouldQuiesce(): + return errors.Newf("demo cluster shut down prematurely while waiting for server %d to have a node ID", idx) + + default: + if c.servers[idx].NodeID() == 0 { + log.Infof(ctx, "server %d does not know its node ID yet", idx) + continue + } else { + log.Infof(ctx, "server %d: n%d", idx, c.servers[idx].NodeID()) + } + } + break } + return nil +} - // Start up the update check loop. - // We don't do this in (*server.Server).Start() because we don't want this - // overhead and possible interference in tests. - if !demoCtx.disableTelemetry { - c.s.StartDiagnostics(ctx) +// waitForSQLReadiness waits until the given server reports it is +// healthy and ready to accept SQL clients. +func (c *transientCluster) waitForSQLReadiness( + baseCtx context.Context, idx int, errCh chan error, timeoutCh <-chan time.Time, +) error { + retryOpts := retry.Options{InitialBackoff: 10 * time.Millisecond, MaxBackoff: time.Second} + for r := retry.StartWithCtx(baseCtx, retryOpts); r.Next(); { + ctx := logtags.AddTag(baseCtx, "n", c.servers[idx].NodeID()) + log.Infof(ctx, "waiting for server %d to become ready", idx) + select { + // Errors or premature shutdown. + case <-timeoutCh: + return errors.Newf("initialization timeout while waiting for server %d readiness", idx) + case err := <-errCh: + return errors.Wrapf(err, "server %d failed to start", idx) + case <-ctx.Done(): + return errors.CombineErrors(errors.Newf("context cancellation while waiting for server %d to become ready", idx), ctx.Err()) + case <-c.servers[idx].Stopper().ShouldQuiesce(): + return errors.Newf("server %s shut down prematurely", idx) + case <-c.stopper.ShouldQuiesce(): + return errors.Newf("demo cluster shut down prematurely while waiting for server %d to become ready", idx) + default: + if err := c.servers[idx].Readiness(ctx); err != nil { + log.Infof(ctx, "server %d not yet ready: %v", idx, err) + continue + } + } + break } return nil } @@ -429,7 +673,7 @@ func (c *transientCluster) Recommission(nodeID roachpb.NodeID) error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - adminClient, finish, err := getAdminClient(ctx, *(c.s.Cfg)) + adminClient, finish, err := getAdminClient(ctx, *(c.firstServer.Cfg)) if err != nil { return err } @@ -454,7 +698,7 @@ func (c *transientCluster) Decommission(nodeID roachpb.NodeID) error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - adminClient, finish, err := getAdminClient(ctx, *(c.s.Cfg)) + adminClient, finish, err := getAdminClient(ctx, *(c.firstServer.Cfg)) if err != nil { return err } @@ -502,7 +746,10 @@ func (c *transientCluster) RestartNode(nodeID roachpb.NodeID) error { } // TODO(#42243): re-compute the latency mapping. - args := testServerArgsForTransientCluster(c.sockForServer(nodeID), nodeID, c.s.ServingRPCAddr(), c.demoDir, + // TODO(...): the RPC address of the first server may not be available + // if the first server was shut down. + args := testServerArgsForTransientCluster(c.sockForServer(nodeID), nodeID, + c.firstServer.ServingRPCAddr(), c.demoDir, c.sqlFirstPort, c.httpFirstPort, c.stickyEngineRegistry) s, err := server.TestServerFactory.New(args) if err != nil { @@ -535,7 +782,7 @@ func (c *transientCluster) RestartNode(nodeID roachpb.NodeID) error { // AddNode create a new node in the cluster and start it. // This function uses RestartNode to perform the actual node // starting. -func (c *transientCluster) AddNode(localityString string) error { +func (c *transientCluster) AddNode(ctx context.Context, localityString string) error { // '\demo add' accepts both strings that are quoted and not quoted. To properly make use of // quoted strings, strip off the quotes. Before we do that though, make sure that the quotes match, // or that there aren't any quotes in the string. @@ -757,17 +1004,23 @@ func (c *transientCluster) runWorkload( log.Warningf(ctx, "error running workload query: %+v", err) } select { - case <-c.s.Stopper().ShouldQuiesce(): + case <-c.firstServer.Stopper().ShouldQuiesce(): + return + case <-ctx.Done(): + log.Warningf(ctx, "workload terminating from context cancellation: %v", ctx.Err()) + return + case <-c.stopper.ShouldQuiesce(): + log.Warningf(ctx, "demo cluster shutting down") return default: } } } } - // As the SQL shell is tied to `c.s`, this means we want to tie the workload + // As the SQL shell is tied to `c.firstServer`, this means we want to tie the workload // onto this as we want the workload to stop when the server dies, // rather than the cluster. Otherwise, interrupts on cockroach demo hangs. - if err := c.s.Stopper().RunAsyncTask(ctx, "workload", workloadFun(workerFn)); err != nil { + if err := c.firstServer.Stopper().RunAsyncTask(ctx, "workload", workloadFun(workerFn)); err != nil { return err } } @@ -797,17 +1050,33 @@ func (c *transientCluster) acquireDemoLicense(ctx context.Context) (chan error, go func() { defer db.Close() - success, err := GetAndApplyLicense(db, c.s.ClusterID(), demoOrg) + success, err := GetAndApplyLicense(db, c.firstServer.ClusterID(), demoOrg) if err != nil { - licenseDone <- err + select { + case licenseDone <- err: + + // Avoid waiting on the license channel write if the + // server or cluster is shutting down. + case <-ctx.Done(): + case <-c.firstServer.Stopper().ShouldQuiesce(): + case <-c.stopper.ShouldQuiesce(): + } return } if !success { if demoCtx.geoPartitionedReplicas { - licenseDone <- errors.WithDetailf( + select { + case licenseDone <- errors.WithDetailf( errors.New("unable to acquire a license for this demo"), "Enterprise features are needed for this demo (--%s).", - cliflags.DemoGeoPartitionedReplicas.Name) + cliflags.DemoGeoPartitionedReplicas.Name): + + // Avoid waiting on the license channel write if the + // server or cluster is shutting down. + case <-ctx.Done(): + case <-c.firstServer.Stopper().ShouldQuiesce(): + case <-c.stopper.ShouldQuiesce(): + } return } } diff --git a/pkg/cli/sql.go b/pkg/cli/sql.go index 4e59d50817d1..6370b0f210e9 100644 --- a/pkg/cli/sql.go +++ b/pkg/cli/sql.go @@ -569,7 +569,7 @@ func (c *cliState) handleDemoAddNode(cmd []string, nextState, errState cliStateE return nextState } - if err := demoCtx.transientCluster.AddNode(cmd[1]); err != nil { + if err := demoCtx.transientCluster.AddNode(context.Background(), cmd[1]); err != nil { return c.internalServerError(errState, err) } addedNodeID := len(demoCtx.transientCluster.servers) diff --git a/pkg/server/server.go b/pkg/server/server.go index 20a0bf6e75dc..4985b4d1eacc 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -1281,9 +1281,11 @@ func (s *Server) PreStart(ctx context.Context) error { if s.cfg.TestingKnobs.Server != nil { knobs := s.cfg.TestingKnobs.Server.(*TestingKnobs) if knobs.SignalAfterGettingRPCAddress != nil { + log.Infof(ctx, "signaling caller that RPC address is ready") close(knobs.SignalAfterGettingRPCAddress) } if knobs.PauseAfterGettingRPCAddress != nil { + log.Infof(ctx, "waiting for signal from caller to proceed with initialization") select { case <-knobs.PauseAfterGettingRPCAddress: // Normal case. Just continue below. @@ -1291,13 +1293,14 @@ func (s *Server) PreStart(ctx context.Context) error { case <-ctx.Done(): // Test timeout or some other condition in the caller, by which // we are instructed to stop. - return ctx.Err() + return errors.CombineErrors(errors.New("server stopping prematurely from context shutdown"), ctx.Err()) case <-s.stopper.ShouldQuiesce(): // The server is instructed to stop before it even finished // starting up. - return nil + return errors.New("server stopping prematurely") } + log.Infof(ctx, "caller is letting us proceed with initialization") } } diff --git a/pkg/server/testserver.go b/pkg/server/testserver.go index a9551d8b24ed..72e8108fe8ff 100644 --- a/pkg/server/testserver.go +++ b/pkg/server/testserver.go @@ -838,6 +838,12 @@ func (ts *TestServer) DrainClients(ctx context.Context) error { return ts.drainClients(ctx, nil /* reporter */) } +// Readiness returns nil when the server's health probe reports +// readiness, a readiness error otherwise. +func (ts *TestServer) Readiness(ctx context.Context) error { + return ts.admin.checkReadinessForHealthCheck(ctx) +} + // WriteSummaries implements TestServerInterface. func (ts *TestServer) WriteSummaries() error { return ts.node.writeNodeStatus(context.TODO(), time.Hour, false)