diff --git a/pkg/acceptance/cluster/dockercluster.go b/pkg/acceptance/cluster/dockercluster.go index e87a44f5d53a..960eb68db609 100644 --- a/pkg/acceptance/cluster/dockercluster.go +++ b/pkg/acceptance/cluster/dockercluster.go @@ -150,7 +150,7 @@ func CreateDocker( ctx context.Context, cfg TestConfig, volumesDir string, stopper *stop.Stopper, ) *DockerCluster { select { - case <-stopper.ShouldStop(): + case <-stopper.ShouldQuiesce(): // The stopper was already closed, exit early. os.Exit(1) default: @@ -563,7 +563,7 @@ func (l *DockerCluster) processEvent(ctx context.Context, event events.Message) // An event on any other container is unexpected. Die. select { - case <-l.stopper.ShouldStop(): + case <-l.stopper.ShouldQuiesce(): case <-l.monitorCtx.Done(): default: // There is a very tiny race here: the signal handler might be closing the @@ -698,7 +698,7 @@ func (l *DockerCluster) AssertAndStop(ctx context.Context, t testing.TB) { func (l *DockerCluster) stop(ctx context.Context) { if *waitOnStop { log.Infof(ctx, "waiting for interrupt") - <-l.stopper.ShouldStop() + <-l.stopper.ShouldQuiesce() } log.Infof(ctx, "stopping") diff --git a/pkg/acceptance/test_acceptance.go b/pkg/acceptance/test_acceptance.go index 1293c0a16ff7..96de4fb88a8d 100644 --- a/pkg/acceptance/test_acceptance.go +++ b/pkg/acceptance/test_acceptance.go @@ -51,13 +51,7 @@ func RunTests(m *testing.M) int { sig := make(chan os.Signal, 1) signal.Notify(sig, os.Interrupt) <-sig - select { - case <-stopper.ShouldStop(): - default: - // There is a very tiny race here: the cluster might be closing - // the stopper simultaneously. 
- stopper.Stop(ctx) - } + stopper.Stop(ctx) }() return m.Run() } diff --git a/pkg/acceptance/util_cluster.go b/pkg/acceptance/util_cluster.go index 14c4416b39a2..f57f6938977f 100644 --- a/pkg/acceptance/util_cluster.go +++ b/pkg/acceptance/util_cluster.go @@ -98,7 +98,7 @@ func StartCluster(ctx context.Context, t *testing.T, cfg cluster.TestConfig) (c testutils.SucceedsSoon(t, func() error { select { - case <-stopper.ShouldStop(): + case <-stopper.ShouldQuiesce(): t.Fatal("interrupted") case <-time.After(time.Second): } diff --git a/pkg/ccl/backupccl/backup_test.go b/pkg/ccl/backupccl/backup_test.go index ad5f87b3a67a..64e4498a4e1d 100644 --- a/pkg/ccl/backupccl/backup_test.go +++ b/pkg/ccl/backupccl/backup_test.go @@ -3414,7 +3414,7 @@ func TestBackupRestoreWithConcurrentWrites(t *testing.T) { var allowErrors int32 for task := 0; task < numBackgroundTasks; task++ { taskNum := task - tc.Stopper().RunWorker(context.Background(), func(context.Context) { + _ = tc.Stopper().RunAsyncTask(context.Background(), "bg-task", func(context.Context) { conn := tc.Conns[taskNum%len(tc.Conns)] // Use different sql gateways to make sure leasing is right. if err := startBackgroundWrites(tc.Stopper(), conn, rows, bgActivity, &allowErrors); err != nil { diff --git a/pkg/ccl/kvccl/kvtenantccl/connector.go b/pkg/ccl/kvccl/kvtenantccl/connector.go index 184ee9d1bec5..6721c0ca47cd 100644 --- a/pkg/ccl/kvccl/kvtenantccl/connector.go +++ b/pkg/ccl/kvccl/kvtenantccl/connector.go @@ -118,12 +118,14 @@ func (connectorFactory) NewConnector( // cluster's ID and set Connector.rpcContext.ClusterID. 
func (c *Connector) Start(ctx context.Context) error { startupC := c.startupC - c.rpcContext.Stopper.RunWorker(context.Background(), func(ctx context.Context) { + if err := c.rpcContext.Stopper.RunAsyncTask(context.Background(), "connector", func(ctx context.Context) { ctx = c.AnnotateCtx(ctx) ctx, cancel := c.rpcContext.Stopper.WithCancelOnQuiesce(ctx) defer cancel() c.runGossipSubscription(ctx) - }) + }); err != nil { + return err + } // Synchronously block until the first GossipSubscription event. select { case <-startupC: diff --git a/pkg/ccl/kvccl/kvtenantccl/connector_test.go b/pkg/ccl/kvccl/kvtenantccl/connector_test.go index 15fd9e9b10da..8cda06c9306c 100644 --- a/pkg/ccl/kvccl/kvtenantccl/connector_test.go +++ b/pkg/ccl/kvccl/kvtenantccl/connector_test.go @@ -371,11 +371,10 @@ func TestConnectorRetriesUnreachable(t *testing.T) { // Decompose netutil.ListenAndServeGRPC so we can listen before serving. ln, err := net.Listen(util.TestAddr.Network(), util.TestAddr.String()) require.NoError(t, err) - stopper.RunWorker(ctx, func(context.Context) { + stopper.AddCloser(stop.CloserFn(s.Stop)) + _ = stopper.RunAsyncTask(ctx, "wait-quiesce", func(context.Context) { <-stopper.ShouldQuiesce() netutil.FatalIfUnexpected(ln.Close()) - <-stopper.ShouldStop() - s.Stop() }) // Add listen address into list of other bogus addresses. @@ -401,7 +400,7 @@ func TestConnectorRetriesUnreachable(t *testing.T) { // Begin serving on gRPC server. Connector should quickly connect // and complete startup. 
- stopper.RunWorker(ctx, func(context.Context) { + _ = stopper.RunAsyncTask(ctx, "serve", func(context.Context) { netutil.FatalIfUnexpected(s.Serve(ln)) }) require.NoError(t, <-startedC) diff --git a/pkg/cli/cli_test.go b/pkg/cli/cli_test.go index dab9f14ae398..99c5ad0866de 100644 --- a/pkg/cli/cli_test.go +++ b/pkg/cli/cli_test.go @@ -187,14 +187,7 @@ func (c *cliTest) stopServer() { if c.TestServer != nil { log.Infof(context.Background(), "stopping server at %s / %s", c.ServingRPCAddr(), c.ServingSQLAddr()) - select { - case <-c.Stopper().ShouldStop(): - // If ShouldStop() doesn't block, that means someone has already - // called Stop(). We just need to wait. - <-c.Stopper().IsStopped() - default: - c.Stopper().Stop(context.Background()) - } + c.Stopper().Stop(context.Background()) } } diff --git a/pkg/cli/debug_synctest.go b/pkg/cli/debug_synctest.go index 15fe828fdeb4..6ea5eb4cd718 100644 --- a/pkg/cli/debug_synctest.go +++ b/pkg/cli/debug_synctest.go @@ -156,7 +156,7 @@ func runSyncer( waitFailure := time.After(time.Duration(rand.Int63n(5 * time.Second.Nanoseconds()))) - stopper.RunWorker(ctx, func(ctx context.Context) { + if err := stopper.RunAsyncTask(ctx, "syncer", func(ctx context.Context) { <-waitFailure if err := nemesis.On(); err != nil { panic(err) @@ -167,7 +167,9 @@ func runSyncer( } }() <-stopper.ShouldQuiesce() - }) + }); err != nil { + return 0, err + } ch := make(chan os.Signal, 1) signal.Notify(ch, drainSignals...) 
diff --git a/pkg/cli/demo_cluster.go b/pkg/cli/demo_cluster.go index f5415cc5bbc4..eb120c00fd46 100644 --- a/pkg/cli/demo_cluster.go +++ b/pkg/cli/demo_cluster.go @@ -751,7 +751,7 @@ func (c *transientCluster) runWorkload( log.Warningf(ctx, "error running workload query: %+v", err) } select { - case <-c.s.Stopper().ShouldStop(): + case <-c.s.Stopper().ShouldQuiesce(): return default: } @@ -761,7 +761,9 @@ func (c *transientCluster) runWorkload( // As the SQL shell is tied to `c.s`, this means we want to tie the workload // onto this as we want the workload to stop when the server dies, // rather than the cluster. Otherwise, interrupts on cockroach demo hangs. - c.s.Stopper().RunWorker(ctx, workloadFun(workerFn)) + if err := c.s.Stopper().RunAsyncTask(ctx, "workload", workloadFun(workerFn)); err != nil { + return err + } } return nil diff --git a/pkg/cli/start.go b/pkg/cli/start.go index 3b6f6fc34a21..1cbc525935b9 100644 --- a/pkg/cli/start.go +++ b/pkg/cli/start.go @@ -715,7 +715,7 @@ If problems persist, please see %s.` log.StartAlwaysFlush() return err - case <-stopper.ShouldStop(): + case <-stopper.ShouldQuiesce(): // Server is being stopped externally and our job is finished // here since we don't know if it's a graceful shutdown or not. <-stopper.IsStopped() @@ -826,7 +826,7 @@ If problems persist, please see %s.` select { case <-ticker.C: log.Ops.Infof(context.Background(), "%d running tasks", stopper.NumTasks()) - case <-stopper.ShouldStop(): + case <-stopper.IsStopped(): return case <-stopWithoutDrain: return diff --git a/pkg/cmd/roachtest/test_runner.go b/pkg/cmd/roachtest/test_runner.go index a1146371817d..dab1200abb21 100644 --- a/pkg/cmd/roachtest/test_runner.go +++ b/pkg/cmd/roachtest/test_runner.go @@ -269,7 +269,7 @@ func (r *testRunner) Run( for i := 0; i < parallelism; i++ { i := i // Copy for closure. 
wg.Add(1) - stopper.RunWorker(ctx, func(ctx context.Context) { + if err := stopper.RunAsyncTask(ctx, "worker", func(ctx context.Context) { defer wg.Done() if err := r.runWorker( @@ -284,15 +284,22 @@ func (r *testRunner) Run( msg := fmt.Sprintf("Worker %d returned with error. Quiescing. Error: %+v", i, err) shout(ctx, l, lopt.stdout, msg) errs.AddErr(err) - // Quiesce the stopper. This will cause all workers to not pick up more - // tests after finishing the currently running one. - stopper.Quiesce(ctx) + // Stop the stopper. This will cause all workers to not pick up more + // tests after finishing the currently running one. We add one to the + // WaitGroup so that wg.Wait() will also wait for the stopper. + wg.Add(1) + go func() { + defer wg.Done() + stopper.Stop(ctx) + }() // Interrupt everybody waiting for resources. if qp != nil { qp.Close(msg) } } - }) + }); err != nil { + wg.Done() + } } // Wait for all the workers to finish. diff --git a/pkg/gossip/client.go b/pkg/gossip/client.go index 9b8c536c160f..31310ada1ce5 100644 --- a/pkg/gossip/client.go +++ b/pkg/gossip/client.go @@ -81,7 +81,7 @@ func (c *client) startLocked( g.outgoing.addPlaceholder() ctx, cancel := context.WithCancel(c.AnnotateCtx(context.Background())) - stopper.RunWorker(ctx, func(ctx context.Context) { + if err := stopper.RunAsyncTask(ctx, "gossip-client", func(ctx context.Context) { var wg sync.WaitGroup defer func() { // This closes the outgoing stream, causing any attempt to send or @@ -133,7 +133,9 @@ func (c *client) startLocked( g.mu.RUnlock() } } - }) + }); err != nil { + disconnected <- c + } } // close stops the client gossip loop and returns immediately. @@ -311,7 +313,7 @@ func (c *client) gossip( // This wait group is used to allow the caller to wait until gossip // processing is terminated. 
wg.Add(1) - stopper.RunWorker(ctx, func(ctx context.Context) { + if err := stopper.RunAsyncTask(ctx, "client-gossip", func(ctx context.Context) { defer wg.Done() errCh <- func() error { @@ -335,7 +337,10 @@ func (c *client) gossip( } } }() - }) + }); err != nil { + wg.Done() + return err + } // We attempt to defer registration of the callback until we've heard a // response from the remote node which will contain the remote's high water @@ -366,7 +371,7 @@ func (c *client) gossip( select { case <-c.closer: return nil - case <-stopper.ShouldStop(): + case <-stopper.ShouldQuiesce(): return nil case err := <-errCh: return err diff --git a/pkg/gossip/gossip.go b/pkg/gossip/gossip.go index 02b714c30148..c65ead58855e 100644 --- a/pkg/gossip/gossip.go +++ b/pkg/gossip/gossip.go @@ -1296,12 +1296,12 @@ func (g *Gossip) getNextBootstrapAddressLocked() net.Addr { // lost and requires re-bootstrapping. func (g *Gossip) bootstrap() { ctx := g.AnnotateCtx(context.Background()) - g.server.stopper.RunWorker(ctx, func(ctx context.Context) { + _ = g.server.stopper.RunAsyncTask(ctx, "gossip-bootstrap", func(ctx context.Context) { ctx = logtags.AddTag(ctx, "bootstrap", nil) var bootstrapTimer timeutil.Timer defer bootstrapTimer.Stop() for { - if g.server.stopper.RunTask(ctx, "gossip.Gossip: bootstrap ", func(ctx context.Context) { + func(ctx context.Context) { g.mu.Lock() defer g.mu.Unlock() haveClients := g.outgoing.len() > 0 @@ -1322,9 +1322,7 @@ func (g *Gossip) bootstrap() { g.maybeSignalStatusChangeLocked() } } - }) != nil { - return - } + }(ctx) // Pause an interval before next possible bootstrap. 
bootstrapTimer.Reset(g.bootstrapInterval) @@ -1333,7 +1331,7 @@ func (g *Gossip) bootstrap() { case <-bootstrapTimer.C: bootstrapTimer.Read = true // continue - case <-g.server.stopper.ShouldStop(): + case <-g.server.stopper.ShouldQuiesce(): return } log.Eventf(ctx, "idling until bootstrap required") @@ -1342,7 +1340,7 @@ func (g *Gossip) bootstrap() { case <-g.stalledCh: log.Eventf(ctx, "detected stall; commencing bootstrap") // continue - case <-g.server.stopper.ShouldStop(): + case <-g.server.stopper.ShouldQuiesce(): return } } @@ -1361,7 +1359,7 @@ func (g *Gossip) bootstrap() { // is notified via the stalled conditional variable. func (g *Gossip) manage() { ctx := g.AnnotateCtx(context.Background()) - g.server.stopper.RunWorker(ctx, func(ctx context.Context) { + _ = g.server.stopper.RunAsyncTask(ctx, "gossip-manage", func(ctx context.Context) { clientsTimer := timeutil.NewTimer() cullTimer := timeutil.NewTimer() stallTimer := timeutil.NewTimer() @@ -1374,7 +1372,7 @@ func (g *Gossip) manage() { stallTimer.Reset(jitteredInterval(g.stallInterval)) for { select { - case <-g.server.stopper.ShouldStop(): + case <-g.server.stopper.ShouldQuiesce(): return case c := <-g.disconnected: g.doDisconnected(c) diff --git a/pkg/gossip/infostore.go b/pkg/gossip/infostore.go index 85ca4f6895ec..ef24561d2f72 100644 --- a/pkg/gossip/infostore.go +++ b/pkg/gossip/infostore.go @@ -171,7 +171,7 @@ func newInfoStore( callbackCh: make(chan struct{}, 1), } - is.stopper.RunWorker(context.Background(), func(ctx context.Context) { + _ = is.stopper.RunAsyncTask(context.Background(), "infostore", func(ctx context.Context) { for { for { is.callbackWorkMu.Lock() diff --git a/pkg/gossip/server.go b/pkg/gossip/server.go index 358f40d281ec..06a0501bd3a1 100644 --- a/pkg/gossip/server.go +++ b/pkg/gossip/server.go @@ -135,11 +135,8 @@ func (s *server) Gossip(stream Gossip_GossipServer) error { errCh := make(chan error, 1) - // Starting workers in a task prevents data races during shutdown. 
- if err := s.stopper.RunTask(ctx, "gossip.server: receiver", func(ctx context.Context) { - s.stopper.RunWorker(ctx, func(ctx context.Context) { - errCh <- s.gossipReceiver(ctx, &args, send, stream.Recv) - }) + if err := s.stopper.RunAsyncTask(ctx, "gossip receiver", func(ctx context.Context) { + errCh <- s.gossipReceiver(ctx, &args, send, stream.Recv) }); err != nil { return err } @@ -379,7 +376,7 @@ func (s *server) start(addr net.Addr) { broadcast() }, Redundant) - s.stopper.RunWorker(context.TODO(), func(context.Context) { + waitQuiesce := func(context.Context) { <-s.stopper.ShouldQuiesce() s.mu.Lock() @@ -387,7 +384,10 @@ func (s *server) start(addr net.Addr) { s.mu.Unlock() broadcast() - }) + } + if err := s.stopper.RunAsyncTask(context.Background(), "gossip-wait-quiesce", waitQuiesce); err != nil { + waitQuiesce(context.Background()) + } } func (s *server) status() ServerStatus { diff --git a/pkg/gossip/simulation/network.go b/pkg/gossip/simulation/network.go index 6ac8e891ea1e..b45f35b21999 100644 --- a/pkg/gossip/simulation/network.go +++ b/pkg/gossip/simulation/network.go @@ -116,11 +116,10 @@ func (n *Network) CreateNode(defaultZoneConfig *zonepb.ZoneConfig) (*Node, error } node := &Node{Server: server, Listener: ln, Registry: metric.NewRegistry()} node.Gossip = gossip.NewTest(0, n.RPCContext, server, n.Stopper, node.Registry, defaultZoneConfig) - n.Stopper.RunWorker(context.TODO(), func(context.Context) { + n.Stopper.AddCloser(stop.CloserFn(server.Stop)) + _ = n.Stopper.RunAsyncTask(context.TODO(), "node-wait-quiesce", func(context.Context) { <-n.Stopper.ShouldQuiesce() netutil.FatalIfUnexpected(ln.Close()) - <-n.Stopper.ShouldStop() - server.Stop() node.Gossip.EnableSimulationCycler(false) }) n.Nodes = append(n.Nodes, node) @@ -144,10 +143,9 @@ func (n *Network) StartNode(node *Node) error { encoding.EncodeUint64Ascending(nil, 0), time.Hour); err != nil { return err } - n.Stopper.RunWorker(context.TODO(), func(context.Context) { + return 
n.Stopper.RunAsyncTask(context.TODO(), "start-node", func(context.Context) { netutil.FatalIfUnexpected(node.Server.Serve(node.Listener)) }) - return nil } // GetNodeFromID returns the simulation node associated with diff --git a/pkg/internal/client/requestbatcher/batcher.go b/pkg/internal/client/requestbatcher/batcher.go index 20c282e45f12..9bc7580e0b03 100644 --- a/pkg/internal/client/requestbatcher/batcher.go +++ b/pkg/internal/client/requestbatcher/batcher.go @@ -267,7 +267,7 @@ func (b *RequestBatcher) sendDone(ba *batch) { } func (b *RequestBatcher) sendBatch(ctx context.Context, ba *batch) { - b.cfg.Stopper.RunWorker(ctx, func(ctx context.Context) { + if err := b.cfg.Stopper.RunAsyncTask(ctx, "send-batch", func(ctx context.Context) { defer b.sendDone(ba) var br *roachpb.BatchResponse send := func(ctx context.Context) error { @@ -340,7 +340,9 @@ func (b *RequestBatcher) sendBatch(ctx context.Context, ba *batch) { } ba.reqs, prevResps = nextReqs, nextPrevResps } - }) + }); err != nil { + b.sendDone(ba) + } } func (b *RequestBatcher) sendResponse(req *request, resp Response) { diff --git a/pkg/jobs/job_scheduler.go b/pkg/jobs/job_scheduler.go index 098526e2835b..59753c21d458 100644 --- a/pkg/jobs/job_scheduler.go +++ b/pkg/jobs/job_scheduler.go @@ -344,7 +344,7 @@ func (s *jobScheduler) executeSchedules( } func (s *jobScheduler) runDaemon(ctx context.Context, stopper *stop.Stopper) { - stopper.RunWorker(ctx, func(ctx context.Context) { + _ = stopper.RunAsyncTask(ctx, "job-scheduler", func(ctx context.Context) { initialDelay := getInitialScanDelay(s.TestingKnobs) log.Infof(ctx, "waiting %v before scheduled jobs daemon start", initialDelay) diff --git a/pkg/jobs/job_scheduler_test.go b/pkg/jobs/job_scheduler_test.go index 15e741d565c7..ede78afbb80e 100644 --- a/pkg/jobs/job_scheduler_test.go +++ b/pkg/jobs/job_scheduler_test.go @@ -303,7 +303,7 @@ func TestJobSchedulerCanBeDisabledWhileSleeping(t *testing.T) { // Notify main thread and return some small delay for 
daemon to sleep. select { case getWaitPeriodCalled <- struct{}{}: - case <-stopper.ShouldStop(): + case <-stopper.ShouldQuiesce(): } return 10 * time.Millisecond diff --git a/pkg/jobs/jobs.go b/pkg/jobs/jobs.go index 483b6903256b..85137f9dbd4d 100644 --- a/pkg/jobs/jobs.go +++ b/pkg/jobs/jobs.go @@ -902,7 +902,8 @@ func (sj *StartableJob) Run(ctx context.Context) error { case <-ctx.Done(): // Launch a goroutine to continue consuming results from the job. if resultsFromJob != nil { - sj.registry.stopper.RunWorker(ctx, func(ctx context.Context) { + // TODO(ajwerner): ctx is done; we shouldn't pass it to RunAsyncTask. + _ = sj.registry.stopper.RunAsyncTask(ctx, "job-results", func(ctx context.Context) { for { select { case <-errCh: @@ -911,6 +912,8 @@ func (sj *StartableJob) Run(ctx context.Context) error { if !ok { return } + case <-sj.registry.stopper.ShouldQuiesce(): + return } } }) diff --git a/pkg/kv/kvclient/kvcoord/transport_race.go b/pkg/kv/kvclient/kvcoord/transport_race.go index e9eb41c15479..3b1d6ad3f596 100644 --- a/pkg/kv/kvclient/kvcoord/transport_race.go +++ b/pkg/kv/kvclient/kvcoord/transport_race.go @@ -97,8 +97,6 @@ func GRPCTransportFactory( opts SendOptions, nodeDialer *nodedialer.Dialer, replicas ReplicaSlice, ) (Transport, error) { if atomic.AddInt32(&running, 1) <= 1 { - // NB: We can't use Stopper.RunWorker because doing so would race with - // calling Stopper.Stop. if err := nodeDialer.Stopper().RunAsyncTask( context.TODO(), "transport racer", func(ctx context.Context) { var iters int diff --git a/pkg/kv/kvserver/closedts/provider/provider.go b/pkg/kv/kvserver/closedts/provider/provider.go index 4d6125c2d0e3..eeee070eec5b 100644 --- a/pkg/kv/kvserver/closedts/provider/provider.go +++ b/pkg/kv/kvserver/closedts/provider/provider.go @@ -76,12 +76,12 @@ func NewProvider(cfg *Config) *Provider { } // Start implements closedts.Provider. 
-// -// TODO(tschottdorf): the closer functionality could be extracted into its own -// component, which would make the interfaces a little cleaner. Decide whether -// it's worth it during testing. func (p *Provider) Start() { - p.cfg.Stopper.RunWorker(logtags.AddTag(context.Background(), "ct-closer", nil), p.runCloser) + if err := p.cfg.Stopper.RunAsyncTask( + logtags.AddTag(context.Background(), "ct-closer", nil), "ct-closer", p.runCloser, + ); err != nil { + p.drain() + } } func (p *Provider) drain() { @@ -186,7 +186,11 @@ func (p *Provider) runCloser(ctx context.Context) { // TODO(tschottdorf): the transport should ignore connection requests from // the node to itself. Those connections would pointlessly loop this around // once more. - ch <- entry + select { + case ch <- entry: + case <-p.cfg.Stopper.ShouldQuiesce(): + return + } } } } @@ -196,7 +200,7 @@ func (p *Provider) runCloser(ctx context.Context) { func (p *Provider) Notify(nodeID roachpb.NodeID) chan<- ctpb.Entry { ch := make(chan ctpb.Entry) - p.cfg.Stopper.RunWorker(context.Background(), func(ctx context.Context) { + _ = p.cfg.Stopper.RunAsyncTask(context.Background(), "provider-notify", func(ctx context.Context) { handle := func(entry ctpb.Entry) { p.cfg.Storage.Add(nodeID, entry) } @@ -220,8 +224,16 @@ func (p *Provider) Notify(nodeID roachpb.NodeID) chan<- ctpb.Entry { p.mu.Broadcast() } } - for entry := range ch { - handle(entry) + for { + select { + case entry, ok := <-ch: + if !ok { + return + } + handle(entry) + case <-p.cfg.Stopper.ShouldQuiesce(): + return + } } }) diff --git a/pkg/kv/kvserver/closedts/provider/provider_test.go b/pkg/kv/kvserver/closedts/provider/provider_test.go index 9901c669053d..879cfe95f9a9 100644 --- a/pkg/kv/kvserver/closedts/provider/provider_test.go +++ b/pkg/kv/kvserver/closedts/provider/provider_test.go @@ -59,6 +59,7 @@ func TestProviderSubscribeNotify(t *testing.T) { Clock: func(roachpb.NodeID) (hlc.Timestamp, ctpb.Epoch, error) { select { case 
<-stopper.ShouldQuiesce(): + return hlc.Timestamp{}, 0, errors.New("stopping") case <-unblockClockCh: } return hlc.Timestamp{}, ctpb.Epoch(1), errors.New("injected clock error") @@ -105,7 +106,7 @@ func TestProviderSubscribeNotify(t *testing.T) { defer log.Infof(ctx, "done") ch := make(chan ctpb.Entry) - stopper.RunWorker(ctx, func(ctx context.Context) { + _ = stopper.RunAsyncTask(ctx, "subscribe", func(ctx context.Context) { p.Subscribe(ctx, ch) }) diff --git a/pkg/kv/kvserver/closedts/transport/clients.go b/pkg/kv/kvserver/closedts/transport/clients.go index d4679e7b56a9..2179ecbf20f9 100644 --- a/pkg/kv/kvserver/closedts/transport/clients.go +++ b/pkg/kv/kvserver/closedts/transport/clients.go @@ -115,7 +115,7 @@ func (pr *Clients) getOrCreateClient(nodeID roachpb.NodeID) *client { // If our client made it into the map, start it. The point in inserting // before starting is to be able to collect RangeIDs immediately while never // blocking callers. - pr.cfg.Stopper.RunWorker(ctx, func(ctx context.Context) { + if err := pr.cfg.Stopper.RunAsyncTask(ctx, "ct-client", func(ctx context.Context) { defer pr.clients.Delete(int64(nodeID)) c, err := pr.cfg.Dialer.Dial(ctx, nodeID) @@ -165,7 +165,9 @@ func (pr *Clients) getOrCreateClient(nodeID roachpb.NodeID) *client { Requested: slice, } } - }) + }); err != nil { + pr.clients.Delete(int64(nodeID)) + } return cl } diff --git a/pkg/kv/kvserver/closedts/transport/server.go b/pkg/kv/kvserver/closedts/transport/server.go index 0c29cac2d609..dc492e6ff65d 100644 --- a/pkg/kv/kvserver/closedts/transport/server.go +++ b/pkg/kv/kvserver/closedts/transport/server.go @@ -59,8 +59,6 @@ func (s *Server) Get(client ctpb.InboundClient) error { const closedTimestampNoUpdateWarnThreshold = 10 * time.Second t := timeutil.NewTimer() - // NB: We can't use Stopper.RunWorker because doing so would race with - // calling Stopper.Stop. 
if err := s.stopper.RunAsyncTask(ctx, "closedts-subscription", func(ctx context.Context) { s.p.Subscribe(ctx, ch) }); err != nil { diff --git a/pkg/kv/kvserver/closedts/transport/testutils/chan_dialer.go b/pkg/kv/kvserver/closedts/transport/testutils/chan_dialer.go index 5fdf13300025..303d9023e8fb 100644 --- a/pkg/kv/kvserver/closedts/transport/testutils/chan_dialer.go +++ b/pkg/kv/kvserver/closedts/transport/testutils/chan_dialer.go @@ -66,9 +66,11 @@ func (d *ChanDialer) Dial(ctx context.Context, nodeID roachpb.NodeID) (ctpb.Clie }, } - d.stopper.RunWorker(ctx, func(ctx context.Context) { + if err := d.stopper.RunAsyncTask(ctx, "closedts-dial", func(ctx context.Context) { _ = d.server.Get((*incomingClient)(c)) - }) + }); err != nil { + return nil, err + } return c, nil } diff --git a/pkg/kv/kvserver/closedts/transport/transport_util_test.go b/pkg/kv/kvserver/closedts/transport/transport_util_test.go index 52d79b95f2a6..b964a168337f 100644 --- a/pkg/kv/kvserver/closedts/transport/transport_util_test.go +++ b/pkg/kv/kvserver/closedts/transport/transport_util_test.go @@ -76,7 +76,7 @@ func newTestNotifyee(stopper *stop.Stopper) *TestNotifyee { func (tn *TestNotifyee) Notify(nodeID roachpb.NodeID) chan<- ctpb.Entry { ch := make(chan ctpb.Entry) - tn.stopper.RunWorker(context.Background(), func(ctx context.Context) { + _ = tn.stopper.RunAsyncTask(context.Background(), "test-notify", func(ctx context.Context) { for entry := range ch { tn.mu.Lock() tn.mu.entries[nodeID] = append(tn.mu.entries[nodeID], entry) diff --git a/pkg/kv/kvserver/consistency_queue_test.go b/pkg/kv/kvserver/consistency_queue_test.go index 95c456f0a485..124d5e5afb33 100644 --- a/pkg/kv/kvserver/consistency_queue_test.go +++ b/pkg/kv/kvserver/consistency_queue_test.go @@ -622,7 +622,7 @@ func testConsistencyQueueRecomputeStatsImpl(t *testing.T, hadEstimates bool) { // RecomputeStats does not see any skew in its MVCC stats when they are // modified concurrently. 
Note that these writes don't interfere with the // field we modified (SysCount). - tc.Stopper().RunWorker(ctx, func(ctx context.Context) { + _ = tc.Stopper().RunAsyncTask(ctx, "recompute-loop", func(ctx context.Context) { // This channel terminates the loop early if the test takes more than five // seconds. This is useful for stress race runs in CI where the tight loop // can starve the actual work to be done. diff --git a/pkg/kv/kvserver/idalloc/id_alloc.go b/pkg/kv/kvserver/idalloc/id_alloc.go index cd20b6cfaae5..5388fa974893 100644 --- a/pkg/kv/kvserver/idalloc/id_alloc.go +++ b/pkg/kv/kvserver/idalloc/id_alloc.go @@ -99,7 +99,7 @@ func (ia *Allocator) Allocate(ctx context.Context) (int64, error) { func (ia *Allocator) start() { ctx := ia.AnnotateCtx(context.Background()) - ia.opts.Stopper.RunWorker(ctx, func(ctx context.Context) { + if err := ia.opts.Stopper.RunAsyncTask(ctx, "id-alloc", func(ctx context.Context) { defer close(ia.ids) var prevValue int64 // for assertions @@ -150,10 +150,12 @@ func (ia *Allocator) start() { for i := start; i < end; i++ { select { case ia.ids <- i: - case <-ia.opts.Stopper.ShouldStop(): + case <-ia.opts.Stopper.ShouldQuiesce(): return } } } - }) + }); err != nil { + close(ia.ids) + } } diff --git a/pkg/kv/kvserver/intentresolver/intent_resolver.go b/pkg/kv/kvserver/intentresolver/intent_resolver.go index 6dc5808ee9f7..4eece296aae0 100644 --- a/pkg/kv/kvserver/intentresolver/intent_resolver.go +++ b/pkg/kv/kvserver/intentresolver/intent_resolver.go @@ -854,6 +854,8 @@ func (ir *IntentResolver) ResolveIntents( _ = resp.Resp // ignore the response case <-ctx.Done(): return roachpb.NewError(ctx.Err()) + case <-ir.stopper.ShouldQuiesce(): + return roachpb.NewErrorf("stopping") } } return nil diff --git a/pkg/kv/kvserver/liveness/liveness.go b/pkg/kv/kvserver/liveness/liveness.go index 35f14c5b3796..6f7e57970e61 100644 --- a/pkg/kv/kvserver/liveness/liveness.go +++ b/pkg/kv/kvserver/liveness/liveness.go @@ -667,10 +667,10 @@ func (nl 
*NodeLiveness) Start(ctx context.Context, opts NodeLivenessStartOptions nl.mu.engines = opts.Engines nl.mu.Unlock() - opts.Stopper.RunWorker(ctx, func(context.Context) { + _ = opts.Stopper.RunAsyncTask(ctx, "liveness-hb", func(context.Context) { ambient := nl.ambientCtx ambient.AddLogTag("liveness-hb", nil) - ctx, cancel := opts.Stopper.WithCancelOnStop(context.Background()) + ctx, cancel := opts.Stopper.WithCancelOnQuiesce(context.Background()) defer cancel() ctx, sp := ambient.AnnotateCtxWithSpan(ctx, "liveness heartbeat loop") defer sp.Finish() @@ -682,7 +682,7 @@ func (nl *NodeLiveness) Start(ctx context.Context, opts NodeLivenessStartOptions for { select { case <-nl.heartbeatToken: - case <-opts.Stopper.ShouldStop(): + case <-opts.Stopper.ShouldQuiesce(): return } // Give the context a timeout approximately as long as the time we @@ -719,7 +719,7 @@ func (nl *NodeLiveness) Start(ctx context.Context, opts NodeLivenessStartOptions nl.heartbeatToken <- struct{}{} select { case <-ticker.C: - case <-opts.Stopper.ShouldStop(): + case <-opts.Stopper.ShouldQuiesce(): return } } diff --git a/pkg/kv/kvserver/queue.go b/pkg/kv/kvserver/queue.go index 9635ea9c87d6..61014b585851 100644 --- a/pkg/kv/kvserver/queue.go +++ b/pkg/kv/kvserver/queue.go @@ -796,12 +796,13 @@ func (bq *baseQueue) MaybeRemove(rangeID roachpb.RangeID) { // stopper signals exit. func (bq *baseQueue) processLoop(stopper *stop.Stopper) { ctx := bq.AnnotateCtx(context.Background()) - stopper.RunWorker(ctx, func(ctx context.Context) { - defer func() { - bq.mu.Lock() - bq.mu.stopped = true - bq.mu.Unlock() - }() + stop := func() { + bq.mu.Lock() + bq.mu.stopped = true + bq.mu.Unlock() + } + if err := stopper.RunAsyncTask(ctx, "queue-loop", func(ctx context.Context) { + defer stop() // nextTime is initially nil; we don't start any timers until the queue // becomes non-empty. @@ -813,7 +814,7 @@ func (bq *baseQueue) processLoop(stopper *stop.Stopper) { for { select { // Exit on stopper. 
- case <-stopper.ShouldStop(): + case <-stopper.ShouldQuiesce(): return // Incoming signal sets the next time to process if there were previously @@ -872,7 +873,9 @@ func (bq *baseQueue) processLoop(stopper *stop.Stopper) { } } } - }) + }); err != nil { + stop() + } } // lastProcessDuration returns the duration of the last processing attempt. @@ -1140,7 +1143,7 @@ func (bq *baseQueue) addToPurgatoryLocked( } workerCtx := bq.AnnotateCtx(context.Background()) - stopper.RunWorker(workerCtx, func(ctx context.Context) { + _ = stopper.RunAsyncTask(workerCtx, "purgatory", func(ctx context.Context) { ticker := time.NewTicker(purgatoryReportInterval) for { select { @@ -1201,7 +1204,7 @@ func (bq *baseQueue) addToPurgatoryLocked( for errStr, count := range errMap { log.Errorf(ctx, "%d replicas failing with %q", count, errStr) } - case <-stopper.ShouldStop(): + case <-stopper.ShouldQuiesce(): return } } diff --git a/pkg/kv/kvserver/raft_transport.go b/pkg/kv/kvserver/raft_transport.go index 15c03ec1249a..5cce2d052a5d 100644 --- a/pkg/kv/kvserver/raft_transport.go +++ b/pkg/kv/kvserver/raft_transport.go @@ -190,7 +190,7 @@ func NewRaftTransport( } if t.stopper != nil && log.V(1) { ctx := t.AnnotateCtx(context.Background()) - t.stopper.RunWorker(ctx, func(ctx context.Context) { + _ = t.stopper.RunAsyncTask(ctx, "raft-transport", func(ctx context.Context) { ticker := time.NewTicker(10 * time.Second) defer ticker.Stop() lastStats := make(map[roachpb.NodeID]raftTransportStats) @@ -254,7 +254,7 @@ func NewRaftTransport( } lastTime = now log.Infof(ctx, "stats:\n%s", buf.String()) - case <-t.stopper.ShouldStop(): + case <-t.stopper.ShouldQuiesce(): return } } @@ -330,52 +330,50 @@ func (t *RaftTransport) RaftMessageBatch(stream MultiRaft_RaftMessageBatchServer errCh := make(chan error, 1) // Node stopping error is caught below in the select. 
- if err := t.stopper.RunTask( + if err := t.stopper.RunAsyncTask( stream.Context(), "storage.RaftTransport: processing batch", func(ctx context.Context) { - t.stopper.RunWorker(ctx, func(ctx context.Context) { - errCh <- func() error { - var stats *raftTransportStats - stream := &lockedRaftMessageResponseStream{wrapped: stream} - for { - batch, err := stream.Recv() - if err != nil { - return err - } - if len(batch.Requests) == 0 { - continue - } + errCh <- func() error { + var stats *raftTransportStats + stream := &lockedRaftMessageResponseStream{wrapped: stream} + for { + batch, err := stream.Recv() + if err != nil { + return err + } + if len(batch.Requests) == 0 { + continue + } - // This code always uses the DefaultClass. Class is primarily a - // client construct and the server has no way to determine which - // class an inbound connection holds on the client side. Because of - // this we associate all server receives and sends with the - // DefaultClass. This data is exclusively used to print a debug - // log message periodically. Using this policy may lead to a - // DefaultClass log line showing a high rate of server recv but - // a low rate of client sends if most of the traffic is due to - // system ranges. - // - // TODO(ajwerner): consider providing transport metadata to inform - // the server of the connection class or keep shared stats for all - // connection with a host. - if stats == nil { - stats = t.getStats(batch.Requests[0].FromReplica.NodeID, rpc.DefaultClass) - } + // This code always uses the DefaultClass. Class is primarily a + // client construct and the server has no way to determine which + // class an inbound connection holds on the client side. Because of + // this we associate all server receives and sends with the + // DefaultClass. This data is exclusively used to print a debug + // log message periodically. 
Using this policy may lead to a + // DefaultClass log line showing a high rate of server recv but + // a low rate of client sends if most of the traffic is due to + // system ranges. + // + // TODO(ajwerner): consider providing transport metadata to inform + // the server of the connection class or keep shared stats for all + // connection with a host. + if stats == nil { + stats = t.getStats(batch.Requests[0].FromReplica.NodeID, rpc.DefaultClass) + } - for i := range batch.Requests { - req := &batch.Requests[i] - atomic.AddInt64(&stats.serverRecv, 1) - if pErr := t.handleRaftRequest(ctx, req, stream); pErr != nil { - atomic.AddInt64(&stats.serverSent, 1) - if err := stream.Send(newRaftMessageResponse(req, pErr)); err != nil { - return err - } + for i := range batch.Requests { + req := &batch.Requests[i] + atomic.AddInt64(&stats.serverRecv, 1) + if pErr := t.handleRaftRequest(ctx, req, stream); pErr != nil { + atomic.AddInt64(&stats.serverSent, 1) + if err := stream.Send(newRaftMessageResponse(req, pErr)); err != nil { + return err } } } - }() - }) + } + }() }); err != nil { return err } @@ -417,7 +415,7 @@ func (t *RaftTransport) RaftSnapshot(stream MultiRaft_RaftSnapshotServer) error return err } select { - case <-t.stopper.ShouldStop(): + case <-t.stopper.ShouldQuiesce(): return nil case err := <-errCh: return err @@ -448,30 +446,29 @@ func (t *RaftTransport) processQueue( ) error { errCh := make(chan error, 1) - // Starting workers in a task prevents data races during shutdown. 
- if err := t.stopper.RunTask( - stream.Context(), "storage.RaftTransport: processing queue", + ctx := stream.Context() + + if err := t.stopper.RunAsyncTask( + ctx, "storage.RaftTransport: processing queue", func(ctx context.Context) { - t.stopper.RunWorker(ctx, func(ctx context.Context) { - errCh <- func() error { - for { - resp, err := stream.Recv() - if err != nil { - return err - } - atomic.AddInt64(&stats.clientRecv, 1) - handler, ok := t.getHandler(resp.ToReplica.StoreID) - if !ok { - log.Warningf(ctx, "no handler found for store %s in response %s", - resp.ToReplica.StoreID, resp) - continue - } - if err := handler.HandleRaftResponse(ctx, resp); err != nil { - return err - } + errCh <- func() error { + for { + resp, err := stream.Recv() + if err != nil { + return err } - }() - }) + atomic.AddInt64(&stats.clientRecv, 1) + handler, ok := t.getHandler(resp.ToReplica.StoreID) + if !ok { + log.Warningf(ctx, "no handler found for store %s in response %s", + resp.ToReplica.StoreID, resp) + continue + } + if err := handler.HandleRaftResponse(ctx, resp); err != nil { + return err + } + } + }() }); err != nil { return err } @@ -482,7 +479,7 @@ func (t *RaftTransport) processQueue( for { raftIdleTimer.Reset(raftIdleTimeout) select { - case <-t.stopper.ShouldStop(): + case <-t.stopper.ShouldQuiesce(): return nil case <-raftIdleTimer.C: raftIdleTimer.Read = true @@ -636,11 +633,7 @@ func (t *RaftTransport) startProcessNewQueue( log.Warningf(ctx, "while processing outgoing Raft queue to node %d: %s:", toNodeID, err) } } - // Starting workers in a task prevents data races during shutdown. 
- workerTask := func(ctx context.Context) { - t.stopper.RunWorker(ctx, worker) - } - err := t.stopper.RunTask(ctx, "storage.RaftTransport: sending messages", workerTask) + err := t.stopper.RunAsyncTask(ctx, "storage.RaftTransport: sending messages", worker) if err != nil { t.queues[class].Delete(int64(toNodeID)) return false diff --git a/pkg/kv/kvserver/replica_range_lease.go b/pkg/kv/kvserver/replica_range_lease.go index e488e65fb389..23b18d51e718 100644 --- a/pkg/kv/kvserver/replica_range_lease.go +++ b/pkg/kv/kvserver/replica_range_lease.go @@ -1117,7 +1117,7 @@ func (r *Replica) redirectOnOrAcquireLease( log.VErrEventf(ctx, 2, "lease acquisition failed: %s", ctx.Err()) return roachpb.NewError(newNotLeaseHolderError(nil, r.store.StoreID(), r.Desc(), "lease acquisition canceled because context canceled")) - case <-r.store.Stopper().ShouldStop(): + case <-r.store.Stopper().ShouldQuiesce(): llHandle.Cancel() return roachpb.NewError(newNotLeaseHolderError(nil, r.store.StoreID(), r.Desc(), "lease acquisition canceled because node is stopping")) diff --git a/pkg/kv/kvserver/replicate_queue_test.go b/pkg/kv/kvserver/replicate_queue_test.go index 372a383705a1..57307210ad2e 100644 --- a/pkg/kv/kvserver/replicate_queue_test.go +++ b/pkg/kv/kvserver/replicate_queue_test.go @@ -569,7 +569,7 @@ func TestTransferLeaseToLaggingNode(t *testing.T) { workerReady := make(chan bool) // Create persistent range load. 
- tc.Stopper().RunWorker(ctx, func(ctx context.Context) { + require.NoError(t, tc.Stopper().RunAsyncTask(ctx, "load", func(ctx context.Context) { s = sqlutils.MakeSQLRunner(tc.Conns[remoteNodeID-1]) workerReady <- true for { @@ -584,7 +584,7 @@ func TestTransferLeaseToLaggingNode(t *testing.T) { case <-time.After(queryInterval): } } - }) + })) <-workerReady // Wait until we see remote making progress leaseHolderRepl, err := leaseHolderStore.GetReplica(rangeID) diff --git a/pkg/kv/kvserver/reports/reporter.go b/pkg/kv/kvserver/reports/reporter.go index 78a1bd0eb0b0..375b2942a5fc 100644 --- a/pkg/kv/kvserver/reports/reporter.go +++ b/pkg/kv/kvserver/reports/reporter.go @@ -115,7 +115,7 @@ func (stats *Reporter) Start(ctx context.Context, stopper *stop.Stopper) { stats.frequencyMu.changeCh = make(chan struct{}) stats.frequencyMu.interval = ReporterInterval.Get(&stats.settings.SV) }) - stopper.RunWorker(ctx, func(ctx context.Context) { + _ = stopper.RunAsyncTask(ctx, "stats-reporter", func(ctx context.Context) { var timer timeutil.Timer defer timer.Stop() ctx = logtags.AddTag(ctx, "replication-reporter", nil /* value */) diff --git a/pkg/kv/kvserver/scanner.go b/pkg/kv/kvserver/scanner.go index a9301dbd7a00..096f1e3d184c 100644 --- a/pkg/kv/kvserver/scanner.go +++ b/pkg/kv/kvserver/scanner.go @@ -235,7 +235,7 @@ func (rs *replicaScanner) waitAndProcess( case repl := <-rs.removed: rs.removeReplica(repl) - case <-stopper.ShouldStop(): + case <-stopper.ShouldQuiesce(): return true } } @@ -259,7 +259,7 @@ func (rs *replicaScanner) removeReplica(repl *Replica) { // is paced to complete a full scan in approximately the scan interval. func (rs *replicaScanner) scanLoop(stopper *stop.Stopper) { ctx := rs.AnnotateCtx(context.Background()) - stopper.RunWorker(ctx, func(ctx context.Context) { + _ = stopper.RunAsyncTask(ctx, "scan-loop", func(ctx context.Context) { start := timeutil.Now() // waitTimer is reset in each call to waitAndProcess. 
@@ -324,7 +324,7 @@ func (rs *replicaScanner) waitEnabled(stopper *stop.Stopper) bool { case repl := <-rs.removed: rs.removeReplica(repl) - case <-stopper.ShouldStop(): + case <-stopper.ShouldQuiesce(): return true } } diff --git a/pkg/kv/kvserver/scanner_test.go b/pkg/kv/kvserver/scanner_test.go index 9d0da518908b..b5baef0effe2 100644 --- a/pkg/kv/kvserver/scanner_test.go +++ b/pkg/kv/kvserver/scanner_test.go @@ -120,7 +120,13 @@ func (tq *testQueue) setDisabled(d bool) { } func (tq *testQueue) Start(stopper *stop.Stopper) { - stopper.RunWorker(context.Background(), func(context.Context) { + done := func() { + tq.Lock() + tq.done = true + tq.Unlock() + } + + if err := stopper.RunAsyncTask(context.Background(), "testqueue", func(context.Context) { for { select { case <-time.After(1 * time.Millisecond): @@ -130,14 +136,14 @@ func (tq *testQueue) Start(stopper *stop.Stopper) { tq.processed++ } tq.Unlock() - case <-stopper.ShouldStop(): - tq.Lock() - tq.done = true - tq.Unlock() + case <-stopper.ShouldQuiesce(): + done() return } } - }) + }); err != nil { + done() + } } // NB: MaybeAddAsync on a testQueue is actually synchronous. 
diff --git a/pkg/kv/kvserver/scheduler.go b/pkg/kv/kvserver/scheduler.go index 037996a7efac..b5748ce42dfb 100644 --- a/pkg/kv/kvserver/scheduler.go +++ b/pkg/kv/kvserver/scheduler.go @@ -189,19 +189,22 @@ func newRaftScheduler( } func (s *raftScheduler) Start(ctx context.Context, stopper *stop.Stopper) { - stopper.RunWorker(ctx, func(ctx context.Context) { - <-stopper.ShouldStop() + waitQuiesce := func(context.Context) { + <-stopper.ShouldQuiesce() s.mu.Lock() s.mu.stopped = true s.mu.Unlock() s.mu.cond.Broadcast() - }) + } + if err := stopper.RunAsyncTask(ctx, "raftsched-wait-quiesce", waitQuiesce); err != nil { + waitQuiesce(ctx) + } s.done.Add(s.numWorkers) for i := 0; i < s.numWorkers; i++ { - stopper.RunWorker(ctx, func(ctx context.Context) { - s.worker(ctx) - }) + if err := stopper.RunAsyncTask(ctx, "raft-worker", s.worker); err != nil { + s.done.Done() + } } } diff --git a/pkg/kv/kvserver/single_key_test.go b/pkg/kv/kvserver/single_key_test.go index 1503c1fe5cba..4e6636aa3b02 100644 --- a/pkg/kv/kvserver/single_key_test.go +++ b/pkg/kv/kvserver/single_key_test.go @@ -104,7 +104,7 @@ func TestSingleKey(t *testing.T) { var results []result for len(results) < num { select { - case <-tc.Stopper().ShouldStop(): + case <-tc.Stopper().ShouldQuiesce(): t.Fatalf("interrupted") case r := <-resultCh: if r.err != nil { diff --git a/pkg/kv/kvserver/store.go b/pkg/kv/kvserver/store.go index 464b69b4b1a1..b9b5f0c5ce14 100644 --- a/pkg/kv/kvserver/store.go +++ b/pkg/kv/kvserver/store.go @@ -1061,7 +1061,7 @@ func (s *Store) SetDraining(drain bool, reporter func(int, redact.SafeString)) { // To prevent this, we add this code here which adds the missing // cancel + wait in the particular case where the stopper is // completing a shutdown while a graceful SetDrain is still ongoing. 
- ctx, cancelFn := s.stopper.WithCancelOnStop(baseCtx) + ctx, cancelFn := s.stopper.WithCancelOnQuiesce(baseCtx) defer cancelFn() var wg sync.WaitGroup @@ -1564,13 +1564,13 @@ func (s *Store) Start(ctx context.Context, stopper *stop.Stopper) error { // This may trigger splits along structured boundaries, // and update max range bytes. gossipUpdateC := s.cfg.Gossip.RegisterSystemConfigChannel() - s.stopper.RunWorker(ctx, func(context.Context) { + _ = s.stopper.RunAsyncTask(ctx, "syscfg-listener", func(context.Context) { for { select { case <-gossipUpdateC: cfg := s.cfg.Gossip.GetSystemConfig() s.systemGossipUpdate(cfg) - case <-s.stopper.ShouldStop(): + case <-s.stopper.ShouldQuiesce(): return } } @@ -1585,11 +1585,11 @@ func (s *Store) Start(ctx context.Context, stopper *stop.Stopper) error { // Start the scanner. The construction here makes sure that the scanner // only starts after Gossip has connected, and that it does not block Start // from returning (as doing so might prevent Gossip from ever connecting). - s.stopper.RunWorker(ctx, func(context.Context) { + _ = s.stopper.RunAsyncTask(ctx, "scanner", func(context.Context) { select { case <-s.cfg.Gossip.Connected: s.scanner.Start(s.stopper) - case <-s.stopper.ShouldStop(): + case <-s.stopper.ShouldQuiesce(): return } }) @@ -1696,7 +1696,7 @@ func (s *Store) startGossip() { s.initComplete.Add(len(gossipFns)) for _, gossipFn := range gossipFns { gossipFn := gossipFn // per-iteration copy - s.stopper.RunWorker(context.Background(), func(ctx context.Context) { + if err := s.stopper.RunAsyncTask(context.Background(), "store-gossip", func(ctx context.Context) { ticker := time.NewTicker(gossipFn.interval) defer ticker.Stop() for first := true; ; { @@ -1705,7 +1705,7 @@ func (s *Store) startGossip() { // making it impossible to get an epoch-based range lease), in which // case we want to retry quickly. 
retryOptions := base.DefaultRetryOptions() - retryOptions.Closer = s.stopper.ShouldStop() + retryOptions.Closer = s.stopper.ShouldQuiesce() for r := retry.Start(retryOptions); r.Next(); { if repl := s.LookupReplica(roachpb.RKey(gossipFn.key)); repl != nil { annotatedCtx := repl.AnnotateCtx(ctx) @@ -1724,11 +1724,13 @@ func (s *Store) startGossip() { } select { case <-ticker.C: - case <-s.stopper.ShouldStop(): + case <-s.stopper.ShouldQuiesce(): return } } - }) + }); err != nil { + s.initComplete.Done() + } } } @@ -1742,7 +1744,7 @@ func (s *Store) startGossip() { func (s *Store) startLeaseRenewer(ctx context.Context) { // Start a goroutine that watches and proactively renews certain // expiration-based leases. - s.stopper.RunWorker(ctx, func(ctx context.Context) { + _ = s.stopper.RunAsyncTask(ctx, "lease-renewer", func(ctx context.Context) { repls := make(map[*Replica]struct{}) timer := timeutil.NewTimer() defer timer.Stop() @@ -1775,7 +1777,7 @@ func (s *Store) startLeaseRenewer(ctx context.Context) { case <-s.renewableLeasesSignal: case <-timer.C: timer.Read = true - case <-s.stopper.ShouldStop(): + case <-s.stopper.ShouldQuiesce(): return } } @@ -1797,7 +1799,7 @@ func (s *Store) startClosedTimestampRangefeedSubscriber(ctx context.Context) { return } - s.stopper.RunWorker(ctx, func(ctx context.Context) { + _ = s.stopper.RunAsyncTask(ctx, "ct-subscriber", func(ctx context.Context) { var replIDs []roachpb.RangeID for { select { diff --git a/pkg/kv/kvserver/store_raft.go b/pkg/kv/kvserver/store_raft.go index 5a54ab903003..9c50aa81c8f6 100644 --- a/pkg/kv/kvserver/store_raft.go +++ b/pkg/kv/kvserver/store_raft.go @@ -608,10 +608,12 @@ func (s *Store) processRaft(ctx context.Context) { s.scheduler.Start(ctx, s.stopper) // Wait for the scheduler worker goroutines to finish. 
- s.stopper.RunWorker(ctx, s.scheduler.Wait) + if err := s.stopper.RunAsyncTask(ctx, "sched-wait", s.scheduler.Wait); err != nil { + s.scheduler.Wait(ctx) + } - s.stopper.RunWorker(ctx, s.raftTickLoop) - s.stopper.RunWorker(ctx, s.coalescedHeartbeatsLoop) + _ = s.stopper.RunAsyncTask(ctx, "sched-tick-loop", s.raftTickLoop) + _ = s.stopper.RunAsyncTask(ctx, "coalesced-hb-loop", s.coalescedHeartbeatsLoop) s.stopper.AddCloser(stop.CloserFn(func() { s.cfg.Transport.Stop(s.StoreID()) })) @@ -646,7 +648,7 @@ func (s *Store) raftTickLoop(ctx context.Context) { s.scheduler.EnqueueRaftTicks(rangeIDs...) s.metrics.RaftTicks.Inc(1) - case <-s.stopper.ShouldStop(): + case <-s.stopper.ShouldQuiesce(): return } } @@ -689,7 +691,7 @@ func (s *Store) coalescedHeartbeatsLoop(ctx context.Context) { select { case <-ticker.C: s.sendQueuedHeartbeats(ctx) - case <-s.stopper.ShouldStop(): + case <-s.stopper.ShouldQuiesce(): return } } diff --git a/pkg/kv/kvserver/store_rebalancer.go b/pkg/kv/kvserver/store_rebalancer.go index 2ac5efc66351..098039cd1a42 100644 --- a/pkg/kv/kvserver/store_rebalancer.go +++ b/pkg/kv/kvserver/store_rebalancer.go @@ -177,7 +177,7 @@ func (sr *StoreRebalancer) Start(ctx context.Context, stopper *stop.Stopper) { // Start a goroutine that watches and proactively renews certain // expiration-based leases. 
- stopper.RunWorker(ctx, func(ctx context.Context) { + _ = stopper.RunAsyncTask(ctx, "store-rebalancer", func(ctx context.Context) { timer := timeutil.NewTimer() defer timer.Stop() timer.Reset(jitteredInterval(storeRebalancerTimerDuration)) diff --git a/pkg/kv/kvserver/store_snapshot.go b/pkg/kv/kvserver/store_snapshot.go index 5860b36394f3..1717c5685ba2 100644 --- a/pkg/kv/kvserver/store_snapshot.go +++ b/pkg/kv/kvserver/store_snapshot.go @@ -524,7 +524,7 @@ func (s *Store) reserveSnapshot( case s.snapshotApplySem <- struct{}{}: case <-ctx.Done(): return nil, "", ctx.Err() - case <-s.stopper.ShouldStop(): + case <-s.stopper.ShouldQuiesce(): return nil, "", errors.Errorf("stopped") default: return nil, snapshotApplySemBusyMsg, nil @@ -534,7 +534,7 @@ func (s *Store) reserveSnapshot( case s.snapshotApplySem <- struct{}{}: case <-ctx.Done(): return nil, "", ctx.Err() - case <-s.stopper.ShouldStop(): + case <-s.stopper.ShouldQuiesce(): return nil, "", errors.Errorf("stopped") } } diff --git a/pkg/rpc/context.go b/pkg/rpc/context.go index 90e20cfb1a98..ee967c3489b9 100644 --- a/pkg/rpc/context.go +++ b/pkg/rpc/context.go @@ -254,7 +254,7 @@ func (c *Connection) Connect(ctx context.Context) (*grpc.ClientConn, error) { // Wait for initial heartbeat. 
select { case <-c.initialHeartbeatDone: - case <-c.stopper.ShouldStop(): + case <-c.stopper.ShouldQuiesce(): return nil, errors.Errorf("stopped") case <-ctx.Done(): return nil, ctx.Err() @@ -398,7 +398,7 @@ func NewContext(opts ContextOptions) *Context { ctx.ClusterID.Set(masterCtx, *id) } - ctx.Stopper.RunWorker(ctx.masterCtx, func(context.Context) { + waitQuiesce := func(context.Context) { <-ctx.Stopper.ShouldQuiesce() cancel() @@ -415,7 +415,10 @@ func NewContext(opts ContextOptions) *Context { ctx.removeConn(conn, k.(connKey)) return true }) - }) + } + if err := ctx.Stopper.RunAsyncTask(ctx.masterCtx, "wait-rpcctx-quiesce", waitQuiesce); err != nil { + waitQuiesce(ctx.masterCtx) + } return ctx } @@ -1046,15 +1049,13 @@ func (ctx *Context) grpcDialNodeInternal( var redialChan <-chan struct{} conn.grpcConn, redialChan, conn.dialErr = ctx.grpcDialRaw(target, remoteNodeID, class) if conn.dialErr == nil { - if err := ctx.Stopper.RunTask( + if err := ctx.Stopper.RunAsyncTask( ctx.masterCtx, "rpc.Context: grpc heartbeat", func(masterCtx context.Context) { - ctx.Stopper.RunWorker(masterCtx, func(masterCtx context.Context) { - err := ctx.runHeartbeat(conn, target, redialChan) - if err != nil && !grpcutil.IsClosedConnection(err) { - log.Health.Errorf(masterCtx, "removing connection to %s due to error: %s", target, err) - } - ctx.removeConn(conn, thisConnKeys...) - }) + err := ctx.runHeartbeat(conn, target, redialChan) + if err != nil && !grpcutil.IsClosedConnection(err) { + log.Health.Errorf(masterCtx, "removing connection to %s due to error: %s", target, err) + } + ctx.removeConn(conn, thisConnKeys...) 
}); err != nil { conn.dialErr = err } diff --git a/pkg/rpc/context_test.go b/pkg/rpc/context_test.go index 031b380d9693..c3d7e93b67d5 100644 --- a/pkg/rpc/context_test.go +++ b/pkg/rpc/context_test.go @@ -328,7 +328,7 @@ func TestHeartbeatHealth(t *testing.T) { } select { - case <-stopper.ShouldStop(): + case <-stopper.ShouldQuiesce(): return case heartbeat.ready <- err: } @@ -599,14 +599,14 @@ func TestHeartbeatHealthTransport(t *testing.T) { }} }() - stopper.RunWorker(ctx, func(context.Context) { + _ = stopper.RunAsyncTask(ctx, "wait-quiesce", func(context.Context) { <-stopper.ShouldQuiesce() netutil.FatalIfUnexpected(ln.Close()) - <-stopper.ShouldStop() + <-stopper.ShouldQuiesce() s.Stop() }) - stopper.RunWorker(ctx, func(context.Context) { + _ = stopper.RunAsyncTask(ctx, "serve", func(context.Context) { netutil.FatalIfUnexpected(s.Serve(ln)) }) diff --git a/pkg/rpc/heartbeat_test.go b/pkg/rpc/heartbeat_test.go index 0b905bc771f8..2ceddf118796 100644 --- a/pkg/rpc/heartbeat_test.go +++ b/pkg/rpc/heartbeat_test.go @@ -25,6 +25,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/stop" "github.com/cockroachdb/cockroach/pkg/util/syncutil" "github.com/cockroachdb/cockroach/pkg/util/uuid" + "github.com/cockroachdb/errors" ) func TestRemoteOffsetString(t *testing.T) { @@ -92,7 +93,8 @@ func (mhs *ManualHeartbeatService) Ping( } case <-ctx.Done(): return nil, ctx.Err() - case <-mhs.stopper.ShouldStop(): + case <-mhs.stopper.ShouldQuiesce(): + return nil, errors.New("quiesce") } hs := HeartbeatService{ clock: mhs.clock, diff --git a/pkg/security/certificate_manager.go b/pkg/security/certificate_manager.go index 7bbead1e84db..30930eebe179 100644 --- a/pkg/security/certificate_manager.go +++ b/pkg/security/certificate_manager.go @@ -230,7 +230,7 @@ func (cm *CertificateManager) RegisterSignalHandler(stopper *stop.Stopper) { ch := sysutil.RefreshSignaledChan() for { select { - case <-stopper.ShouldStop(): + case <-stopper.ShouldQuiesce(): return case sig := <-ch: 
log.Infof(context.Background(), "received signal %q, triggering certificate reload", sig) diff --git a/pkg/server/diagnostics/reporter.go b/pkg/server/diagnostics/reporter.go index 7acd7feedab0..4123f60a5aa3 100644 --- a/pkg/server/diagnostics/reporter.go +++ b/pkg/server/diagnostics/reporter.go @@ -90,7 +90,7 @@ type Reporter struct { // PeriodicallyReportDiagnostics starts a background worker that periodically // phones home to report usage and diagnostics. func (r *Reporter) PeriodicallyReportDiagnostics(ctx context.Context, stopper *stop.Stopper) { - stopper.RunWorker(ctx, func(ctx context.Context) { + _ = stopper.RunAsyncTask(ctx, "diagnostics", func(ctx context.Context) { defer logcrash.RecoverAndReportNonfatalPanic(ctx, &r.Settings.SV) nextReport := r.StartTime diff --git a/pkg/server/diagnostics/update_checker.go b/pkg/server/diagnostics/update_checker.go index a48a4c4cbc7d..d945fe6438f3 100644 --- a/pkg/server/diagnostics/update_checker.go +++ b/pkg/server/diagnostics/update_checker.go @@ -72,7 +72,7 @@ type UpdateChecker struct { // PeriodicallyCheckForUpdates starts a background worker that periodically // phones home to check for updates. 
func (u *UpdateChecker) PeriodicallyCheckForUpdates(ctx context.Context, stopper *stop.Stopper) { - stopper.RunWorker(ctx, func(ctx context.Context) { + _ = stopper.RunAsyncTask(ctx, "update-checker", func(ctx context.Context) { defer logcrash.RecoverAndReportNonfatalPanic(ctx, &u.Settings.SV) nextUpdateCheck := u.StartTime diff --git a/pkg/server/init.go b/pkg/server/init.go index ce1741f7fea5..7da1de23b474 100644 --- a/pkg/server/init.go +++ b/pkg/server/init.go @@ -215,16 +215,15 @@ func (s *initServer) ServeAndWait( joinCtx, cancelJoin = context.WithCancel(ctx) defer cancelJoin() - err := stopper.RunTask(joinCtx, "init server: join loop", - func(joinCtx context.Context) { - stopper.RunWorker(joinCtx, func(joinCtx context.Context) { - defer wg.Done() - - state, err := s.startJoinLoop(joinCtx, stopper) - joinCh <- joinResult{state: state, err: err} - }) + err := stopper.RunAsyncTask(joinCtx, "init server: join loop", + func(ctx context.Context) { + defer wg.Done() + + state, err := s.startJoinLoop(ctx, stopper) + joinCh <- joinResult{state: state, err: err} }) if err != nil { + wg.Done() return nil, false, err } } diff --git a/pkg/server/node.go b/pkg/server/node.go index f623dd125a76..1358271235d8 100644 --- a/pkg/server/node.go +++ b/pkg/server/node.go @@ -599,7 +599,7 @@ func (n *Node) initializeAdditionalStores( // information. Starts a goroutine to loop until the node is closed. func (n *Node) startGossiping(ctx context.Context, stopper *stop.Stopper) { ctx = n.AnnotateCtx(ctx) - stopper.RunWorker(ctx, func(ctx context.Context) { + _ = stopper.RunAsyncTask(ctx, "start-gossip", func(ctx context.Context) { // Verify we've already gossiped our node descriptor. // // TODO(tbg): see if we really needed to do this earlier already. 
We @@ -630,7 +630,7 @@ func (n *Node) startGossiping(ctx context.Context, stopper *stop.Stopper) { if err := n.storeCfg.Gossip.SetNodeDescriptor(&n.Descriptor); err != nil { log.Warningf(ctx, "couldn't gossip descriptor for node %d: %s", n.Descriptor.NodeID, err) } - case <-stopper.ShouldStop(): + case <-stopper.ShouldQuiesce(): return } } @@ -651,7 +651,7 @@ func (n *Node) gossipStores(ctx context.Context) { // maintained. func (n *Node) startComputePeriodicMetrics(stopper *stop.Stopper, interval time.Duration) { ctx := n.AnnotateCtx(context.Background()) - stopper.RunWorker(ctx, func(ctx context.Context) { + _ = stopper.RunAsyncTask(ctx, "compute-metrics", func(ctx context.Context) { // Compute periodic stats at the same frequency as metrics are sampled. ticker := time.NewTicker(interval) defer ticker.Stop() @@ -661,7 +661,7 @@ func (n *Node) startComputePeriodicMetrics(stopper *stop.Stopper, interval time. if err := n.computePeriodicMetrics(ctx, tick); err != nil { log.Errorf(ctx, "failed computing periodic metrics: %s", err) } - case <-stopper.ShouldStop(): + case <-stopper.ShouldQuiesce(): return } } @@ -683,13 +683,13 @@ func (n *Node) startGraphiteStatsExporter(st *cluster.Settings) { ctx := logtags.AddTag(n.AnnotateCtx(context.Background()), "graphite stats exporter", nil) pm := metric.MakePrometheusExporter() - n.stopper.RunWorker(ctx, func(ctx context.Context) { + _ = n.stopper.RunAsyncTask(ctx, "graphite-exporter", func(ctx context.Context) { var timer timeutil.Timer defer timer.Stop() for { timer.Reset(graphiteInterval.Get(&st.SV)) select { - case <-n.stopper.ShouldStop(): + case <-n.stopper.ShouldQuiesce(): return case <-timer.C: timer.Read = true @@ -714,7 +714,7 @@ func (n *Node) startWriteNodeStatus(frequency time.Duration) error { if err := n.writeNodeStatus(ctx, 0 /* alertTTL */, false /* mustExist */); err != nil { return errors.Wrap(err, "error recording initial status summaries") } - n.stopper.RunWorker(ctx, func(ctx context.Context) { + return 
n.stopper.RunAsyncTask(ctx, "write-node-status", func(ctx context.Context) { // Write a status summary immediately; this helps the UI remain // responsive when new nodes are added. ticker := time.NewTicker(frequency) @@ -735,12 +735,11 @@ func (n *Node) startWriteNodeStatus(frequency time.Duration) error { if err := n.writeNodeStatus(ctx, 2*frequency, true /* mustExist */); err != nil { log.Warningf(ctx, "error recording status summaries: %s", err) } - case <-n.stopper.ShouldStop(): + case <-n.stopper.ShouldQuiesce(): return } } }) - return nil } // writeNodeStatus retrieves status summaries from the supplied @@ -814,11 +813,11 @@ func (n *Node) recordJoinEvent(ctx context.Context) { return } - n.stopper.RunWorker(ctx, func(bgCtx context.Context) { + _ = n.stopper.RunAsyncTask(ctx, "record-join", func(bgCtx context.Context) { ctx, span := n.AnnotateCtxWithSpan(bgCtx, "record-join-event") defer span.Finish() retryOpts := base.DefaultRetryOptions() - retryOpts.Closer = n.stopper.ShouldStop() + retryOpts.Closer = n.stopper.ShouldQuiesce() for r := retry.Start(retryOpts); r.Next(); { if err := n.storeCfg.DB.Txn(ctx, func(ctx context.Context, txn *kv.Txn) error { return sql.InsertEventRecord(ctx, n.sqlExec, diff --git a/pkg/server/node_engine_health.go b/pkg/server/node_engine_health.go index ffa128e83af6..efff31544fef 100644 --- a/pkg/server/node_engine_health.go +++ b/pkg/server/node_engine_health.go @@ -29,7 +29,7 @@ func (n *Node) startAssertEngineHealth( ) { maxSyncDuration := storage.MaxSyncDuration.Get(&settings.SV) fatalOnExceeded := storage.MaxSyncDurationFatalOnExceeded.Get(&settings.SV) - n.stopper.RunWorker(ctx, func(ctx context.Context) { + _ = n.stopper.RunAsyncTask(ctx, "engine-health", func(ctx context.Context) { t := timeutil.NewTimer() t.Reset(0) diff --git a/pkg/server/server.go b/pkg/server/server.go index 325290eb0154..b686374270d1 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -1041,15 +1041,16 @@ func (s *Server) 
startPersistingHLCUpperBound( } } - s.stopper.RunWorker( + _ = s.stopper.RunAsyncTask( ctx, + "persist-hlc-upper-bound", func(context.Context) { periodicallyPersistHLCUpperBound( s.clock, persistHLCUpperBoundIntervalCh, persistHLCUpperBoundFn, tickerFn, - s.stopper.ShouldStop(), + s.stopper.ShouldQuiesce(), nil, /* tick callback */ ) }, @@ -1273,12 +1274,15 @@ func (s *Server) PreStart(ctx context.Context) error { // loopback handles the HTTP <-> RPC loopback connection. loopback := newLoopbackListener(workersCtx, s.stopper) - s.stopper.RunWorker(workersCtx, func(workersCtx context.Context) { + waitQuiesce := func(context.Context) { <-s.stopper.ShouldQuiesce() _ = loopback.Close() - }) + } + if err := s.stopper.RunAsyncTask(workersCtx, "gw-quiesce", waitQuiesce); err != nil { + waitQuiesce(workersCtx) + } - s.stopper.RunWorker(workersCtx, func(context.Context) { + _ = s.stopper.RunAsyncTask(workersCtx, "serve-loopback", func(context.Context) { netutil.FatalIfUnexpected(s.grpc.Serve(loopback)) }) @@ -1310,12 +1314,21 @@ func (s *Server) PreStart(ctx context.Context) error { if err != nil { return err } - s.stopper.RunWorker(workersCtx, func(workersCtx context.Context) { - <-s.stopper.ShouldQuiesce() - if err := conn.Close(); err != nil { - log.Ops.Fatalf(workersCtx, "%v", err) + { + waitQuiesce := func(workersCtx context.Context) { + <-s.stopper.ShouldQuiesce() + // NB: we can't do this as a Closer because (*Server).ServeWith is + // running in a worker and usually sits on accept() which unblocks + // only when the listener closes. In other words, the listener needs + // to close when quiescing starts to allow that worker to shut down. 
+ if err := conn.Close(); err != nil { + log.Ops.Fatalf(workersCtx, "%v", err) + } } - }) + if err := s.stopper.RunAsyncTask(workersCtx, "wait-quiesce", waitQuiesce); err != nil { + waitQuiesce(workersCtx) + } + } for _, gw := range []grpcGatewayServer{s.admin, s.status, s.authentication, s.tsServer} { if err := gw.RegisterGateway(gwCtx, gwMux, conn); err != nil { @@ -1840,12 +1853,20 @@ func (s *Server) startListenRPCAndSQL( } // The SQL listener shutdown worker, which closes everything under // the SQL port when the stopper indicates we are shutting down. - s.stopper.RunWorker(workersCtx, func(workersCtx context.Context) { + waitQuiesce := func(ctx context.Context) { <-s.stopper.ShouldQuiesce() + // NB: we can't do this as a Closer because (*Server).ServeWith is + // running in a worker and usually sits on accept() which unblocks + // only when the listener closes. In other words, the listener needs + // to close when quiescing starts to allow that worker to shut down. if err := pgL.Close(); err != nil { - log.Ops.Fatalf(workersCtx, "%v", err) + log.Ops.Fatalf(ctx, "%v", err) } - }) + } + if err := s.stopper.RunAsyncTask(workersCtx, "wait-quiesce", waitQuiesce); err != nil { + waitQuiesce(workersCtx) + return nil, nil, err + } log.Eventf(ctx, "listening on sql port %s", s.cfg.SQLAddr) } @@ -1876,11 +1897,12 @@ func (s *Server) startListenRPCAndSQL( } // The remainder shutdown worker. - s.stopper.RunWorker(workersCtx, func(context.Context) { + waitForQuiesce := func(context.Context) { <-s.stopper.ShouldQuiesce() // TODO(bdarnell): Do we need to also close the other listeners? netutil.FatalIfUnexpected(anyL.Close()) - <-s.stopper.ShouldStop() + } + s.stopper.AddCloser(stop.CloserFn(func() { s.grpc.Stop() serveOnMux.Do(func() { // The cmux matches don't shut down properly unless serve is called on the @@ -1888,7 +1910,12 @@ func (s *Server) startListenRPCAndSQL( // if we wouldn't otherwise reach the point where we start serving on it. 
netutil.FatalIfUnexpected(m.Serve()) }) - }) + })) + if err := s.stopper.RunAsyncTask( + workersCtx, "grpc-quiesce", waitForQuiesce, + ); err != nil { + return nil, nil, err + } // startRPCServer starts the RPC server. We do not do this // immediately because we want the cluster to be ready (or ready to @@ -1896,11 +1923,11 @@ func (s *Server) startListenRPCAndSQL( // (Server.Start) will call this at the right moment. startRPCServer = func(ctx context.Context) { // Serve the gRPC endpoint. - s.stopper.RunWorker(workersCtx, func(context.Context) { + _ = s.stopper.RunAsyncTask(workersCtx, "serve-grpc", func(context.Context) { netutil.FatalIfUnexpected(s.grpc.Serve(anyL)) }) - s.stopper.RunWorker(ctx, func(context.Context) { + _ = s.stopper.RunAsyncTask(ctx, "serve-mux", func(context.Context) { serveOnMux.Do(func() { netutil.FatalIfUnexpected(m.Serve()) }) @@ -1921,12 +1948,20 @@ func (s *Server) startServeUI( // The HTTP listener shutdown worker, which closes everything under // the HTTP port when the stopper indicates we are shutting down. - s.stopper.RunWorker(workersCtx, func(workersCtx context.Context) { + waitQuiesce := func(ctx context.Context) { + // NB: we can't do this as a Closer because (*Server).ServeWith is + // running in a worker and usually sits on accept() which unblocks + // only when the listener closes. In other words, the listener needs + // to close when quiescing starts to allow that worker to shut down. <-s.stopper.ShouldQuiesce() if err := httpLn.Close(); err != nil { - log.Ops.Fatalf(workersCtx, "%v", err) + log.Ops.Fatalf(ctx, "%v", err) } - }) + } + if err := s.stopper.RunAsyncTask(workersCtx, "wait-quiesce", waitQuiesce); err != nil { + waitQuiesce(workersCtx) + return err + } if uiTLSConfig != nil { httpMux := cmux.New(httpLn) @@ -1934,15 +1969,17 @@ func (s *Server) startServeUI( tlsL := httpMux.Match(cmux.Any()) // Dispatch incoming requests to either clearL or tlsL. 
- s.stopper.RunWorker(workersCtx, func(context.Context) { + if err := s.stopper.RunAsyncTask(workersCtx, "serve-ui", func(context.Context) { netutil.FatalIfUnexpected(httpMux.Serve()) - }) + }); err != nil { + return err + } // Serve the plain HTTP (non-TLS) connection over clearL. // This produces a HTTP redirect to the `https` URL for the path /, // handles the request normally (via s.ServeHTTP) for the path /health, // and produces 404 for anything else. - s.stopper.RunWorker(workersCtx, func(context.Context) { + if err := s.stopper.RunAsyncTask(workersCtx, "serve-health", func(context.Context) { mux := http.NewServeMux() mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, "https://"+r.Host+r.RequestURI, http.StatusTemporaryRedirect) @@ -1952,7 +1989,9 @@ func (s *Server) startServeUI( plainRedirectServer := netutil.MakeServer(s.stopper, uiTLSConfig, mux) netutil.FatalIfUnexpected(plainRedirectServer.Serve(clearL)) - }) + }); err != nil { + return err + } httpLn = tls.NewListener(tlsL, uiTLSConfig) } @@ -1961,11 +2000,9 @@ func (s *Server) startServeUI( // listening on --http-addr without TLS if uiTLSConfig was // nil, or overridden above if uiTLSConfig was not nil to come from // the TLS negotiation over the HTTP port. - s.stopper.RunWorker(workersCtx, func(context.Context) { + return s.stopper.RunAsyncTask(workersCtx, "server-http", func(context.Context) { netutil.FatalIfUnexpected(connManager.Serve(httpLn)) }) - - return nil } // TODO(tbg): move into server_sql.go. 
@@ -1984,7 +2021,7 @@ func (s *SQLServer) startServeSQL( tcpKeepAlive: envutil.EnvOrDefaultDuration("COCKROACH_SQL_TCP_KEEP_ALIVE", time.Minute), } - stopper.RunWorker(pgCtx, func(pgCtx context.Context) { + _ = stopper.RunAsyncTask(pgCtx, "serve-conn", func(pgCtx context.Context) { netutil.FatalIfUnexpected(connManager.ServeWith(pgCtx, stopper, pgL, func(conn net.Conn) { connCtx := s.pgServer.AnnotateCtxForIncomingConn(pgCtx, conn) tcpKeepAlive.configure(connCtx, conn) @@ -2005,21 +2042,33 @@ func (s *SQLServer) startServeSQL( return err } - stopper.RunWorker(ctx, func(workersCtx context.Context) { + waitQuiesce := func(ctx context.Context) { <-stopper.ShouldQuiesce() + // NB: we can't do this as a Closer because (*Server).ServeWith is + // running in a worker and usually sits on accept() which unblocks + // only when the listener closes. In other words, the listener needs + // to close when quiescing starts to allow that worker to shut down. if err := unixLn.Close(); err != nil { - log.Ops.Fatalf(workersCtx, "%v", err) + log.Ops.Fatalf(ctx, "%v", err) } - }) + } + if err := stopper.RunAsyncTask(ctx, "unix-ln-close", func(ctx context.Context) { + waitQuiesce(ctx) + }); err != nil { + waitQuiesce(ctx) + return err + } - stopper.RunWorker(pgCtx, func(pgCtx context.Context) { + if err := stopper.RunAsyncTask(pgCtx, "unix-ln-serve", func(pgCtx context.Context) { netutil.FatalIfUnexpected(connManager.ServeWith(pgCtx, stopper, unixLn, func(conn net.Conn) { connCtx := s.pgServer.AnnotateCtxForIncomingConn(pgCtx, conn) if err := s.pgServer.ServeConn(connCtx, conn, pgwire.SocketUnix); err != nil { log.Ops.Errorf(connCtx, "%v", err) } })) - }) + }); err != nil { + return err + } } return nil } @@ -2177,7 +2226,7 @@ func startSampleEnvironment(ctx context.Context, cfg sampleEnvironmentCfg) error } } - cfg.stopper.RunWorker(ctx, func(ctx context.Context) { + return cfg.stopper.RunAsyncTask(ctx, "mem-logger", func(ctx context.Context) { var goMemStats atomic.Value // 
*status.GoMemStats goMemStats.Store(&status.GoMemStats{}) var collectingMemStats int32 // atomic, 1 when stats call is ongoing @@ -2188,7 +2237,7 @@ func startSampleEnvironment(ctx context.Context, cfg sampleEnvironmentCfg) error for { select { - case <-cfg.stopper.ShouldStop(): + case <-cfg.stopper.ShouldQuiesce(): return case <-timer.C: timer.Read = true @@ -2237,7 +2286,6 @@ func startSampleEnvironment(ctx context.Context, cfg sampleEnvironmentCfg) error } } }) - return nil } // Stop stops the server. diff --git a/pkg/server/server_systemlog_gc.go b/pkg/server/server_systemlog_gc.go index 19f506ccbac3..aaacb0bc6860 100644 --- a/pkg/server/server_systemlog_gc.go +++ b/pkg/server/server_systemlog_gc.go @@ -164,7 +164,7 @@ func (s *Server) startSystemLogsGC(ctx context.Context) { }, } - s.stopper.RunWorker(ctx, func(ctx context.Context) { + _ = s.stopper.RunAsyncTask(ctx, "system-log-gc", func(ctx context.Context) { period := systemLogGCPeriod if storeKnobs, ok := s.cfg.TestingKnobs.Store.(*kvserver.StoreTestingKnobs); ok && storeKnobs.SystemLogsGCPeriod != 0 { period = storeKnobs.SystemLogsGCPeriod @@ -205,12 +205,12 @@ func (s *Server) startSystemLogsGC(ctx context.Context) { if storeKnobs, ok := s.cfg.TestingKnobs.Store.(*kvserver.StoreTestingKnobs); ok && storeKnobs.SystemLogsGCGCDone != nil { select { case storeKnobs.SystemLogsGCGCDone <- struct{}{}: - case <-s.stopper.ShouldStop(): + case <-s.stopper.ShouldQuiesce(): // Test has finished. return } } - case <-s.stopper.ShouldStop(): + case <-s.stopper.ShouldQuiesce(): return } } diff --git a/pkg/server/settingsworker.go b/pkg/server/settingsworker.go index 9d099ee31457..8e435c113c8d 100644 --- a/pkg/server/settingsworker.go +++ b/pkg/server/settingsworker.go @@ -130,7 +130,7 @@ func (s *Server) refreshSettings(initialSettingsKVs []roachpb.KeyValue) error { } } // Setup updater that listens for changes in settings. 
- s.stopper.RunWorker(ctx, func(ctx context.Context) { + return s.stopper.RunAsyncTask(ctx, "refresh-settings", func(ctx context.Context) { gossipUpdateC := s.gossip.RegisterSystemConfigChannel() // No new settings can be defined beyond this point. for { @@ -141,10 +141,9 @@ func (s *Server) refreshSettings(initialSettingsKVs []roachpb.KeyValue) error { if err := processSystemConfigKVs(ctx, cfg.Values, u, s.engines[0]); err != nil { log.Warningf(ctx, "error processing config KVs: %+v", err) } - case <-s.stopper.ShouldStop(): + case <-s.stopper.ShouldQuiesce(): return } } }) - return nil } diff --git a/pkg/server/status.go b/pkg/server/status.go index 5650ced777a1..a96eb190a932 100644 --- a/pkg/server/status.go +++ b/pkg/server/status.go @@ -1860,7 +1860,7 @@ func (s *statusServer) iterateNodes( // Issue the requests concurrently. sem := quotapool.NewIntPool("node status", maxConcurrentRequests) - ctx, cancel := s.stopper.WithCancelOnStop(ctx) + ctx, cancel := s.stopper.WithCancelOnQuiesce(ctx) defer cancel() for nodeID := range nodeStatuses { nodeID := nodeID // needed to ensure the closure below captures a copy. diff --git a/pkg/server/testserver.go b/pkg/server/testserver.go index 39d79b2f6cce..baf121a452c1 100644 --- a/pkg/server/testserver.go +++ b/pkg/server/testserver.go @@ -699,24 +699,36 @@ func StartTenant( return nil, "", "", err } - args.stopper.RunWorker(ctx, func(ctx context.Context) { - <-args.stopper.ShouldQuiesce() - // NB: we can't do this as a Closer because (*Server).ServeWith is - // running in a worker and usually sits on accept(pgL) which unblocks - // only when pgL closes. In other words, pgL needs to close when - // quiescing starts to allow that worker to shut down. - _ = pgL.Close() - }) + { + waitQuiesce := func(ctx context.Context) { + <-args.stopper.ShouldQuiesce() + // NB: we can't do this as a Closer because (*Server).ServeWith is + // running in a worker and usually sits on accept(pgL) which unblocks + // only when pgL closes. 
In other words, pgL needs to close when + // quiescing starts to allow that worker to shut down. + _ = pgL.Close() + } + if err := args.stopper.RunAsyncTask(ctx, "wait-quiesce-pgl", waitQuiesce); err != nil { + waitQuiesce(ctx) + return nil, "", "", err + } + } httpL, err := listen(ctx, &args.Config.HTTPAddr, &args.Config.HTTPAdvertiseAddr, "http") if err != nil { return nil, "", "", err } - args.stopper.RunWorker(ctx, func(ctx context.Context) { - <-args.stopper.ShouldQuiesce() - _ = httpL.Close() - }) + { + waitQuiesce := func(ctx context.Context) { + <-args.stopper.ShouldQuiesce() + _ = httpL.Close() + } + if err := args.stopper.RunAsyncTask(ctx, "wait-quiesce-http", waitQuiesce); err != nil { + waitQuiesce(ctx) + return nil, "", "", err + } + } pgLAddr := pgL.Addr().String() httpLAddr := httpL.Addr().String() @@ -729,7 +741,7 @@ func StartTenant( pgLAddr, // sql addr ) - args.stopper.RunWorker(ctx, func(ctx context.Context) { + if err := args.stopper.RunAsyncTask(ctx, "serve-http", func(ctx context.Context) { mux := http.NewServeMux() debugServer := debug.NewServer(args.Settings, s.pgServer.HBADebugFn()) mux.Handle("/", debugServer) @@ -743,7 +755,9 @@ func StartTenant( f := varsHandler{metricSource: args.recorder, st: args.Settings}.handleVars mux.Handle(statusVars, http.HandlerFunc(f)) _ = http.Serve(httpL, mux) - }) + }); err != nil { + return nil, "", "", err + } const ( socketFile = "" // no unix socket diff --git a/pkg/sql/catalog/lease/lease.go b/pkg/sql/catalog/lease/lease.go index ef654e5a11a2..36b2d749e1db 100644 --- a/pkg/sql/catalog/lease/lease.go +++ b/pkg/sql/catalog/lease/lease.go @@ -1652,7 +1652,9 @@ func (m *Manager) findDescriptorState(id descpb.ID, create bool) *descriptorStat func (m *Manager) RefreshLeases( ctx context.Context, s *stop.Stopper, db *kv.DB, g gossip.OptionalGossip, ) { - s.RunWorker(ctx, func(ctx context.Context) { + // TODO(ajwerner): is this task needed? 
refreshLeases appears to already + // delegate everything to a goroutine. + _ = s.RunAsyncTask(ctx, "refresh-leases", func(ctx context.Context) { m.refreshLeases(ctx, g, db, s) }) } @@ -1662,7 +1664,7 @@ func (m *Manager) refreshLeases( ) { descUpdateCh := make(chan *descpb.Descriptor) m.watchForUpdates(ctx, s, db, g, descUpdateCh) - s.RunWorker(ctx, func(ctx context.Context) { + _ = s.RunAsyncTask(ctx, "refresh-leases", func(ctx context.Context) { for { select { case desc := <-descUpdateCh: @@ -1762,7 +1764,7 @@ func (m *Manager) watchForGossipUpdates( return } - s.RunWorker(ctx, func(ctx context.Context) { + _ = s.RunAsyncTask(ctx, "gossip-updates", func(ctx context.Context) { descKeyPrefix := m.storage.codec.TablePrefix(uint32(systemschema.DescriptorTable.ID)) // TODO(ajwerner): Add a mechanism to unregister this channel upon // return. NB: this call is allowed to bypass OptionalGossip because @@ -1860,9 +1862,11 @@ func (m *Manager) watchForRangefeedUpdates( case descUpdateCh <- &descriptor: } } - s.RunWorker(ctx, func(ctx context.Context) { + _ = s.RunAsyncTask(ctx, "lease-rangefeed", func(ctx context.Context) { for { select { + case <-m.stopper.ShouldQuiesce(): + return case <-ctx.Done(): return case e := <-eventCh: @@ -1945,7 +1949,8 @@ func (m *Manager) waitForRangefeedsToBeUsable(ctx context.Context, s *stop.Stopp upgradeChan := make(chan struct{}) timer := timeutil.NewTimer() timer.Reset(0) - s.RunWorker(ctx, func(ctx context.Context) { + // NB: we intentionally do *not* close upgradeChan if the task never starts. + _ = s.RunAsyncTask(ctx, "wait-rangefeed-version", func(ctx context.Context) { for { select { case <-timer.C: @@ -1995,7 +2000,7 @@ var leaseRefreshLimit = settings.RegisterIntSetting( // traffic immediately. // TODO(vivek): Remove once epoch based table leases are implemented. 
func (m *Manager) PeriodicallyRefreshSomeLeases(ctx context.Context) { - m.stopper.RunWorker(ctx, func(ctx context.Context) { + _ = m.stopper.RunAsyncTask(ctx, "lease-refresher", func(ctx context.Context) { if m.storage.leaseDuration <= 0 { return } @@ -2074,7 +2079,7 @@ func (m *Manager) DeleteOrphanedLeases(timeThreshold int64) { // Run as async worker to prevent blocking the main server Start method. // Exit after releasing all the orphaned leases. - m.stopper.RunWorker(context.Background(), func(ctx context.Context) { + _ = m.stopper.RunAsyncTask(context.Background(), "del-orphaned-leases", func(ctx context.Context) { // This could have been implemented using DELETE WHERE, but DELETE WHERE // doesn't implement AS OF SYSTEM TIME. @@ -2113,7 +2118,6 @@ SELECT "descID", version, expiration FROM system.public.lease AS OF SYSTEM TIME log.Infof(ctx, "released orphaned lease: %+v", lease) wg.Done() }); err != nil { - log.Warningf(ctx, "did not release orphaned lease: %+v, err = %s", lease, err) wg.Done() } } diff --git a/pkg/sql/conn_executor.go b/pkg/sql/conn_executor.go index 194ca11a738d..0790b6e98259 100644 --- a/pkg/sql/conn_executor.go +++ b/pkg/sql/conn_executor.go @@ -736,7 +736,7 @@ func (s *Server) PeriodicallyClearSQLStats( stats *sqlStats, reset func(ctx context.Context), ) { - stopper.RunWorker(ctx, func(ctx context.Context) { + _ = stopper.RunAsyncTask(ctx, "sql-stats-clearer", func(ctx context.Context) { var timer timeutil.Timer for { s.sqlStats.Lock() diff --git a/pkg/sql/distsql_physical_planner_test.go b/pkg/sql/distsql_physical_planner_test.go index 5c6962bbd7f8..1221c7f2d0a2 100644 --- a/pkg/sql/distsql_physical_planner_test.go +++ b/pkg/sql/distsql_physical_planner_test.go @@ -152,12 +152,12 @@ func TestPlanningDuringSplitsAndMerges(t *testing.T) { ) // Start a worker that continuously performs splits in the background. 
- tc.Stopper().RunWorker(context.Background(), func(ctx context.Context) { + _ = tc.Stopper().RunAsyncTask(context.Background(), "splitter", func(ctx context.Context) { rng, _ := randutil.NewPseudoRand() cdb := tc.Server(0).DB() for { select { - case <-tc.Stopper().ShouldStop(): + case <-tc.Stopper().ShouldQuiesce(): return default: // Split the table at a random row. diff --git a/pkg/sql/distsql_running.go b/pkg/sql/distsql_running.go index 27111e2f922a..da61cdd68499 100644 --- a/pkg/sql/distsql_running.go +++ b/pkg/sql/distsql_running.go @@ -96,9 +96,9 @@ func (dsp *DistSQLPlanner) initRunners(ctx context.Context) { // requests if a worker is actually there to receive them. dsp.runnerChan = make(chan runnerRequest) for i := 0; i < numRunners; i++ { - dsp.stopper.RunWorker(ctx, func(context.Context) { + _ = dsp.stopper.RunAsyncTask(ctx, "distslq-runner", func(context.Context) { runnerChan := dsp.runnerChan - stopChan := dsp.stopper.ShouldStop() + stopChan := dsp.stopper.ShouldQuiesce() for { select { case req := <-runnerChan: diff --git a/pkg/sql/flowinfra/flow_scheduler.go b/pkg/sql/flowinfra/flow_scheduler.go index e4d4afeea73b..5bd2e090f225 100644 --- a/pkg/sql/flowinfra/flow_scheduler.go +++ b/pkg/sql/flowinfra/flow_scheduler.go @@ -146,7 +146,9 @@ func (fs *FlowScheduler) ScheduleFlow(ctx context.Context, f Flow) error { // Start launches the main loop of the scheduler. func (fs *FlowScheduler) Start() { ctx := fs.AnnotateCtx(context.Background()) - fs.stopper.RunWorker(ctx, func(context.Context) { + // TODO(radu): we may end up with a few flows in the queue that will + // never be processed. Is that an issue? 
+ _ = fs.stopper.RunAsyncTask(ctx, "flow-scheduler", func(context.Context) { stopped := false fs.mu.Lock() defer fs.mu.Unlock() @@ -188,7 +190,7 @@ func (fs *FlowScheduler) Start() { atomic.AddInt32(&fs.atomics.numRunning, -1) } - case <-fs.stopper.ShouldStop(): + case <-fs.stopper.ShouldQuiesce(): fs.mu.Lock() stopped = true } diff --git a/pkg/sql/rowexec/sampler.go b/pkg/sql/rowexec/sampler.go index a9ab39417654..9313398508a5 100644 --- a/pkg/sql/rowexec/sampler.go +++ b/pkg/sql/rowexec/sampler.go @@ -308,7 +308,7 @@ func (s *samplerProcessor) mainLoop(ctx context.Context) (earlyExit bool, err er case <-timer.C: timer.Read = true break - case <-s.flowCtx.Stopper().ShouldStop(): + case <-s.flowCtx.Stopper().ShouldQuiesce(): break } } diff --git a/pkg/sql/stats/automatic_stats.go b/pkg/sql/stats/automatic_stats.go index 591cec854c80..8b011bd6fb64 100644 --- a/pkg/sql/stats/automatic_stats.go +++ b/pkg/sql/stats/automatic_stats.go @@ -244,7 +244,7 @@ func MakeRefresher( func (r *Refresher) Start( ctx context.Context, stopper *stop.Stopper, refreshInterval time.Duration, ) error { - stopper.RunWorker(context.Background(), func(ctx context.Context) { + _ = stopper.RunAsyncTask(context.Background(), "refresher", func(ctx context.Context) { // We always sleep for r.asOfTime at the beginning of each refresh, so // subtract it from the refreshInterval. refreshInterval -= r.asOfTime @@ -306,7 +306,7 @@ func (r *Refresher) Start( case mut := <-r.mutations: r.mutationCounts[mut.tableID] += int64(mut.rowsAffected) - case <-stopper.ShouldStop(): + case <-stopper.ShouldQuiesce(): return } } diff --git a/pkg/sql/temporary_schema.go b/pkg/sql/temporary_schema.go index 7674efc000f1..9c1ffa6d93c3 100644 --- a/pkg/sql/temporary_schema.go +++ b/pkg/sql/temporary_schema.go @@ -585,7 +585,7 @@ func (c *TemporaryObjectCleaner) doTemporaryObjectCleanup( // Start initializes the background thread which periodically cleans up leftover temporary objects. 
func (c *TemporaryObjectCleaner) Start(ctx context.Context, stopper *stop.Stopper) { - stopper.RunWorker(ctx, func(ctx context.Context) { + _ = stopper.RunAsyncTask(ctx, "object-cleaner", func(ctx context.Context) { nextTick := timeutil.Now() for { nextTickCh := time.After(nextTick.Sub(timeutil.Now())) diff --git a/pkg/sql/tests/monotonic_insert_test.go b/pkg/sql/tests/monotonic_insert_test.go index 3f88864a7347..e3c769668991 100644 --- a/pkg/sql/tests/monotonic_insert_test.go +++ b/pkg/sql/tests/monotonic_insert_test.go @@ -230,7 +230,7 @@ RETURNING val, sts, node, tb`, for { select { case sem <- struct{}{}: - case <-tc.Stopper().ShouldStop(): + case <-tc.Stopper().ShouldQuiesce(): return case <-timer: return diff --git a/pkg/ts/db.go b/pkg/ts/db.go index b653a296b0e4..ffc04058c659 100644 --- a/pkg/ts/db.go +++ b/pkg/ts/db.go @@ -157,7 +157,7 @@ func (db *DB) PollSource( // start begins the goroutine for this poller, which will periodically request // time series data from the DataSource and store it. func (p *poller) start() { - p.stopper.RunWorker(context.TODO(), func(context.Context) { + _ = p.stopper.RunAsyncTask(context.TODO(), "ts-poller", func(context.Context) { // Poll once immediately. 
p.poll() ticker := time.NewTicker(p.frequency) @@ -166,7 +166,7 @@ func (p *poller) start() { select { case <-ticker.C: p.poll() - case <-p.stopper.ShouldStop(): + case <-p.stopper.ShouldQuiesce(): return } } diff --git a/pkg/util/netutil/net.go b/pkg/util/netutil/net.go index f4f5afdab8e2..ffdfbb09323d 100644 --- a/pkg/util/netutil/net.go +++ b/pkg/util/netutil/net.go @@ -42,16 +42,21 @@ func ListenAndServeGRPC( ctx := context.TODO() - stopper.RunWorker(ctx, func(context.Context) { + stopper.AddCloser(stop.CloserFn(server.Stop)) + waitQuiesce := func(context.Context) { <-stopper.ShouldQuiesce() FatalIfUnexpected(ln.Close()) - <-stopper.ShouldStop() - server.Stop() - }) + } + if err := stopper.RunAsyncTask(ctx, "listen-quiesce", waitQuiesce); err != nil { + waitQuiesce(ctx) + return nil, err + } - stopper.RunWorker(ctx, func(context.Context) { + if err := stopper.RunAsyncTask(ctx, "serve", func(context.Context) { FatalIfUnexpected(server.Serve(ln)) - }) + }); err != nil { + return nil, err + } return ln, nil } @@ -104,15 +109,18 @@ func MakeServer(stopper *stop.Stopper, tlsConfig *tls.Config, handler http.Handl log.Fatalf(ctx, "%v", err) } - stopper.RunWorker(ctx, func(context.Context) { - <-stopper.ShouldStop() + waitQuiesce := func(context.Context) { + <-stopper.ShouldQuiesce() mu.Lock() for conn := range activeConns { conn.Close() } mu.Unlock() - }) + } + if err := stopper.RunAsyncTask(ctx, "http2-wait-quiesce", waitQuiesce); err != nil { + waitQuiesce(ctx) + } return server } diff --git a/pkg/util/stop/stopper.go b/pkg/util/stop/stopper.go index b21e7a823598..1ef13d5439e8 100644 --- a/pkg/util/stop/stopper.go +++ b/pkg/util/stop/stopper.go @@ -114,28 +114,58 @@ func (f CloserFn) Close() { f() } -// A Stopper provides a channel-based mechanism to stop an arbitrary -// array of workers. Each worker is registered with the stopper via -// the RunWorker() method. The system further allows execution of functions -// through RunTask() and RunAsyncTask(). 
+// A Stopper provides control over the lifecycle of goroutines started +// through it via its RunTask, RunAsyncTask, and other similar methods. // -// Stopping occurs in two phases: the first is the request to stop, which moves -// the stopper into a quiescing phase. While quiescing, calls to RunTask() & -// RunAsyncTask() don't execute the function passed in and return ErrUnavailable. -// When all outstanding tasks have been completed, the stopper -// closes its stopper channel, which signals all live workers that it's safe to -// shut down. When all workers have shutdown, the stopper is complete. +// When Stop is invoked, the Stopper // -// An arbitrary list of objects implementing the Closer interface may -// be added to the stopper via AddCloser(), to be closed after the -// stopper has stopped. +// - invokes Quiesce, which causes the Stopper to refuse new work +// (that is, its Run* family of methods starts returning ErrUnavailable), +// closes the channel returned by ShouldQuiesce, and blocks +// until no more tasks are tracked, then +// - runs all of the methods supplied to AddCloser, then +// - closes the IsStopped channel. +// +// When a Run* method returns ErrUnavailable, the caller needs +// to handle it appropriately by terminating any work that it had +// hoped to defer to the task (which is guaranteed to never have been +// invoked). A simple example of this can be seen in the below snippet: +// +// var wg sync.WaitGroup +// wg.Add(1) +// if err := s.RunAsyncTask(ctx, "foo", func(ctx context.Context) { +// defer wg.Done() +// }); err != nil { +// // Task never ran. +// wg.Done() +// } +// +// To ensure that tasks that do get started are sensitive to Quiesce, +// they need to observe the ShouldQuiesce channel similar to how they +// are expected to observe context cancellation: +// +// func x() { +// select { +// case <-s.ShouldQuiesce(): +// return +// case <-ctx.Done(): +// return +// case <-someChan: +// // Do work. 
+// } +// } +// +// TODO(tbg): many improvements here are possible: +// - propagate quiescing via context cancellation +// - better API around refused tasks +// - all the other things mentioned in: +// https://github.com/cockroachdb/cockroach/issues/58164 type Stopper struct { quiescer chan struct{} // Closed when quiescing - stopper chan struct{} // Closed when stopping stopped chan struct{} // Closed when stopped completely onPanic func(interface{}) // called with recover() on panic on any goroutine - stop sync.WaitGroup // Incremented for outstanding workers - mu struct { + + mu struct { syncutil.Mutex quiesce *sync.Cond // Conditional variable to wait for outstanding tasks quiescing bool // true when Stop() has been called @@ -174,7 +204,6 @@ func OnPanic(handler func(interface{})) Option { func NewStopper(options ...Option) *Stopper { s := &Stopper{ quiescer: make(chan struct{}), - stopper: make(chan struct{}), stopped: make(chan struct{}), } @@ -208,31 +237,18 @@ func (s *Stopper) Recover(ctx context.Context) { } } -// RunWorker runs the supplied function as a "worker" to be stopped -// by the stopper. The function is run in a goroutine. -func (s *Stopper) RunWorker(ctx context.Context, f func(context.Context)) { - s.stop.Add(1) - go func() { - // Remove any associated span; we need to ensure this because the - // worker may run longer than the caller which presumably closes - // any spans it has created. - ctx = tracing.ContextWithSpan(ctx, nil) - defer s.Recover(ctx) - defer s.stop.Done() - f(ctx) - }() -} - // AddCloser adds an object to close after the stopper has been stopped. // // WARNING: memory resources acquired by this method will stay around for // the lifetime of the Stopper. Use with care to avoid leaking memory. +// +// A closer that is added after Stop has already been called will be +// called immediately. func (s *Stopper) AddCloser(c Closer) { s.mu.Lock() defer s.mu.Unlock() select { - case <-s.stopper: - // Close immediately. 
+ case <-s.stopped: c.Close() default: s.mu.closers = append(s.mu.closers, c) @@ -249,16 +265,6 @@ func (s *Stopper) WithCancelOnQuiesce(ctx context.Context) (context.Context, fun return s.withCancel(ctx, s.mu.qCancels, s.quiescer) } -// WithCancelOnStop returns a child context which is canceled when the -// returned cancel function is called or when the Stopper begins to stop, -// whichever happens first. -// -// Canceling this context releases resources associated with it, so code should -// call cancel as soon as the operations running in this Context complete. -func (s *Stopper) WithCancelOnStop(ctx context.Context) (context.Context, func()) { - return s.withCancel(ctx, s.mu.sCancels, s.stopper) -} - func (s *Stopper) withCancel( ctx context.Context, cancels map[int]func(), cancelCh chan struct{}, ) (context.Context, func()) { @@ -469,6 +475,8 @@ func (s *Stopper) runningTasksLocked() TaskMap { // Stop signals all live workers to stop and then waits for each to // confirm it has stopped. +// +// Stop is idempotent; concurrent calls will block on each other. func (s *Stopper) Stop(ctx context.Context) { s.mu.Lock() stopCalled := s.mu.stopCalled @@ -498,7 +506,6 @@ func (s *Stopper) Stop(ctx context.Context) { // panics happen on purpose). if r := recover(); r != nil { go s.Quiesce(ctx) - close(s.stopper) s.mu.Lock() for _, c := range s.mu.closers { go c.Close() @@ -512,10 +519,8 @@ func (s *Stopper) Stop(ctx context.Context) { for _, cancel := range s.mu.sCancels { cancel() } - close(s.stopper) s.mu.Unlock() - s.stop.Wait() s.mu.Lock() defer s.mu.Unlock() for _, c := range s.mu.closers { @@ -533,16 +538,6 @@ func (s *Stopper) ShouldQuiesce() <-chan struct{} { return s.quiescer } -// ShouldStop returns a channel which will be closed when Stop() has been -// invoked and outstanding tasks have quiesced. -func (s *Stopper) ShouldStop() <-chan struct{} { - if s == nil { - // A nil stopper will never signal ShouldStop, but will also never panic. 
- return nil - } - return s.stopper -} - // IsStopped returns a channel which will be closed after Stop() has // been invoked to full completion, meaning all workers have completed // and all closers have been closed. diff --git a/pkg/util/stop/stopper_test.go b/pkg/util/stop/stopper_test.go index 0b85cb8c5df2..d97aac411de1 100644 --- a/pkg/util/stop/stopper_test.go +++ b/pkg/util/stop/stopper_test.go @@ -38,7 +38,7 @@ func TestStopper(t *testing.T) { cleanup := make(chan struct{}) ctx := context.Background() - s.RunWorker(ctx, func(context.Context) { + _ = s.RunAsyncTask(ctx, "task", func(context.Context) { <-running }) @@ -48,7 +48,7 @@ func TestStopper(t *testing.T) { <-cleanup }() - <-s.ShouldStop() + <-s.ShouldQuiesce() select { case <-waiting: close(cleanup) @@ -91,7 +91,7 @@ func TestStopperIsStopped(t *testing.T) { go s.Stop(context.Background()) select { - case <-s.ShouldStop(): + case <-s.ShouldQuiesce(): case <-time.After(time.Second): t.Fatal("stopper should have finished waiting") } @@ -112,16 +112,16 @@ func TestStopperIsStopped(t *testing.T) { s.Stop(context.Background()) } -func TestStopperMultipleStopees(t *testing.T) { +func TestStopperMultipleTasks(t *testing.T) { defer leaktest.AfterTest(t)() const count = 3 s := stop.NewStopper() ctx := context.Background() for i := 0; i < count; i++ { - s.RunWorker(ctx, func(context.Context) { - <-s.ShouldStop() - }) + require.NoError(t, s.RunAsyncTask(ctx, "task", func(context.Context) { + <-s.ShouldQuiesce() + })) } done := make(chan struct{}) @@ -144,8 +144,8 @@ func TestStopperStartFinishTasks(t *testing.T) { go s.Stop(ctx) select { - case <-s.ShouldStop(): - t.Fatal("expected stopper to be quiesceing") + case <-s.IsStopped(): + t.Fatal("stopper not fully stopped") case <-time.After(100 * time.Millisecond): // Expected. } @@ -153,27 +153,7 @@ func TestStopperStartFinishTasks(t *testing.T) { t.Error(err) } select { - case <-s.ShouldStop(): - // Success. 
- case <-time.After(time.Second): - t.Fatal("stopper should be ready to stop") - } -} - -func TestStopperRunWorker(t *testing.T) { - defer leaktest.AfterTest(t)() - s := stop.NewStopper() - ctx := context.Background() - s.RunWorker(ctx, func(context.Context) { - <-s.ShouldStop() - }) - closer := make(chan struct{}) - go func() { - s.Stop(ctx) - close(closer) - }() - select { - case <-closer: + case <-s.IsStopped(): // Success. case <-time.After(time.Second): t.Fatal("stopper should be ready to stop") @@ -197,17 +177,17 @@ func TestStopperQuiesce(t *testing.T) { quiesceDone = append(quiesceDone, qc) sc := make(chan struct{}) runTaskDone = append(runTaskDone, sc) - thisStopper.RunWorker(ctx, func(ctx context.Context) { + go func() { // Wait until Quiesce() is called. <-qc - err := thisStopper.RunTask(ctx, "test", func(context.Context) {}) + err := thisStopper.RunTask(ctx, "inner", func(context.Context) {}) if !errors.HasType(err, (*roachpb.NodeUnavailableError)(nil)) { t.Error(err) } // Make the stoppers call Stop(). 
close(sc) - <-thisStopper.ShouldStop() - }) + <-thisStopper.ShouldQuiesce() + }() } done := make(chan struct{}) @@ -368,9 +348,6 @@ func TestStopperRunTaskPanic(t *testing.T) { func(ctx context.Context) { explode(ctx) }, ) }, - func() { - s.RunWorker(ctx, explode) - }, } { go test() recovered := <-ch @@ -385,50 +362,29 @@ func TestStopperWithCancel(t *testing.T) { s := stop.NewStopper() ctx := context.Background() ctx1, _ := s.WithCancelOnQuiesce(ctx) - ctx2, _ := s.WithCancelOnStop(ctx) ctx3, cancel3 := s.WithCancelOnQuiesce(ctx) - ctx4, cancel4 := s.WithCancelOnStop(ctx) if err := ctx1.Err(); err != nil { t.Fatalf("should not be canceled: %v", err) } - if err := ctx2.Err(); err != nil { - t.Fatalf("should not be canceled: %v", err) - } if err := ctx3.Err(); err != nil { t.Fatalf("should not be canceled: %v", err) } - if err := ctx4.Err(); err != nil { - t.Fatalf("should not be canceled: %v", err) - } cancel3() - cancel4() if err := ctx1.Err(); err != nil { t.Fatalf("should not be canceled: %v", err) } - if err := ctx2.Err(); err != nil { - t.Fatalf("should not be canceled: %v", err) - } if err := ctx3.Err(); !errors.Is(err, context.Canceled) { t.Fatalf("should be canceled: %v", err) } - if err := ctx4.Err(); !errors.Is(err, context.Canceled) { - t.Fatalf("should be canceled: %v", err) - } s.Quiesce(ctx) if err := ctx1.Err(); !errors.Is(err, context.Canceled) { t.Fatalf("should be canceled: %v", err) } - if err := ctx2.Err(); err != nil { - t.Fatalf("should not be canceled: %v", err) - } s.Stop(ctx) - if err := ctx2.Err(); !errors.Is(err, context.Canceled) { - t.Fatalf("should be canceled: %v", err) - } } func TestStopperWithCancelConcurrent(t *testing.T) { @@ -437,22 +393,17 @@ func TestStopperWithCancelConcurrent(t *testing.T) { for i := 0; i < trials; i++ { s := stop.NewStopper() ctx := context.Background() - var ctx1, ctx2 context.Context + var ctx1 context.Context - // Tie two contexts to the Stopper and Stop concurrently. 
There should + // Tie a context to the Stopper and Stop concurrently. There should // be no circumstance where either Context is not canceled. var wg sync.WaitGroup - wg.Add(3) + wg.Add(2) go func() { defer wg.Done() runtime.Gosched() ctx1, _ = s.WithCancelOnQuiesce(ctx) }() - go func() { - defer wg.Done() - runtime.Gosched() - ctx2, _ = s.WithCancelOnStop(ctx) - }() go func() { defer wg.Done() runtime.Gosched() @@ -463,27 +414,17 @@ func TestStopperWithCancelConcurrent(t *testing.T) { if err := ctx1.Err(); !errors.Is(err, context.Canceled) { t.Errorf("should be canceled: %v", err) } - if err := ctx2.Err(); !errors.Is(err, context.Canceled) { - t.Errorf("should be canceled: %v", err) - } } } func TestStopperShouldQuiesce(t *testing.T) { defer leaktest.AfterTest(t)() s := stop.NewStopper() - running := make(chan struct{}) runningTask := make(chan struct{}) waiting := make(chan struct{}) cleanup := make(chan struct{}) ctx := context.Background() - // Run a worker. A call to stopper.Stop(context.Background()) will not close until all workers - // have completed, and this worker will complete when the "running" channel - // is closed. - s.RunWorker(ctx, func(context.Context) { - <-running - }) // Run an asynchronous task. A stopper which has been Stop()ed will not // close it's ShouldStop() channel until all tasks have completed. This task // will complete when the "runningTask" channel is closed. @@ -502,33 +443,15 @@ func TestStopperShouldQuiesce(t *testing.T) { // The ShouldQuiesce() channel should close as soon as the stopper is // Stop()ed. <-s.ShouldQuiesce() - // However, the ShouldStop() channel should still be blocked because the - // async task started above is still running, meaning we haven't quiesceed - // yet. - select { - case <-s.ShouldStop(): - close(cleanup) - t.Fatal("expected ShouldStop() to block until quiesceing complete") - default: - // Expected. - } // After completing the running task, the ShouldStop() channel should // now close. 
close(runningTask) - <-s.ShouldStop() - // However, the working running above prevents the call to Stop() from - // returning; it blocks until the runner's goroutine is finished. We - // use the "waiting" channel to detect this. select { - case <-waiting: - close(cleanup) - t.Fatal("expected stopper to have blocked") - default: - // Expected. + case <-s.IsStopped(): + // Good. + case <-time.After(10 * time.Second): + t.Fatal("stopper did not fully stop in time") } - // Finally, close the "running" channel, which should cause the original - // call to Stop() to return. - close(running) <-waiting close(cleanup) }