diff --git a/pkg/ddl/backfilling_dist_scheduler_test.go b/pkg/ddl/backfilling_dist_scheduler_test.go index 7e3a3f60e0918..12384bda39de5 100644 --- a/pkg/ddl/backfilling_dist_scheduler_test.go +++ b/pkg/ddl/backfilling_dist_scheduler_test.go @@ -170,15 +170,16 @@ func TestBackfillingSchedulerGlobalSortMode(t *testing.T) { subtaskMetas, err := sch.OnNextSubtasksBatch(ctx, sch, task, execIDs, sch.GetNextStep(task)) require.NoError(t, err) require.Len(t, subtaskMetas, 1) - task.Step = ext.GetNextStep(task) - require.Equal(t, proto.BackfillStepReadIndex, task.Step) + nextStep := ext.GetNextStep(task) + require.Equal(t, proto.BackfillStepReadIndex, nextStep) // update task/subtask, and finish subtask, so we can go to next stage subtasks := make([]*proto.Subtask, 0, len(subtaskMetas)) - for _, m := range subtaskMetas { - subtasks = append(subtasks, proto.NewSubtask(task.Step, task.ID, task.Type, "", 1, m, 0)) + for i, m := range subtaskMetas { + subtasks = append(subtasks, proto.NewSubtask(nextStep, task.ID, task.Type, "", 1, m, i+1)) } - _, err = mgr.UpdateTaskAndAddSubTasks(ctx, task, subtasks, proto.TaskStatePending) + err = mgr.SwitchTaskStep(ctx, task, proto.TaskStatePending, nextStep, subtasks) require.NoError(t, err) + task.Step = nextStep gotSubtasks, err := mgr.GetSubtasksWithHistory(ctx, taskID, proto.BackfillStepReadIndex) require.NoError(t, err) @@ -210,16 +211,17 @@ func TestBackfillingSchedulerGlobalSortMode(t *testing.T) { subtaskMetas, err = ext.OnNextSubtasksBatch(ctx, sch, task, execIDs, ext.GetNextStep(task)) require.NoError(t, err) require.Len(t, subtaskMetas, 1) - task.Step = ext.GetNextStep(task) - require.Equal(t, proto.BackfillStepMergeSort, task.Step) + nextStep = ext.GetNextStep(task) + require.Equal(t, proto.BackfillStepMergeSort, nextStep) // update meta, same as import into. subtasks = make([]*proto.Subtask, 0, len(subtaskMetas)) - for _, m := range subtaskMetas { - subtasks = append(subtasks, proto.NewSubtask(task.Step, task.ID, task.Type, "", 1, m, 0)) + for i, m := range subtaskMetas { + subtasks = append(subtasks, proto.NewSubtask(nextStep, task.ID, task.Type, "", 1, m, i+1)) } - _, err = mgr.UpdateTaskAndAddSubTasks(ctx, task, subtasks, proto.TaskStatePending) + err = mgr.SwitchTaskStepInBatch(ctx, task, proto.TaskStatePending, nextStep, subtasks) require.NoError(t, err) + task.Step = nextStep gotSubtasks, err = mgr.GetSubtasksWithHistory(ctx, taskID, task.Step) require.NoError(t, err) mergeSortStepMeta := &ddl.BackfillSubTaskMeta{ diff --git a/pkg/ddl/backfilling_import_cloud.go b/pkg/ddl/backfilling_import_cloud.go index b5f0e70749ec7..77003923cc9f8 100644 --- a/pkg/ddl/backfilling_import_cloud.go +++ b/pkg/ddl/backfilling_import_cloud.go @@ -107,8 +107,3 @@ func (*cloudImportExecutor) OnFinished(ctx context.Context, _ *proto.Subtask) er logutil.Logger(ctx).Info("cloud import executor finish subtask") return nil } - -func (*cloudImportExecutor) Rollback(ctx context.Context) error { - logutil.Logger(ctx).Info("cloud import executor rollback subtask") - return nil -} diff --git a/pkg/ddl/backfilling_merge_sort.go b/pkg/ddl/backfilling_merge_sort.go index da8954d966060..e93443285fe3b 100644 --- a/pkg/ddl/backfilling_merge_sort.go +++ b/pkg/ddl/backfilling_merge_sort.go @@ -135,8 +135,3 @@ func (m *mergeSortExecutor) OnFinished(ctx context.Context, subtask *proto.Subta subtask.Meta = newMeta return nil } - -func (*mergeSortExecutor) Rollback(ctx context.Context) error { - logutil.Logger(ctx).Info("merge sort executor rollback backfill add index task") - return nil -} diff --git a/pkg/ddl/backfilling_read_index.go b/pkg/ddl/backfilling_read_index.go index e238abfaf19ab..ce207c53941f4 100644 --- a/pkg/ddl/backfilling_read_index.go +++ b/pkg/ddl/backfilling_read_index.go @@ -194,12 +194,6 @@ func (r *readIndexExecutor) OnFinished(ctx context.Context, subtask *proto.Subta return nil } -func (r *readIndexExecutor) Rollback(ctx context.Context) error { - logutil.Logger(ctx).Info("read index executor rollback backfill add index task", - zap.String("category", "ddl"), zap.Int64("jobID", r.job.ID)) - return nil -} - func (r *readIndexExecutor) getTableStartEndKey(sm *BackfillSubTaskMeta) ( start, end kv.Key, tbl table.PhysicalTable, err error) { currentVer, err1 := getValidCurrentVersion(r.d.store) diff --git a/pkg/disttask/framework/framework_rollback_test.go b/pkg/disttask/framework/framework_rollback_test.go index 39d4bf3ace8bc..78e84dc1b29d7 100644 --- a/pkg/disttask/framework/framework_rollback_test.go +++ b/pkg/disttask/framework/framework_rollback_test.go @@ -34,7 +34,5 @@ func TestFrameworkRollback(t *testing.T) { task := testutil.SubmitAndWaitTask(ctx, t, "key1") require.Equal(t, proto.TaskStateReverted, task.State) - require.Equal(t, int32(2), testContext.RollbackCnt.Load()) - testContext.RollbackCnt.Store(0) distContext.Close() } diff --git a/pkg/disttask/framework/mock/execute/execute_mock.go b/pkg/disttask/framework/mock/execute/execute_mock.go index 4aaeab1e2b77c..cd6ca330a3afa 100644 --- a/pkg/disttask/framework/mock/execute/execute_mock.go +++ b/pkg/disttask/framework/mock/execute/execute_mock.go @@ -81,20 +81,6 @@ func (mr *MockStepExecutorMockRecorder) OnFinished(arg0, arg1 any) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnFinished", reflect.TypeOf((*MockStepExecutor)(nil).OnFinished), arg0, arg1) } -// Rollback mocks base method. -func (m *MockStepExecutor) Rollback(arg0 context.Context) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Rollback", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// Rollback indicates an expected call of Rollback. -func (mr *MockStepExecutorMockRecorder) Rollback(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rollback", reflect.TypeOf((*MockStepExecutor)(nil).Rollback), arg0) -} - // RunSubtask mocks base method. func (m *MockStepExecutor) RunSubtask(arg0 context.Context, arg1 *proto.Subtask) error { m.ctrl.T.Helper() diff --git a/pkg/disttask/framework/mock/scheduler_mock.go b/pkg/disttask/framework/mock/scheduler_mock.go index 57002fbd44fcd..ad578ebd3249f 100644 --- a/pkg/disttask/framework/mock/scheduler_mock.go +++ b/pkg/disttask/framework/mock/scheduler_mock.go @@ -398,21 +398,6 @@ func (mr *MockTaskManagerMockRecorder) GetTaskByID(arg0, arg1 any) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTaskByID", reflect.TypeOf((*MockTaskManager)(nil).GetTaskByID), arg0, arg1) } -// GetTaskExecutorIDsByTaskID mocks base method. -func (m *MockTaskManager) GetTaskExecutorIDsByTaskID(arg0 context.Context, arg1 int64) ([]string, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetTaskExecutorIDsByTaskID", arg0, arg1) - ret0, _ := ret[0].([]string) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetTaskExecutorIDsByTaskID indicates an expected call of GetTaskExecutorIDsByTaskID. -func (mr *MockTaskManagerMockRecorder) GetTaskExecutorIDsByTaskID(arg0, arg1 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTaskExecutorIDsByTaskID", reflect.TypeOf((*MockTaskManager)(nil).GetTaskExecutorIDsByTaskID), arg0, arg1) -} - // GetTasksInStates mocks base method. func (m *MockTaskManager) GetTasksInStates(arg0 context.Context, arg1 ...any) ([]*proto.Task, error) { m.ctrl.T.Helper() @@ -506,6 +491,34 @@ func (mr *MockTaskManagerMockRecorder) ResumeSubtasks(arg0, arg1 any) *gomock.Ca return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResumeSubtasks", reflect.TypeOf((*MockTaskManager)(nil).ResumeSubtasks), arg0, arg1) } +// ResumedTask mocks base method. +func (m *MockTaskManager) ResumedTask(arg0 context.Context, arg1 int64) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ResumedTask", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// ResumedTask indicates an expected call of ResumedTask. +func (mr *MockTaskManagerMockRecorder) ResumedTask(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResumedTask", reflect.TypeOf((*MockTaskManager)(nil).ResumedTask), arg0, arg1) +} + +// RevertTask mocks base method. +func (m *MockTaskManager) RevertTask(arg0 context.Context, arg1 int64, arg2 proto.TaskState, arg3 error) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RevertTask", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(error) + return ret0 +} + +// RevertTask indicates an expected call of RevertTask. +func (mr *MockTaskManagerMockRecorder) RevertTask(arg0, arg1, arg2, arg3 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevertTask", reflect.TypeOf((*MockTaskManager)(nil).RevertTask), arg0, arg1, arg2, arg3) +} + // RevertedTask mocks base method. func (m *MockTaskManager) RevertedTask(arg0 context.Context, arg1 int64) error { m.ctrl.T.Helper() @@ -590,21 +603,6 @@ func (mr *MockTaskManagerMockRecorder) UpdateSubtasksExecIDs(arg0, arg1 any) *go return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSubtasksExecIDs", reflect.TypeOf((*MockTaskManager)(nil).UpdateSubtasksExecIDs), arg0, arg1) } -// UpdateTaskAndAddSubTasks mocks base method. -func (m *MockTaskManager) UpdateTaskAndAddSubTasks(arg0 context.Context, arg1 *proto.Task, arg2 []*proto.Subtask, arg3 proto.TaskState) (bool, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateTaskAndAddSubTasks", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(bool) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// UpdateTaskAndAddSubTasks indicates an expected call of UpdateTaskAndAddSubTasks. -func (mr *MockTaskManagerMockRecorder) UpdateTaskAndAddSubTasks(arg0, arg1, arg2, arg3 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateTaskAndAddSubTasks", reflect.TypeOf((*MockTaskManager)(nil).UpdateTaskAndAddSubTasks), arg0, arg1, arg2, arg3) -} - // WithNewSession mocks base method. func (m *MockTaskManager) WithNewSession(arg0 func(sessionctx.Context) error) error { m.ctrl.T.Helper() diff --git a/pkg/disttask/framework/proto/subtask.go b/pkg/disttask/framework/proto/subtask.go index e825aa3202300..401888ac67ab2 100644 --- a/pkg/disttask/framework/proto/subtask.go +++ b/pkg/disttask/framework/proto/subtask.go @@ -42,30 +42,13 @@ import ( // │ ┌────────┐ // └───────►│canceled│ // └────────┘ -// -// for reverting subtask: -// -// ┌──────────────┐ ┌─────────┐ ┌─────────┐ -// │revert_pending├───►│reverting├──►│ reverted│ -// └──────────────┘ └────┬────┘ └─────────┘ -// │ ┌─────────────┐ -// └────────►│revert_failed│ -// └─────────────┘ -// 1. succeed/failed: pending -> running -> succeed/failed -// 2. canceled: pending -> running -> canceled -// 3. rollback: revert_pending -> reverting -> reverted/revert_failed -// 4. pause/resume: pending -> running -> paused -> running const ( - SubtaskStatePending SubtaskState = "pending" - SubtaskStateRunning SubtaskState = "running" - SubtaskStateSucceed SubtaskState = "succeed" - SubtaskStateFailed SubtaskState = "failed" - SubtaskStateCanceled SubtaskState = "canceled" - SubtaskStatePaused SubtaskState = "paused" - SubtaskStateRevertPending SubtaskState = "revert_pending" - SubtaskStateReverting SubtaskState = "reverting" - SubtaskStateReverted SubtaskState = "reverted" - SubtaskStateRevertFailed SubtaskState = "revert_failed" + SubtaskStatePending SubtaskState = "pending" + SubtaskStateRunning SubtaskState = "running" + SubtaskStateSucceed SubtaskState = "succeed" + SubtaskStateFailed SubtaskState = "failed" + SubtaskStateCanceled SubtaskState = "canceled" + SubtaskStatePaused SubtaskState = "paused" ) type ( @@ -117,8 +100,8 @@ func (t *Subtask) String() string { // IsDone checks if the subtask is done. func (t *Subtask) IsDone() bool { - return t.State == SubtaskStateSucceed || t.State == SubtaskStateReverted || t.State == SubtaskStateCanceled || - t.State == SubtaskStateFailed || t.State == SubtaskStateRevertFailed + return t.State == SubtaskStateSucceed || t.State == SubtaskStateCanceled || + t.State == SubtaskStateFailed } // NewSubtask create a new subtask. diff --git a/pkg/disttask/framework/proto/subtask_test.go b/pkg/disttask/framework/proto/subtask_test.go index 09e01111428e4..22380a1410b47 100644 --- a/pkg/disttask/framework/proto/subtask_test.go +++ b/pkg/disttask/framework/proto/subtask_test.go @@ -31,12 +31,8 @@ func TestSubtaskIsDone(t *testing.T) { {SubtaskStatePending, false}, {SubtaskStateRunning, false}, {SubtaskStateSucceed, true}, - {SubtaskStateReverting, false}, - {SubtaskStateRevertPending, false}, {SubtaskStateFailed, true}, - {SubtaskStateRevertFailed, true}, {SubtaskStatePaused, false}, - {SubtaskStateReverted, true}, {SubtaskStateCanceled, true}, } for _, c := range cases { diff --git a/pkg/disttask/framework/proto/task.go b/pkg/disttask/framework/proto/task.go index 4a27d731e8d8e..cfa4c3a7555c2 100644 --- a/pkg/disttask/framework/proto/task.go +++ b/pkg/disttask/framework/proto/task.go @@ -20,40 +20,37 @@ import ( // task state machine // -// ┌────────┐ -// ┌───────────│resuming│◄────────┐ -// │ └────────┘ │ -// ┌──────┐ │ ┌───────┐ ┌──┴───┐ -// │failed│ │ ┌────────►│pausing├──────►│paused│ -// └──────┘ │ │ └───────┘ └──────┘ -// ▲ ▼ │ -// ┌──┴────┐ ┌───┴───┐ ┌────────┐ -// │pending├────►│running├────►│succeed │ -// └──┬────┘ └──┬┬───┘ └────────┘ -// │ ││ ┌─────────┐ ┌────────┐ -// │ │└────────►│reverting├────►│reverted│ -// │ ▼ └────┬────┘ └────────┘ -// │ ┌──────────┐ ▲ │ ┌─────────────┐ -// └─────────►│cancelling├────┘ └─────────►│revert_failed│ -// └──────────┘ └─────────────┘ -// 1. succeed: pending -> running -> succeed -// 2. failed: pending -> running -> reverting -> reverted/revert_failed, pending -> failed -// 3. canceled: pending -> running -> cancelling -> reverting -> reverted/revert_failed -// 4. pause/resume: pending -> running -> pausing -> paused -> running +// Note: if a task fails during running, it will end with `reverted` state. +// The `failed` state is used to mean the framework cannot run the task, such as +// invalid task type, scheduler init error(fatal), etc. // -// TODO: we don't have revert_failed task for now. +// ┌────────┐ +// ┌───────────│resuming│◄────────┐ +// │ └────────┘ │ +// ┌──────┐ │ ┌───────┐ ┌──┴───┐ +// │failed│ │ ┌────────►│pausing├──────►│paused│ +// └──────┘ │ │ └───────┘ └──────┘ +// ▲ ▼ │ +// ┌──┴────┐ ┌───┴───┐ ┌────────┐ +// │pending├────►│running├────►│succeed │ +// └──┬────┘ └──┬┬───┘ └────────┘ +// │ ││ ┌─────────┐ ┌────────┐ +// │ │└────────►│reverting├────►│reverted│ +// │ ▼ └─────────┘ └────────┘ +// │ ┌──────────┐ ▲ +// └─────────►│cancelling├────┘ +// └──────────┘ const ( - TaskStatePending TaskState = "pending" - TaskStateRunning TaskState = "running" - TaskStateSucceed TaskState = "succeed" - TaskStateFailed TaskState = "failed" - TaskStateReverting TaskState = "reverting" - TaskStateReverted TaskState = "reverted" - TaskStateRevertFailed TaskState = "revert_failed" - TaskStateCancelling TaskState = "cancelling" - TaskStatePausing TaskState = "pausing" - TaskStatePaused TaskState = "paused" - TaskStateResuming TaskState = "resuming" + TaskStatePending TaskState = "pending" + TaskStateRunning TaskState = "running" + TaskStateSucceed TaskState = "succeed" + TaskStateFailed TaskState = "failed" + TaskStateReverting TaskState = "reverting" + TaskStateReverted TaskState = "reverted" + TaskStateCancelling TaskState = "cancelling" + TaskStatePausing TaskState = "pausing" + TaskStatePaused TaskState = "paused" + TaskStateResuming TaskState = "resuming" ) type ( diff --git a/pkg/disttask/framework/proto/task_test.go b/pkg/disttask/framework/proto/task_test.go index a59981da333dd..845e65aa0d0f6 100644 --- a/pkg/disttask/framework/proto/task_test.go +++ b/pkg/disttask/framework/proto/task_test.go @@ -37,7 +37,6 @@ func TestTaskIsDone(t *testing.T) { {TaskStateSucceed, true}, {TaskStateReverting, false}, {TaskStateFailed, true}, - {TaskStateRevertFailed, false}, {TaskStateCancelling, false}, {TaskStatePausing, false}, {TaskStatePaused, false}, diff --git a/pkg/disttask/framework/scheduler/BUILD.bazel b/pkg/disttask/framework/scheduler/BUILD.bazel index c721ddd939d19..0c678390d6283 100644 --- a/pkg/disttask/framework/scheduler/BUILD.bazel +++ b/pkg/disttask/framework/scheduler/BUILD.bazel @@ -53,7 +53,7 @@ go_test( embed = [":scheduler"], flaky = True, race = "off", - shard_count = 30, + shard_count = 29, deps = [ "//pkg/config", "//pkg/disttask/framework/mock", diff --git a/pkg/disttask/framework/scheduler/interface.go b/pkg/disttask/framework/scheduler/interface.go index ec9b8f6d6992f..a1fea8c2126c8 100644 --- a/pkg/disttask/framework/scheduler/interface.go +++ b/pkg/disttask/framework/scheduler/interface.go @@ -32,22 +32,25 @@ type TaskManager interface { GetTopUnfinishedTasks(ctx context.Context) ([]*proto.Task, error) GetTasksInStates(ctx context.Context, states ...interface{}) (task []*proto.Task, err error) GetTaskByID(ctx context.Context, taskID int64) (task *proto.Task, err error) - UpdateTaskAndAddSubTasks(ctx context.Context, task *proto.Task, subtasks []*proto.Subtask, prevState proto.TaskState) (bool, error) GCSubtasks(ctx context.Context) error GetAllNodes(ctx context.Context) ([]proto.ManagedNode, error) DeleteDeadNodes(ctx context.Context, nodes []string) error - // TransferTask2History transfer tasks and it's related subtasks to history tables. + // TransferTasks2History transfer tasks, and it's related subtasks to history tables. TransferTasks2History(ctx context.Context, tasks []*proto.Task) error // CancelTask updated task state to canceling. CancelTask(ctx context.Context, taskID int64) error // FailTask updates task state to Failed and updates task error. FailTask(ctx context.Context, taskID int64, currentState proto.TaskState, taskErr error) error + // RevertTask updates task state to reverting, and task error. + RevertTask(ctx context.Context, taskID int64, taskState proto.TaskState, taskErr error) error // RevertedTask updates task state to reverted. RevertedTask(ctx context.Context, taskID int64) error // PauseTask updated task state to pausing. PauseTask(ctx context.Context, taskKey string) (bool, error) // PausedTask updated task state to paused. PausedTask(ctx context.Context, taskID int64) error + // ResumedTask updated task state from resuming to running. + ResumedTask(ctx context.Context, taskID int64) error // SucceedTask updates a task to success state. SucceedTask(ctx context.Context, taskID int64) error // SwitchTaskStep switches the task to the next step and add subtasks in one @@ -80,8 +83,6 @@ type TaskManager interface { // else we use nodes without role. // returned nodes are sorted by node id(host:port). GetManagedNodes(ctx context.Context) ([]proto.ManagedNode, error) - // GetTaskExecutorIDsByTaskID gets the task executor IDs of the given task ID. - GetTaskExecutorIDsByTaskID(ctx context.Context, taskID int64) ([]string, error) // GetAllSubtasksByStepAndState gets all subtasks by given states for one step. GetAllSubtasksByStepAndState(ctx context.Context, taskID int64, step proto.Step, state proto.SubtaskState) ([]*proto.Subtask, error) diff --git a/pkg/disttask/framework/scheduler/scheduler.go b/pkg/disttask/framework/scheduler/scheduler.go index 69d312aa72e74..fa10adbcb11af 100644 --- a/pkg/disttask/framework/scheduler/scheduler.go +++ b/pkg/disttask/framework/scheduler/scheduler.go @@ -32,7 +32,6 @@ import ( "github.com/pingcap/tidb/pkg/metrics" "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/util/backoff" - disttaskutil "github.com/pingcap/tidb/pkg/util/disttask" "github.com/pingcap/tidb/pkg/util/intest" "github.com/pingcap/tidb/pkg/util/logutil" "go.uber.org/zap" @@ -285,7 +284,7 @@ func (s *BaseScheduler) onResuming() error { if cntByStates[proto.SubtaskStatePaused] == 0 { // Finish the resuming process. s.logger.Info("all paused tasks converted to pending state, update the task to running state") - err := s.updateTask(proto.TaskStateRunning, nil, RetrySQLTimes) + err := s.taskMgr.ResumedTask(s.ctx, task.ID) failpoint.Inject("syncAfterResume", func() { TestSyncChan <- struct{}{} }) @@ -304,7 +303,7 @@ func (s *BaseScheduler) onReverting() error { s.logger.Warn("check task failed", zap.Error(err)) return err } - activeRevertCnt := cntByStates[proto.SubtaskStateRevertPending] + cntByStates[proto.SubtaskStateReverting] + activeRevertCnt := cntByStates[proto.SubtaskStatePending] + cntByStates[proto.SubtaskStateRunning] if activeRevertCnt == 0 { if err = s.OnDone(s.ctx, s, task); err != nil { return errors.Trace(err) @@ -363,63 +362,13 @@ func (s *BaseScheduler) onFinished() { s.logger.Debug("schedule task, task is finished", zap.Stringer("state", task.State)) } -// updateTask update the task in tidb_global_task table. -func (s *BaseScheduler) updateTask(taskState proto.TaskState, newSubTasks []*proto.Subtask, retryTimes int) (err error) { - task := *s.GetTask() - prevState := task.State - task.State = taskState - s.task.Store(&task) - logutil.BgLogger().Info("task state transform", zap.Stringer("from", prevState), zap.Stringer("to", taskState)) - if !VerifyTaskStateTransform(prevState, taskState) { - return errors.Errorf("invalid task state transform, from %s to %s", prevState, taskState) - } - - var retryable bool - for i := 0; i < retryTimes; i++ { - retryable, err = s.taskMgr.UpdateTaskAndAddSubTasks(s.ctx, &task, newSubTasks, prevState) - if err == nil || !retryable { - break - } - if err1 := s.ctx.Err(); err1 != nil { - return err1 - } - if i%10 == 0 { - s.logger.Warn("updateTask first failed", zap.Stringer("from", prevState), zap.Stringer("to", task.State), - zap.Int("retry times", i), zap.Error(err)) - } - time.Sleep(RetrySQLInterval) - } - if err != nil && retryTimes != nonRetrySQLTime { - s.logger.Warn("updateTask failed", - zap.Stringer("from", prevState), zap.Stringer("to", task.State), zap.Int("retry times", retryTimes), zap.Error(err)) - } - return err -} - func (s *BaseScheduler) onErrHandlingStage(receiveErrs []error) error { task := *s.GetTask() // we only store the first error. task.Error = receiveErrs[0] s.task.Store(&task) - var subTasks []*proto.Subtask - // when step of task is `StepInit`, no need to do revert - if task.Step != proto.StepInit { - instanceIDs, err := s.GetAllTaskExecutorIDs(s.ctx, &task) - if err != nil { - s.logger.Warn("get task's all instances failed", zap.Error(err)) - return err - } - - subTasks = make([]*proto.Subtask, 0, len(instanceIDs)) - for _, id := range instanceIDs { - // reverting subtasks belong to the same step as current active step. - subTasks = append(subTasks, proto.NewSubtask( - task.Step, task.ID, task.Type, id, - task.Concurrency, proto.EmptyMeta, 0)) - } - } - return s.updateTask(proto.TaskStateReverting, subTasks, RetrySQLTimes) + return s.taskMgr.RevertTask(s.ctx, task.ID, task.State, task.Error) } func (s *BaseScheduler) switch2NextStep() (err error) { @@ -532,6 +481,7 @@ func (s *BaseScheduler) handlePlanErr(err error) error { return errors.Trace(err) } + // TODO: to reverting state? return s.taskMgr.FailTask(s.ctx, task.ID, task.State, task.Error) } @@ -564,31 +514,6 @@ func GenerateTaskExecutorNodes(ctx context.Context) (serverNodes []*infosync.Ser return serverNodes, nil } -// GetAllTaskExecutorIDs gets all the task executor IDs. -func (s *BaseScheduler) GetAllTaskExecutorIDs(ctx context.Context, task *proto.Task) ([]string, error) { - // We get all servers instead of eligible servers here - // because eligible servers may change during the task execution. - serverInfos, err := GenerateTaskExecutorNodes(ctx) - if err != nil { - return nil, err - } - if len(serverInfos) == 0 { - return nil, nil - } - - executorIDs, err := s.taskMgr.GetTaskExecutorIDsByTaskID(s.ctx, task.ID) - if err != nil { - return nil, err - } - ids := make([]string, 0, len(executorIDs)) - for _, id := range executorIDs { - if ok := disttaskutil.MatchServerInfo(serverInfos, id); ok { - ids = append(ids, id) - } - } - return ids, nil -} - // GetPreviousSubtaskMetas get subtask metas from specific step. func (s *BaseScheduler) GetPreviousSubtaskMetas(taskID int64, step proto.Step) ([][]byte, error) { previousSubtasks, err := s.taskMgr.GetAllSubtasksByStepAndState(s.ctx, taskID, step, proto.SubtaskStateSucceed) diff --git a/pkg/disttask/framework/scheduler/scheduler_nokit_test.go b/pkg/disttask/framework/scheduler/scheduler_nokit_test.go index 2e67331c90e1b..2b2558b24dcb4 100644 --- a/pkg/disttask/framework/scheduler/scheduler_nokit_test.go +++ b/pkg/disttask/framework/scheduler/scheduler_nokit_test.go @@ -211,7 +211,6 @@ func TestSchedulerIsStepSucceed(t *testing.T) { for _, state := range []proto.SubtaskState{ proto.SubtaskStateCanceled, proto.SubtaskStateFailed, - proto.SubtaskStateReverting, } { require.False(t, s.isStepSucceed(map[proto.SubtaskState]int64{ state: 1, diff --git a/pkg/disttask/framework/scheduler/scheduler_test.go b/pkg/disttask/framework/scheduler/scheduler_test.go index 752d499832f84..c263c2bbb3775 100644 --- a/pkg/disttask/framework/scheduler/scheduler_test.go +++ b/pkg/disttask/framework/scheduler/scheduler_test.go @@ -132,76 +132,6 @@ func deleteTasks(t *testing.T, store kv.Storage, taskID int64) { tk.MustExec(fmt.Sprintf("delete from mysql.tidb_global_task where id = %d", taskID)) } -func TestGetInstance(t *testing.T) { - ctx := context.Background() - ctx = util.WithInternalSourceType(ctx, "scheduler") - - store := testkit.CreateMockStore(t) - gtk := testkit.NewTestKit(t, store) - pool := pools.NewResourcePool(func() (pools.Resource, error) { - return gtk.Session(), nil - }, 1, 1, time.Second) - defer pool.Close() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/disttask/framework/scheduler/mockTaskExecutorNodes", "return()")) - schManager, mgr := MockSchedulerManager(t, ctrl, pool, getTestSchedulerExt(ctrl), nil) - // test no server - task := &proto.Task{ID: 1, Type: proto.TaskTypeExample} - sch := schManager.MockScheduler(task) - sch.Extension = getTestSchedulerExt(ctrl) - instanceIDs, err := sch.GetAllTaskExecutorIDs(ctx, task) - require.Lenf(t, instanceIDs, 0, "GetAllTaskExecutorIDs when there's no subtask") - require.NoError(t, err) - - // test 2 servers - // server ids: uuid0, uuid1 - // subtask instance ids: nil - uuids := []string{"ddl_id_1", "ddl_id_2"} - serverIDs := []string{"10.123.124.10:32457", "[ABCD:EF01:2345:6789:ABCD:EF01:2345:6789]:65535"} - - scheduler.MockServerInfo = []*infosync.ServerInfo{ - { - ID: uuids[0], - IP: "10.123.124.10", - Port: 32457, - }, - { - ID: uuids[1], - IP: "ABCD:EF01:2345:6789:ABCD:EF01:2345:6789", - Port: 65535, - }, - } - instanceIDs, err = sch.GetAllTaskExecutorIDs(ctx, task) - require.Lenf(t, instanceIDs, 0, "GetAllTaskExecutorIDs") - require.NoError(t, err) - - // server ids: uuid0, uuid1 - // subtask instance ids: uuid1 - subtask := &proto.Subtask{ - Type: proto.TaskTypeExample, - TaskID: task.ID, - ExecID: serverIDs[1], - } - testutil.CreateSubTask(t, mgr, task.ID, proto.StepInit, subtask.ExecID, nil, subtask.Type, 11, true) - instanceIDs, err = sch.GetAllTaskExecutorIDs(ctx, task) - require.NoError(t, err) - require.Equal(t, []string{serverIDs[1]}, instanceIDs) - // server ids: uuid0, uuid1 - // subtask instance ids: uuid0, uuid1 - subtask = &proto.Subtask{ - Type: proto.TaskTypeExample, - TaskID: task.ID, - ExecID: serverIDs[0], - } - testutil.CreateSubTask(t, mgr, task.ID, proto.StepInit, subtask.ExecID, nil, subtask.Type, 11, true) - instanceIDs, err = sch.GetAllTaskExecutorIDs(ctx, task) - require.NoError(t, err) - require.Len(t, instanceIDs, len(serverIDs)) - require.ElementsMatch(t, instanceIDs, serverIDs) - require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/disttask/framework/scheduler/mockTaskExecutorNodes")) -} - func TestTaskFailInManager(t *testing.T) { store := testkit.CreateMockStore(t) gtk := testkit.NewTestKit(t, store) @@ -395,12 +325,6 @@ func checkDispatch(t *testing.T, taskCnt int, isSucc, isCancel, isSubtaskCancel, checkGetTaskState(proto.TaskStateSucceed) return } else { - // Test each task has a subtask failed. - require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/disttask/framework/storage/MockUpdateTaskErr", "1*return(true)")) - defer func() { - require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/disttask/framework/storage/MockUpdateTaskErr")) - }() - if isSubtaskCancel { // Mock a subtask canceled for i := 1; i <= subtaskCnt*taskCnt; i += subtaskCnt { @@ -418,11 +342,14 @@ func checkDispatch(t *testing.T, taskCnt int, isSucc, isCancel, isSubtaskCancel, checkGetTaskState(proto.TaskStateReverting) require.Len(t, tasks, taskCnt) - // Mock all subtask reverted. - start := subtaskCnt * taskCnt - for i := start; i <= start+subtaskCnt*taskCnt; i++ { - err = mgr.UpdateSubtaskStateAndError(ctx, ":4000", int64(i), proto.SubtaskStateReverted, nil) + for _, task := range tasks { + subtasks, err := mgr.GetSubtasksByExecIDAndStepAndStates( + ctx, ":4000", task.ID, task.Step, + proto.SubtaskStatePending, proto.SubtaskStateRunning) require.NoError(t, err) + for _, subtask := range subtasks { + require.NoError(t, mgr.UpdateSubtaskStateAndError(ctx, ":4000", subtask.ID, proto.SubtaskStateCanceled, nil)) + } } checkGetTaskState(proto.TaskStateReverted) require.Len(t, tasks, taskCnt) diff --git a/pkg/disttask/framework/scheduler/state_transform.go b/pkg/disttask/framework/scheduler/state_transform.go index 1faaea344d396..94ea5f3f24680 100644 --- a/pkg/disttask/framework/scheduler/state_transform.go +++ b/pkg/disttask/framework/scheduler/state_transform.go @@ -41,8 +41,7 @@ func VerifyTaskStateTransform(from, to proto.TaskState) bool { // no revert_failed now // proto.TaskStateRevertFailed, }, - proto.TaskStateFailed: {}, - proto.TaskStateRevertFailed: {}, + proto.TaskStateFailed: {}, proto.TaskStateCancelling: { proto.TaskStateReverting, }, diff --git a/pkg/disttask/framework/storage/BUILD.bazel b/pkg/disttask/framework/storage/BUILD.bazel index d68d574ff4c97..10428f4a2b472 100644 --- a/pkg/disttask/framework/storage/BUILD.bazel +++ b/pkg/disttask/framework/storage/BUILD.bazel @@ -19,7 +19,6 @@ go_library( "//pkg/sessionctx/variable", "//pkg/util/chunk", "//pkg/util/cpu", - "//pkg/util/intest", "//pkg/util/logutil", "//pkg/util/sqlescape", "//pkg/util/sqlexec", @@ -54,7 +53,6 @@ go_test( "//pkg/testkit", "//pkg/testkit/testsetup", "//pkg/util/sqlexec", - "@com_github_ngaut_pools//:pools", "@com_github_pingcap_errors//:errors", "@com_github_pingcap_failpoint//:failpoint", "@com_github_stretchr_testify//require", diff --git a/pkg/disttask/framework/storage/table_test.go b/pkg/disttask/framework/storage/table_test.go index 5e0050048f700..6610240d65154 100644 --- a/pkg/disttask/framework/storage/table_test.go +++ b/pkg/disttask/framework/storage/table_test.go @@ -22,7 +22,6 @@ import ( "testing" "time" - "github.com/ngaut/pools" "github.com/pingcap/errors" "github.com/pingcap/failpoint" "github.com/pingcap/tidb/pkg/disttask/framework/proto" @@ -37,14 +36,6 @@ import ( "github.com/tikv/client-go/v2/util" ) -func GetTaskManager(t *testing.T, pool *pools.ResourcePool) *storage.TaskManager { - manager := storage.NewTaskManager(pool) - storage.SetTaskManager(manager) - manager, err := storage.GetTaskManager() - require.NoError(t, err) - return manager -} - func checkTaskStateStep(t *testing.T, task *proto.Task, state proto.TaskState, step proto.Step) { require.Equal(t, state, task.State) require.Equal(t, step, task.Step) @@ -93,12 +84,10 @@ func TestTaskTable(t *testing.T) { require.Equal(t, task, task4[0]) require.GreaterOrEqual(t, task4[0].StateUpdateTime, task.StateUpdateTime) - prevState := task.State - task.State = proto.TaskStateRunning - retryable, err := gm.UpdateTaskAndAddSubTasks(ctx, task, nil, prevState) + err = gm.SwitchTaskStep(ctx, task, proto.TaskStateRunning, proto.StepOne, nil) require.NoError(t, err) - require.True(t, retryable) + task.State = proto.TaskStateRunning task5, err := gm.GetTasksInStates(ctx, proto.TaskStateRunning) require.NoError(t, err) require.Len(t, task5, 1) @@ -172,18 +161,23 @@ func TestTaskTable(t *testing.T) { require.NoError(t, err) checkTaskStateStep(t, task, proto.TaskStatePending, proto.StepInit) // reverted a reverting task + require.NoError(t, gm.SwitchTaskStep(ctx, task, proto.TaskStateRunning, proto.StepOne, nil)) + task, err = gm.GetTaskByID(ctx, id) + require.NoError(t, err) + checkTaskStateStep(t, task, proto.TaskStateRunning, proto.StepOne) task.State = proto.TaskStateReverting - _, err = gm.UpdateTaskAndAddSubTasks(ctx, task, nil, proto.TaskStatePending) + err = gm.RevertTask(ctx, task.ID, proto.TaskStateRunning, errors.New("test error")) require.NoError(t, err) task, err = gm.GetTaskByID(ctx, id) require.NoError(t, err) require.Equal(t, proto.TaskStateReverting, task.State) + require.ErrorContains(t, task.Error, "test error") require.NoError(t, gm.RevertedTask(ctx, task.ID)) task, err = gm.GetTaskByID(ctx, id) require.NoError(t, err) require.Equal(t, proto.TaskStateReverted, task.State) - // paused + // paused id, err = gm.CreateTask(ctx, "key-paused", "test", 4, []byte("test")) require.NoError(t, err) require.NoError(t, gm.PausedTask(ctx, id)) @@ -192,8 +186,9 @@ func TestTaskTable(t *testing.T) { checkTaskStateStep(t, task, proto.TaskStatePending, proto.StepInit) // reverted a reverting task task.State = proto.TaskStatePausing - _, err = gm.UpdateTaskAndAddSubTasks(ctx, task, nil, proto.TaskStatePending) + found, err := gm.PauseTask(ctx, task.Key) require.NoError(t, err) + require.True(t, found) task, err = gm.GetTaskByID(ctx, id) require.NoError(t, err) require.Equal(t, proto.TaskStatePausing, task.State) @@ -476,32 +471,22 @@ func TestSubTaskTable(t *testing.T) { id, err := sm.CreateTask(ctx, "key1", "test", 4, []byte("test")) require.NoError(t, err) require.Equal(t, int64(1), id) - _, err = sm.UpdateTaskAndAddSubTasks( + err = sm.SwitchTaskStep( ctx, - &proto.Task{ - ID: 1, - State: proto.TaskStateRunning, - }, - []*proto.Subtask{ - { - Step: proto.StepInit, - Type: proto.TaskTypeExample, - Concurrency: 11, - ExecID: "tidb1", - Meta: []byte("test"), - Ordinal: 1, - }, - }, proto.TaskStatePending, + &proto.Task{ID: 1, State: proto.TaskStatePending, Step: proto.StepInit}, + proto.TaskStateRunning, + proto.StepOne, + []*proto.Subtask{proto.NewSubtask(proto.StepOne, 1, proto.TaskTypeExample, "tidb1", 11, []byte("test"), 1)}, ) require.NoError(t, err) - nilTask, err := sm.GetFirstSubtaskInStates(ctx, "tidb2", 1, proto.StepInit, proto.SubtaskStatePending) + nilTask, err := sm.GetFirstSubtaskInStates(ctx, "tidb2", 1, proto.StepOne, proto.SubtaskStatePending) require.NoError(t, err) require.Nil(t, nilTask) - subtask, err := sm.GetFirstSubtaskInStates(ctx, "tidb1", 1, proto.StepInit, proto.SubtaskStatePending) + subtask, err := sm.GetFirstSubtaskInStates(ctx, "tidb1", 1, proto.StepOne, proto.SubtaskStatePending) require.NoError(t, err) - require.Equal(t, proto.StepInit, subtask.Step) + require.Equal(t, proto.StepOne, subtask.Step) require.Equal(t, proto.TaskTypeExample, subtask.Type) require.Equal(t, int64(1), subtask.TaskID) require.Equal(t, proto.SubtaskStatePending, subtask.State) @@ -509,33 +494,21 @@ func TestSubTaskTable(t *testing.T) { require.Equal(t, []byte("test"), subtask.Meta) require.Equal(t, 11, subtask.Concurrency) require.GreaterOrEqual(t, subtask.CreateTime, timeBeforeCreate) - require.Equal(t, 0, subtask.Ordinal) + require.Equal(t, 1, subtask.Ordinal) require.Zero(t, subtask.StartTime) require.Zero(t, subtask.UpdateTime) require.Equal(t, "{}", subtask.Summary) - subtask2, err := sm.GetFirstSubtaskInStates(ctx, "tidb1", 1, proto.StepInit, proto.SubtaskStatePending, proto.SubtaskStateReverted) + subtask2, err := sm.GetFirstSubtaskInStates(ctx, "tidb1", 1, proto.StepOne, proto.SubtaskStatePending) require.NoError(t, err) require.Equal(t, subtask, subtask2) - ids, err := sm.GetTaskExecutorIDsByTaskID(ctx, 1) - require.NoError(t, err) - require.Len(t, ids, 1) - require.Equal(t, "tidb1", ids[0]) - - ids, err = sm.GetTaskExecutorIDsByTaskID(ctx, 3) - require.NoError(t, err) - require.Len(t, ids, 0) - - cntByStates, err := sm.GetSubtaskCntGroupByStates(ctx, 1, proto.StepInit) + cntByStates, err := sm.GetSubtaskCntGroupByStates(ctx, 1, proto.StepOne) require.NoError(t, err) + require.Len(t, cntByStates, 1) require.Equal(t, int64(1), cntByStates[proto.SubtaskStatePending]) - cntByStates, err = sm.GetSubtaskCntGroupByStates(ctx, 1, proto.StepInit) - require.NoError(t, err) - require.Equal(t, int64(1), cntByStates[proto.SubtaskStatePending]+cntByStates[proto.SubtaskStateRevertPending]) - - ok, err := sm.HasSubtasksInStates(ctx, "tidb1", 1, proto.StepInit, proto.SubtaskStatePending) + ok, err := sm.HasSubtasksInStates(ctx, "tidb1", 1, proto.StepOne, proto.SubtaskStatePending) require.NoError(t, err) require.True(t, ok) @@ -546,11 +519,11 @@ func TestSubTaskTable(t *testing.T) { err = sm.StartSubtask(ctx, 1, "tidb2") require.Error(t, storage.ErrSubtaskNotFound, err) - subtask, err = sm.GetFirstSubtaskInStates(ctx, "tidb1", 1, proto.StepInit, proto.SubtaskStatePending) + subtask, err = sm.GetFirstSubtaskInStates(ctx, "tidb1", 1, proto.StepOne, proto.SubtaskStatePending) require.NoError(t, err) require.Nil(t, subtask) - subtask, err = sm.GetFirstSubtaskInStates(ctx, "tidb1", 1, proto.StepInit, proto.SubtaskStateRunning) + subtask, err = sm.GetFirstSubtaskInStates(ctx, "tidb1", 1, proto.StepOne, proto.SubtaskStateRunning) require.NoError(t, err) require.Equal(t, proto.TaskTypeExample, subtask.Type) require.Equal(t, int64(1), subtask.TaskID) @@ -562,85 +535,81 @@ func TestSubTaskTable(t *testing.T) { // check update time after state change to cancel time.Sleep(time.Second) - require.NoError(t, sm.UpdateSubtaskStateAndError(ctx, "tidb1", 1, proto.SubtaskStateReverting, nil)) - subtask2, err = sm.GetFirstSubtaskInStates(ctx, "tidb1", 1, proto.StepInit, proto.SubtaskStateReverting) + require.NoError(t, sm.FailSubtask(ctx, "tidb1", 1, errors.New("mock err"))) + subtask2, err = sm.GetFirstSubtaskInStates(ctx, "tidb1", 1, proto.StepOne, proto.SubtaskStateFailed) require.NoError(t, err) - require.Equal(t, proto.SubtaskStateReverting, subtask2.State) + require.Equal(t, proto.SubtaskStateFailed, subtask2.State) require.Greater(t, subtask2.UpdateTime, subtask.UpdateTime) - cntByStates, err = sm.GetSubtaskCntGroupByStates(ctx, 1, proto.StepInit) + cntByStates, err = sm.GetSubtaskCntGroupByStates(ctx, 1, proto.StepOne) require.NoError(t, err) require.Equal(t, int64(0), cntByStates[proto.SubtaskStatePending]) - ok, err = sm.HasSubtasksInStates(ctx, "tidb1", 1, proto.StepInit, proto.SubtaskStatePending) + ok, err = sm.HasSubtasksInStates(ctx, "tidb1", 1, proto.StepOne, proto.SubtaskStatePending) require.NoError(t, err) require.False(t, ok) require.NoError(t, testutil.DeleteSubtasksByTaskID(ctx, sm, 1)) - ok, err = sm.HasSubtasksInStates(ctx, "tidb1", 1, proto.StepInit, proto.SubtaskStatePending, proto.SubtaskStateRunning) + ok, err = sm.HasSubtasksInStates(ctx, "tidb1", 1, proto.StepOne, proto.SubtaskStatePending, proto.SubtaskStateRunning) require.NoError(t, err) require.False(t, ok) - testutil.CreateSubTask(t, sm, 2, proto.StepInit, "tidb1", []byte("test"), proto.TaskTypeExample, 11, true) + testutil.CreateSubTask(t, sm, 2, proto.StepOne, "tidb1", []byte("test"), proto.TaskTypeExample, 11) - cntByStates, err = sm.GetSubtaskCntGroupByStates(ctx, 2, proto.StepInit) - require.NoError(t, err) - require.Equal(t, int64(1), cntByStates[proto.SubtaskStateRevertPending]) - - subtasks, err := sm.GetAllSubtasksByStepAndState(ctx, 2, proto.StepInit, proto.SubtaskStateSucceed) + subtasks, err := sm.GetAllSubtasksByStepAndState(ctx, 2, proto.StepOne, proto.SubtaskStateSucceed) require.NoError(t, err) require.Len(t, subtasks, 0) require.NoError(t, sm.FinishSubtask(ctx, "tidb1", 2, []byte{})) - subtasks, err = sm.GetAllSubtasksByStepAndState(ctx, 2, proto.StepInit, proto.SubtaskStateSucceed) + subtasks, err = sm.GetAllSubtasksByStepAndState(ctx, 2, proto.StepOne, proto.SubtaskStateSucceed) require.NoError(t, err) require.Len(t, subtasks, 1) - rowCount, err := sm.GetSubtaskRowCount(ctx, 2, proto.StepInit) + rowCount, err := sm.GetSubtaskRowCount(ctx, 2, proto.StepOne) require.NoError(t, err) require.Equal(t, int64(0), rowCount) require.NoError(t, sm.UpdateSubtaskRowCount(ctx, 2, 100)) - rowCount, err = sm.GetSubtaskRowCount(ctx, 2, proto.StepInit) + rowCount, err = sm.GetSubtaskRowCount(ctx, 2, proto.StepOne) require.NoError(t, err) require.Equal(t, int64(100), rowCount) // test UpdateSubtasksExecIDs // 1. update one subtask - testutil.CreateSubTask(t, sm, 5, proto.StepInit, "tidb1", []byte("test"), proto.TaskTypeExample, 11, false) - subtasks, err = sm.GetAllSubtasksByStepAndState(ctx, 5, proto.StepInit, proto.SubtaskStatePending) + testutil.CreateSubTask(t, sm, 5, proto.StepOne, "tidb1", []byte("test"), proto.TaskTypeExample, 11) + subtasks, err = sm.GetAllSubtasksByStepAndState(ctx, 5, proto.StepOne, proto.SubtaskStatePending) require.NoError(t, err) subtasks[0].ExecID = "tidb2" require.NoError(t, sm.UpdateSubtasksExecIDs(ctx, subtasks)) - subtasks, err = sm.GetAllSubtasksByStepAndState(ctx, 5, proto.StepInit, proto.SubtaskStatePending) + subtasks, err = sm.GetAllSubtasksByStepAndState(ctx, 5, proto.StepOne, proto.SubtaskStatePending) require.NoError(t, err) require.Equal(t, "tidb2", subtasks[0].ExecID) // 2. update 2 subtasks - testutil.CreateSubTask(t, sm, 5, proto.StepInit, "tidb1", []byte("test"), proto.TaskTypeExample, 11, false) - subtasks, err = sm.GetAllSubtasksByStepAndState(ctx, 5, proto.StepInit, proto.SubtaskStatePending) + testutil.CreateSubTask(t, sm, 5, proto.StepOne, "tidb1", []byte("test"), proto.TaskTypeExample, 11) + subtasks, err = sm.GetAllSubtasksByStepAndState(ctx, 5, proto.StepOne, proto.SubtaskStatePending) require.NoError(t, err) subtasks[0].ExecID = "tidb3" require.NoError(t, sm.UpdateSubtasksExecIDs(ctx, subtasks)) - subtasks, err = sm.GetAllSubtasksByStepAndState(ctx, 5, proto.StepInit, proto.SubtaskStatePending) + subtasks, err = sm.GetAllSubtasksByStepAndState(ctx, 5, proto.StepOne, proto.SubtaskStatePending) require.NoError(t, err) require.Equal(t, "tidb3", subtasks[0].ExecID) require.Equal(t, "tidb1", subtasks[1].ExecID) // update fail require.NoError(t, sm.UpdateSubtaskStateAndError(ctx, "tidb1", subtasks[0].ID, proto.SubtaskStateRunning, nil)) - subtasks, err = sm.GetAllSubtasksByStepAndState(ctx, 5, proto.StepInit, proto.SubtaskStatePending) + subtasks, err = sm.GetAllSubtasksByStepAndState(ctx, 5, proto.StepOne, proto.SubtaskStatePending) require.NoError(t, err) require.Equal(t, "tidb3", subtasks[0].ExecID) subtasks[0].ExecID = "tidb2" // update success require.NoError(t, sm.UpdateSubtasksExecIDs(ctx, subtasks)) - subtasks, err = sm.GetAllSubtasksByStepAndState(ctx, 5, proto.StepInit, proto.SubtaskStatePending) + subtasks, err = sm.GetAllSubtasksByStepAndState(ctx, 5, proto.StepOne, proto.SubtaskStatePending) require.NoError(t, err) require.Equal(t, "tidb2", subtasks[0].ExecID) // test GetSubtaskErrors - testutil.CreateSubTask(t, sm, 7, proto.StepInit, "tidb1", []byte("test"), proto.TaskTypeExample, 11, false) - subtasks, err = sm.GetAllSubtasksByStepAndState(ctx, 7, proto.StepInit, proto.SubtaskStatePending) + testutil.CreateSubTask(t, sm, 7, proto.StepOne, "tidb1", []byte("test"), proto.TaskTypeExample, 11) + subtasks, err = sm.GetAllSubtasksByStepAndState(ctx, 7, proto.StepOne, proto.SubtaskStatePending) require.NoError(t, err) require.Equal(t, 1, len(subtasks)) require.NoError(t, sm.UpdateSubtaskStateAndError(ctx, "tidb1", subtasks[0].ID, proto.SubtaskStateFailed, errors.New("test err"))) @@ -662,107 +631,33 @@ func TestBothTaskAndSubTaskTable(t *testing.T) { require.Equal(t, proto.TaskStatePending, task.State) // isSubTaskRevert: false - prevState := task.State - task.State = proto.TaskStateRunning subTasks := []*proto.Subtask{ - { - Step: proto.StepInit, - Type: proto.TaskTypeExample, - ExecID: "instance1", - Meta: []byte("m1"), - }, - { - Step: proto.StepInit, - Type: proto.TaskTypeExample, - ExecID: "instance2", - Meta: []byte("m2"), - }, + proto.NewSubtask(proto.StepOne, task.ID, proto.TaskTypeExample, "instance1", 1, []byte("m1"), 1), + proto.NewSubtask(proto.StepOne, task.ID, proto.TaskTypeExample, "instance2", 1, []byte("m2"), 2), } - retryable, err := sm.UpdateTaskAndAddSubTasks(ctx, task, subTasks, prevState) + err = sm.SwitchTaskStep(ctx, task, proto.TaskStateRunning, proto.StepOne, subTasks) require.NoError(t, err) - require.True(t, retryable) task, err = sm.GetTaskByID(ctx, 1) require.NoError(t, err) require.Equal(t, proto.TaskStateRunning, task.State) - subtask1, err := sm.GetFirstSubtaskInStates(ctx, "instance1", 1, proto.StepInit, proto.SubtaskStatePending) + subtask1, err := sm.GetFirstSubtaskInStates(ctx, "instance1", 1, proto.StepOne, proto.SubtaskStatePending) require.NoError(t, err) require.Equal(t, int64(1), subtask1.ID) require.Equal(t, proto.TaskTypeExample, subtask1.Type) require.Equal(t, []byte("m1"), subtask1.Meta) - subtask2, err := sm.GetFirstSubtaskInStates(ctx, "instance2", 1, proto.StepInit, proto.SubtaskStatePending) + subtask2, err := sm.GetFirstSubtaskInStates(ctx, "instance2", 1, proto.StepOne, proto.SubtaskStatePending) require.NoError(t, err) require.Equal(t, int64(2), subtask2.ID) require.Equal(t, proto.TaskTypeExample, subtask2.Type) require.Equal(t, []byte("m2"), subtask2.Meta) - cntByStates, err := sm.GetSubtaskCntGroupByStates(ctx, 1, proto.StepInit) + cntByStates, err := sm.GetSubtaskCntGroupByStates(ctx, 1, proto.StepOne) require.NoError(t, err) require.Len(t, cntByStates, 1) require.Equal(t, int64(2), cntByStates[proto.SubtaskStatePending]) - - // isSubTaskRevert: true - prevState = task.State - task.State = proto.TaskStateReverting - subTasks = []*proto.Subtask{ - { - Step: proto.StepInit, - Type: proto.TaskTypeExample, - ExecID: "instance3", - Meta: []byte("m3"), - }, - { - Step: proto.StepInit, - Type: proto.TaskTypeExample, - ExecID: "instance4", - Meta: []byte("m4"), - }, - } - retryable, err = sm.UpdateTaskAndAddSubTasks(ctx, task, subTasks, prevState) - require.NoError(t, err) - require.True(t, retryable) - - task, err = sm.GetTaskByID(ctx, 1) - require.NoError(t, err) - require.Equal(t, proto.TaskStateReverting, task.State) - - subtask1, err = sm.GetFirstSubtaskInStates(ctx, "instance3", 1, proto.StepInit, proto.SubtaskStateRevertPending) - require.NoError(t, err) - require.Equal(t, int64(3), subtask1.ID) - require.Equal(t, proto.TaskTypeExample, subtask1.Type) - require.Equal(t, []byte("m3"), subtask1.Meta) - - subtask2, err = sm.GetFirstSubtaskInStates(ctx, "instance4", 1, proto.StepInit, proto.SubtaskStateRevertPending) - require.NoError(t, err) - require.Equal(t, int64(4), subtask2.ID) - require.Equal(t, proto.TaskTypeExample, subtask2.Type) - require.Equal(t, []byte("m4"), subtask2.Meta) - - cntByStates, err = sm.GetSubtaskCntGroupByStates(ctx, 1, proto.StepInit) - require.NoError(t, err) - require.Equal(t, int64(2), cntByStates[proto.SubtaskStateRevertPending]) - - // test transactional - require.NoError(t, testutil.DeleteSubtasksByTaskID(ctx, sm, 1)) - require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/disttask/framework/storage/MockUpdateTaskErr", "1*return(true)")) - defer func() { - require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/disttask/framework/storage/MockUpdateTaskErr")) - }() - prevState = task.State - task.State = proto.TaskStateFailed - retryable, err = sm.UpdateTaskAndAddSubTasks(ctx, task, subTasks, prevState) - require.EqualError(t, err, "updateTaskErr") - require.True(t, retryable) - - task, err = sm.GetTaskByID(ctx, 1) - require.NoError(t, err) - require.Equal(t, proto.TaskStateReverting, task.State) - - cntByStates, err = sm.GetSubtaskCntGroupByStates(ctx, 1, proto.StepInit) - require.NoError(t, err) - require.Equal(t, int64(0), cntByStates[proto.SubtaskStateRevertPending]) } func TestGetSubtaskCntByStates(t *testing.T) { @@ -885,11 +780,11 @@ func TestSubtaskHistoryTable(t *testing.T) { finishedMeta = "finished" ) - testutil.CreateSubTask(t, sm, taskID, proto.StepInit, tidb1, []byte(meta), proto.TaskTypeExample, 11, false) + testutil.CreateSubTask(t, sm, taskID, proto.StepInit, tidb1, []byte(meta), proto.TaskTypeExample, 11) require.NoError(t, sm.FinishSubtask(ctx, tidb1, subTask1, []byte(finishedMeta))) - testutil.CreateSubTask(t, sm, taskID, proto.StepInit, tidb2, []byte(meta), proto.TaskTypeExample, 11, false) + testutil.CreateSubTask(t, sm, taskID, proto.StepInit, tidb2, []byte(meta), proto.TaskTypeExample, 11) require.NoError(t, sm.UpdateSubtaskStateAndError(ctx, tidb2, subTask2, proto.SubtaskStateCanceled, nil)) - testutil.CreateSubTask(t, sm, taskID, proto.StepInit, tidb3, []byte(meta), proto.TaskTypeExample, 11, false) + testutil.CreateSubTask(t, sm, taskID, proto.StepInit, tidb3, []byte(meta), proto.TaskTypeExample, 11) require.NoError(t, sm.UpdateSubtaskStateAndError(ctx, tidb3, subTask3, proto.SubtaskStateFailed, nil)) subTasks, err := testutil.GetSubtasksByTaskID(ctx, sm, taskID) @@ -922,7 +817,7 @@ func TestSubtaskHistoryTable(t *testing.T) { }() time.Sleep(2 * time.Second) - testutil.CreateSubTask(t, sm, taskID2, proto.StepInit, tidb1, []byte(meta), proto.TaskTypeExample, 11, false) + testutil.CreateSubTask(t, sm, taskID2, proto.StepInit, tidb1, []byte(meta), proto.TaskTypeExample, 11) require.NoError(t, sm.UpdateSubtaskStateAndError(ctx, tidb1, subTask4, proto.SubtaskStateFailed, nil)) require.NoError(t, testutil.TransferSubTasks2History(ctx, sm, taskID2)) @@ -987,9 +882,9 @@ func TestTaskHistoryTable(t *testing.T) { func TestPauseAndResume(t *testing.T) { _, sm, ctx := testutil.InitTableTest(t) - testutil.CreateSubTask(t, sm, 1, proto.StepInit, "tidb1", []byte("test"), proto.TaskTypeExample, 11, false) - testutil.CreateSubTask(t, sm, 1, proto.StepInit, "tidb1", []byte("test"), proto.TaskTypeExample, 11, false) - testutil.CreateSubTask(t, sm, 1, proto.StepInit, "tidb1", []byte("test"), proto.TaskTypeExample, 11, false) + testutil.CreateSubTask(t, sm, 1, proto.StepInit, "tidb1", []byte("test"), proto.TaskTypeExample, 11) + testutil.CreateSubTask(t, sm, 1, proto.StepInit, "tidb1", []byte("test"), proto.TaskTypeExample, 11) + testutil.CreateSubTask(t, sm, 1, proto.StepInit, "tidb1", []byte("test"), proto.TaskTypeExample, 11) // 1.1 pause all subtasks. require.NoError(t, sm.PauseSubtasks(ctx, "tidb1", 1)) cntByStates, err := sm.GetSubtaskCntGroupByStates(ctx, 1, proto.StepInit) @@ -1017,7 +912,7 @@ func TestPauseAndResume(t *testing.T) { func TestCancelAndExecIdChanged(t *testing.T) { sm, ctx, cancel := testutil.InitTableTestWithCancel(t) - testutil.CreateSubTask(t, sm, 1, proto.StepInit, "tidb1", []byte("test"), proto.TaskTypeExample, 11, false) + testutil.CreateSubTask(t, sm, 1, proto.StepInit, "tidb1", []byte("test"), proto.TaskTypeExample, 11) subtask, err := sm.GetFirstSubtaskInStates(ctx, "tidb1", 1, proto.StepInit, proto.SubtaskStatePending) require.NoError(t, err) // 1. cancel the ctx, then update subtask state. @@ -1161,7 +1056,7 @@ func TestSubtasksState(t *testing.T) { ts := time.Now() time.Sleep(1 * time.Second) // 1. test FailSubtask do update start/update time - testutil.CreateSubTask(t, sm, 3, proto.StepInit, "for_test", []byte("test"), proto.TaskTypeExample, 11, false) + testutil.CreateSubTask(t, sm, 3, proto.StepInit, "for_test", []byte("test"), proto.TaskTypeExample, 11) require.NoError(t, sm.FailSubtask(ctx, "for_test", 3, errors.New("fail"))) subtask, err := sm.GetFirstSubtaskInStates(ctx, "for_test", 3, proto.StepInit, proto.SubtaskStateFailed) require.NoError(t, err) @@ -1174,7 +1069,7 @@ func TestSubtasksState(t *testing.T) { require.Greater(t, endTime, ts) // 2. test FinishSubtask do update update time - testutil.CreateSubTask(t, sm, 4, proto.StepInit, "for_test1", []byte("test"), proto.TaskTypeExample, 11, false) + testutil.CreateSubTask(t, sm, 4, proto.StepInit, "for_test1", []byte("test"), proto.TaskTypeExample, 11) subtask, err = sm.GetFirstSubtaskInStates(ctx, "for_test1", 4, proto.StepInit, proto.SubtaskStatePending) require.NoError(t, err) err = sm.StartSubtask(ctx, subtask.ID, "for_test1") @@ -1196,7 +1091,7 @@ func TestSubtasksState(t *testing.T) { require.Greater(t, endTime, ts) // 3. test CancelSubtask - testutil.CreateSubTask(t, sm, 3, proto.StepInit, "for_test", []byte("test"), proto.TaskTypeExample, 11, false) + testutil.CreateSubTask(t, sm, 3, proto.StepInit, "for_test", []byte("test"), proto.TaskTypeExample, 11) require.NoError(t, sm.CancelSubtask(ctx, "for_test", 3)) subtask, err = sm.GetFirstSubtaskInStates(ctx, "for_test", 3, proto.StepInit, proto.SubtaskStateCanceled) require.NoError(t, err) diff --git a/pkg/disttask/framework/storage/task_state.go b/pkg/disttask/framework/storage/task_state.go index 72761aaeb1b5f..e3b743eadf88f 100644 --- a/pkg/disttask/framework/storage/task_state.go +++ b/pkg/disttask/framework/storage/task_state.go @@ -59,6 +59,19 @@ func (mgr *TaskManager) FailTask(ctx context.Context, taskID int64, currentState return err } +// RevertTask implements the scheduler.TaskManager interface. +func (mgr *TaskManager) RevertTask(ctx context.Context, taskID int64, taskState proto.TaskState, taskErr error) error { + _, err := mgr.ExecuteSQLWithNewSession(ctx, ` + update mysql.tidb_global_task + set state = %?, + error = %?, + state_update_time = CURRENT_TIMESTAMP() + where id = %? and state = %?`, + proto.TaskStateReverting, serializeErr(taskErr), taskID, taskState, + ) + return err +} + // RevertedTask implements the scheduler.TaskManager interface. func (mgr *TaskManager) RevertedTask(ctx context.Context, taskID int64) error { _, err := mgr.ExecuteSQLWithNewSession(ctx, @@ -102,8 +115,7 @@ func (mgr *TaskManager) PausedTask(ctx context.Context, taskID int64) error { _, err := mgr.ExecuteSQLWithNewSession(ctx, `update mysql.tidb_global_task set state = %?, - state_update_time = CURRENT_TIMESTAMP(), - end_time = CURRENT_TIMESTAMP() + state_update_time = CURRENT_TIMESTAMP() where id = %? and state = %?`, proto.TaskStatePaused, taskID, proto.TaskStatePausing, ) @@ -135,6 +147,18 @@ func (mgr *TaskManager) ResumeTask(ctx context.Context, taskKey string) (bool, e return found, nil } +// ResumedTask implements the scheduler.TaskManager interface. +func (mgr *TaskManager) ResumedTask(ctx context.Context, taskID int64) error { + _, err := mgr.ExecuteSQLWithNewSession(ctx, ` + update mysql.tidb_global_task + set state = %?, + state_update_time = CURRENT_TIMESTAMP() + where id = %? and state = %?`, + proto.TaskStateRunning, taskID, proto.TaskStateResuming, + ) + return err +} + // SucceedTask update task state from running to succeed. func (mgr *TaskManager) SucceedTask(ctx context.Context, taskID int64) error { return mgr.WithNewSession(func(se sessionctx.Context) error { diff --git a/pkg/disttask/framework/storage/task_state_test.go b/pkg/disttask/framework/storage/task_state_test.go index 777d5477d1abe..8819521d6e4f1 100644 --- a/pkg/disttask/framework/storage/task_state_test.go +++ b/pkg/disttask/framework/storage/task_state_test.go @@ -38,7 +38,7 @@ func TestTaskState(t *testing.T) { require.NoError(t, gm.CancelTask(ctx, id)) task, err := gm.GetTaskByID(ctx, id) require.NoError(t, err) - require.Equal(t, proto.TaskStateCancelling, task.State) + checkTaskStateStep(t, task, proto.TaskStateCancelling, proto.StepInit) // 2. cancel task by key session id, err = gm.CreateTask(ctx, "key2", "test", 4, []byte("test")) @@ -50,7 +50,7 @@ func TestTaskState(t *testing.T) { })) task, err = gm.GetTaskByID(ctx, id) require.NoError(t, err) - require.Equal(t, proto.TaskStateCancelling, task.State) + checkTaskStateStep(t, task, proto.TaskStateCancelling, proto.StepInit) // 3. fail task id, err = gm.CreateTask(ctx, "key3", "test", 4, []byte("test")) @@ -60,7 +60,7 @@ func TestTaskState(t *testing.T) { require.NoError(t, gm.FailTask(ctx, id, proto.TaskStatePending, failedErr)) task, err = gm.GetTaskByID(ctx, id) require.NoError(t, err) - require.Equal(t, proto.TaskStateFailed, task.State) + checkTaskStateStep(t, task, proto.TaskStateFailed, proto.StepInit) require.ErrorContains(t, task.Error, "test err") // 4. Reverted task @@ -69,15 +69,17 @@ func TestTaskState(t *testing.T) { require.Equal(t, int64(4), id) task, err = gm.GetTaskByID(ctx, 4) require.NoError(t, err) - task.State = proto.TaskStateReverting - retryable, err := gm.UpdateTaskAndAddSubTasks(ctx, task, nil, proto.TaskStatePending) + checkTaskStateStep(t, task, proto.TaskStatePending, proto.StepInit) + err = gm.RevertTask(ctx, task.ID, proto.TaskStatePending, nil) require.NoError(t, err) - require.True(t, retryable) + task, err = gm.GetTaskByID(ctx, 4) + require.NoError(t, err) + checkTaskStateStep(t, task, proto.TaskStateReverting, proto.StepInit) require.NoError(t, gm.RevertedTask(ctx, id)) task, err = gm.GetTaskByID(ctx, id) require.NoError(t, err) - require.Equal(t, proto.TaskStateReverted, task.State) + checkTaskStateStep(t, task, proto.TaskStateReverted, proto.StepInit) // 5. pause task id, err = gm.CreateTask(ctx, "key5", "test", 4, []byte("test")) @@ -103,6 +105,10 @@ func TestTaskState(t *testing.T) { task, err = gm.GetTaskByID(ctx, id) require.NoError(t, err) require.Equal(t, proto.TaskStateResuming, task.State) + require.NoError(t, gm.ResumedTask(ctx, id)) + task, err = gm.GetTaskByID(ctx, id) + require.NoError(t, err) + require.Equal(t, proto.TaskStateRunning, task.State) // 8. succeed task id, err = gm.CreateTask(ctx, "key6", "test", 4, []byte("test")) @@ -110,13 +116,13 @@ func TestTaskState(t *testing.T) { require.Equal(t, int64(6), id) task, err = gm.GetTaskByID(ctx, 6) require.NoError(t, err) - task.State = proto.TaskStateRunning - retryable, err = gm.UpdateTaskAndAddSubTasks(ctx, task, nil, proto.TaskStatePending) + checkTaskStateStep(t, task, proto.TaskStatePending, proto.StepInit) + require.NoError(t, gm.SwitchTaskStep(ctx, task, proto.TaskStateRunning, proto.StepOne, nil)) + task, err = gm.GetTaskByID(ctx, 6) require.NoError(t, err) - require.True(t, retryable) + checkTaskStateStep(t, task, proto.TaskStateRunning, proto.StepOne) require.NoError(t, gm.SucceedTask(ctx, id)) task, err = gm.GetTaskByID(ctx, id) require.NoError(t, err) - require.Equal(t, proto.TaskStateSucceed, task.State) - require.Equal(t, proto.StepDone, task.Step) + checkTaskStateStep(t, task, proto.TaskStateSucceed, proto.StepDone) } diff --git a/pkg/disttask/framework/storage/task_table.go b/pkg/disttask/framework/storage/task_table.go index 524212b903c95..3b8ee181d8254 100644 --- a/pkg/disttask/framework/storage/task_table.go +++ b/pkg/disttask/framework/storage/task_table.go @@ -29,8 +29,6 @@ import ( "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/sessionctx/variable" "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/intest" - "github.com/pingcap/tidb/pkg/util/sqlescape" "github.com/pingcap/tidb/pkg/util/sqlexec" "github.com/tikv/client-go/v2/util" ) @@ -483,26 +481,6 @@ func (mgr *TaskManager) HasSubtasksInStates(ctx context.Context, tidbID string, return len(rs) > 0, nil } -// GetTaskExecutorIDsByTaskID gets the task executor IDs of the given task ID. -func (mgr *TaskManager) GetTaskExecutorIDsByTaskID(ctx context.Context, taskID int64) ([]string, error) { - rs, err := mgr.ExecuteSQLWithNewSession(ctx, `select distinct(exec_id) from mysql.tidb_background_subtask - where task_key = %?`, taskID) - if err != nil { - return nil, err - } - if len(rs) == 0 { - return nil, nil - } - - instanceIDs := make([]string, 0, len(rs)) - for _, r := range rs { - id := r.GetString(0) - instanceIDs = append(instanceIDs, id) - } - - return instanceIDs, nil -} - // UpdateSubtasksExecIDs update subtasks' execID. func (mgr *TaskManager) UpdateSubtasksExecIDs(ctx context.Context, subtasks []*proto.Subtask) error { // skip the update process. @@ -659,82 +637,6 @@ func (*TaskManager) splitSubtasks(subtasks []*proto.Subtask) [][]*proto.Subtask return res } -// UpdateTaskAndAddSubTasks update the task and add new subtasks -// TODO: remove this when we remove reverting subtasks. -func (mgr *TaskManager) UpdateTaskAndAddSubTasks(ctx context.Context, task *proto.Task, subtasks []*proto.Subtask, prevState proto.TaskState) (bool, error) { - retryable := true - err := mgr.WithNewTxn(ctx, func(se sessionctx.Context) error { - _, err := sqlexec.ExecSQL(ctx, se, "update mysql.tidb_global_task "+ - "set state = %?, dispatcher_id = %?, step = %?, concurrency = %?, meta = %?, error = %?, state_update_time = CURRENT_TIMESTAMP()"+ - "where id = %? and state = %?", - task.State, task.SchedulerID, task.Step, task.Concurrency, task.Meta, serializeErr(task.Error), task.ID, prevState) - if err != nil { - return err - } - // When AffectedRows == 0, means other admin command have changed the task state, it's illegal to schedule subtasks. - if se.GetSessionVars().StmtCtx.AffectedRows() == 0 { - if !intest.InTest { - // task state have changed by other admin command - retryable = false - return errors.New("invalid task state transform, state already changed") - } - // TODO: remove it, when OnNextSubtasksBatch returns subtasks, just insert subtasks without updating tidb_global_task. - // Currently the business running on distributed task framework will update proto.Task in OnNextSubtasksBatch. - // So when scheduling subtasks, framework needs to update task and insert subtasks in one Txn. - // - // In future, it's needed to restrict changes of task in OnNextSubtasksBatch. - // If OnNextSubtasksBatch won't update any fields in proto.Task, we can insert subtasks only. - // - // For now, we update nothing in proto.Task in UT's OnNextSubtasksBatch, so the AffectedRows will be 0. So UT can't fully compatible - // with current UpdateTaskAndAddSubTasks implementation. - rs, err := sqlexec.ExecSQL(ctx, se, "select id from mysql.tidb_global_task where id = %? and state = %?", task.ID, prevState) - if err != nil { - return err - } - // state have changed. - if len(rs) == 0 { - retryable = false - return errors.New("invalid task state transform, state already changed") - } - } - - failpoint.Inject("MockUpdateTaskErr", func(val failpoint.Value) { - if val.(bool) { - failpoint.Return(errors.New("updateTaskErr")) - } - }) - if len(subtasks) > 0 { - subtaskState := proto.SubtaskStatePending - if task.State == proto.TaskStateReverting { - subtaskState = proto.SubtaskStateRevertPending - } - - sql := new(strings.Builder) - if err := sqlescape.FormatSQL(sql, `insert into mysql.tidb_background_subtask(`+InsertSubtaskColumns+`) values`); err != nil { - return err - } - for i, subtask := range subtasks { - if i != 0 { - if err := sqlescape.FormatSQL(sql, ","); err != nil { - return err - } - } - if err := sqlescape.FormatSQL(sql, "(%?, %?, %?, %?, %?, %?, %?, NULL, CURRENT_TIMESTAMP(), '{}', '{}')", - subtask.Step, task.ID, subtask.ExecID, subtask.Meta, subtaskState, proto.Type2Int(subtask.Type), subtask.Concurrency); err != nil { - return err - } - } - _, err := sqlexec.ExecSQL(ctx, se, sql.String()) - if err != nil { - return nil - } - } - return nil - }) - - return retryable, err -} - func serializeErr(err error) []byte { if err == nil { return nil diff --git a/pkg/disttask/framework/taskexecutor/execute/interface.go b/pkg/disttask/framework/taskexecutor/execute/interface.go index 7bce5d2f49421..beaef2b86a49a 100644 --- a/pkg/disttask/framework/taskexecutor/execute/interface.go +++ b/pkg/disttask/framework/taskexecutor/execute/interface.go @@ -21,17 +21,21 @@ import ( ) // StepExecutor defines the executor of a subtask. +// the calling sequence is: +// +// Init +// for every subtask of this step: +// if RunSubtask failed then break +// else OnFinished +// Cleanup type StepExecutor interface { // Init is used to initialize the environment for the subtask executor. Init(context.Context) error // RunSubtask is used to run the subtask. RunSubtask(ctx context.Context, subtask *proto.Subtask) error - // Cleanup is used to clean up the environment for the subtask executor. - Cleanup(context.Context) error // OnFinished is used to handle the subtask when it is finished. // The subtask meta can be updated in place. OnFinished(ctx context.Context, subtask *proto.Subtask) error - // Rollback is used to roll back all subtasks. - // TODO: right now all impl of Rollback is empty, maybe we can remove it. - Rollback(context.Context) error + // Cleanup is used to clean up the environment for the subtask executor. + Cleanup(context.Context) error } diff --git a/pkg/disttask/framework/taskexecutor/interface.go b/pkg/disttask/framework/taskexecutor/interface.go index bc405a160f683..9aed883bd992c 100644 --- a/pkg/disttask/framework/taskexecutor/interface.go +++ b/pkg/disttask/framework/taskexecutor/interface.go @@ -117,8 +117,3 @@ func (*EmptyStepExecutor) Cleanup(context.Context) error { func (*EmptyStepExecutor) OnFinished(_ context.Context, _ *proto.Subtask) error { return nil } - -// Rollback implements the StepExecutor interface. -func (*EmptyStepExecutor) Rollback(context.Context) error { - return nil -} diff --git a/pkg/disttask/framework/taskexecutor/manager.go b/pkg/disttask/framework/taskexecutor/manager.go index f32f39533155b..5388e5831f1b2 100644 --- a/pkg/disttask/framework/taskexecutor/manager.go +++ b/pkg/disttask/framework/taskexecutor/manager.go @@ -44,9 +44,8 @@ var ( retrySQLTimes = 30 retrySQLInterval = 500 * time.Millisecond unfinishedSubtaskStates = []proto.SubtaskState{ - proto.SubtaskStatePending, proto.SubtaskStateRevertPending, - // for the case that the tidb is restarted when the subtask is running. - proto.SubtaskStateRunning, proto.SubtaskStateReverting, + proto.SubtaskStatePending, + proto.SubtaskStateRunning, } ) diff --git a/pkg/disttask/framework/taskexecutor/task_executor.go b/pkg/disttask/framework/taskexecutor/task_executor.go index 3262f850b8830..f23a2425fb025 100644 --- a/pkg/disttask/framework/taskexecutor/task_executor.go +++ b/pkg/disttask/framework/taskexecutor/task_executor.go @@ -51,8 +51,6 @@ var ( ErrCancelSubtask = errors.New("cancel subtasks") // ErrFinishSubtask is the cancel cause when TaskExecutor successfully processed subtasks. ErrFinishSubtask = errors.New("finish subtasks") - // ErrFinishRollback is the cancel cause when TaskExecutor rollback successfully. - ErrFinishRollback = errors.New("finish rollback") // ErrNonIdempotentSubtask means the subtask is left in running state and is not idempotent, // so cannot be run again. ErrNonIdempotentSubtask = errors.New("subtask in running state and is not idempotent") @@ -444,18 +442,18 @@ func (e *BaseTaskExecutor) onSubtaskFinished(ctx context.Context, executor execu } // Rollback rollbacks the subtask. +// TODO no need to start executor to do it, refactor it later. func (e *BaseTaskExecutor) Rollback(ctx context.Context, task *proto.Task) error { // TODO: we can centralized this when we move handleExecutableTask loop here. e.task.Store(task) - rollbackCtx, rollbackCancel := context.WithCancelCause(ctx) - defer rollbackCancel(ErrFinishRollback) - e.registerCancelFunc(rollbackCancel) e.resetError() e.logger.Info("taskExecutor rollback a step", zap.String("step", proto.Step2Str(task.Type, task.Step))) // We should cancel all subtasks before rolling back for { + // TODO we can update them using one sql, but requires change the metric + // gathering logic. subtask, err := e.taskTable.GetFirstSubtaskInStates(ctx, e.id, task.ID, task.Step, proto.SubtaskStatePending, proto.SubtaskStateRunning) if err != nil { @@ -472,38 +470,6 @@ func (e *BaseTaskExecutor) Rollback(ctx context.Context, task *proto.Task) error return err } } - - executor, err := e.GetStepExecutor(ctx, task, nil, nil) - if err != nil { - e.onError(err) - return e.getError() - } - subtask, err := e.taskTable.GetFirstSubtaskInStates(ctx, e.id, task.ID, task.Step, - proto.SubtaskStateRevertPending, proto.SubtaskStateReverting) - if err != nil { - e.onError(err) - return e.getError() - } - if subtask == nil { - logutil.BgLogger().Warn("taskExecutor rollback a step, but no subtask in revert_pending state") - return nil - } - if subtask.State == proto.SubtaskStateRevertPending { - e.updateSubtaskStateAndError(ctx, subtask, proto.SubtaskStateReverting, nil) - } - if err := e.getError(); err != nil { - return err - } - - // right now all impl of Rollback is empty, so we don't check idempotent here. - // will try to remove this rollback completely in the future. - err = executor.Rollback(rollbackCtx) - if err != nil { - e.updateSubtaskStateAndError(ctx, subtask, proto.SubtaskStateRevertFailed, nil) - e.onError(err) - } else { - e.updateSubtaskStateAndError(ctx, subtask, proto.SubtaskStateReverted, nil) - } return e.getError() } diff --git a/pkg/disttask/framework/taskexecutor/task_executor_test.go b/pkg/disttask/framework/taskexecutor/task_executor_test.go index dbcad18762840..8894994d6d1b6 100644 --- a/pkg/disttask/framework/taskexecutor/task_executor_test.go +++ b/pkg/disttask/framework/taskexecutor/task_executor_test.go @@ -35,9 +35,6 @@ var ( unfinishedNormalSubtaskStates = []interface{}{ proto.SubtaskStatePending, proto.SubtaskStateRunning, } - unfinishedRevertSubtaskStates = []interface{}{ - proto.SubtaskStateRevertPending, proto.SubtaskStateReverting, - } ) func TestTaskExecutorRun(t *testing.T) { @@ -292,14 +289,8 @@ func TestTaskExecutorRollback(t *testing.T) { // 1. no taskExecutor constructor task1 := &proto.Task{Step: proto.StepOne, ID: 1, Type: tp} - taskExecutorRegisterErr := errors.Errorf("constructor of taskExecutor for key not found") - mockExtension.EXPECT().GetStepExecutor(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, taskExecutorRegisterErr) taskExecutor := NewBaseTaskExecutor(ctx, "id", task1, mockSubtaskTable) taskExecutor.Extension = mockExtension - mockSubtaskTable.EXPECT().GetFirstSubtaskInStates(runCtx, "id", int64(1), proto.StepOne, - unfinishedNormalSubtaskStates...).Return(nil, nil) - err := taskExecutor.Rollback(runCtx, task1) - require.EqualError(t, err, taskExecutorRegisterErr.Error()) mockExtension.EXPECT().GetStepExecutor(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(mockStepExecutor, nil).AnyTimes() @@ -307,32 +298,16 @@ func TestTaskExecutorRollback(t *testing.T) { getSubtaskErr := errors.New("get subtask error") var taskID int64 = 1 mockSubtaskTable.EXPECT().GetFirstSubtaskInStates(runCtx, "id", taskID, proto.StepOne, - unfinishedNormalSubtaskStates...).Return(nil, nil) - mockSubtaskTable.EXPECT().GetFirstSubtaskInStates(runCtx, "id", taskID, proto.StepOne, - unfinishedRevertSubtaskStates...).Return(nil, getSubtaskErr) - err = taskExecutor.Rollback(runCtx, &proto.Task{Step: proto.StepOne, Type: tp, ID: taskID}) + unfinishedNormalSubtaskStates...).Return(nil, getSubtaskErr) + err := taskExecutor.Rollback(runCtx, &proto.Task{Step: proto.StepOne, Type: tp, ID: taskID}) require.EqualError(t, err, getSubtaskErr.Error()) // 3. no subtask mockSubtaskTable.EXPECT().GetFirstSubtaskInStates(runCtx, "id", taskID, proto.StepOne, unfinishedNormalSubtaskStates...).Return(nil, nil) - mockSubtaskTable.EXPECT().GetFirstSubtaskInStates(runCtx, "id", taskID, proto.StepOne, - unfinishedRevertSubtaskStates...).Return(nil, nil) err = taskExecutor.Rollback(runCtx, &proto.Task{Step: proto.StepOne, Type: tp, ID: taskID}) require.NoError(t, err) - // 4. rollback failed - rollbackErr := errors.New("rollback error") - mockSubtaskTable.EXPECT().GetFirstSubtaskInStates(runCtx, "id", taskID, proto.StepOne, - unfinishedNormalSubtaskStates...).Return(nil, nil) - mockSubtaskTable.EXPECT().GetFirstSubtaskInStates(runCtx, "id", taskID, proto.StepOne, - unfinishedRevertSubtaskStates...).Return(&proto.Subtask{ID: 1, State: proto.SubtaskStateRevertPending, ExecID: "id"}, nil) - mockSubtaskTable.EXPECT().UpdateSubtaskStateAndError(runCtx, "id", taskID, proto.SubtaskStateReverting, nil).Return(nil) - mockStepExecutor.EXPECT().Rollback(gomock.Any()).Return(rollbackErr) - mockSubtaskTable.EXPECT().UpdateSubtaskStateAndError(runCtx, "id", taskID, proto.SubtaskStateRevertFailed, nil).Return(nil) - err = taskExecutor.Rollback(runCtx, &proto.Task{Step: proto.StepOne, Type: tp, ID: taskID}) - require.EqualError(t, err, rollbackErr.Error()) - // 5. rollback success mockSubtaskTable.EXPECT().GetFirstSubtaskInStates(runCtx, "id", taskID, proto.StepOne, unfinishedNormalSubtaskStates...).Return(&proto.Subtask{ID: 1, ExecID: "id"}, nil) @@ -342,21 +317,12 @@ func TestTaskExecutorRollback(t *testing.T) { mockSubtaskTable.EXPECT().UpdateSubtaskStateAndError(runCtx, "id", int64(2), proto.SubtaskStateCanceled, nil).Return(nil) mockSubtaskTable.EXPECT().GetFirstSubtaskInStates(runCtx, "id", taskID, proto.StepOne, unfinishedNormalSubtaskStates...).Return(nil, nil) - mockSubtaskTable.EXPECT().GetFirstSubtaskInStates(runCtx, "id", taskID, proto.StepOne, - unfinishedRevertSubtaskStates...).Return(&proto.Subtask{ID: 3, State: proto.SubtaskStateRevertPending, ExecID: "id"}, nil) - mockSubtaskTable.EXPECT().UpdateSubtaskStateAndError(runCtx, "id", int64(3), proto.SubtaskStateReverting, nil).Return(nil) - mockStepExecutor.EXPECT().Rollback(gomock.Any()).Return(nil) - mockSubtaskTable.EXPECT().UpdateSubtaskStateAndError(runCtx, "id", int64(3), proto.SubtaskStateReverted, nil).Return(nil) err = taskExecutor.Rollback(runCtx, &proto.Task{Step: proto.StepOne, Type: tp, ID: taskID}) require.NoError(t, err) // rollback again for previous left subtask in TaskStateReverting state mockSubtaskTable.EXPECT().GetFirstSubtaskInStates(runCtx, "id", taskID, proto.StepOne, unfinishedNormalSubtaskStates...).Return(nil, nil) - mockSubtaskTable.EXPECT().GetFirstSubtaskInStates(runCtx, "id", taskID, proto.StepOne, - unfinishedRevertSubtaskStates...).Return(&proto.Subtask{ID: 3, State: proto.SubtaskStateReverting, ExecID: "id"}, nil) - mockStepExecutor.EXPECT().Rollback(gomock.Any()).Return(nil) - mockSubtaskTable.EXPECT().UpdateSubtaskStateAndError(runCtx, "id", int64(3), proto.SubtaskStateReverted, nil).Return(nil) err = taskExecutor.Rollback(runCtx, &proto.Task{Step: proto.StepOne, Type: tp, ID: taskID}) require.NoError(t, err) } @@ -405,11 +371,6 @@ func TestTaskExecutor(t *testing.T) { // 2. rollback success. mockSubtaskTable.EXPECT().GetFirstSubtaskInStates(runCtx, "id", taskID, proto.StepOne, unfinishedNormalSubtaskStates...).Return(nil, nil) - mockSubtaskTable.EXPECT().GetFirstSubtaskInStates(runCtx, "id", taskID, proto.StepOne, - unfinishedRevertSubtaskStates...).Return(&proto.Subtask{ID: 1, Type: tp, State: proto.SubtaskStateRevertPending, ExecID: "id"}, nil) - mockSubtaskTable.EXPECT().UpdateSubtaskStateAndError(runCtx, "id", taskID, proto.SubtaskStateReverting, nil).Return(nil) - mockStepExecutor.EXPECT().Rollback(gomock.Any()).Return(nil) - mockSubtaskTable.EXPECT().UpdateSubtaskStateAndError(runCtx, "id", taskID, proto.SubtaskStateReverted, nil).Return(nil) err = taskExecutor.Rollback(runCtx, &proto.Task{Step: proto.StepOne, Type: tp, ID: taskID}) require.NoError(t, err) require.True(t, ctrl.Satisfied()) diff --git a/pkg/disttask/framework/taskexecutor/task_executor_testkit_test.go b/pkg/disttask/framework/taskexecutor/task_executor_testkit_test.go index cfef0c527490e..1bb6a9dd8338a 100644 --- a/pkg/disttask/framework/taskexecutor/task_executor_testkit_test.go +++ b/pkg/disttask/framework/taskexecutor/task_executor_testkit_test.go @@ -40,25 +40,22 @@ func runOneTask(ctx context.Context, t *testing.T, mgr *storage.TaskManager, tas task, err := mgr.GetTaskByID(ctx, taskID) require.NoError(t, err) // 1. stepOne - task.Step = proto.StepOne - task.State = proto.TaskStateRunning - _, err = mgr.UpdateTaskAndAddSubTasks(ctx, task, nil, proto.TaskStatePending) + err = mgr.SwitchTaskStep(ctx, task, proto.TaskStateRunning, proto.StepOne, nil) require.NoError(t, err) for i := 0; i < subtaskCnt; i++ { - testutil.CreateSubTask(t, mgr, taskID, proto.StepOne, "test", nil, proto.TaskTypeExample, 11, false) + testutil.CreateSubTask(t, mgr, taskID, proto.StepOne, ":4000", nil, proto.TaskTypeExample, 11) } task, err = mgr.GetTaskByID(ctx, taskID) require.NoError(t, err) factory := taskexecutor.GetTaskExecutorFactory(task.Type) require.NotNil(t, factory) - executor := factory(ctx, "test", task, mgr) + executor := factory(ctx, ":4000", task, mgr) require.NoError(t, executor.RunStep(ctx, task, nil)) // 2. stepTwo - task.Step = proto.StepTwo - _, err = mgr.UpdateTaskAndAddSubTasks(ctx, task, nil, proto.TaskStateRunning) + err = mgr.SwitchTaskStep(ctx, task, proto.TaskStateRunning, proto.StepTwo, nil) require.NoError(t, err) for i := 0; i < subtaskCnt; i++ { - testutil.CreateSubTask(t, mgr, taskID, proto.StepTwo, "test", nil, proto.TaskTypeExample, 11, false) + testutil.CreateSubTask(t, mgr, taskID, proto.StepTwo, ":4000", nil, proto.TaskTypeExample, 11) } task, err = mgr.GetTaskByID(ctx, taskID) require.NoError(t, err) diff --git a/pkg/disttask/framework/testutil/context.go b/pkg/disttask/framework/testutil/context.go index 0abf2fe5021f0..d02e5b41ff698 100644 --- a/pkg/disttask/framework/testutil/context.go +++ b/pkg/disttask/framework/testutil/context.go @@ -18,7 +18,6 @@ import ( "context" "fmt" "sync" - "sync/atomic" "testing" "github.com/pingcap/failpoint" @@ -34,8 +33,6 @@ type TestContext struct { sync.RWMutex // taskID/step -> subtask map. subtasksHasRun map[string]map[int64]struct{} - // for rollback tests. - RollbackCnt atomic.Int32 // for plan err handling tests. CallTime int } diff --git a/pkg/disttask/framework/testutil/disttest_util.go b/pkg/disttask/framework/testutil/disttest_util.go index df22f3f05a5d6..b92526823c926 100644 --- a/pkg/disttask/framework/testutil/disttest_util.go +++ b/pkg/disttask/framework/testutil/disttest_util.go @@ -90,12 +90,6 @@ func RegisterRollbackTaskMeta(t *testing.T, ctrl *gomock.Controller, mockSchedul mockCleanupRountine.EXPECT().CleanUp(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() mockExecutor.EXPECT().Init(gomock.Any()).Return(nil).AnyTimes() mockExecutor.EXPECT().Cleanup(gomock.Any()).Return(nil).AnyTimes() - mockExecutor.EXPECT().Rollback(gomock.Any()).DoAndReturn( - func(_ context.Context) error { - testContext.RollbackCnt.Add(1) - return nil - }, - ).AnyTimes() mockExecutor.EXPECT().RunSubtask(gomock.Any(), gomock.Any()).DoAndReturn( func(_ context.Context, subtask *proto.Subtask) error { testContext.CollectSubtask(subtask) @@ -107,7 +101,6 @@ func RegisterRollbackTaskMeta(t *testing.T, ctrl *gomock.Controller, mockSchedul mockExtension.EXPECT().IsRetryableError(gomock.Any()).Return(false).AnyTimes() registerTaskMetaInner(t, proto.TaskTypeExample, mockExtension, mockCleanupRountine, mockScheduler) - testContext.RollbackCnt.Store(0) } // SubmitAndWaitTask schedule one task. diff --git a/pkg/disttask/framework/testutil/executor_util.go b/pkg/disttask/framework/testutil/executor_util.go index be34f9d04dc63..c36ab746862ac 100644 --- a/pkg/disttask/framework/testutil/executor_util.go +++ b/pkg/disttask/framework/testutil/executor_util.go @@ -29,7 +29,6 @@ func GetMockStepExecutor(ctrl *gomock.Controller) *mockexecute.MockStepExecutor executor := mockexecute.NewMockStepExecutor(ctrl) executor.EXPECT().Init(gomock.Any()).Return(nil).AnyTimes() executor.EXPECT().Cleanup(gomock.Any()).Return(nil).AnyTimes() - executor.EXPECT().Rollback(gomock.Any()).Return(nil).AnyTimes() executor.EXPECT().OnFinished(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() return executor } diff --git a/pkg/disttask/framework/testutil/task_util.go b/pkg/disttask/framework/testutil/task_util.go index 0f8d92deac04e..44eecf857deb8 100644 --- a/pkg/disttask/framework/testutil/task_util.go +++ b/pkg/disttask/framework/testutil/task_util.go @@ -28,12 +28,8 @@ import ( // CreateSubTask adds a new task to subtask table. // used for testing. -func CreateSubTask(t *testing.T, gm *storage.TaskManager, taskID int64, step proto.Step, execID string, meta []byte, tp proto.TaskType, concurrency int, isRevert bool) { - state := proto.SubtaskStatePending - if isRevert { - state = proto.SubtaskStateRevertPending - } - InsertSubtask(t, gm, taskID, step, execID, meta, state, tp, concurrency) +func CreateSubTask(t *testing.T, gm *storage.TaskManager, taskID int64, step proto.Step, execID string, meta []byte, tp proto.TaskType, concurrency int) { + InsertSubtask(t, gm, taskID, step, execID, meta, proto.SubtaskStatePending, tp, concurrency) } // InsertSubtask adds a new subtask of any state to subtask table. diff --git a/pkg/disttask/importinto/job_testkit_test.go b/pkg/disttask/importinto/job_testkit_test.go index c26f4fb6d8c51..f165afa6d8a54 100644 --- a/pkg/disttask/importinto/job_testkit_test.go +++ b/pkg/disttask/importinto/job_testkit_test.go @@ -71,7 +71,7 @@ func TestGetTaskImportedRows(t *testing.T) { bytes, err := json.Marshal(m) require.NoError(t, err) testutil.CreateSubTask(t, manager, taskID, proto.ImportStepImport, - "", bytes, proto.ImportInto, 11, false) + "", bytes, proto.ImportInto, 11) } rows, err := importinto.GetTaskImportedRows(ctx, 111) require.NoError(t, err) @@ -103,7 +103,7 @@ func TestGetTaskImportedRows(t *testing.T) { bytes, err := json.Marshal(m) require.NoError(t, err) testutil.CreateSubTask(t, manager, taskID, proto.ImportStepWriteAndIngest, - "", bytes, proto.ImportInto, 11, false) + "", bytes, proto.ImportInto, 11) } rows, err = importinto.GetTaskImportedRows(ctx, 222) require.NoError(t, err) diff --git a/pkg/disttask/importinto/scheduler_testkit_test.go b/pkg/disttask/importinto/scheduler_testkit_test.go index 37037ba6e1800..7fc90c05633cd 100644 --- a/pkg/disttask/importinto/scheduler_testkit_test.go +++ b/pkg/disttask/importinto/scheduler_testkit_test.go @@ -95,18 +95,19 @@ func TestSchedulerExtLocalSort(t *testing.T) { subtaskMetas, err := ext.OnNextSubtasksBatch(ctx, d, task, []string{":4000"}, ext.GetNextStep(task)) require.NoError(t, err) require.Len(t, subtaskMetas, 1) - task.Step = ext.GetNextStep(task) - require.Equal(t, proto.ImportStepImport, task.Step) + nextStep := ext.GetNextStep(task) + require.Equal(t, proto.ImportStepImport, nextStep) gotJobInfo, err = importer.GetJob(ctx, conn, jobID, "root", true) require.NoError(t, err) require.Equal(t, "running", gotJobInfo.Status) // update task/subtask, and finish subtask, so we can go to next stage subtasks := make([]*proto.Subtask, 0, len(subtaskMetas)) - for _, m := range subtaskMetas { - subtasks = append(subtasks, proto.NewSubtask(task.Step, task.ID, task.Type, "", 1, m, 0)) + for i, m := range subtaskMetas { + subtasks = append(subtasks, proto.NewSubtask(nextStep, task.ID, task.Type, "", 1, m, i+1)) } - _, err = manager.UpdateTaskAndAddSubTasks(ctx, task, subtasks, proto.TaskStatePending) + err = manager.SwitchTaskStep(ctx, task, proto.TaskStateRunning, nextStep, subtasks) require.NoError(t, err) + task.Step = nextStep gotSubtasks, err := manager.GetSubtasksWithHistory(ctx, taskID, proto.ImportStepImport) require.NoError(t, err) for _, s := range gotSubtasks { @@ -240,18 +241,19 @@ func TestSchedulerExtGlobalSort(t *testing.T) { subtaskMetas, err := ext.OnNextSubtasksBatch(ctx, d, task, []string{":4000"}, ext.GetNextStep(task)) require.NoError(t, err) require.Len(t, subtaskMetas, 2) - task.Step = ext.GetNextStep(task) - require.Equal(t, proto.ImportStepEncodeAndSort, task.Step) + nextStep := ext.GetNextStep(task) + require.Equal(t, proto.ImportStepEncodeAndSort, nextStep) gotJobInfo, err = importer.GetJob(ctx, conn, jobID, "root", true) require.NoError(t, err) require.Equal(t, "running", gotJobInfo.Status) require.Equal(t, "global-sorting", gotJobInfo.Step) // update task/subtask, and finish subtask, so we can go to next stage subtasks := make([]*proto.Subtask, 0, len(subtaskMetas)) - for _, m := range subtaskMetas { - subtasks = append(subtasks, proto.NewSubtask(task.Step, task.ID, task.Type, "", 1, m, 0)) + for i, m := range subtaskMetas { + subtasks = append(subtasks, proto.NewSubtask(nextStep, task.ID, task.Type, "", 1, m, i+1)) } - _, err = manager.UpdateTaskAndAddSubTasks(ctx, task, subtasks, proto.TaskStatePending) + err = manager.SwitchTaskStep(ctx, task, proto.TaskStatePending, nextStep, subtasks) + task.Step = nextStep require.NoError(t, err) gotSubtasks, err := manager.GetSubtasksWithHistory(ctx, taskID, task.Step) require.NoError(t, err) @@ -297,19 +299,20 @@ func TestSchedulerExtGlobalSort(t *testing.T) { subtaskMetas, err = ext.OnNextSubtasksBatch(ctx, d, task, []string{":4000"}, ext.GetNextStep(task)) require.NoError(t, err) require.Len(t, subtaskMetas, 1) - task.Step = ext.GetNextStep(task) - require.Equal(t, proto.ImportStepMergeSort, task.Step) + nextStep = ext.GetNextStep(task) + require.Equal(t, proto.ImportStepMergeSort, nextStep) gotJobInfo, err = importer.GetJob(ctx, conn, jobID, "root", true) require.NoError(t, err) require.Equal(t, "running", gotJobInfo.Status) require.Equal(t, "global-sorting", gotJobInfo.Step) // update task/subtask, and finish subtask, so we can go to next stage subtasks = make([]*proto.Subtask, 0, len(subtaskMetas)) - for _, m := range subtaskMetas { - subtasks = append(subtasks, proto.NewSubtask(task.Step, task.ID, task.Type, "", 1, m, 0)) + for i, m := range subtaskMetas { + subtasks = append(subtasks, proto.NewSubtask(nextStep, task.ID, task.Type, "", 1, m, i+1)) } - _, err = manager.UpdateTaskAndAddSubTasks(ctx, task, subtasks, proto.TaskStatePending) + err = manager.SwitchTaskStep(ctx, task, proto.TaskStatePending, nextStep, subtasks) require.NoError(t, err) + task.Step = nextStep gotSubtasks, err = manager.GetSubtasksWithHistory(ctx, taskID, task.Step) require.NoError(t, err) mergeSortStepMeta := &importinto.MergeSortStepMeta{ diff --git a/pkg/disttask/importinto/task_executor.go b/pkg/disttask/importinto/task_executor.go index df872d093a398..41ff289b65432 100644 --- a/pkg/disttask/importinto/task_executor.go +++ b/pkg/disttask/importinto/task_executor.go @@ -251,12 +251,6 @@ func (s *importStepExecutor) Cleanup(_ context.Context) (err error) { return s.tableImporter.Close() } -func (s *importStepExecutor) Rollback(context.Context) error { - // TODO: add rollback - s.logger.Info("rollback") - return nil -} - type mergeSortStepExecutor struct { taskexecutor.EmptyStepExecutor taskID int64 @@ -424,11 +418,6 @@ func (e *writeAndIngestStepExecutor) Cleanup(_ context.Context) (err error) { return e.tableImporter.Close() } -func (e *writeAndIngestStepExecutor) Rollback(context.Context) error { - e.logger.Info("rollback") - return nil -} - type postProcessStepExecutor struct { taskexecutor.EmptyStepExecutor taskID int64 diff --git a/tests/realtikvtest/importintotest/job_test.go b/tests/realtikvtest/importintotest/job_test.go index 81fdf99597ad3..8cc6e1ce2f590 100644 --- a/tests/realtikvtest/importintotest/job_test.go +++ b/tests/realtikvtest/importintotest/job_test.go @@ -510,7 +510,7 @@ func (s *mockGCSSuite) TestCancelJob() { s.NoError(err2) subtasks, err2 := taskManager.GetSubtasksWithHistory(ctx, task2.ID, proto.ImportStepPostProcess) s.NoError(err2) - s.Len(subtasks, 2) // framework will generate a subtask when canceling + s.Len(subtasks, 1) var cancelled bool for _, st := range subtasks { if st.State == proto.SubtaskStateCanceled {