diff --git a/pkg/sql/sql_activity_update_job.go b/pkg/sql/sql_activity_update_job.go index a3ef9c59e8eb..c53630d32b87 100644 --- a/pkg/sql/sql_activity_update_job.go +++ b/pkg/sql/sql_activity_update_job.go @@ -92,14 +92,12 @@ func (j *sqlActivityUpdateJob) Resume(ctx context.Context, execCtxI interface{}) flushDoneSignal := make(chan struct{}) defer func() { - statsFlush.SetFlushDoneCallback(nil) + statsFlush.SetFlushDoneSignalCh(nil) close(flushDoneSignal) }() + statsFlush.SetFlushDoneSignalCh(flushDoneSignal) for { - statsFlush.SetFlushDoneCallback(func() { - flushDoneSignal <- struct{}{} - }) select { case <-flushDoneSignal: // A flush was done. Set the timer and wait for it to complete. diff --git a/pkg/sql/sqlstats/persistedsqlstats/provider.go b/pkg/sql/sqlstats/persistedsqlstats/provider.go index d77c5502f098..574fdc275f12 100644 --- a/pkg/sql/sqlstats/persistedsqlstats/provider.go +++ b/pkg/sql/sqlstats/persistedsqlstats/provider.go @@ -69,8 +69,10 @@ type PersistedSQLStats struct { memoryPressureSignal chan struct{} // Used to signal the flush completed. - flushDoneCallback func() - flushMutex syncutil.Mutex + flushDoneMu struct { + syncutil.Mutex + signalCh chan<- struct{} + } lastFlushStarted time.Time jobMonitor jobMonitor @@ -94,7 +96,6 @@ func New(cfg *Config, memSQLStats *sslocal.SQLStats) *PersistedSQLStats { cfg: cfg, memoryPressureSignal: make(chan struct{}), drain: make(chan struct{}), - flushDoneCallback: nil, } p.jobMonitor = jobMonitor{ @@ -134,10 +135,11 @@ func (s *PersistedSQLStats) Stop(ctx context.Context) { s.tasksDoneWG.Wait() } -func (s *PersistedSQLStats) SetFlushDoneCallback(callBackFunc func()) { - s.flushMutex.Lock() - defer s.flushMutex.Unlock() - s.flushDoneCallback = callBackFunc +// SetFlushDoneSignalCh sets the channel to signal each time a flush has been completed. +func (s *PersistedSQLStats) SetFlushDoneSignalCh(sigCh chan<- struct{}) { + s.flushDoneMu.Lock() + defer s.flushDoneMu.Unlock() + s.flushDoneMu.signalCh = sigCh } // GetController returns the controller of the PersistedSQLStats. @@ -186,13 +188,21 @@ func (s *PersistedSQLStats) startSQLStatsFlushLoop(ctx context.Context, stopper s.Flush(ctx) - func() { - s.flushMutex.Lock() - defer s.flushMutex.Unlock() - if s.flushDoneCallback != nil { - s.flushDoneCallback() + // Tell the local activity translator job, if any, that we've + // performed a round of flush. + if sigCh := func() chan<- struct{} { + s.flushDoneMu.Lock() + defer s.flushDoneMu.Unlock() + return s.flushDoneMu.signalCh + }(); sigCh != nil { + select { + case sigCh <- struct{}{}: + case <-stopper.ShouldQuiesce(): + return + case <-s.drain: + return } - }() + } } }) if err != nil {