From cb250ee6ee7ff68bc33bdd3d8385d668e2831b70 Mon Sep 17 00:00:00 2001 From: Raphael 'kena' Poss Date: Fri, 2 Apr 2021 17:45:10 +0200 Subject: [PATCH] server,testutils: ensure server start is friendly to async cancellation This commit ensures that the server Start() method is properly sensitive to its context cancellation and possible async interruption by its stopper. Release note: None --- pkg/cli/demo_cluster.go | 6 ++--- pkg/server/connectivity_test.go | 3 ++- pkg/server/server.go | 15 ++++++++++- pkg/server/testing_knobs.go | 9 ++++--- pkg/server/testserver.go | 3 +-- .../reduce/reducesql/reducesql_test.go | 2 +- pkg/testutils/serverutils/test_server_shim.go | 6 ++--- pkg/testutils/testcluster/testcluster.go | 27 +++++++++++-------- 8 files changed, 45 insertions(+), 26 deletions(-) diff --git a/pkg/cli/demo_cluster.go b/pkg/cli/demo_cluster.go index 970d1d97d50a..93f6f67a038f 100644 --- a/pkg/cli/demo_cluster.go +++ b/pkg/cli/demo_cluster.go @@ -193,7 +193,7 @@ func (c *transientCluster) start( // the start routine needs to wait for the latency map construction after their RPC address has been computed. if demoCtx.simulateLatency { go func(i int) { - if err := serv.Start(); err != nil { + if err := serv.Start(ctx); err != nil { errCh <- err } else { // Block until the ReadyFn has been called before continuing. @@ -203,7 +203,7 @@ func (c *transientCluster) start( }(i) <-servRPCReadyCh } else { - if err := serv.Start(); err != nil { + if err := serv.Start(ctx); err != nil { return err } // Block until the ReadyFn has been called before continuing. @@ -516,7 +516,7 @@ func (c *transientCluster) RestartNode(nodeID roachpb.NodeID) error { close(readyCh) } - if err := serv.Start(); err != nil { + if err := serv.Start(context.Background()); err != nil { return err } diff --git a/pkg/server/connectivity_test.go b/pkg/server/connectivity_test.go index f474365d8097..869d5d4ca5ea 100644 --- a/pkg/server/connectivity_test.go +++ b/pkg/server/connectivity_test.go @@ -333,7 +333,8 @@ func TestJoinVersionGate(t *testing.T) { } defer serv.Stop() - if err := serv.Start(); !errors.Is(errors.Cause(err), server.ErrIncompatibleBinaryVersion) { + ctx := context.Background() + if err := serv.Start(ctx); !errors.Is(errors.Cause(err), server.ErrIncompatibleBinaryVersion) { t.Fatalf("expected error %s, got %v", server.ErrIncompatibleBinaryVersion.Error(), err.Error()) } } diff --git a/pkg/server/server.go b/pkg/server/server.go index 2a7f65b65de4..20a0bf6e75dc 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -1284,7 +1284,20 @@ func (s *Server) PreStart(ctx context.Context) error { close(knobs.SignalAfterGettingRPCAddress) } if knobs.PauseAfterGettingRPCAddress != nil { - <-knobs.PauseAfterGettingRPCAddress + select { + case <-knobs.PauseAfterGettingRPCAddress: + // Normal case. Just continue below. + + case <-ctx.Done(): + // Test timeout or some other condition in the caller, by which + // we are instructed to stop. + return ctx.Err() + + case <-s.stopper.ShouldQuiesce(): + // The server is instructed to stop before it even finished + // starting up. + return nil + } } } diff --git a/pkg/server/testing_knobs.go b/pkg/server/testing_knobs.go index e7cc9dac6f9d..b22c0886be22 100644 --- a/pkg/server/testing_knobs.go +++ b/pkg/server/testing_knobs.go @@ -29,12 +29,13 @@ type TestingKnobs struct { DefaultZoneConfigOverride *zonepb.ZoneConfig // DefaultSystemZoneConfigOverride, if set, overrides the default system zone config defined in `pkg/config/zone.go` DefaultSystemZoneConfigOverride *zonepb.ZoneConfig - // PauseAfterGettingRPCAddress, if non-nil, instructs the server to wait until - // the channel is closed after getting an RPC serving address. - PauseAfterGettingRPCAddress chan struct{} // SignalAfterGettingRPCAddress, if non-nil, is closed after the server gets - // an RPC server address. + // an RPC server address, and prior to waiting on PauseAfterGettingRPCAddress below. SignalAfterGettingRPCAddress chan struct{} + // PauseAfterGettingRPCAddress, if non-nil, instructs the server to wait until + // the channel is closed after determining its RPC serving address, and after + // closing SignalAfterGettingRPCAddress. + PauseAfterGettingRPCAddress chan struct{} // ContextTestingKnobs allows customization of the RPC context testing knobs. ContextTestingKnobs rpc.ContextTestingKnobs // DiagnosticsTestingKnobs allows customization of diagnostics testing knobs. diff --git a/pkg/server/testserver.go b/pkg/server/testserver.go index 65a8427fb7e3..a9551d8b24ed 100644 --- a/pkg/server/testserver.go +++ b/pkg/server/testserver.go @@ -460,8 +460,7 @@ func (ts *TestServer) NodeDialer() *nodedialer.Dialer { // TestServer.ServingRPCAddr() after Start() for client connections. // Use TestServer.Stopper().Stop() to shutdown the server after the test // completes. -func (ts *TestServer) Start() error { - ctx := context.Background() +func (ts *TestServer) Start(ctx context.Context) error { return ts.Server.Start(ctx) } diff --git a/pkg/testutils/reduce/reducesql/reducesql_test.go b/pkg/testutils/reduce/reducesql/reducesql_test.go index 601aaf6a3f8a..058ac0e2b071 100644 --- a/pkg/testutils/reduce/reducesql/reducesql_test.go +++ b/pkg/testutils/reduce/reducesql/reducesql_test.go @@ -47,7 +47,7 @@ func isInterestingSQL(contains string) reduce.InterestingFn { } serv := ts.(*server.TestServer) defer serv.Stopper().Stop(ctx) - if err := serv.Start(); err != nil { + if err := serv.Start(context.Background()); err != nil { panic(err) } diff --git a/pkg/testutils/serverutils/test_server_shim.go b/pkg/testutils/serverutils/test_server_shim.go index bcd6ee07722f..86f99d44798d 100644 --- a/pkg/testutils/serverutils/test_server_shim.go +++ b/pkg/testutils/serverutils/test_server_shim.go @@ -47,7 +47,7 @@ import ( type TestServerInterface interface { Stopper() *stop.Stopper - Start() error + Start(context.Context) error // Node returns the server.Node as an interface{}. Node() interface{} @@ -260,7 +260,7 @@ func StartServer( if err != nil { t.Fatalf("%+v", err) } - if err := server.Start(); err != nil { + if err := server.Start(context.Background()); err != nil { t.Fatalf("%+v", err) } goDB := OpenDBConn( @@ -328,7 +328,7 @@ func StartServerRaw(args base.TestServerArgs) (TestServerInterface, error) { if err != nil { return nil, err } - if err := server.Start(); err != nil { + if err := server.Start(context.Background()); err != nil { return nil, err } return server, nil diff --git a/pkg/testutils/testcluster/testcluster.go b/pkg/testutils/testcluster/testcluster.go index 8016081440d1..f9b7c9963b5a 100644 --- a/pkg/testutils/testcluster/testcluster.go +++ b/pkg/testutils/testcluster/testcluster.go @@ -462,7 +462,7 @@ func (tc *TestCluster) AddServer(serverArgs base.TestServerArgs) (*server.TestSe // actually starting the server. func (tc *TestCluster) startServer(idx int, serverArgs base.TestServerArgs) error { server := tc.Servers[idx] - if err := server.Start(); err != nil { + if err := server.Start(context.Background()); err != nil { return err } @@ -1306,17 +1306,22 @@ func (tc *TestCluster) RestartServerWithInspect(idx int, inspect func(s *server. } s := srv.(*server.TestServer) + ctx := context.Background() if err := func() error { - tc.mu.Lock() - defer tc.mu.Unlock() - tc.Servers[idx] = s - tc.mu.serverStoppers[idx] = s.Stopper() - - if inspect != nil { - inspect(s) - } + func() { + // Only lock the assignment of the server and the stopper and the call to the inspect function. + // This ensures that the stopper's Stop() method can abort an async Start() call. + tc.mu.Lock() + defer tc.mu.Unlock() + tc.Servers[idx] = s + tc.mu.serverStoppers[idx] = s.Stopper() + + if inspect != nil { + inspect(s) + } + }() - if err := srv.Start(); err != nil { + if err := srv.Start(ctx); err != nil { return err } @@ -1336,7 +1341,7 @@ func (tc *TestCluster) RestartServerWithInspect(idx int, inspect func(s *server. // different port, and a cycle of gossip is necessary to make all other nodes // aware. return contextutil.RunWithTimeout( - context.Background(), "check-conn", 15*time.Second, + ctx, "check-conn", 15*time.Second, func(ctx context.Context) error { r := retry.StartWithCtx(ctx, retry.Options{ InitialBackoff: 1 * time.Millisecond,