diff --git a/ttl/ttlworker/job_manager.go b/ttl/ttlworker/job_manager.go index 43d3f7fbfca88..f71f79bd2e78c 100644 --- a/ttl/ttlworker/job_manager.go +++ b/ttl/ttlworker/job_manager.go @@ -184,39 +184,73 @@ func (m *JobManager) reportMetrics() { func (m *JobManager) resizeScanWorkers(count int) error { var err error - m.scanWorkers, err = m.resizeWorkers(m.scanWorkers, count, func() worker { + var canceledWorkers []worker + m.scanWorkers, canceledWorkers, err = m.resizeWorkers(m.scanWorkers, count, func() worker { return newScanWorker(m.delCh, m.notifyStateCh, m.sessPool) }) + for _, w := range canceledWorkers { + s := w.(scanWorker) + + var tableID int64 + var scanErr error + result := s.PollTaskResult() + if result != nil { + tableID = result.task.tbl.ID + scanErr = result.err + } else { + // if the scan worker failed to poll the task, it's possible that the `WaitStopped` has timeout + // we still consider the scan task as finished + curTask := s.CurrentTask() + if curTask == nil { + continue + } + tableID = curTask.tbl.ID + scanErr = errors.New("timeout to cancel scan task") + } + + job := findJobWithTableID(m.runningJobs, tableID) + if job == nil { + logutil.Logger(m.ctx).Warn("task state changed but job not found", zap.Int64("tableID", tableID)) + continue + } + logutil.Logger(m.ctx).Debug("scan task finished", zap.String("jobID", job.id)) + job.finishedScanTaskCounter += 1 + job.scanTaskErr = multierr.Append(job.scanTaskErr, scanErr) + } return err } func (m *JobManager) resizeDelWorkers(count int) error { var err error - m.delWorkers, err = m.resizeWorkers(m.delWorkers, count, func() worker { + m.delWorkers, _, err = m.resizeWorkers(m.delWorkers, count, func() worker { return newDeleteWorker(m.delCh, m.sessPool) }) return err } -func (m *JobManager) resizeWorkers(workers []worker, count int, factory func() worker) ([]worker, error) { +// resizeWorkers scales the worker, and returns the full set of workers as the first return value. If there are workers +// stopped, return the stopped worker in the second return value +func (m *JobManager) resizeWorkers(workers []worker, count int, factory func() worker) ([]worker, []worker, error) { if count < len(workers) { logutil.Logger(m.ctx).Info("shrink ttl worker", zap.Int("originalCount", len(workers)), zap.Int("newCount", count)) for _, w := range workers[count:] { w.Stop() } + var errs error + ctx, cancel := context.WithTimeout(m.ctx, 30*time.Second) for _, w := range workers[count:] { - err := w.WaitStopped(m.ctx, 30*time.Second) + err := w.WaitStopped(ctx, 30*time.Second) if err != nil { logutil.Logger(m.ctx).Warn("fail to stop ttl worker", zap.Error(err)) errs = multierr.Append(errs, err) } } + cancel() // remove the existing workers, and keep the left workers - workers = workers[:count] - return workers, errs + return workers[:count], workers[count:], errs } if count > len(workers) { @@ -227,10 +261,10 @@ func (m *JobManager) resizeWorkers(workers []worker, count int, factory func() w w.Start() workers = append(workers, w) } - return workers, nil + return workers, nil, nil } - return workers, nil + return workers, nil, nil } // updateTaskState polls the result from scan worker and returns whether there are result polled @@ -238,12 +272,14 @@ func (m *JobManager) updateTaskState() bool { results := m.pollScanWorkerResults() for _, result := range results { job := findJobWithTableID(m.runningJobs, result.task.tbl.ID) - if job != nil { - logutil.Logger(m.ctx).Debug("scan task state changed", zap.String("jobID", job.id)) - - job.finishedScanTaskCounter += 1 - job.scanTaskErr = multierr.Append(job.scanTaskErr, result.err) + if job == nil { + logutil.Logger(m.ctx).Warn("task state changed but job not found", zap.Int64("tableID", result.task.tbl.ID)) + continue } + logutil.Logger(m.ctx).Debug("scan task finished", zap.String("jobID", job.id)) + + job.finishedScanTaskCounter += 1 + job.scanTaskErr = multierr.Append(job.scanTaskErr, result.err) } return len(results) > 0 @@ -252,7 +288,7 @@ func (m *JobManager) updateTaskState() bool { func (m *JobManager) pollScanWorkerResults() []*ttlScanTaskExecResult { results := make([]*ttlScanTaskExecResult, 0, len(m.scanWorkers)) for _, w := range m.scanWorkers { - worker := w.(*ttlScanWorker) + worker := w.(scanWorker) result := worker.PollTaskResult() if result != nil { results = append(results, result) diff --git a/ttl/ttlworker/job_manager_test.go b/ttl/ttlworker/job_manager_test.go index 7261eb2edf8f7..f1566b2bc02bf 100644 --- a/ttl/ttlworker/job_manager_test.go +++ b/ttl/ttlworker/job_manager_test.go @@ -305,7 +305,7 @@ func TestResizeWorkers(t *testing.T) { m.SetScanWorkers4Test([]worker{ scanWorker1, }) - newWorkers, err := m.resizeWorkers(m.scanWorkers, 2, func() worker { + newWorkers, _, err := m.resizeWorkers(m.scanWorkers, 2, func() worker { return scanWorker2 }) assert.NoError(t, err) @@ -327,6 +327,24 @@ func TestResizeWorkers(t *testing.T) { assert.NoError(t, m.resizeScanWorkers(1)) scanWorker2.checkWorkerStatus(workerStatusStopped, false, nil) + + // shrink scan workers after job is run + scanWorker1 = newMockScanWorker(t) + scanWorker1.Start() + scanWorker2 = newMockScanWorker(t) + scanWorker2.Start() + + m = NewJobManager("test-id", newMockSessionPool(t, tbl), nil) + m.SetScanWorkers4Test([]worker{ + scanWorker1, + scanWorker2, + }) + m.runningJobs = append(m.runningJobs, &ttlJob{tbl: tbl}) + + scanWorker2.curTaskResult = &ttlScanTaskExecResult{task: &ttlScanTask{tbl: tbl}} + assert.NoError(t, m.resizeScanWorkers(1)) + scanWorker2.checkWorkerStatus(workerStatusStopped, false, nil) + assert.Equal(t, m.runningJobs[0].finishedScanTaskCounter, 1) } func TestLocalJobs(t *testing.T) { diff --git a/ttl/ttlworker/scan.go b/ttl/ttlworker/scan.go index 7e997425b851d..38a4fd544535d 100644 --- a/ttl/ttlworker/scan.go +++ b/ttl/ttlworker/scan.go @@ -334,4 +334,6 @@ type scanWorker interface { Idle() bool Schedule(*ttlScanTask) error + PollTaskResult() *ttlScanTaskExecResult + CurrentTask() *ttlScanTask } diff --git a/ttl/ttlworker/worker.go b/ttl/ttlworker/worker.go index a04110373cdbf..783384862cacf 100644 --- a/ttl/ttlworker/worker.go +++ b/ttl/ttlworker/worker.go @@ -96,6 +96,12 @@ func (w *baseWorker) Error() error { } func (w *baseWorker) WaitStopped(ctx context.Context, timeout time.Duration) error { + // consider the situation when the worker has stopped, but the context has also stopped. We should + // return without error + if w.Status() == workerStatusStopped { + return nil + } + ctx, cancel := context.WithTimeout(ctx, timeout) go func() { w.wg.Wait()