diff --git a/pkg/cli/demo_cluster.go b/pkg/cli/demo_cluster.go index 970d1d97d50a..93f6f67a038f 100644 --- a/pkg/cli/demo_cluster.go +++ b/pkg/cli/demo_cluster.go @@ -193,7 +193,7 @@ func (c *transientCluster) start( // the start routine needs to wait for the latency map construction after their RPC address has been computed. if demoCtx.simulateLatency { go func(i int) { - if err := serv.Start(); err != nil { + if err := serv.Start(ctx); err != nil { errCh <- err } else { // Block until the ReadyFn has been called before continuing. @@ -203,7 +203,7 @@ func (c *transientCluster) start( }(i) <-servRPCReadyCh } else { - if err := serv.Start(); err != nil { + if err := serv.Start(ctx); err != nil { return err } // Block until the ReadyFn has been called before continuing. @@ -516,7 +516,7 @@ func (c *transientCluster) RestartNode(nodeID roachpb.NodeID) error { close(readyCh) } - if err := serv.Start(); err != nil { + if err := serv.Start(context.Background()); err != nil { return err } diff --git a/pkg/server/connectivity_test.go b/pkg/server/connectivity_test.go index f474365d8097..869d5d4ca5ea 100644 --- a/pkg/server/connectivity_test.go +++ b/pkg/server/connectivity_test.go @@ -333,7 +333,8 @@ func TestJoinVersionGate(t *testing.T) { } defer serv.Stop() - if err := serv.Start(); !errors.Is(errors.Cause(err), server.ErrIncompatibleBinaryVersion) { + ctx := context.Background() + if err := serv.Start(ctx); !errors.Is(errors.Cause(err), server.ErrIncompatibleBinaryVersion) { t.Fatalf("expected error %s, got %v", server.ErrIncompatibleBinaryVersion.Error(), err.Error()) } } diff --git a/pkg/server/server.go b/pkg/server/server.go index 2a7f65b65de4..20a0bf6e75dc 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -1284,7 +1284,20 @@ func (s *Server) PreStart(ctx context.Context) error { close(knobs.SignalAfterGettingRPCAddress) } if knobs.PauseAfterGettingRPCAddress != nil { - <-knobs.PauseAfterGettingRPCAddress + select { + case <-knobs.PauseAfterGettingRPCAddress: + // Normal case. Just continue below. + + case <-ctx.Done(): + // Test timeout or some other condition in the caller, by which + // we are instructed to stop. + return ctx.Err() + + case <-s.stopper.ShouldQuiesce(): + // The server is instructed to stop before it even finished + // starting up. + return nil + } } } diff --git a/pkg/server/testing_knobs.go b/pkg/server/testing_knobs.go index e7cc9dac6f9d..b22c0886be22 100644 --- a/pkg/server/testing_knobs.go +++ b/pkg/server/testing_knobs.go @@ -29,12 +29,13 @@ type TestingKnobs struct { DefaultZoneConfigOverride *zonepb.ZoneConfig // DefaultSystemZoneConfigOverride, if set, overrides the default system zone config defined in `pkg/config/zone.go` DefaultSystemZoneConfigOverride *zonepb.ZoneConfig - // PauseAfterGettingRPCAddress, if non-nil, instructs the server to wait until - // the channel is closed after getting an RPC serving address. - PauseAfterGettingRPCAddress chan struct{} // SignalAfterGettingRPCAddress, if non-nil, is closed after the server gets - // an RPC server address. + // an RPC server address, and prior to waiting on PauseAfterGettingRPCAddress below. SignalAfterGettingRPCAddress chan struct{} + // PauseAfterGettingRPCAddress, if non-nil, instructs the server to wait until + // the channel is closed after determining its RPC serving address, and after + // closing SignalAfterGettingRPCAddress. + PauseAfterGettingRPCAddress chan struct{} // ContextTestingKnobs allows customization of the RPC context testing knobs. ContextTestingKnobs rpc.ContextTestingKnobs // DiagnosticsTestingKnobs allows customization of diagnostics testing knobs. diff --git a/pkg/server/testserver.go b/pkg/server/testserver.go index 65a8427fb7e3..a9551d8b24ed 100644 --- a/pkg/server/testserver.go +++ b/pkg/server/testserver.go @@ -460,8 +460,7 @@ func (ts *TestServer) NodeDialer() *nodedialer.Dialer { // TestServer.ServingRPCAddr() after Start() for client connections. // Use TestServer.Stopper().Stop() to shutdown the server after the test // completes. -func (ts *TestServer) Start() error { - ctx := context.Background() +func (ts *TestServer) Start(ctx context.Context) error { return ts.Server.Start(ctx) } diff --git a/pkg/testutils/reduce/reducesql/reducesql_test.go b/pkg/testutils/reduce/reducesql/reducesql_test.go index 601aaf6a3f8a..058ac0e2b071 100644 --- a/pkg/testutils/reduce/reducesql/reducesql_test.go +++ b/pkg/testutils/reduce/reducesql/reducesql_test.go @@ -47,7 +47,7 @@ func isInterestingSQL(contains string) reduce.InterestingFn { } serv := ts.(*server.TestServer) defer serv.Stopper().Stop(ctx) - if err := serv.Start(); err != nil { + if err := serv.Start(context.Background()); err != nil { panic(err) } diff --git a/pkg/testutils/serverutils/test_server_shim.go b/pkg/testutils/serverutils/test_server_shim.go index bcd6ee07722f..86f99d44798d 100644 --- a/pkg/testutils/serverutils/test_server_shim.go +++ b/pkg/testutils/serverutils/test_server_shim.go @@ -47,7 +47,7 @@ import ( type TestServerInterface interface { Stopper() *stop.Stopper - Start() error + Start(context.Context) error // Node returns the server.Node as an interface{}. Node() interface{} @@ -260,7 +260,7 @@ func StartServer( if err != nil { t.Fatalf("%+v", err) } - if err := server.Start(); err != nil { + if err := server.Start(context.Background()); err != nil { t.Fatalf("%+v", err) } goDB := OpenDBConn( @@ -328,7 +328,7 @@ func StartServerRaw(args base.TestServerArgs) (TestServerInterface, error) { if err != nil { return nil, err } - if err := server.Start(); err != nil { + if err := server.Start(context.Background()); err != nil { return nil, err } return server, nil diff --git a/pkg/testutils/testcluster/testcluster.go b/pkg/testutils/testcluster/testcluster.go index 8016081440d1..f9b7c9963b5a 100644 --- a/pkg/testutils/testcluster/testcluster.go +++ b/pkg/testutils/testcluster/testcluster.go @@ -462,7 +462,7 @@ func (tc *TestCluster) AddServer(serverArgs base.TestServerArgs) (*server.TestSe // actually starting the server. func (tc *TestCluster) startServer(idx int, serverArgs base.TestServerArgs) error { server := tc.Servers[idx] - if err := server.Start(); err != nil { + if err := server.Start(context.Background()); err != nil { return err } @@ -1306,17 +1306,22 @@ func (tc *TestCluster) RestartServerWithInspect(idx int, inspect func(s *server. } s := srv.(*server.TestServer) + ctx := context.Background() if err := func() error { - tc.mu.Lock() - defer tc.mu.Unlock() - tc.Servers[idx] = s - tc.mu.serverStoppers[idx] = s.Stopper() - - if inspect != nil { - inspect(s) - } + func() { + // Only lock the assignment of the server and the stopper and the call to the inspect function. + // This ensures that the stopper's Stop() method can abort an async Start() call. + tc.mu.Lock() + defer tc.mu.Unlock() + tc.Servers[idx] = s + tc.mu.serverStoppers[idx] = s.Stopper() + + if inspect != nil { + inspect(s) + } + }() - if err := srv.Start(); err != nil { + if err := srv.Start(ctx); err != nil { return err } @@ -1336,7 +1341,7 @@ func (tc *TestCluster) RestartServerWithInspect(idx int, inspect func(s *server. // different port, and a cycle of gossip is necessary to make all other nodes // aware. return contextutil.RunWithTimeout( - context.Background(), "check-conn", 15*time.Second, + ctx, "check-conn", 15*time.Second, func(ctx context.Context) error { r := retry.StartWithCtx(ctx, retry.Options{ InitialBackoff: 1 * time.Millisecond,