Skip to content

Commit

Permalink
server,testutils: ensure server start is friendly to async cancellation
Browse files Browse the repository at this point in the history
This commit ensures that the server Start() method responds properly
to cancellation of its context and to asynchronous interruption
by its stopper.

Release note: None
  • Loading branch information
knz committed Apr 2, 2021
1 parent 527f243 commit cb250ee
Show file tree
Hide file tree
Showing 8 changed files with 45 additions and 26 deletions.
6 changes: 3 additions & 3 deletions pkg/cli/demo_cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ func (c *transientCluster) start(
// the start routine needs to wait for the latency map construction after their RPC address has been computed.
if demoCtx.simulateLatency {
go func(i int) {
if err := serv.Start(); err != nil {
if err := serv.Start(ctx); err != nil {
errCh <- err
} else {
// Block until the ReadyFn has been called before continuing.
Expand All @@ -203,7 +203,7 @@ func (c *transientCluster) start(
}(i)
<-servRPCReadyCh
} else {
if err := serv.Start(); err != nil {
if err := serv.Start(ctx); err != nil {
return err
}
// Block until the ReadyFn has been called before continuing.
Expand Down Expand Up @@ -516,7 +516,7 @@ func (c *transientCluster) RestartNode(nodeID roachpb.NodeID) error {
close(readyCh)
}

if err := serv.Start(); err != nil {
if err := serv.Start(context.Background()); err != nil {
return err
}

Expand Down
3 changes: 2 additions & 1 deletion pkg/server/connectivity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,8 @@ func TestJoinVersionGate(t *testing.T) {
}
defer serv.Stop()

if err := serv.Start(); !errors.Is(errors.Cause(err), server.ErrIncompatibleBinaryVersion) {
ctx := context.Background()
if err := serv.Start(ctx); !errors.Is(errors.Cause(err), server.ErrIncompatibleBinaryVersion) {
t.Fatalf("expected error %s, got %v", server.ErrIncompatibleBinaryVersion.Error(), err.Error())
}
}
Expand Down
15 changes: 14 additions & 1 deletion pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1284,7 +1284,20 @@ func (s *Server) PreStart(ctx context.Context) error {
close(knobs.SignalAfterGettingRPCAddress)
}
if knobs.PauseAfterGettingRPCAddress != nil {
<-knobs.PauseAfterGettingRPCAddress
select {
case <-knobs.PauseAfterGettingRPCAddress:
// Normal case. Just continue below.

case <-ctx.Done():
// Test timeout or some other condition in the caller, by which
// we are instructed to stop.
return ctx.Err()

case <-s.stopper.ShouldQuiesce():
// The server is instructed to stop before it even finished
// starting up.
return nil
}
}
}

Expand Down
9 changes: 5 additions & 4 deletions pkg/server/testing_knobs.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,13 @@ type TestingKnobs struct {
DefaultZoneConfigOverride *zonepb.ZoneConfig
// DefaultSystemZoneConfigOverride, if set, overrides the default system zone config defined in `pkg/config/zone.go`
DefaultSystemZoneConfigOverride *zonepb.ZoneConfig
// PauseAfterGettingRPCAddress, if non-nil, instructs the server to wait until
// the channel is closed after getting an RPC serving address.
PauseAfterGettingRPCAddress chan struct{}
// SignalAfterGettingRPCAddress, if non-nil, is closed after the server gets
// an RPC server address.
// an RPC server address, and prior to waiting on PauseAfterGettingRPCAddress below.
SignalAfterGettingRPCAddress chan struct{}
// PauseAfterGettingRPCAddress, if non-nil, instructs the server to wait until
// the channel is closed after determining its RPC serving address, and after
// closing SignalAfterGettingRPCAddress.
PauseAfterGettingRPCAddress chan struct{}
// ContextTestingKnobs allows customization of the RPC context testing knobs.
ContextTestingKnobs rpc.ContextTestingKnobs
// DiagnosticsTestingKnobs allows customization of diagnostics testing knobs.
Expand Down
3 changes: 1 addition & 2 deletions pkg/server/testserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -460,8 +460,7 @@ func (ts *TestServer) NodeDialer() *nodedialer.Dialer {
// TestServer.ServingRPCAddr() after Start() for client connections.
// Use TestServer.Stopper().Stop() to shutdown the server after the test
// completes.
func (ts *TestServer) Start() error {
ctx := context.Background()
func (ts *TestServer) Start(ctx context.Context) error {
return ts.Server.Start(ctx)
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/testutils/reduce/reducesql/reducesql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func isInterestingSQL(contains string) reduce.InterestingFn {
}
serv := ts.(*server.TestServer)
defer serv.Stopper().Stop(ctx)
if err := serv.Start(); err != nil {
if err := serv.Start(context.Background()); err != nil {
panic(err)
}

Expand Down
6 changes: 3 additions & 3 deletions pkg/testutils/serverutils/test_server_shim.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ import (
type TestServerInterface interface {
Stopper() *stop.Stopper

Start() error
Start(context.Context) error

// Node returns the server.Node as an interface{}.
Node() interface{}
Expand Down Expand Up @@ -260,7 +260,7 @@ func StartServer(
if err != nil {
t.Fatalf("%+v", err)
}
if err := server.Start(); err != nil {
if err := server.Start(context.Background()); err != nil {
t.Fatalf("%+v", err)
}
goDB := OpenDBConn(
Expand Down Expand Up @@ -328,7 +328,7 @@ func StartServerRaw(args base.TestServerArgs) (TestServerInterface, error) {
if err != nil {
return nil, err
}
if err := server.Start(); err != nil {
if err := server.Start(context.Background()); err != nil {
return nil, err
}
return server, nil
Expand Down
27 changes: 16 additions & 11 deletions pkg/testutils/testcluster/testcluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ func (tc *TestCluster) AddServer(serverArgs base.TestServerArgs) (*server.TestSe
// actually starting the server.
func (tc *TestCluster) startServer(idx int, serverArgs base.TestServerArgs) error {
server := tc.Servers[idx]
if err := server.Start(); err != nil {
if err := server.Start(context.Background()); err != nil {
return err
}

Expand Down Expand Up @@ -1306,17 +1306,22 @@ func (tc *TestCluster) RestartServerWithInspect(idx int, inspect func(s *server.
}
s := srv.(*server.TestServer)

ctx := context.Background()
if err := func() error {
tc.mu.Lock()
defer tc.mu.Unlock()
tc.Servers[idx] = s
tc.mu.serverStoppers[idx] = s.Stopper()

if inspect != nil {
inspect(s)
}
func() {
// Hold the lock only for the assignment of the server and its stopper and for the call to the inspect function.
// Start() runs outside the lock, which ensures that the stopper's Stop() method can abort an async Start() call.
tc.mu.Lock()
defer tc.mu.Unlock()
tc.Servers[idx] = s
tc.mu.serverStoppers[idx] = s.Stopper()

if inspect != nil {
inspect(s)
}
}()

if err := srv.Start(); err != nil {
if err := srv.Start(ctx); err != nil {
return err
}

Expand All @@ -1336,7 +1341,7 @@ func (tc *TestCluster) RestartServerWithInspect(idx int, inspect func(s *server.
// different port, and a cycle of gossip is necessary to make all other nodes
// aware.
return contextutil.RunWithTimeout(
context.Background(), "check-conn", 15*time.Second,
ctx, "check-conn", 15*time.Second,
func(ctx context.Context) error {
r := retry.StartWithCtx(ctx, retry.Options{
InitialBackoff: 1 * time.Millisecond,
Expand Down

0 comments on commit cb250ee

Please sign in to comment.