diff --git a/pkg/disttask/framework/framework_pause_and_resume_test.go b/pkg/disttask/framework/framework_pause_and_resume_test.go index 28ba8b4cefaa3..a0c297d1acc49 100644 --- a/pkg/disttask/framework/framework_pause_and_resume_test.go +++ b/pkg/disttask/framework/framework_pause_and_resume_test.go @@ -34,11 +34,11 @@ func CheckSubtasksState(ctx context.Context, t *testing.T, taskID int64, state p mgr, err := storage.GetTaskManager() require.NoError(t, err) mgr.PrintSubtaskInfo(ctx, taskID) - cnt, err := mgr.GetSubtaskInStatesCnt(ctx, taskID, state) + cntByStates, err := mgr.GetSubtaskCntGroupByStates(ctx, taskID, proto.StepTwo) require.NoError(t, err) historySubTasksCnt, err := storage.GetSubtasksFromHistoryByTaskIDForTest(ctx, mgr, taskID) require.NoError(t, err) - require.Equal(t, expectedCnt, cnt+int64(historySubTasksCnt)) + require.Equal(t, expectedCnt, cntByStates[state]+int64(historySubTasksCnt)) } func TestFrameworkPauseAndResume(t *testing.T) { diff --git a/pkg/disttask/framework/handle/BUILD.bazel b/pkg/disttask/framework/handle/BUILD.bazel index dd0697e8d2f72..a3065ce31eb18 100644 --- a/pkg/disttask/framework/handle/BUILD.bazel +++ b/pkg/disttask/framework/handle/BUILD.bazel @@ -25,6 +25,7 @@ go_test( ":handle", "//pkg/disttask/framework/proto", "//pkg/disttask/framework/storage", + "//pkg/disttask/framework/testutil", "//pkg/testkit", "//pkg/util/backoff", "@com_github_ngaut_pools//:pools", diff --git a/pkg/disttask/framework/handle/handle_test.go b/pkg/disttask/framework/handle/handle_test.go index 65e2eaa52c42b..dc711850610bb 100644 --- a/pkg/disttask/framework/handle/handle_test.go +++ b/pkg/disttask/framework/handle/handle_test.go @@ -27,6 +27,7 @@ import ( "github.com/pingcap/tidb/pkg/disttask/framework/handle" "github.com/pingcap/tidb/pkg/disttask/framework/proto" "github.com/pingcap/tidb/pkg/disttask/framework/storage" + "github.com/pingcap/tidb/pkg/disttask/framework/testutil" "github.com/pingcap/tidb/pkg/testkit" 
"github.com/pingcap/tidb/pkg/util/backoff" "github.com/stretchr/testify/require" @@ -46,9 +47,23 @@ func TestHandle(t *testing.T) { mgr := storage.NewTaskManager(pool) storage.SetTaskManager(mgr) +<<<<<<< HEAD // no dispatcher registered err := handle.SubmitAndRunGlobalTask(ctx, "1", proto.TaskTypeExample, 2, []byte("byte")) require.Error(t, err) +======= + testutil.WaitNodeRegistered(ctx, t) + + // no scheduler registered + task, err := handle.SubmitTask(ctx, "1", proto.TaskTypeExample, 2, []byte("byte")) + require.NoError(t, err) + waitedTask, err := handle.WaitTask(ctx, task.ID, func(task *proto.Task) bool { + return task.IsDone() + }) + require.NoError(t, err) + require.Equal(t, proto.TaskStateFailed, waitedTask.State) + require.ErrorContains(t, waitedTask.Error, "unknown task type") +>>>>>>> 99f0349bfb6 (disttask: fix failed step is taken as success (#49971)) task, err := mgr.GetGlobalTaskByID(ctx, 1) require.NoError(t, err) diff --git a/pkg/disttask/framework/mock/scheduler_mock.go b/pkg/disttask/framework/mock/scheduler_mock.go index 6f99256c9cede..d5e6cfa4ea673 100644 --- a/pkg/disttask/framework/mock/scheduler_mock.go +++ b/pkg/disttask/framework/mock/scheduler_mock.go @@ -429,7 +429,223 @@ func (mr *MockExtensionMockRecorder) GetSubtaskExecutor(arg0, arg1, arg2 any) *g // IsIdempotent mocks base method. func (m *MockExtension) IsIdempotent(arg0 *proto.Subtask) bool { m.ctrl.T.Helper() +<<<<<<< HEAD ret := m.ctrl.Call(m, "IsIdempotent", arg0) +======= + ret := m.ctrl.Call(m, "DeleteDeadNodes", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteDeadNodes indicates an expected call of DeleteDeadNodes. +func (mr *MockTaskManagerMockRecorder) DeleteDeadNodes(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteDeadNodes", reflect.TypeOf((*MockTaskManager)(nil).DeleteDeadNodes), arg0, arg1) +} + +// FailTask mocks base method. 
+func (m *MockTaskManager) FailTask(arg0 context.Context, arg1 int64, arg2 proto.TaskState, arg3 error) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "FailTask", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(error) + return ret0 +} + +// FailTask indicates an expected call of FailTask. +func (mr *MockTaskManagerMockRecorder) FailTask(arg0, arg1, arg2, arg3 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FailTask", reflect.TypeOf((*MockTaskManager)(nil).FailTask), arg0, arg1, arg2, arg3) +} + +// GCSubtasks mocks base method. +func (m *MockTaskManager) GCSubtasks(arg0 context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GCSubtasks", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// GCSubtasks indicates an expected call of GCSubtasks. +func (mr *MockTaskManagerMockRecorder) GCSubtasks(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GCSubtasks", reflect.TypeOf((*MockTaskManager)(nil).GCSubtasks), arg0) +} + +// GetAllNodes mocks base method. +func (m *MockTaskManager) GetAllNodes(arg0 context.Context) ([]proto.ManagedNode, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAllNodes", arg0) + ret0, _ := ret[0].([]proto.ManagedNode) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAllNodes indicates an expected call of GetAllNodes. +func (mr *MockTaskManagerMockRecorder) GetAllNodes(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllNodes", reflect.TypeOf((*MockTaskManager)(nil).GetAllNodes), arg0) +} + +// GetManagedNodes mocks base method. 
+func (m *MockTaskManager) GetManagedNodes(arg0 context.Context) ([]proto.ManagedNode, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetManagedNodes", arg0) + ret0, _ := ret[0].([]proto.ManagedNode) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetManagedNodes indicates an expected call of GetManagedNodes. +func (mr *MockTaskManagerMockRecorder) GetManagedNodes(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetManagedNodes", reflect.TypeOf((*MockTaskManager)(nil).GetManagedNodes), arg0) +} + +// GetSubtaskCntGroupByStates mocks base method. +func (m *MockTaskManager) GetSubtaskCntGroupByStates(arg0 context.Context, arg1 int64, arg2 proto.Step) (map[proto.SubtaskState]int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSubtaskCntGroupByStates", arg0, arg1, arg2) + ret0, _ := ret[0].(map[proto.SubtaskState]int64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSubtaskCntGroupByStates indicates an expected call of GetSubtaskCntGroupByStates. +func (mr *MockTaskManagerMockRecorder) GetSubtaskCntGroupByStates(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSubtaskCntGroupByStates", reflect.TypeOf((*MockTaskManager)(nil).GetSubtaskCntGroupByStates), arg0, arg1, arg2) +} + +// GetSubtasksByExecIdsAndStepAndState mocks base method. +func (m *MockTaskManager) GetSubtasksByExecIdsAndStepAndState(arg0 context.Context, arg1 []string, arg2 int64, arg3 proto.Step, arg4 proto.SubtaskState) ([]*proto.Subtask, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSubtasksByExecIdsAndStepAndState", arg0, arg1, arg2, arg3, arg4) + ret0, _ := ret[0].([]*proto.Subtask) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSubtasksByExecIdsAndStepAndState indicates an expected call of GetSubtasksByExecIdsAndStepAndState. 
+func (mr *MockTaskManagerMockRecorder) GetSubtasksByExecIdsAndStepAndState(arg0, arg1, arg2, arg3, arg4 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSubtasksByExecIdsAndStepAndState", reflect.TypeOf((*MockTaskManager)(nil).GetSubtasksByExecIdsAndStepAndState), arg0, arg1, arg2, arg3, arg4) +} + +// GetSubtasksByStepAndState mocks base method. +func (m *MockTaskManager) GetSubtasksByStepAndState(arg0 context.Context, arg1 int64, arg2 proto.Step, arg3 proto.TaskState) ([]*proto.Subtask, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSubtasksByStepAndState", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].([]*proto.Subtask) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSubtasksByStepAndState indicates an expected call of GetSubtasksByStepAndState. +func (mr *MockTaskManagerMockRecorder) GetSubtasksByStepAndState(arg0, arg1, arg2, arg3 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSubtasksByStepAndState", reflect.TypeOf((*MockTaskManager)(nil).GetSubtasksByStepAndState), arg0, arg1, arg2, arg3) +} + +// GetTaskByID mocks base method. +func (m *MockTaskManager) GetTaskByID(arg0 context.Context, arg1 int64) (*proto.Task, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetTaskByID", arg0, arg1) + ret0, _ := ret[0].(*proto.Task) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetTaskByID indicates an expected call of GetTaskByID. +func (mr *MockTaskManagerMockRecorder) GetTaskByID(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTaskByID", reflect.TypeOf((*MockTaskManager)(nil).GetTaskByID), arg0, arg1) +} + +// GetTaskExecutorIDsByTaskID mocks base method. 
+func (m *MockTaskManager) GetTaskExecutorIDsByTaskID(arg0 context.Context, arg1 int64) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetTaskExecutorIDsByTaskID", arg0, arg1) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetTaskExecutorIDsByTaskID indicates an expected call of GetTaskExecutorIDsByTaskID. +func (mr *MockTaskManagerMockRecorder) GetTaskExecutorIDsByTaskID(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTaskExecutorIDsByTaskID", reflect.TypeOf((*MockTaskManager)(nil).GetTaskExecutorIDsByTaskID), arg0, arg1) +} + +// GetTaskExecutorIDsByTaskIDAndStep mocks base method. +func (m *MockTaskManager) GetTaskExecutorIDsByTaskIDAndStep(arg0 context.Context, arg1 int64, arg2 proto.Step) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetTaskExecutorIDsByTaskIDAndStep", arg0, arg1, arg2) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetTaskExecutorIDsByTaskIDAndStep indicates an expected call of GetTaskExecutorIDsByTaskIDAndStep. +func (mr *MockTaskManagerMockRecorder) GetTaskExecutorIDsByTaskIDAndStep(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTaskExecutorIDsByTaskIDAndStep", reflect.TypeOf((*MockTaskManager)(nil).GetTaskExecutorIDsByTaskIDAndStep), arg0, arg1, arg2) +} + +// GetTasksInStates mocks base method. +func (m *MockTaskManager) GetTasksInStates(arg0 context.Context, arg1 ...any) ([]*proto.Task, error) { + m.ctrl.T.Helper() + varargs := []any{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "GetTasksInStates", varargs...) + ret0, _ := ret[0].([]*proto.Task) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetTasksInStates indicates an expected call of GetTasksInStates. 
+func (mr *MockTaskManagerMockRecorder) GetTasksInStates(arg0 any, arg1 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTasksInStates", reflect.TypeOf((*MockTaskManager)(nil).GetTasksInStates), varargs...) +} + +// GetTopUnfinishedTasks mocks base method. +func (m *MockTaskManager) GetTopUnfinishedTasks(arg0 context.Context) ([]*proto.Task, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetTopUnfinishedTasks", arg0) + ret0, _ := ret[0].([]*proto.Task) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetTopUnfinishedTasks indicates an expected call of GetTopUnfinishedTasks. +func (mr *MockTaskManagerMockRecorder) GetTopUnfinishedTasks(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTopUnfinishedTasks", reflect.TypeOf((*MockTaskManager)(nil).GetTopUnfinishedTasks), arg0) +} + +// GetUsedSlotsOnNodes mocks base method. +func (m *MockTaskManager) GetUsedSlotsOnNodes(arg0 context.Context) (map[string]int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetUsedSlotsOnNodes", arg0) + ret0, _ := ret[0].(map[string]int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetUsedSlotsOnNodes indicates an expected call of GetUsedSlotsOnNodes. +func (mr *MockTaskManagerMockRecorder) GetUsedSlotsOnNodes(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUsedSlotsOnNodes", reflect.TypeOf((*MockTaskManager)(nil).GetUsedSlotsOnNodes), arg0) +} + +// PauseTask mocks base method. 
+func (m *MockTaskManager) PauseTask(arg0 context.Context, arg1 string) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PauseTask", arg0, arg1) +>>>>>>> 99f0349bfb6 (disttask: fix failed step is taken as success (#49971)) ret0, _ := ret[0].(bool) return ret0 } diff --git a/pkg/disttask/framework/planner/BUILD.bazel b/pkg/disttask/framework/planner/BUILD.bazel index 3f28cede80197..8ea3e7c63e3b7 100644 --- a/pkg/disttask/framework/planner/BUILD.bazel +++ b/pkg/disttask/framework/planner/BUILD.bazel @@ -27,6 +27,7 @@ go_test( ":planner", "//pkg/disttask/framework/mock", "//pkg/disttask/framework/storage", + "//pkg/disttask/framework/testutil", "//pkg/kv", "//pkg/testkit", "@com_github_ngaut_pools//:pools", diff --git a/pkg/disttask/framework/planner/planner_test.go b/pkg/disttask/framework/planner/planner_test.go index e515c3e0e266f..0801ce6ecdb7a 100644 --- a/pkg/disttask/framework/planner/planner_test.go +++ b/pkg/disttask/framework/planner/planner_test.go @@ -23,6 +23,7 @@ import ( "github.com/pingcap/tidb/pkg/disttask/framework/mock" "github.com/pingcap/tidb/pkg/disttask/framework/planner" "github.com/pingcap/tidb/pkg/disttask/framework/storage" + "github.com/pingcap/tidb/pkg/disttask/framework/testutil" "github.com/pingcap/tidb/pkg/kv" "github.com/pingcap/tidb/pkg/testkit" "github.com/stretchr/testify/require" @@ -45,7 +46,7 @@ func TestPlanner(t *testing.T) { defer pool.Close() mgr := storage.NewTaskManager(pool) storage.SetTaskManager(mgr) - + testutil.WaitNodeRegistered(ctx, t) p := &planner.Planner{} pCtx := planner.PlanCtx{ Ctx: ctx, diff --git a/pkg/disttask/framework/scheduler/BUILD.bazel b/pkg/disttask/framework/scheduler/BUILD.bazel index 9ff3449a5411d..cb39089ed0a78 100644 --- a/pkg/disttask/framework/scheduler/BUILD.bazel +++ b/pkg/disttask/framework/scheduler/BUILD.bazel @@ -40,14 +40,27 @@ go_test( name = "scheduler_test", timeout = "short", srcs = [ +<<<<<<< HEAD "manager_test.go", "register_test.go", +======= + "main_test.go", + 
"nodes_test.go", + "rebalance_test.go", + "scheduler_manager_test.go", + "scheduler_nokit_test.go", +>>>>>>> 99f0349bfb6 (disttask: fix failed step is taken as success (#49971)) "scheduler_test.go", ], embed = [":scheduler"], flaky = True, +<<<<<<< HEAD race = "on", shard_count = 8, +======= + race = "off", + shard_count = 26, +>>>>>>> 99f0349bfb6 (disttask: fix failed step is taken as success (#49971)) deps = [ "//pkg/disttask/framework/mock", "//pkg/disttask/framework/mock/execute", diff --git a/pkg/disttask/framework/scheduler/interface.go b/pkg/disttask/framework/scheduler/interface.go index dc2c4aa9375af..0468c2d75da96 100644 --- a/pkg/disttask/framework/scheduler/interface.go +++ b/pkg/disttask/framework/scheduler/interface.go @@ -21,10 +21,65 @@ import ( "github.com/pingcap/tidb/pkg/disttask/framework/scheduler/execute" ) +<<<<<<< HEAD // TaskTable defines the interface to access task table. type TaskTable interface { GetGlobalTasksInStates(ctx context.Context, states ...interface{}) (task []*proto.Task, err error) GetGlobalTaskByID(ctx context.Context, taskID int64) (task *proto.Task, err error) +======= +// TaskManager defines the interface to access task table. +type TaskManager interface { + // GetTopUnfinishedTasks returns unfinished tasks, limited by MaxConcurrentTask*2, + // to make sure lower priority tasks can be scheduled if resource is enough. + // The returned tasks are sorted by task order, see proto.Task, and only contains + // some fields, see row2TaskBasic. 
+ GetTopUnfinishedTasks(ctx context.Context) ([]*proto.Task, error) + GetTasksInStates(ctx context.Context, states ...interface{}) (task []*proto.Task, err error) + GetTaskByID(ctx context.Context, taskID int64) (task *proto.Task, err error) + UpdateTaskAndAddSubTasks(ctx context.Context, task *proto.Task, subtasks []*proto.Subtask, prevState proto.TaskState) (bool, error) + GCSubtasks(ctx context.Context) error + GetAllNodes(ctx context.Context) ([]proto.ManagedNode, error) + DeleteDeadNodes(ctx context.Context, nodes []string) error + TransferTasks2History(ctx context.Context, tasks []*proto.Task) error + CancelTask(ctx context.Context, taskID int64) error + // FailTask updates task state to Failed and updates task error. + FailTask(ctx context.Context, taskID int64, currentState proto.TaskState, taskErr error) error + PauseTask(ctx context.Context, taskKey string) (bool, error) + // SwitchTaskStep switches the task to the next step and add subtasks in one + // transaction. It will change task state too if we're switch from InitStep to + // next step. + SwitchTaskStep(ctx context.Context, task *proto.Task, nextState proto.TaskState, nextStep proto.Step, subtasks []*proto.Subtask) error + // SwitchTaskStepInBatch similar to SwitchTaskStep, but it will insert subtasks + // in batch, and task step change will be in a separate transaction. + // Note: subtasks of this step must be stable, i.e. count, order and content + // should be the same on each try, else the subtasks inserted might be messed up. + // And each subtask of this step must be different, to handle the network + // partition or owner change. + SwitchTaskStepInBatch(ctx context.Context, task *proto.Task, nextState proto.TaskState, nextStep proto.Step, subtasks []*proto.Subtask) error + // SucceedTask updates a task to success state. + SucceedTask(ctx context.Context, taskID int64) error + // GetUsedSlotsOnNodes returns the used slots on nodes that have subtask scheduled. 
+ // subtasks of each task on one node is only accounted once as we don't support + // running them concurrently. + // we only consider pending/running subtasks, subtasks related to revert are + // not considered. + GetUsedSlotsOnNodes(ctx context.Context) (map[string]int, error) + // GetSubtaskCntGroupByStates returns the count of subtasks of some step group by state. + GetSubtaskCntGroupByStates(ctx context.Context, taskID int64, step proto.Step) (map[proto.SubtaskState]int64, error) + ResumeSubtasks(ctx context.Context, taskID int64) error + CollectSubTaskError(ctx context.Context, taskID int64) ([]error, error) + TransferSubTasks2History(ctx context.Context, taskID int64) error + UpdateSubtasksExecIDs(ctx context.Context, taskID int64, subtasks []*proto.Subtask) error + // GetManagedNodes returns the nodes managed by dist framework and can be used + // to execute tasks. If there are any nodes with background role, we use them, + // else we use nodes without role. + // returned nodes are sorted by node id(host:port). 
+ GetManagedNodes(ctx context.Context) ([]proto.ManagedNode, error) + GetTaskExecutorIDsByTaskID(ctx context.Context, taskID int64) ([]string, error) + GetSubtasksByStepAndState(ctx context.Context, taskID int64, step proto.Step, state proto.TaskState) ([]*proto.Subtask, error) + GetSubtasksByExecIdsAndStepAndState(ctx context.Context, tidbIDs []string, taskID int64, step proto.Step, state proto.SubtaskState) ([]*proto.Subtask, error) + GetTaskExecutorIDsByTaskIDAndStep(ctx context.Context, taskID int64, step proto.Step) ([]string, error) +>>>>>>> 99f0349bfb6 (disttask: fix failed step is taken as success (#49971)) GetSubtasksInStates(ctx context.Context, tidbID string, taskID int64, step proto.Step, states ...interface{}) ([]*proto.Subtask, error) GetFirstSubtaskInStates(ctx context.Context, instanceID string, taskID int64, step proto.Step, states ...interface{}) (*proto.Subtask, error) diff --git a/pkg/disttask/framework/scheduler/scheduler.go b/pkg/disttask/framework/scheduler/scheduler.go index 6ad8835abeb4e..2e9accf189a3d 100644 --- a/pkg/disttask/framework/scheduler/scheduler.go +++ b/pkg/disttask/framework/scheduler/scheduler.go @@ -517,10 +517,28 @@ func (s *BaseScheduler) markErrorHandled() { s.mu.handled = true } +<<<<<<< HEAD func (s *BaseScheduler) getError() error { s.mu.RLock() defer s.mu.RUnlock() return s.mu.err +======= +// handle task in pausing state, cancel all running subtasks. 
+func (s *BaseScheduler) onPausing() error { + logutil.Logger(s.logCtx).Info("on pausing state", zap.Stringer("state", s.Task.State), zap.Int64("step", int64(s.Task.Step))) + cntByStates, err := s.taskMgr.GetSubtaskCntGroupByStates(s.ctx, s.Task.ID, s.Task.Step) + if err != nil { + logutil.Logger(s.logCtx).Warn("check task failed", zap.Error(err)) + return err + } + runningPendingCnt := cntByStates[proto.SubtaskStateRunning] + cntByStates[proto.SubtaskStatePending] + if runningPendingCnt == 0 { + logutil.Logger(s.logCtx).Info("all running subtasks paused, update the task to paused state") + return s.updateTask(proto.TaskStatePaused, nil, RetrySQLTimes) + } + logutil.Logger(s.logCtx).Debug("on pausing state, this task keeps current state", zap.Stringer("state", s.Task.State)) + return nil +>>>>>>> 99f0349bfb6 (disttask: fix failed step is taken as success (#49971)) } func (s *BaseScheduler) resetError() { @@ -530,6 +548,7 @@ func (s *BaseScheduler) resetError() { s.mu.handled = false } +<<<<<<< HEAD func (s *BaseScheduler) startSubtaskAndUpdateState(ctx context.Context, subtask *proto.Subtask) { metrics.DecDistTaskSubTaskCnt(subtask) metrics.EndDistTaskSubTask(subtask) @@ -544,6 +563,368 @@ func (s *BaseScheduler) updateSubtaskStateAndErrorImpl(ctx context.Context, tidb logger := logutil.Logger(s.logCtx) backoffer := backoff.NewExponential(dispatcher.RetrySQLInterval, 2, dispatcher.RetrySQLMaxInterval) err := handle.RunWithRetry(ctx, dispatcher.RetrySQLTimes, backoffer, logger, +======= +// TestSyncChan is used to sync the test. +var TestSyncChan = make(chan struct{}) + +// handle task in resuming state. 
+func (s *BaseScheduler) onResuming() error { + logutil.Logger(s.logCtx).Info("on resuming state", zap.Stringer("state", s.Task.State), zap.Int64("step", int64(s.Task.Step))) + cntByStates, err := s.taskMgr.GetSubtaskCntGroupByStates(s.ctx, s.Task.ID, s.Task.Step) + if err != nil { + logutil.Logger(s.logCtx).Warn("check task failed", zap.Error(err)) + return err + } + if cntByStates[proto.SubtaskStatePaused] == 0 { + // Finish the resuming process. + logutil.Logger(s.logCtx).Info("all paused tasks converted to pending state, update the task to running state") + err := s.updateTask(proto.TaskStateRunning, nil, RetrySQLTimes) + failpoint.Inject("syncAfterResume", func() { + TestSyncChan <- struct{}{} + }) + return err + } + + return s.taskMgr.ResumeSubtasks(s.ctx, s.Task.ID) +} + +// handle task in reverting state, check all revert subtasks finishes. +func (s *BaseScheduler) onReverting() error { + logutil.Logger(s.logCtx).Debug("on reverting state", zap.Stringer("state", s.Task.State), zap.Int64("step", int64(s.Task.Step))) + cntByStates, err := s.taskMgr.GetSubtaskCntGroupByStates(s.ctx, s.Task.ID, s.Task.Step) + if err != nil { + logutil.Logger(s.logCtx).Warn("check task failed", zap.Error(err)) + return err + } + activeRevertCnt := cntByStates[proto.SubtaskStateRevertPending] + cntByStates[proto.SubtaskStateReverting] + if activeRevertCnt == 0 { + if err = s.OnDone(s.ctx, s, s.Task); err != nil { + return errors.Trace(err) + } + return s.updateTask(proto.TaskStateReverted, nil, RetrySQLTimes) + } + // Wait all subtasks in this step finishes. + s.OnTick(s.ctx, s.Task) + logutil.Logger(s.logCtx).Debug("on reverting state, this task keeps current state", zap.Stringer("state", s.Task.State)) + return nil +} + +// handle task in pending state, schedule subtasks. 
+func (s *BaseScheduler) onPending() error { + logutil.Logger(s.logCtx).Debug("on pending state", zap.Stringer("state", s.Task.State), zap.Int64("step", int64(s.Task.Step))) + return s.switch2NextStep() +} + +// handle task in running state, check all running subtasks finishes. +// If subtasks finished, run into the next step. +func (s *BaseScheduler) onRunning() error { + logutil.Logger(s.logCtx).Debug("on running state", + zap.Stringer("state", s.Task.State), + zap.Int64("step", int64(s.Task.Step))) + // check current step finishes. + cntByStates, err := s.taskMgr.GetSubtaskCntGroupByStates(s.ctx, s.Task.ID, s.Task.Step) + if err != nil { + logutil.Logger(s.logCtx).Warn("check task failed", zap.Error(err)) + return err + } + if cntByStates[proto.SubtaskStateFailed] > 0 || cntByStates[proto.SubtaskStateCanceled] > 0 { + subTaskErrs, err := s.taskMgr.CollectSubTaskError(s.ctx, s.Task.ID) + if err != nil { + logutil.Logger(s.logCtx).Warn("collect subtask error failed", zap.Error(err)) + return err + } + if len(subTaskErrs) > 0 { + logutil.Logger(s.logCtx).Warn("subtasks encounter errors") + return s.onErrHandlingStage(subTaskErrs) + } + } else if s.isStepSucceed(cntByStates) { + return s.switch2NextStep() + } + + if err := s.balanceSubtasks(); err != nil { + return err + } + // Wait all subtasks in this step finishes. + s.OnTick(s.ctx, s.Task) + logutil.Logger(s.logCtx).Debug("on running state, this task keeps current state", zap.Stringer("state", s.Task.State)) + return nil +} + +func (s *BaseScheduler) onFinished() error { + metrics.UpdateMetricsForFinishTask(s.Task) + logutil.Logger(s.logCtx).Debug("schedule task, task is finished", zap.Stringer("state", s.Task.State)) + return s.taskMgr.TransferSubTasks2History(s.ctx, s.Task.ID) +} + +// balanceSubtasks check the liveNode num every liveNodeFetchInterval then rebalance subtasks. 
+func (s *BaseScheduler) balanceSubtasks() error { + if len(s.TaskNodes) == 0 { + var err error + s.TaskNodes, err = s.taskMgr.GetTaskExecutorIDsByTaskIDAndStep(s.ctx, s.Task.ID, s.Task.Step) + if err != nil { + return err + } + } + s.balanceSubtaskTick++ + if s.balanceSubtaskTick == defaultBalanceSubtaskTicks { + s.balanceSubtaskTick = 0 + eligibleNodes, err := s.getEligibleNodes() + if err != nil { + return err + } + if len(eligibleNodes) > 0 { + return s.doBalanceSubtasks(eligibleNodes) + } + } + return nil +} + +// DoBalanceSubtasks make count of subtasks on each liveNodes balanced and clean up subtasks on dead nodes. +// TODO(ywqzzy): refine to make it easier for testing. +func (s *BaseScheduler) doBalanceSubtasks(eligibleNodes []string) error { + eligibleNodeMap := make(map[string]struct{}, len(eligibleNodes)) + for _, n := range eligibleNodes { + eligibleNodeMap[n] = struct{}{} + } + // 1. find out nodes need to clean subtasks. + deadNodes := make([]string, 0) + deadNodesMap := make(map[string]bool, 0) + for _, node := range s.TaskNodes { + if _, ok := eligibleNodeMap[node]; !ok { + deadNodes = append(deadNodes, node) + deadNodesMap[node] = true + } + } + // 2. get subtasks for each node before rebalance. + subtasks, err := s.taskMgr.GetSubtasksByStepAndState(s.ctx, s.Task.ID, s.Task.Step, proto.TaskStatePending) + if err != nil { + return err + } + if len(deadNodes) != 0 { + /// get subtask from deadNodes, since there might be some running subtasks on deadNodes. + /// In this case, all subtasks on deadNodes are in running/pending state. + subtasksOnDeadNodes, err := s.taskMgr.GetSubtasksByExecIdsAndStepAndState( + s.ctx, + deadNodes, + s.Task.ID, + s.Task.Step, + proto.SubtaskStateRunning) + if err != nil { + return err + } + subtasks = append(subtasks, subtasksOnDeadNodes...) + } + // 3. group subtasks for each task executor. 
+ subtasksOnTaskExecutor := make(map[string][]*proto.Subtask, len(eligibleNodes)+len(deadNodes)) + for _, node := range eligibleNodes { + subtasksOnTaskExecutor[node] = make([]*proto.Subtask, 0) + } + for _, subtask := range subtasks { + subtasksOnTaskExecutor[subtask.ExecID] = append( + subtasksOnTaskExecutor[subtask.ExecID], + subtask) + } + // 4. prepare subtasks that need to rebalance to other nodes. + averageSubtaskCnt := len(subtasks) / len(eligibleNodes) + rebalanceSubtasks := make([]*proto.Subtask, 0) + for k, v := range subtasksOnTaskExecutor { + if ok := deadNodesMap[k]; ok { + rebalanceSubtasks = append(rebalanceSubtasks, v...) + continue + } + // When no tidb scale-in/out and averageSubtaskCnt*len(eligibleNodes) < len(subtasks), + // no need to send subtask to other nodes. + // eg: tidb1 with 3 subtasks, tidb2 with 2 subtasks, subtasks are balanced now. + if averageSubtaskCnt*len(eligibleNodes) < len(subtasks) && len(s.TaskNodes) == len(eligibleNodes) { + if len(v) > averageSubtaskCnt+1 { + rebalanceSubtasks = append(rebalanceSubtasks, v[0:len(v)-averageSubtaskCnt]...) + } + continue + } + if len(v) > averageSubtaskCnt { + rebalanceSubtasks = append(rebalanceSubtasks, v[0:len(v)-averageSubtaskCnt]...) + } + } + // 5. skip rebalance. + if len(rebalanceSubtasks) == 0 { + return nil + } + // 6.rebalance subtasks to other nodes. + rebalanceIdx := 0 + for k, v := range subtasksOnTaskExecutor { + if ok := deadNodesMap[k]; !ok { + if len(v) < averageSubtaskCnt { + for i := 0; i < averageSubtaskCnt-len(v) && rebalanceIdx < len(rebalanceSubtasks); i++ { + rebalanceSubtasks[rebalanceIdx].ExecID = k + rebalanceIdx++ + } + } + } + } + // 7. rebalance rest subtasks evenly to liveNodes. + liveNodeIdx := 0 + for rebalanceIdx < len(rebalanceSubtasks) { + rebalanceSubtasks[rebalanceIdx].ExecID = eligibleNodes[liveNodeIdx] + rebalanceIdx++ + liveNodeIdx++ + } + + // 8. update subtasks and do clean up logic. 
+ if err = s.taskMgr.UpdateSubtasksExecIDs(s.ctx, s.Task.ID, subtasks); err != nil { + return err + } + logutil.Logger(s.logCtx).Info("balance subtasks", + zap.Stringers("subtasks-rebalanced", subtasks)) + s.TaskNodes = append([]string{}, eligibleNodes...) + return nil +} + +// updateTask update the task in tidb_global_task table. +func (s *BaseScheduler) updateTask(taskState proto.TaskState, newSubTasks []*proto.Subtask, retryTimes int) (err error) { + prevState := s.Task.State + s.Task.State = taskState + logutil.BgLogger().Info("task state transform", zap.Stringer("from", prevState), zap.Stringer("to", taskState)) + if !VerifyTaskStateTransform(prevState, taskState) { + return errors.Errorf("invalid task state transform, from %s to %s", prevState, taskState) + } + + var retryable bool + for i := 0; i < retryTimes; i++ { + retryable, err = s.taskMgr.UpdateTaskAndAddSubTasks(s.ctx, s.Task, newSubTasks, prevState) + if err == nil || !retryable { + break + } + if err1 := s.ctx.Err(); err1 != nil { + return err1 + } + if i%10 == 0 { + logutil.Logger(s.logCtx).Warn("updateTask first failed", zap.Stringer("from", prevState), zap.Stringer("to", s.Task.State), + zap.Int("retry times", i), zap.Error(err)) + } + time.Sleep(RetrySQLInterval) + } + if err != nil && retryTimes != nonRetrySQLTime { + logutil.Logger(s.logCtx).Warn("updateTask failed", + zap.Stringer("from", prevState), zap.Stringer("to", s.Task.State), zap.Int("retry times", retryTimes), zap.Error(err)) + } + return err +} + +func (s *BaseScheduler) onErrHandlingStage(receiveErrs []error) error { + // we only store the first error. 
+ s.Task.Error = receiveErrs[0] + + var subTasks []*proto.Subtask + // when step of task is `StepInit`, no need to do revert + if s.Task.Step != proto.StepInit { + instanceIDs, err := s.GetAllTaskExecutorIDs(s.ctx, s.Task) + if err != nil { + logutil.Logger(s.logCtx).Warn("get task's all instances failed", zap.Error(err)) + return err + } + + subTasks = make([]*proto.Subtask, 0, len(instanceIDs)) + for _, id := range instanceIDs { + // reverting subtasks belong to the same step as current active step. + subTasks = append(subTasks, proto.NewSubtask( + s.Task.Step, s.Task.ID, s.Task.Type, id, + s.Task.Concurrency, proto.EmptyMeta, 0)) + } + } + return s.updateTask(proto.TaskStateReverting, subTasks, RetrySQLTimes) +} + +func (s *BaseScheduler) switch2NextStep() (err error) { + nextStep := s.GetNextStep(s.Task) + logutil.Logger(s.logCtx).Info("on next step", + zap.Int64("current-step", int64(s.Task.Step)), + zap.Int64("next-step", int64(nextStep))) + + if nextStep == proto.StepDone { + s.Task.Step = nextStep + s.Task.StateUpdateTime = time.Now().UTC() + if err = s.OnDone(s.ctx, s, s.Task); err != nil { + return errors.Trace(err) + } + return s.taskMgr.SucceedTask(s.ctx, s.Task.ID) + } + + serverNodes, err := s.getEligibleNodes() + if err != nil { + return err + } + logutil.Logger(s.logCtx).Info("eligible instances", zap.Int("num", len(serverNodes))) + if len(serverNodes) == 0 { + return errors.New("no available TiDB node to dispatch subtasks") + } + + metas, err := s.OnNextSubtasksBatch(s.ctx, s, s.Task, serverNodes, nextStep) + if err != nil { + logutil.Logger(s.logCtx).Warn("generate part of subtasks failed", zap.Error(err)) + return s.handlePlanErr(err) + } + + return s.scheduleSubTask(nextStep, metas, serverNodes) +} + +// getEligibleNodes returns the eligible(live) nodes for the task. +// if the task can only be scheduled to some specific nodes, return them directly, +// we don't care liveliness of them. 
+func (s *BaseScheduler) getEligibleNodes() ([]string, error) { + serverNodes, err := s.GetEligibleInstances(s.ctx, s.Task) + if err != nil { + return nil, err + } + logutil.Logger(s.logCtx).Debug("eligible instances", zap.Int("num", len(serverNodes))) + if len(serverNodes) == 0 { + serverNodes = append([]string{}, s.nodeMgr.getManagedNodes()...) + } + return serverNodes, nil +} + +func (s *BaseScheduler) scheduleSubTask( + subtaskStep proto.Step, + metas [][]byte, + serverNodes []string) error { + logutil.Logger(s.logCtx).Info("schedule subtasks", + zap.Stringer("state", s.Task.State), + zap.Int64("step", int64(s.Task.Step)), + zap.Int("concurrency", s.Task.Concurrency), + zap.Int("subtasks", len(metas))) + s.TaskNodes = serverNodes + var size uint64 + subTasks := make([]*proto.Subtask, 0, len(metas)) + for i, meta := range metas { + // we assign the subtask to the instance in a round-robin way. + // TODO: assign the subtask to the instance according to the system load of each nodes + pos := i % len(serverNodes) + instanceID := serverNodes[pos] + logutil.Logger(s.logCtx).Debug("create subtasks", zap.String("instanceID", instanceID)) + subTasks = append(subTasks, proto.NewSubtask( + subtaskStep, s.Task.ID, s.Task.Type, instanceID, s.Task.Concurrency, meta, i+1)) + + size += uint64(len(meta)) + } + failpoint.Inject("cancelBeforeUpdateTask", func() { + _ = s.taskMgr.CancelTask(s.ctx, s.Task.ID) + }) + + // as other fields and generated key and index KV takes space too, we limit + // the size of subtasks to 80% of the transaction limit. + limit := max(uint64(float64(kv.TxnTotalSizeLimit.Load())*0.8), 1) + fn := s.taskMgr.SwitchTaskStep + if size >= limit { + // On default, transaction size limit is controlled by tidb_mem_quota_query + // which is 1G on default, so it's unlikely to reach this limit, but in + // case user set txn-total-size-limit explicitly, we insert in batch. 
+ logutil.Logger(s.logCtx).Info("subtasks size exceed limit, will insert in batch", + zap.Uint64("size", size), zap.Uint64("limit", limit)) + fn = s.taskMgr.SwitchTaskStepInBatch + } + + backoffer := backoff.NewExponential(RetrySQLInterval, 2, RetrySQLMaxInterval) + return handle.RunWithRetry(s.ctx, RetrySQLTimes, backoffer, logutil.Logger(s.logCtx), +>>>>>>> 99f0349bfb6 (disttask: fix failed step is taken as success (#49971)) func(ctx context.Context) (bool, error) { return true, s.taskTable.UpdateSubtaskStateAndError(ctx, tidbID, subtaskID, state, subTaskErr) }, @@ -646,5 +1027,33 @@ func (s *BaseScheduler) updateErrorToSubtask(ctx context.Context, taskID int64, if err1 == nil { logger.Warn("update error to subtask success", zap.Error(err)) } +<<<<<<< HEAD return err1 +======= + previousSubtaskMetas := make([][]byte, 0, len(previousSubtasks)) + for _, subtask := range previousSubtasks { + previousSubtaskMetas = append(previousSubtaskMetas, subtask.Meta) + } + return previousSubtaskMetas, nil +} + +// WithNewSession executes the function with a new session. +func (s *BaseScheduler) WithNewSession(fn func(se sessionctx.Context) error) error { + return s.taskMgr.WithNewSession(fn) +} + +// WithNewTxn executes the fn in a new transaction. +func (s *BaseScheduler) WithNewTxn(ctx context.Context, fn func(se sessionctx.Context) error) error { + return s.taskMgr.WithNewTxn(ctx, fn) +} + +func (*BaseScheduler) isStepSucceed(cntByStates map[proto.SubtaskState]int64) bool { + _, ok := cntByStates[proto.SubtaskStateSucceed] + return len(cntByStates) == 0 || (len(cntByStates) == 1 && ok) +} + +// IsCancelledErr checks if the error is a cancelled error. 
+func IsCancelledErr(err error) bool { + return strings.Contains(err.Error(), taskCancelMsg) +>>>>>>> 99f0349bfb6 (disttask: fix failed step is taken as success (#49971)) } diff --git a/pkg/disttask/framework/scheduler/scheduler_manager_test.go b/pkg/disttask/framework/scheduler/scheduler_manager_test.go new file mode 100644 index 0000000000000..3fb7804244683 --- /dev/null +++ b/pkg/disttask/framework/scheduler/scheduler_manager_test.go @@ -0,0 +1,84 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package scheduler_test + +import ( + "context" + "testing" + "time" + + "github.com/ngaut/pools" + "github.com/pingcap/tidb/pkg/disttask/framework/mock" + "github.com/pingcap/tidb/pkg/disttask/framework/proto" + "github.com/pingcap/tidb/pkg/disttask/framework/testutil" + "github.com/pingcap/tidb/pkg/testkit" + "github.com/stretchr/testify/require" + "github.com/tikv/client-go/v2/util" + "go.uber.org/mock/gomock" +) + +func TestCleanUpRoutine(t *testing.T) { + store := testkit.CreateMockStore(t) + gtk := testkit.NewTestKit(t, store) + pool := pools.NewResourcePool(func() (pools.Resource, error) { + return gtk.Session(), nil + }, 1, 1, time.Second) + defer pool.Close() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + ctx := context.Background() + ctx = util.WithInternalSourceType(ctx, "scheduler_manager") + mockCleanupRoutine := mock.NewMockCleanUpRoutine(ctrl) + + sch, mgr := MockSchedulerManager(t, ctrl, pool, getNumberExampleSchedulerExt(ctrl), mockCleanupRoutine) + mockCleanupRoutine.EXPECT().CleanUp(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + sch.Start() + defer sch.Stop() + testutil.WaitNodeRegistered(ctx, t) + taskID, err := mgr.CreateTask(ctx, "test", proto.TaskTypeExample, 1, nil) + require.NoError(t, err) + + checkTaskRunningCnt := func() []*proto.Task { + var tasks []*proto.Task + require.Eventually(t, func() bool { + var err error + tasks, err = mgr.GetTasksInStates(ctx, proto.TaskStateRunning) + require.NoError(t, err) + return len(tasks) == 1 + }, time.Second, 50*time.Millisecond) + return tasks + } + + checkSubtaskCnt := func(tasks []*proto.Task, taskID int64) { + require.Eventually(t, func() bool { + cntByStates, err := mgr.GetSubtaskCntGroupByStates(ctx, taskID, proto.StepOne) + require.NoError(t, err) + return int64(subtaskCnt) == cntByStates[proto.SubtaskStatePending] + }, time.Second, 50*time.Millisecond) + } + + tasks := checkTaskRunningCnt() + checkSubtaskCnt(tasks, taskID) + for i := 1; i <= subtaskCnt; i++ { + err = 
mgr.UpdateSubtaskStateAndError(ctx, ":4000", int64(i), proto.SubtaskStateSucceed, nil) + require.NoError(t, err) + } + sch.DoCleanUpRoutine() + require.Eventually(t, func() bool { + tasks, err := mgr.GetTasksFromHistoryInStates(ctx, proto.TaskStateSucceed) + require.NoError(t, err) + return len(tasks) != 0 + }, time.Second*10, time.Millisecond*300) +} diff --git a/pkg/disttask/framework/scheduler/scheduler_nokit_test.go b/pkg/disttask/framework/scheduler/scheduler_nokit_test.go new file mode 100644 index 0000000000000..e6320ded9ceec --- /dev/null +++ b/pkg/disttask/framework/scheduler/scheduler_nokit_test.go @@ -0,0 +1,40 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package scheduler + +import ( + "testing" + + "github.com/pingcap/tidb/pkg/disttask/framework/proto" + "github.com/stretchr/testify/require" +) + +func TestSchedulerIsStepSucceed(t *testing.T) { + s := &BaseScheduler{} + require.True(t, s.isStepSucceed(nil)) + require.True(t, s.isStepSucceed(map[proto.SubtaskState]int64{})) + require.True(t, s.isStepSucceed(map[proto.SubtaskState]int64{ + proto.SubtaskStateSucceed: 1, + })) + for _, state := range []proto.SubtaskState{ + proto.SubtaskStateCanceled, + proto.SubtaskStateFailed, + proto.SubtaskStateReverting, + } { + require.False(t, s.isStepSucceed(map[proto.SubtaskState]int64{ + state: 1, + })) + } +} diff --git a/pkg/disttask/framework/scheduler/scheduler_test.go b/pkg/disttask/framework/scheduler/scheduler_test.go index 508cdcfae1a13..ca54c59be84cc 100644 --- a/pkg/disttask/framework/scheduler/scheduler_test.go +++ b/pkg/disttask/framework/scheduler/scheduler_test.go @@ -208,6 +208,7 @@ func TestSchedulerRun(t *testing.T) { err = scheduler.Run(runCtx, task) require.EqualError(t, err, context.Canceled.Error()) +<<<<<<< HEAD // 8. 
grpc cancel mockSubtaskTable.EXPECT().GetSubtasksInStates(gomock.Any(), "id", taskID, proto.StepOne, unfinishedNormalSubtaskStates...).Return([]*proto.Subtask{{ @@ -237,6 +238,439 @@ func TestSchedulerRun(t *testing.T) { grpcErr, " %s", "test annotate", +======= + testutil.WaitNodeRegistered(ctx, t) + + // unknown task type + taskID, err := mgr.CreateTask(ctx, "test", "test-type", 1, nil) + require.NoError(t, err) + require.Eventually(t, func() bool { + task, err := mgr.GetTaskByID(ctx, taskID) + require.NoError(t, err) + return task.State == proto.TaskStateFailed && + strings.Contains(task.Error.Error(), "unknown task type") + }, time.Second*10, time.Millisecond*300) + + // scheduler init error + taskID, err = mgr.CreateTask(ctx, "test2", proto.TaskTypeExample, 1, nil) + require.NoError(t, err) + require.Eventually(t, func() bool { + task, err := mgr.GetTaskByID(ctx, taskID) + require.NoError(t, err) + return task.State == proto.TaskStateFailed && + strings.Contains(task.Error.Error(), "mock scheduler init error") + }, time.Second*10, time.Millisecond*300) +} + +func checkDispatch(t *testing.T, taskCnt int, isSucc, isCancel, isSubtaskCancel, isPauseAndResume bool) { + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/domain/MockDisableDistTask", "return(true)")) + defer func() { + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/domain/MockDisableDistTask")) + }() + // test DispatchTaskLoop + // test parallelism control + var originalConcurrency int + if taskCnt == 1 { + originalConcurrency = proto.MaxConcurrentTask + proto.MaxConcurrentTask = 1 + } + + store := testkit.CreateMockStore(t) + gtk := testkit.NewTestKit(t, store) + pool := pools.NewResourcePool(func() (pools.Resource, error) { + return gtk.Session(), nil + }, 1, 1, time.Second) + defer pool.Close() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx := context.Background() + ctx = util.WithInternalSourceType(ctx, "scheduler") + + sch, mgr := 
MockSchedulerManager(t, ctrl, pool, getNumberExampleSchedulerExt(ctrl), nil) + sch.Start() + defer func() { + sch.Stop() + // make data race happy + if taskCnt == 1 { + proto.MaxConcurrentTask = originalConcurrency + } + }() + + require.NoError(t, mgr.StartManager(ctx, ":4000", "background")) + + // 3s + cnt := 60 + checkGetRunningTaskCnt := func(expected int) { + require.Eventually(t, func() bool { + return sch.GetRunningTaskCnt() == expected + }, time.Second, 50*time.Millisecond) + } + + checkTaskRunningCnt := func() []*proto.Task { + var tasks []*proto.Task + require.Eventually(t, func() bool { + var err error + tasks, err = mgr.GetTasksInStates(ctx, proto.TaskStateRunning) + require.NoError(t, err) + return len(tasks) == taskCnt + }, time.Second, 50*time.Millisecond) + return tasks + } + + checkSubtaskCnt := func(tasks []*proto.Task, taskIDs []int64) { + for i, taskID := range taskIDs { + require.Equal(t, int64(i+1), tasks[i].ID) + require.Eventually(t, func() bool { + cntByStates, err := mgr.GetSubtaskCntGroupByStates(ctx, taskID, proto.StepOne) + require.NoError(t, err) + return int64(subtaskCnt) == cntByStates[proto.SubtaskStatePending] + }, time.Second, 50*time.Millisecond) + } + } + + // Mock add tasks. + taskIDs := make([]int64, 0, taskCnt) + for i := 0; i < taskCnt; i++ { + taskID, err := mgr.CreateTask(ctx, fmt.Sprintf("%d", i), proto.TaskTypeExample, 0, nil) + require.NoError(t, err) + taskIDs = append(taskIDs, taskID) + } + // test OnNextSubtasksBatch. + checkGetRunningTaskCnt(taskCnt) + tasks := checkTaskRunningCnt() + checkSubtaskCnt(tasks, taskIDs) + // test parallelism control + if taskCnt == 1 { + taskID, err := mgr.CreateTask(ctx, fmt.Sprintf("%d", taskCnt), proto.TaskTypeExample, 0, nil) + require.NoError(t, err) + checkGetRunningTaskCnt(taskCnt) + // Clean the task. 
+ deleteTasks(t, store, taskID) + sch.DelRunningTask(taskID) + } + + // test DetectTaskLoop + checkGetTaskState := func(expectedState proto.TaskState) { + i := 0 + for ; i < cnt; i++ { + tasks, err := mgr.GetTasksInStates(ctx, expectedState) + require.NoError(t, err) + if len(tasks) == taskCnt { + break + } + historyTasks, err := mgr.GetTasksFromHistoryInStates(ctx, expectedState) + require.NoError(t, err) + if len(tasks)+len(historyTasks) == taskCnt { + break + } + time.Sleep(time.Millisecond * 50) + } + require.Less(t, i, cnt) + } + // Test all subtasks are successful. + var err error + if isSucc { + // Mock subtasks succeed. + for i := 1; i <= subtaskCnt*taskCnt; i++ { + err = mgr.UpdateSubtaskStateAndError(ctx, ":4000", int64(i), proto.SubtaskStateSucceed, nil) + require.NoError(t, err) + } + checkGetTaskState(proto.TaskStateSucceed) + require.Len(t, tasks, taskCnt) + + checkGetRunningTaskCnt(0) + return + } + + if isCancel { + for i := 1; i <= taskCnt; i++ { + err = mgr.CancelTask(ctx, int64(i)) + require.NoError(t, err) + } + } else if isPauseAndResume { + for i := 0; i < taskCnt; i++ { + found, err := mgr.PauseTask(ctx, fmt.Sprintf("%d", i)) + require.Equal(t, true, found) + require.NoError(t, err) + } + for i := 1; i <= subtaskCnt*taskCnt; i++ { + err = mgr.UpdateSubtaskStateAndError(ctx, ":4000", int64(i), proto.SubtaskStatePaused, nil) + require.NoError(t, err) + } + checkGetTaskState(proto.TaskStatePaused) + for i := 0; i < taskCnt; i++ { + found, err := mgr.ResumeTask(ctx, fmt.Sprintf("%d", i)) + require.Equal(t, true, found) + require.NoError(t, err) + } + + // Mock subtasks succeed. + for i := 1; i <= subtaskCnt*taskCnt; i++ { + err = mgr.UpdateSubtaskStateAndError(ctx, ":4000", int64(i), proto.SubtaskStateSucceed, nil) + require.NoError(t, err) + } + checkGetTaskState(proto.TaskStateSucceed) + return + } else { + // Test each task has a subtask failed. 
+ require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/disttask/framework/storage/MockUpdateTaskErr", "1*return(true)")) + defer func() { + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/disttask/framework/storage/MockUpdateTaskErr")) + }() + + if isSubtaskCancel { + // Mock a subtask canceled + for i := 1; i <= subtaskCnt*taskCnt; i += subtaskCnt { + err = mgr.UpdateSubtaskStateAndError(ctx, ":4000", int64(i), proto.SubtaskStateCanceled, nil) + require.NoError(t, err) + } + } else { + // Mock a subtask fails. + for i := 1; i <= subtaskCnt*taskCnt; i += subtaskCnt { + err = mgr.UpdateSubtaskStateAndError(ctx, ":4000", int64(i), proto.SubtaskStateFailed, nil) + require.NoError(t, err) + } + } + } + + checkGetTaskState(proto.TaskStateReverting) + require.Len(t, tasks, taskCnt) + // Mock all subtask reverted. + start := subtaskCnt * taskCnt + for i := start; i <= start+subtaskCnt*taskCnt; i++ { + err = mgr.UpdateSubtaskStateAndError(ctx, ":4000", int64(i), proto.SubtaskStateReverted, nil) + require.NoError(t, err) + } + checkGetTaskState(proto.TaskStateReverted) + require.Len(t, tasks, taskCnt) +} + +func TestSimple(t *testing.T) { + checkDispatch(t, 1, true, false, false, false) +} + +func TestSimpleErrStage(t *testing.T) { + checkDispatch(t, 1, false, false, false, false) +} + +func TestSimpleCancel(t *testing.T) { + checkDispatch(t, 1, false, true, false, false) +} + +func TestSimpleSubtaskCancel(t *testing.T) { + checkDispatch(t, 1, false, false, true, false) +} + +func TestParallel(t *testing.T) { + checkDispatch(t, 3, true, false, false, false) +} + +func TestParallelErrStage(t *testing.T) { + checkDispatch(t, 3, false, false, false, false) +} + +func TestParallelCancel(t *testing.T) { + checkDispatch(t, 3, false, true, false, false) +} + +func TestParallelSubtaskCancel(t *testing.T) { + checkDispatch(t, 3, false, false, true, false) +} + +func TestPause(t *testing.T) { + checkDispatch(t, 1, false, false, false, true) +} + +func 
TestParallelPause(t *testing.T) { + checkDispatch(t, 3, false, false, false, true) +} + +func TestVerifyTaskStateTransform(t *testing.T) { + testCases := []struct { + oldState proto.TaskState + newState proto.TaskState + expect bool + }{ + {proto.TaskStateRunning, proto.TaskStateRunning, true}, + {proto.TaskStatePending, proto.TaskStateRunning, true}, + {proto.TaskStatePending, proto.TaskStateReverting, false}, + {proto.TaskStateRunning, proto.TaskStateReverting, true}, + {proto.TaskStateReverting, proto.TaskStateReverted, true}, + {proto.TaskStateReverting, proto.TaskStateSucceed, false}, + {proto.TaskStateRunning, proto.TaskStatePausing, true}, + {proto.TaskStateRunning, proto.TaskStateResuming, false}, + {proto.TaskStateCancelling, proto.TaskStateRunning, false}, + {proto.TaskStateCanceled, proto.TaskStateRunning, false}, + } + for _, tc := range testCases { + require.Equal(t, tc.expect, scheduler.VerifyTaskStateTransform(tc.oldState, tc.newState)) + } +} + +func TestIsCancelledErr(t *testing.T) { + require.False(t, scheduler.IsCancelledErr(errors.New("some err"))) + require.False(t, scheduler.IsCancelledErr(context.Canceled)) + require.True(t, scheduler.IsCancelledErr(errors.New("cancelled by user"))) +} + +func TestDispatcherOnNextStage(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + taskMgr := mock.NewMockTaskManager(ctrl) + schExt := mockDispatch.NewMockExtension(ctrl) + + ctx := context.Background() + ctx = util.WithInternalSourceType(ctx, "dispatcher") + task := proto.Task{ + ID: 1, + State: proto.TaskStatePending, + Step: proto.StepInit, + } + cloneTask := task + nodeMgr := scheduler.NewNodeManager() + sch := scheduler.NewBaseScheduler(ctx, taskMgr, nodeMgr, &cloneTask) + sch.Extension = schExt + + // test next step is done + schExt.EXPECT().GetNextStep(gomock.Any()).Return(proto.StepDone) + schExt.EXPECT().OnDone(gomock.Any(), gomock.Any(), gomock.Any()).Return(errors.New("done err")) + require.ErrorContains(t, 
sch.Switch2NextStep(), "done err") + require.True(t, ctrl.Satisfied()) + // we update task step before OnDone + require.Equal(t, proto.StepDone, sch.Task.Step) + *sch.Task = task + schExt.EXPECT().GetNextStep(gomock.Any()).Return(proto.StepDone) + schExt.EXPECT().OnDone(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + taskMgr.EXPECT().SucceedTask(gomock.Any(), gomock.Any()).Return(nil) + require.NoError(t, sch.Switch2NextStep()) + require.True(t, ctrl.Satisfied()) + + // GetEligibleInstances err + schExt.EXPECT().GetNextStep(gomock.Any()).Return(proto.StepOne) + schExt.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).Return(nil, errors.New("GetEligibleInstances err")) + require.ErrorContains(t, sch.Switch2NextStep(), "GetEligibleInstances err") + require.True(t, ctrl.Satisfied()) + // GetEligibleInstances no instance + schExt.EXPECT().GetNextStep(gomock.Any()).Return(proto.StepOne) + schExt.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).Return(nil, nil) + require.ErrorContains(t, sch.Switch2NextStep(), "no available TiDB node to dispatch subtasks") + require.True(t, ctrl.Satisfied()) + + serverNodes := []string{":4000"} + // OnNextSubtasksBatch err + schExt.EXPECT().GetNextStep(gomock.Any()).Return(proto.StepOne) + schExt.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).Return(serverNodes, nil) + schExt.EXPECT().OnNextSubtasksBatch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). 
+ Return(nil, errors.New("OnNextSubtasksBatch err")) + schExt.EXPECT().IsRetryableErr(gomock.Any()).Return(true) + require.ErrorContains(t, sch.Switch2NextStep(), "OnNextSubtasksBatch err") + require.True(t, ctrl.Satisfied()) + + bak := kv.TxnTotalSizeLimit.Load() + t.Cleanup(func() { + kv.TxnTotalSizeLimit.Store(bak) + }) + + // dispatch in batch + subtaskMetas := [][]byte{ + []byte(`{"xx": "1"}`), + []byte(`{"xx": "2"}`), + } + schExt.EXPECT().GetNextStep(gomock.Any()).Return(proto.StepOne) + schExt.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).Return(serverNodes, nil) + schExt.EXPECT().OnNextSubtasksBatch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(subtaskMetas, nil) + taskMgr.EXPECT().SwitchTaskStepInBatch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + kv.TxnTotalSizeLimit.Store(1) + require.NoError(t, sch.Switch2NextStep()) + require.True(t, ctrl.Satisfied()) + // met unstable subtasks + schExt.EXPECT().GetNextStep(gomock.Any()).Return(proto.StepOne) + schExt.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).Return(serverNodes, nil) + schExt.EXPECT().OnNextSubtasksBatch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(subtaskMetas, nil) + taskMgr.EXPECT().SwitchTaskStepInBatch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). 
+ Return(errors.Annotatef(storage.ErrUnstableSubtasks, "expected %d, got %d", + 2, 100)) + kv.TxnTotalSizeLimit.Store(1) + startTime := time.Now() + err := sch.Switch2NextStep() + require.ErrorIs(t, err, storage.ErrUnstableSubtasks) + require.ErrorContains(t, err, "expected 2, got 100") + require.WithinDuration(t, startTime, time.Now(), 10*time.Second) + require.True(t, ctrl.Satisfied()) + + // dispatch in one txn + schExt.EXPECT().GetNextStep(gomock.Any()).Return(proto.StepOne) + schExt.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).Return(serverNodes, nil) + schExt.EXPECT().OnNextSubtasksBatch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(subtaskMetas, nil) + taskMgr.EXPECT().SwitchTaskStep(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + kv.TxnTotalSizeLimit.Store(config.DefTxnTotalSizeLimit) + require.NoError(t, sch.Switch2NextStep()) + require.True(t, ctrl.Satisfied()) +} + +func TestManagerDispatchLoop(t *testing.T) { + // Mock 16 cpu node. + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/util/cpu/mockNumCpu", "return(16)")) + t.Cleanup(func() { + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/util/cpu/mockNumCpu")) + }) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockScheduler := mock.NewMockScheduler(ctrl) + + _ = testkit.CreateMockStore(t) + require.Eventually(t, func() bool { + taskMgr, err := storage.GetTaskManager() + return err == nil && taskMgr != nil + }, 10*time.Second, 100*time.Millisecond) + + ctx := context.Background() + ctx = util.WithInternalSourceType(ctx, "scheduler") + taskMgr, err := storage.GetTaskManager() + require.NoError(t, err) + require.NotNil(t, taskMgr) + + // in this test, we only test scheduler manager, so we add a subtask takes 16 + // slots to avoid reserve by slots, and make sure below test cases works. 
+ serverInfos, err := infosync.GetAllServerInfo(ctx) + require.NoError(t, err) + for _, s := range serverInfos { + execID := disttaskutil.GenerateExecID(s) + testutil.InsertSubtask(t, taskMgr, 1000000, proto.StepOne, execID, []byte(""), proto.SubtaskStatePending, proto.TaskTypeExample, 16) + } + concurrencies := []int{4, 6, 16, 2, 4, 4} + waitChannels := make([]chan struct{}, len(concurrencies)) + for i := range waitChannels { + waitChannels[i] = make(chan struct{}) + } + var counter atomic.Int32 + scheduler.RegisterSchedulerFactory(proto.TaskTypeExample, + func(ctx context.Context, taskMgr scheduler.TaskManager, nodeMgr *scheduler.NodeManager, task *proto.Task) scheduler.Scheduler { + idx := counter.Load() + mockScheduler = mock.NewMockScheduler(ctrl) + mockScheduler.EXPECT().Init().Return(nil) + mockScheduler.EXPECT().ScheduleTask().Do(func() { + require.NoError(t, taskMgr.WithNewSession(func(se sessionctx.Context) error { + _, err := sqlexec.ExecSQL(ctx, se, "update mysql.tidb_global_task set state=%?, step=%? where id=%?", + proto.TaskStateRunning, proto.StepOne, task.ID) + return err + })) + <-waitChannels[idx] + require.NoError(t, taskMgr.WithNewSession(func(se sessionctx.Context) error { + _, err := sqlexec.ExecSQL(ctx, se, "update mysql.tidb_global_task set state=%?, step=%? 
where id=%?", + proto.TaskStateSucceed, proto.StepDone, task.ID) + return err + })) + }) + mockScheduler.EXPECT().Close() + counter.Add(1) + return mockScheduler + }, +>>>>>>> 99f0349bfb6 (disttask: fix failed step is taken as success (#49971)) ) mockSubtaskExecutor.EXPECT().RunSubtask(gomock.Any(), gomock.Any()).Return(annotatedError) mockSubtaskExecutor.EXPECT().Cleanup(gomock.Any()).Return(nil) diff --git a/pkg/disttask/framework/storage/BUILD.bazel b/pkg/disttask/framework/storage/BUILD.bazel index bc9f7331912d9..817d3ce767d8b 100644 --- a/pkg/disttask/framework/storage/BUILD.bazel +++ b/pkg/disttask/framework/storage/BUILD.bazel @@ -31,7 +31,11 @@ go_test( srcs = ["table_test.go"], flaky = True, race = "on", +<<<<<<< HEAD shard_count = 8, +======= + shard_count = 14, +>>>>>>> 99f0349bfb6 (disttask: fix failed step is taken as success (#49971)) deps = [ ":storage", "//pkg/disttask/framework/proto", diff --git a/pkg/disttask/framework/storage/table_test.go b/pkg/disttask/framework/storage/table_test.go index 260e08d98ed42..a0ba8b4de4340 100644 --- a/pkg/disttask/framework/storage/table_test.go +++ b/pkg/disttask/framework/storage/table_test.go @@ -126,6 +126,274 @@ func TestGlobalTaskTable(t *testing.T) { cancelling, err = gm.IsGlobalTaskCancelling(ctx, id) require.NoError(t, err) require.True(t, cancelling) +<<<<<<< HEAD +======= + + id, err = gm.CreateTask(ctx, "key-fail", "test2", 4, []byte("test2")) + require.NoError(t, err) + // state not right, update nothing + require.NoError(t, gm.FailTask(ctx, id, proto.TaskStateRunning, errors.New("test error"))) + task, err = gm.GetTaskByID(ctx, id) + require.NoError(t, err) + require.Equal(t, proto.TaskStatePending, task.State) + require.Nil(t, task.Error) + require.NoError(t, gm.FailTask(ctx, id, proto.TaskStatePending, errors.New("test error"))) + task, err = gm.GetTaskByID(ctx, id) + require.NoError(t, err) + require.Equal(t, proto.TaskStateFailed, task.State) + require.ErrorContains(t, task.Error, "test error") + 
+ // succeed a pending task, no effect + id, err = gm.CreateTask(ctx, "key-success", "test", 4, []byte("test")) + require.NoError(t, err) + require.NoError(t, gm.SucceedTask(ctx, id)) + task, err = gm.GetTaskByID(ctx, id) + require.NoError(t, err) + checkTaskStateStep(t, task, proto.TaskStatePending, proto.StepInit) + // succeed a running task + require.NoError(t, gm.SwitchTaskStep(ctx, task, proto.TaskStateRunning, proto.StepOne, nil)) + task, err = gm.GetTaskByID(ctx, id) + require.NoError(t, err) + checkTaskStateStep(t, task, proto.TaskStateRunning, proto.StepOne) + startTime := time.Unix(time.Now().Unix(), 0) + require.NoError(t, gm.SucceedTask(ctx, id)) + task, err = gm.GetTaskByID(ctx, id) + require.NoError(t, err) + checkTaskStateStep(t, task, proto.TaskStateSucceed, proto.StepDone) + require.GreaterOrEqual(t, task.StateUpdateTime, startTime) +} + +func checkAfterSwitchStep(t *testing.T, startTime time.Time, task *proto.Task, subtasks []*proto.Subtask, step proto.Step) { + tm, err := storage.GetTaskManager() + require.NoError(t, err) + ctx := context.Background() + ctx = util.WithInternalSourceType(ctx, "table_test") + + checkTaskStateStep(t, task, proto.TaskStateRunning, step) + require.GreaterOrEqual(t, task.StartTime, startTime) + require.GreaterOrEqual(t, task.StateUpdateTime, startTime) + gotSubtasks, err := tm.GetSubtasksByStepAndState(ctx, task.ID, task.Step, proto.TaskStatePending) + require.NoError(t, err) + require.Len(t, gotSubtasks, len(subtasks)) + sort.Slice(gotSubtasks, func(i, j int) bool { + return gotSubtasks[i].Ordinal < gotSubtasks[j].Ordinal + }) + for i := 0; i < len(gotSubtasks); i++ { + subtask := gotSubtasks[i] + require.Equal(t, []byte(fmt.Sprintf("%d", i)), subtask.Meta) + require.Equal(t, i+1, subtask.Ordinal) + require.Equal(t, 11, subtask.Concurrency) + require.Equal(t, ":4000", subtask.ExecID) + require.Equal(t, proto.TaskTypeExample, subtask.Type) + require.GreaterOrEqual(t, subtask.CreateTime, startTime) + } +} + +func 
TestSwitchTaskStep(t *testing.T) { + store, tm, ctx := testutil.InitTableTest(t) + tk := testkit.NewTestKit(t, store) + + require.NoError(t, tm.StartManager(ctx, ":4000", "")) + taskID, err := tm.CreateTask(ctx, "key1", "test", 4, []byte("test")) + require.NoError(t, err) + task, err := tm.GetTaskByID(ctx, taskID) + require.NoError(t, err) + checkTaskStateStep(t, task, proto.TaskStatePending, proto.StepInit) + subtasksStepOne := make([]*proto.Subtask, 3) + for i := 0; i < len(subtasksStepOne); i++ { + subtasksStepOne[i] = proto.NewSubtask(proto.StepOne, taskID, proto.TaskTypeExample, + ":4000", 11, []byte(fmt.Sprintf("%d", i)), i+1) + } + startTime := time.Unix(time.Now().Unix(), 0) + task.Meta = []byte("changed meta") + require.NoError(t, tm.SwitchTaskStep(ctx, task, proto.TaskStateRunning, proto.StepOne, subtasksStepOne)) + task, err = tm.GetTaskByID(ctx, taskID) + require.NoError(t, err) + require.Equal(t, []byte("changed meta"), task.Meta) + checkAfterSwitchStep(t, startTime, task, subtasksStepOne, proto.StepOne) + // switch step again, no effect + require.NoError(t, tm.SwitchTaskStep(ctx, task, proto.TaskStateRunning, proto.StepOne, subtasksStepOne)) + task, err = tm.GetTaskByID(ctx, taskID) + require.NoError(t, err) + checkAfterSwitchStep(t, startTime, task, subtasksStepOne, proto.StepOne) + // switch step to step two + time.Sleep(time.Second) + taskStartTime := task.StartTime + subtasksStepTwo := make([]*proto.Subtask, 3) + for i := 0; i < len(subtasksStepTwo); i++ { + subtasksStepTwo[i] = proto.NewSubtask(proto.StepTwo, taskID, proto.TaskTypeExample, + ":4000", 11, []byte(fmt.Sprintf("%d", i)), i+1) + } + require.NoError(t, tk.Session().GetSessionVars().SetSystemVar(variable.TiDBMemQuotaQuery, "1024")) + require.NoError(t, tm.SwitchTaskStep(ctx, task, proto.TaskStateRunning, proto.StepTwo, subtasksStepTwo)) + value, ok := tk.Session().GetSessionVars().GetSystemVar(variable.TiDBMemQuotaQuery) + require.True(t, ok) + require.Equal(t, "1024", value) + task, 
err = tm.GetTaskByID(ctx, taskID) + require.NoError(t, err) + // start time should not change + require.Equal(t, taskStartTime, task.StartTime) + checkAfterSwitchStep(t, startTime, task, subtasksStepTwo, proto.StepTwo) +} + +func TestSwitchTaskStepInBatch(t *testing.T) { + store, tm, ctx := testutil.InitTableTest(t) + tk := testkit.NewTestKit(t, store) + + require.NoError(t, tm.StartManager(ctx, ":4000", "")) + // normal flow + prepare := func(taskKey string) (*proto.Task, []*proto.Subtask) { + taskID, err := tm.CreateTask(ctx, taskKey, "test", 4, []byte("test")) + require.NoError(t, err) + task, err := tm.GetTaskByID(ctx, taskID) + require.NoError(t, err) + checkTaskStateStep(t, task, proto.TaskStatePending, proto.StepInit) + subtasks := make([]*proto.Subtask, 3) + for i := 0; i < len(subtasks); i++ { + subtasks[i] = proto.NewSubtask(proto.StepOne, taskID, proto.TaskTypeExample, + ":4000", 11, []byte(fmt.Sprintf("%d", i)), i+1) + } + return task, subtasks + } + startTime := time.Unix(time.Now().Unix(), 0) + task1, subtasks1 := prepare("key1") + require.NoError(t, tm.SwitchTaskStepInBatch(ctx, task1, proto.TaskStateRunning, proto.StepOne, subtasks1)) + task1, err := tm.GetTaskByID(ctx, task1.ID) + require.NoError(t, err) + checkAfterSwitchStep(t, startTime, task1, subtasks1, proto.StepOne) + + // mock another dispatcher inserted some subtasks + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/disttask/framework/storage/waitBeforeInsertSubtasks", `1*return(true)`)) + t.Cleanup(func() { + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/disttask/framework/storage/waitBeforeInsertSubtasks")) + }) + task2, subtasks2 := prepare("key2") + go func() { + storage.TestChannel <- struct{}{} + tk2 := testkit.NewTestKit(t, store) + subtask := subtasks2[0] + _, err = sqlexec.ExecSQL(ctx, tk2.Session(), ` + insert into mysql.tidb_background_subtask( + step, task_key, exec_id, meta, state, type, concurrency, ordinal, create_time, checkpoint, 
summary) + values (%?, %?, %?, %?, %?, %?, %?, %?, CURRENT_TIMESTAMP(), '{}', '{}')`, + subtask.Step, subtask.TaskID, subtask.ExecID, subtask.Meta, + proto.TaskStatePending, proto.Type2Int(subtask.Type), subtask.Concurrency, subtask.Ordinal) + require.NoError(t, err) + storage.TestChannel <- struct{}{} + }() + err = tm.SwitchTaskStepInBatch(ctx, task2, proto.TaskStateRunning, proto.StepOne, subtasks2) + require.ErrorIs(t, err, kv.ErrKeyExists) + task2, err = tm.GetTaskByID(ctx, task2.ID) + require.NoError(t, err) + checkTaskStateStep(t, task2, proto.TaskStatePending, proto.StepInit) + gotSubtasks, err := tm.GetSubtasksByStepAndState(ctx, task2.ID, proto.StepOne, proto.TaskStatePending) + require.NoError(t, err) + require.Len(t, gotSubtasks, 1) + // run again, should success + require.NoError(t, tm.SwitchTaskStepInBatch(ctx, task2, proto.TaskStateRunning, proto.StepOne, subtasks2)) + task2, err = tm.GetTaskByID(ctx, task2.ID) + require.NoError(t, err) + checkAfterSwitchStep(t, startTime, task2, subtasks2, proto.StepOne) + + // mock subtasks unstable + task3, subtasks3 := prepare("key3") + for i := 0; i < 2; i++ { + subtask := subtasks3[i] + _, err = sqlexec.ExecSQL(ctx, tk.Session(), ` + insert into mysql.tidb_background_subtask( + step, task_key, exec_id, meta, state, type, concurrency, ordinal, create_time, checkpoint, summary) + values (%?, %?, %?, %?, %?, %?, %?, %?, CURRENT_TIMESTAMP(), '{}', '{}')`, + subtask.Step, subtask.TaskID, subtask.ExecID, subtask.Meta, + proto.TaskStatePending, proto.Type2Int(subtask.Type), subtask.Concurrency, subtask.Ordinal) + require.NoError(t, err) + } + err = tm.SwitchTaskStepInBatch(ctx, task3, proto.TaskStateRunning, proto.StepOne, subtasks3[:1]) + require.ErrorIs(t, err, storage.ErrUnstableSubtasks) + require.ErrorContains(t, err, "expected 1, got 2") +} + +func TestGetTopUnfinishedTasks(t *testing.T) { + _, gm, ctx := testutil.InitTableTest(t) + + require.NoError(t, gm.StartManager(ctx, ":4000", "")) + taskStates := 
[]proto.TaskState{ + proto.TaskStateSucceed, + proto.TaskStatePending, + proto.TaskStateRunning, + proto.TaskStateReverting, + proto.TaskStateCancelling, + proto.TaskStatePausing, + proto.TaskStateResuming, + proto.TaskStateFailed, + proto.TaskStatePending, + proto.TaskStatePending, + proto.TaskStatePending, + proto.TaskStatePending, + } + for i, state := range taskStates { + taskKey := fmt.Sprintf("key/%d", i) + _, err := gm.CreateTask(ctx, taskKey, "test", 4, []byte("test")) + require.NoError(t, err) + require.NoError(t, gm.WithNewSession(func(se sessionctx.Context) error { + _, err := se.(sqlexec.SQLExecutor).ExecuteInternal(ctx, ` + update mysql.tidb_global_task set state = %? where task_key = %?`, + state, taskKey) + return err + })) + } + // adjust task order + require.NoError(t, gm.WithNewSession(func(se sessionctx.Context) error { + _, err := se.(sqlexec.SQLExecutor).ExecuteInternal(ctx, ` + update mysql.tidb_global_task set create_time = current_timestamp`) + return err + })) + require.NoError(t, gm.WithNewSession(func(se sessionctx.Context) error { + _, err := se.(sqlexec.SQLExecutor).ExecuteInternal(ctx, ` + update mysql.tidb_global_task + set create_time = timestampadd(minute, -10, current_timestamp) + where task_key = 'key/5'`) + return err + })) + require.NoError(t, gm.WithNewSession(func(se sessionctx.Context) error { + _, err := se.(sqlexec.SQLExecutor).ExecuteInternal(ctx, ` + update mysql.tidb_global_task set priority = 100 where task_key = 'key/6'`) + return err + })) + require.NoError(t, gm.WithNewSession(func(se sessionctx.Context) error { + rs, err := sqlexec.ExecSQL(ctx, se, ` + select count(1) from mysql.tidb_global_task`) + require.Len(t, rs, 1) + require.Equal(t, int64(12), rs[0].GetInt64(0)) + return err + })) + tasks, err := gm.GetTopUnfinishedTasks(ctx) + require.NoError(t, err) + require.Len(t, tasks, 8) + taskKeys := make([]string, 0, len(tasks)) + for _, task := range tasks { + taskKeys = append(taskKeys, task.Key) + // not filled + 
require.Empty(t, task.Meta) + } + require.Equal(t, []string{"key/6", "key/5", "key/1", "key/2", "key/3", "key/4", "key/8", "key/9"}, taskKeys) +} + +func TestGetUsedSlotsOnNodes(t *testing.T) { + _, sm, ctx := testutil.InitTableTest(t) + + testutil.InsertSubtask(t, sm, 1, proto.StepOne, "tidb-1", []byte(""), proto.SubtaskStateRunning, "test", 12) + testutil.InsertSubtask(t, sm, 1, proto.StepOne, "tidb-2", []byte(""), proto.SubtaskStatePending, "test", 12) + testutil.InsertSubtask(t, sm, 2, proto.StepOne, "tidb-2", []byte(""), proto.SubtaskStatePending, "test", 8) + testutil.InsertSubtask(t, sm, 3, proto.StepOne, "tidb-3", []byte(""), proto.SubtaskStatePending, "test", 8) + testutil.InsertSubtask(t, sm, 4, proto.StepOne, "tidb-3", []byte(""), proto.SubtaskStateFailed, "test", 8) + slotsOnNodes, err := sm.GetUsedSlotsOnNodes(ctx) + require.NoError(t, err) + require.Equal(t, map[string]int{ + "tidb-1": 12, + "tidb-2": 20, + "tidb-3": 8, + }, slotsOnNodes) +>>>>>>> 99f0349bfb6 (disttask: fix failed step is taken as success (#49971)) } func TestSubTaskTable(t *testing.T) { @@ -167,13 +435,21 @@ func TestSubTaskTable(t *testing.T) { require.NoError(t, err) require.Len(t, ids, 0) +<<<<<<< HEAD cnt, err := sm.GetSubtaskInStatesCnt(ctx, 1, proto.TaskStatePending) +======= + cntByStates, err := sm.GetSubtaskCntGroupByStates(ctx, 1, proto.StepInit) +>>>>>>> 99f0349bfb6 (disttask: fix failed step is taken as success (#49971)) require.NoError(t, err) - require.Equal(t, int64(1), cnt) + require.Equal(t, int64(1), cntByStates[proto.SubtaskStatePending]) +<<<<<<< HEAD cnt, err = sm.GetSubtaskInStatesCnt(ctx, 1, proto.TaskStatePending, proto.TaskStateRevertPending) +======= + cntByStates, err = sm.GetSubtaskCntGroupByStates(ctx, 1, proto.StepInit) +>>>>>>> 99f0349bfb6 (disttask: fix failed step is taken as success (#49971)) require.NoError(t, err) - require.Equal(t, int64(1), cnt) + require.Equal(t, int64(1), 
cntByStates[proto.SubtaskStatePending]+cntByStates[proto.SubtaskStateRevertPending]) ok, err := sm.HasSubtasksInStates(ctx, "tidb1", 1, proto.StepInit, proto.TaskStatePending) require.NoError(t, err) @@ -205,9 +481,13 @@ func TestSubTaskTable(t *testing.T) { require.Equal(t, proto.TaskStateCancelling, subtask2.State) require.Greater(t, subtask2.UpdateTime, subtask.UpdateTime) +<<<<<<< HEAD cnt, err = sm.GetSubtaskInStatesCnt(ctx, 1, proto.TaskStatePending) +======= + cntByStates, err = sm.GetSubtaskCntGroupByStates(ctx, 1, proto.StepInit) +>>>>>>> 99f0349bfb6 (disttask: fix failed step is taken as success (#49971)) require.NoError(t, err) - require.Equal(t, int64(0), cnt) + require.Equal(t, int64(0), cntByStates[proto.SubtaskStatePending]) ok, err = sm.HasSubtasksInStates(ctx, "tidb1", 1, proto.StepInit, proto.TaskStatePending) require.NoError(t, err) @@ -223,9 +503,13 @@ func TestSubTaskTable(t *testing.T) { err = sm.AddNewSubTask(ctx, 2, proto.StepInit, "tidb1", []byte("test"), proto.TaskTypeExample, true) require.NoError(t, err) +<<<<<<< HEAD cnt, err = sm.GetSubtaskInStatesCnt(ctx, 2, proto.TaskStateRevertPending) +======= + cntByStates, err = sm.GetSubtaskCntGroupByStates(ctx, 2, proto.StepInit) +>>>>>>> 99f0349bfb6 (disttask: fix failed step is taken as success (#49971)) require.NoError(t, err) - require.Equal(t, int64(1), cnt) + require.Equal(t, int64(1), cntByStates[proto.SubtaskStateRevertPending]) subtasks, err := sm.GetSucceedSubtasksByStep(ctx, 2, proto.StepInit) require.NoError(t, err) @@ -354,9 +638,14 @@ func TestBothGlobalAndSubTaskTable(t *testing.T) { require.Equal(t, proto.TaskTypeExample, subtask2.Type) require.Equal(t, []byte("m2"), subtask2.Meta) +<<<<<<< HEAD cnt, err := sm.GetSubtaskInStatesCnt(ctx, 1, proto.TaskStatePending) +======= + cntByStates, err := sm.GetSubtaskCntGroupByStates(ctx, 1, proto.StepInit) +>>>>>>> 99f0349bfb6 (disttask: fix failed step is taken as success (#49971)) require.NoError(t, err) - require.Equal(t, int64(2), 
cnt) + require.Len(t, cntByStates, 1) + require.Equal(t, int64(2), cntByStates[proto.SubtaskStatePending]) // isSubTaskRevert: true prevState = task.State @@ -395,9 +684,13 @@ func TestBothGlobalAndSubTaskTable(t *testing.T) { require.Equal(t, proto.TaskTypeExample, subtask2.Type) require.Equal(t, []byte("m4"), subtask2.Meta) +<<<<<<< HEAD cnt, err = sm.GetSubtaskInStatesCnt(ctx, 1, proto.TaskStateRevertPending) +======= + cntByStates, err = sm.GetSubtaskCntGroupByStates(ctx, 1, proto.StepInit) +>>>>>>> 99f0349bfb6 (disttask: fix failed step is taken as success (#49971)) require.NoError(t, err) - require.Equal(t, int64(2), cnt) + require.Equal(t, int64(2), cntByStates[proto.SubtaskStateRevertPending]) // test transactional require.NoError(t, sm.DeleteSubtasksByTaskID(ctx, 1)) @@ -415,9 +708,34 @@ func TestBothGlobalAndSubTaskTable(t *testing.T) { require.NoError(t, err) require.Equal(t, proto.TaskStateReverting, task.State) +<<<<<<< HEAD cnt, err = sm.GetSubtaskInStatesCnt(ctx, 1, proto.TaskStateRevertPending) +======= + cntByStates, err = sm.GetSubtaskCntGroupByStates(ctx, 1, proto.StepInit) +>>>>>>> 99f0349bfb6 (disttask: fix failed step is taken as success (#49971)) require.NoError(t, err) - require.Equal(t, int64(0), cnt) + require.Equal(t, int64(0), cntByStates[proto.SubtaskStateRevertPending]) +} + +func TestGetSubtaskCntByStates(t *testing.T) { + _, sm, ctx := testutil.InitTableTest(t) + testutil.InsertSubtask(t, sm, 1, proto.StepOne, "tidb1", nil, proto.SubtaskStatePending, "test", 1) + testutil.InsertSubtask(t, sm, 1, proto.StepOne, "tidb1", nil, proto.SubtaskStatePending, "test", 1) + testutil.InsertSubtask(t, sm, 1, proto.StepOne, "tidb1", nil, proto.SubtaskStateRunning, "test", 1) + testutil.InsertSubtask(t, sm, 1, proto.StepOne, "tidb1", nil, proto.SubtaskStateSucceed, "test", 1) + testutil.InsertSubtask(t, sm, 1, proto.StepOne, "tidb1", nil, proto.SubtaskStateFailed, "test", 1) + testutil.InsertSubtask(t, sm, 1, proto.StepTwo, "tidb1", nil, 
proto.SubtaskStateFailed, "test", 1) + cntByStates, err := sm.GetSubtaskCntGroupByStates(ctx, 1, proto.StepOne) + require.NoError(t, err) + require.Len(t, cntByStates, 4) + require.Equal(t, int64(2), cntByStates[proto.SubtaskStatePending]) + require.Equal(t, int64(1), cntByStates[proto.SubtaskStateRunning]) + require.Equal(t, int64(1), cntByStates[proto.SubtaskStateSucceed]) + require.Equal(t, int64(1), cntByStates[proto.SubtaskStateFailed]) + cntByStates, err = sm.GetSubtaskCntGroupByStates(ctx, 1, proto.StepTwo) + require.NoError(t, err) + require.Len(t, cntByStates, 1) + require.Equal(t, int64(1), cntByStates[proto.SubtaskStateFailed]) } func TestDistFrameworkMeta(t *testing.T) { @@ -584,26 +902,42 @@ func TestPauseAndResume(t *testing.T) { require.NoError(t, sm.AddNewSubTask(ctx, 1, proto.StepInit, "tidb1", []byte("test"), proto.TaskTypeExample, false)) // 1.1 pause all subtasks. require.NoError(t, sm.PauseSubtasks(ctx, "tidb1", 1)) +<<<<<<< HEAD cnt, err := sm.GetSubtaskInStatesCnt(ctx, 1, proto.TaskStatePaused) +======= + cntByStates, err := sm.GetSubtaskCntGroupByStates(ctx, 1, proto.StepInit) +>>>>>>> 99f0349bfb6 (disttask: fix failed step is taken as success (#49971)) require.NoError(t, err) - require.Equal(t, int64(3), cnt) + require.Equal(t, int64(3), cntByStates[proto.SubtaskStatePaused]) // 1.2 resume all subtasks. require.NoError(t, sm.ResumeSubtasks(ctx, 1)) +<<<<<<< HEAD cnt, err = sm.GetSubtaskInStatesCnt(ctx, 1, proto.TaskStatePending) +======= + cntByStates, err = sm.GetSubtaskCntGroupByStates(ctx, 1, proto.StepInit) +>>>>>>> 99f0349bfb6 (disttask: fix failed step is taken as success (#49971)) require.NoError(t, err) - require.Equal(t, int64(3), cnt) + require.Equal(t, int64(3), cntByStates[proto.SubtaskStatePending]) // 2.1 pause 2 subtasks. 
sm.UpdateSubtaskStateAndError(ctx, "tidb1", 1, proto.TaskStateSucceed, nil) require.NoError(t, sm.PauseSubtasks(ctx, "tidb1", 1)) +<<<<<<< HEAD cnt, err = sm.GetSubtaskInStatesCnt(ctx, 1, proto.TaskStatePaused) +======= + cntByStates, err = sm.GetSubtaskCntGroupByStates(ctx, 1, proto.StepInit) +>>>>>>> 99f0349bfb6 (disttask: fix failed step is taken as success (#49971)) require.NoError(t, err) - require.Equal(t, int64(2), cnt) + require.Equal(t, int64(2), cntByStates[proto.SubtaskStatePaused]) // 2.2 resume 2 subtasks. require.NoError(t, sm.ResumeSubtasks(ctx, 1)) +<<<<<<< HEAD cnt, err = sm.GetSubtaskInStatesCnt(ctx, 1, proto.TaskStatePending) +======= + cntByStates, err = sm.GetSubtaskCntGroupByStates(ctx, 1, proto.StepInit) +>>>>>>> 99f0349bfb6 (disttask: fix failed step is taken as success (#49971)) require.NoError(t, err) - require.Equal(t, int64(2), cnt) + require.Equal(t, int64(2), cntByStates[proto.SubtaskStatePending]) } func TestCancelAndExecIdChanged(t *testing.T) { diff --git a/pkg/disttask/framework/storage/task_table.go b/pkg/disttask/framework/storage/task_table.go index d5205ffce2254..cbfa0b1ec532a 100644 --- a/pkg/disttask/framework/storage/task_table.go +++ b/pkg/disttask/framework/storage/task_table.go @@ -488,24 +488,41 @@ func (stm *TaskManager) UpdateSubtaskRowCount(ctx context.Context, subtaskID int return err } +<<<<<<< HEAD // GetSubtaskInStatesCnt gets the subtask count in the states. func (stm *TaskManager) GetSubtaskInStatesCnt(ctx context.Context, taskID int64, states ...interface{}) (int64, error) { args := []interface{}{taskID} args = append(args, states...) rs, err := stm.executeSQLWithNewSession(ctx, `select count(*) from mysql.tidb_background_subtask where task_key = %? and state in (`+strings.Repeat("%?,", len(states)-1)+"%?)", args...) +======= +// GetSubtaskCntGroupByStates gets the subtask count by states. 
+func (stm *TaskManager) GetSubtaskCntGroupByStates(ctx context.Context, taskID int64, step proto.Step) (map[proto.SubtaskState]int64, error) { + rs, err := stm.executeSQLWithNewSession(ctx, ` + select state, count(*) + from mysql.tidb_background_subtask + where task_key = %? and step = %? + group by state`, + taskID, step) +>>>>>>> 99f0349bfb6 (disttask: fix failed step is taken as success (#49971)) if err != nil { - return 0, err + return nil, err } - return rs[0].GetInt64(0), nil + res := make(map[proto.SubtaskState]int64, len(rs)) + for _, r := range rs { + state := proto.SubtaskState(r.GetString(0)) + res[state] = r.GetInt64(1) + } + + return res, nil } // CollectSubTaskError collects the subtask error. func (stm *TaskManager) CollectSubTaskError(ctx context.Context, taskID int64) ([]error, error) { rs, err := stm.executeSQLWithNewSession(ctx, `select error from mysql.tidb_background_subtask - where task_key = %? AND state in (%?, %?)`, taskID, proto.TaskStateFailed, proto.TaskStateCanceled) + where task_key = %? AND state in (%?, %?)`, taskID, proto.SubtaskStateFailed, proto.SubtaskStateCanceled) if err != nil { return nil, err } diff --git a/pkg/disttask/framework/taskexecutor/task_executor.go b/pkg/disttask/framework/taskexecutor/task_executor.go new file mode 100644 index 0000000000000..fd585a259daef --- /dev/null +++ b/pkg/disttask/framework/taskexecutor/task_executor.go @@ -0,0 +1,653 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package taskexecutor + +import ( + "context" + "sync" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/br/pkg/lightning/common" + "github.com/pingcap/tidb/br/pkg/lightning/log" + "github.com/pingcap/tidb/pkg/disttask/framework/handle" + "github.com/pingcap/tidb/pkg/disttask/framework/proto" + "github.com/pingcap/tidb/pkg/disttask/framework/scheduler" + "github.com/pingcap/tidb/pkg/disttask/framework/storage" + "github.com/pingcap/tidb/pkg/disttask/framework/taskexecutor/execute" + "github.com/pingcap/tidb/pkg/domain/infosync" + "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/util/backoff" + "github.com/pingcap/tidb/pkg/util/gctuner" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/memory" + "go.uber.org/zap" +) + +const ( + // DefaultCheckSubtaskCanceledInterval is the default check interval for cancel cancelled subtasks. + DefaultCheckSubtaskCanceledInterval = 2 * time.Second +) + +var ( + // ErrCancelSubtask is the cancel cause when cancelling subtasks. + ErrCancelSubtask = errors.New("cancel subtasks") + // ErrFinishSubtask is the cancel cause when TaskExecutor successfully processed subtasks. + ErrFinishSubtask = errors.New("finish subtasks") + // ErrFinishRollback is the cancel cause when TaskExecutor rollback successfully. + ErrFinishRollback = errors.New("finish rollback") + + // TestSyncChan is used to sync the test. + TestSyncChan = make(chan struct{}) +) + +// BaseTaskExecutor is the base implementation of TaskExecutor. +type BaseTaskExecutor struct { + // id, it's the same as server id now, i.e. host:port. 
+ id string + taskID int64 + taskTable TaskTable + logCtx context.Context + // ctx from manager + ctx context.Context + Extension + + mu struct { + sync.RWMutex + err error + // handled indicates whether the error has been updated to one of the subtask. + handled bool + // runtimeCancel is used to cancel the Run/Rollback when error occurs. + runtimeCancel context.CancelCauseFunc + } +} + +// NewBaseTaskExecutor creates a new BaseTaskExecutor. +func NewBaseTaskExecutor(ctx context.Context, id string, taskID int64, taskTable TaskTable) *BaseTaskExecutor { + taskExecutorImpl := &BaseTaskExecutor{ + id: id, + taskID: taskID, + taskTable: taskTable, + ctx: ctx, + logCtx: logutil.WithFields(context.Background(), zap.Int64("task-id", taskID)), + } + return taskExecutorImpl +} + +func (s *BaseTaskExecutor) startCancelCheck(ctx context.Context, wg *sync.WaitGroup, cancelFn context.CancelCauseFunc) { + wg.Add(1) + go func() { + defer wg.Done() + ticker := time.NewTicker(DefaultCheckSubtaskCanceledInterval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + logutil.Logger(s.logCtx).Info("task executor exits") + return + case <-ticker.C: + canceled, err := s.taskTable.IsTaskExecutorCanceled(ctx, s.id, s.taskID) + if err != nil { + continue + } + if canceled { + logutil.Logger(s.logCtx).Info("taskExecutor canceled") + if cancelFn != nil { + // subtask transferred to other tidb, don't mark subtask as canceled. + // Should not change the subtask's state. + cancelFn(nil) + } + } + } + } + }() +} + +// Init implements the TaskExecutor interface. +func (*BaseTaskExecutor) Init(_ context.Context) error { + return nil +} + +// Run start to fetch and run all subtasks of the task on the node. 
+func (s *BaseTaskExecutor) Run(ctx context.Context, task *proto.Task) (err error) { + defer func() { + if r := recover(); r != nil { + logutil.Logger(s.logCtx).Error("BaseTaskExecutor panicked", zap.Any("recover", r), zap.Stack("stack")) + err4Panic := errors.Errorf("%v", r) + err1 := s.updateErrorToSubtask(ctx, task.ID, err4Panic) + if err == nil { + err = err1 + } + } + }() + err = s.run(ctx, task) + if s.mu.handled { + return err + } + if err == nil { + return nil + } + return s.updateErrorToSubtask(ctx, task.ID, err) +} + +func (s *BaseTaskExecutor) run(ctx context.Context, task *proto.Task) (resErr error) { + if ctx.Err() != nil { + s.onError(ctx.Err()) + return s.getError() + } + runCtx, runCancel := context.WithCancelCause(ctx) + defer runCancel(ErrFinishSubtask) + s.registerCancelFunc(runCancel) + s.resetError() + stepLogger := log.BeginTask(logutil.Logger(s.logCtx).With( + zap.Any("step", task.Step), + zap.Int("concurrency", task.Concurrency), + zap.Float64("mem-limit-percent", gctuner.GlobalMemoryLimitTuner.GetPercentage()), + zap.String("server-mem-limit", memory.ServerMemoryLimitOriginText.Load()), + ), "execute task") + // log as info level, subtask might be cancelled, let caller check it. 
+ defer func() { + stepLogger.End(zap.InfoLevel, resErr) + }() + + summary, cleanup, err := runSummaryCollectLoop(ctx, task, s.taskTable) + if err != nil { + s.onError(err) + return s.getError() + } + defer cleanup() + subtaskExecutor, err := s.GetSubtaskExecutor(ctx, task, summary) + if err != nil { + s.onError(err) + return s.getError() + } + + failpoint.Inject("mockExecSubtaskInitEnvErr", func() { + failpoint.Return(errors.New("mockExecSubtaskInitEnvErr")) + }) + if err := subtaskExecutor.Init(runCtx); err != nil { + s.onError(err) + return s.getError() + } + + var wg sync.WaitGroup + cancelCtx, checkCancel := context.WithCancel(ctx) + s.startCancelCheck(cancelCtx, &wg, runCancel) + + defer func() { + err := subtaskExecutor.Cleanup(runCtx) + if err != nil { + logutil.Logger(s.logCtx).Error("cleanup subtask exec env failed", zap.Error(err)) + } + checkCancel() + wg.Wait() + }() + + subtasks, err := s.taskTable.GetSubtasksByStepAndStates( + runCtx, s.id, task.ID, task.Step, + proto.SubtaskStatePending, proto.SubtaskStateRunning) + if err != nil { + s.onError(err) + if common.IsRetryableError(err) { + logutil.Logger(s.logCtx).Warn("met retryable error", zap.Error(err)) + return nil + } + return s.getError() + } + for _, subtask := range subtasks { + metrics.IncDistTaskSubTaskCnt(subtask) + metrics.StartDistTaskSubTask(subtask) + } + + for { + // check if any error occurs. 
+ if err := s.getError(); err != nil { + break + } + if runCtx.Err() != nil { + logutil.Logger(s.logCtx).Info("taskExecutor runSubtask loop exit") + break + } + + subtask, err := s.taskTable.GetFirstSubtaskInStates(runCtx, s.id, task.ID, task.Step, + proto.SubtaskStatePending, proto.SubtaskStateRunning) + if err != nil { + logutil.Logger(s.logCtx).Warn("GetFirstSubtaskInStates meets error", zap.Error(err)) + continue + } + if subtask == nil { + failpoint.Inject("breakInTaskExecutorUT", func() { + failpoint.Break() + }) + newTask, err := s.taskTable.GetTaskByID(runCtx, task.ID) + if err != nil { + logutil.Logger(s.logCtx).Warn("GetTaskByID meets error", zap.Error(err)) + continue + } + // When the task move to next step or task state changes, the TaskExecutor should exit. + if newTask.Step != task.Step || newTask.State != task.State { + break + } + time.Sleep(checkTime) + continue + } + + if subtask.State == proto.SubtaskStateRunning { + if !s.IsIdempotent(subtask) { + logutil.Logger(s.logCtx).Info("subtask in running state and is not idempotent, fail it", + zap.Int64("subtask-id", subtask.ID)) + subtaskErr := errors.New("subtask in running state and is not idempotent") + s.onError(subtaskErr) + s.updateSubtaskStateAndError(runCtx, subtask, proto.SubtaskStateFailed, subtaskErr) + s.markErrorHandled() + break + } + } else { + // subtask.State == proto.TaskStatePending + err := s.startSubtaskAndUpdateState(runCtx, subtask) + if err != nil { + logutil.Logger(s.logCtx).Warn("startSubtaskAndUpdateState meets error", zap.Error(err)) + // should ignore ErrSubtaskNotFound + // since the err only indicate that the subtask not owned by current task executor. 
+ if err == storage.ErrSubtaskNotFound { + continue + } + s.onError(err) + continue + } + } + + failpoint.Inject("mockCleanExecutor", func() { + v, ok := testContexts.Load(s.id) + if ok { + if v.(*TestContext).mockDown.Load() { + failpoint.Break() + } + } + }) + + failpoint.Inject("cancelBeforeRunSubtask", func() { + runCancel(nil) + }) + + s.runSubtask(runCtx, subtaskExecutor, subtask) + } + return s.getError() +} + +func (s *BaseTaskExecutor) runSubtask(ctx context.Context, subtaskExecutor execute.SubtaskExecutor, subtask *proto.Subtask) { + err := subtaskExecutor.RunSubtask(ctx, subtask) + failpoint.Inject("MockRunSubtaskCancel", func(val failpoint.Value) { + if val.(bool) { + err = ErrCancelSubtask + } + }) + + failpoint.Inject("MockRunSubtaskContextCanceled", func(val failpoint.Value) { + if val.(bool) { + err = context.Canceled + } + }) + + if err != nil { + s.onError(err) + } + + finished := s.markSubTaskCanceledOrFailed(ctx, subtask) + if finished { + return + } + + failpoint.Inject("mockTiDBDown", func(val failpoint.Value) { + logutil.Logger(s.logCtx).Info("trigger mockTiDBDown") + if s.id == val.(string) || s.id == ":4001" || s.id == ":4002" { + v, ok := testContexts.Load(s.id) + if ok { + v.(*TestContext).TestSyncSubtaskRun <- struct{}{} + v.(*TestContext).mockDown.Store(true) + logutil.Logger(s.logCtx).Info("mockTiDBDown") + time.Sleep(2 * time.Second) + failpoint.Return() + } + } + }) + failpoint.Inject("mockTiDBDown2", func() { + if s.id == ":4003" && subtask.Step == proto.StepTwo { + v, ok := testContexts.Load(s.id) + if ok { + v.(*TestContext).TestSyncSubtaskRun <- struct{}{} + v.(*TestContext).mockDown.Store(true) + time.Sleep(2 * time.Second) + return + } + } + }) + + failpoint.Inject("mockTiDBPartitionThenResume", func(val failpoint.Value) { + if val.(bool) && (s.id == ":4000" || s.id == ":4001" || s.id == ":4002") { + infosync.MockGlobalServerInfoManagerEntry.DeleteByExecID(s.id) + time.Sleep(20 * time.Second) + } + }) + + 
failpoint.Inject("MockExecutorRunErr", func(val failpoint.Value) { + if val.(bool) { + s.onError(errors.New("MockExecutorRunErr")) + } + }) + failpoint.Inject("MockExecutorRunCancel", func(val failpoint.Value) { + if taskID, ok := val.(int); ok { + mgr, err := storage.GetTaskManager() + if err != nil { + logutil.BgLogger().Error("get task manager failed", zap.Error(err)) + } else { + err = mgr.CancelTask(ctx, int64(taskID)) + if err != nil { + logutil.BgLogger().Error("cancel task failed", zap.Error(err)) + } + } + } + }) + s.onSubtaskFinished(ctx, subtaskExecutor, subtask) +} + +func (s *BaseTaskExecutor) onSubtaskFinished(ctx context.Context, executor execute.SubtaskExecutor, subtask *proto.Subtask) { + if err := s.getError(); err == nil { + if err = executor.OnFinished(ctx, subtask); err != nil { + s.onError(err) + } + } + failpoint.Inject("MockSubtaskFinishedCancel", func(val failpoint.Value) { + if val.(bool) { + s.onError(ErrCancelSubtask) + } + }) + + finished := s.markSubTaskCanceledOrFailed(ctx, subtask) + if finished { + return + } + + s.finishSubtaskAndUpdateState(ctx, subtask) + + finished = s.markSubTaskCanceledOrFailed(ctx, subtask) + if finished { + return + } + + failpoint.Inject("syncAfterSubtaskFinish", func() { + TestSyncChan <- struct{}{} + <-TestSyncChan + }) +} + +// Rollback rollbacks the subtask. 
+func (s *BaseTaskExecutor) Rollback(ctx context.Context, task *proto.Task) error { + rollbackCtx, rollbackCancel := context.WithCancelCause(ctx) + defer rollbackCancel(ErrFinishRollback) + s.registerCancelFunc(rollbackCancel) + + s.resetError() + logutil.Logger(s.logCtx).Info("taskExecutor rollback a step", zap.Any("step", task.Step)) + + // We should cancel all subtasks before rolling back + for { + subtask, err := s.taskTable.GetFirstSubtaskInStates(ctx, s.id, task.ID, task.Step, + proto.SubtaskStatePending, proto.SubtaskStateRunning) + if err != nil { + s.onError(err) + return s.getError() + } + + if subtask == nil { + break + } + + s.updateSubtaskStateAndError(ctx, subtask, proto.SubtaskStateCanceled, nil) + if err = s.getError(); err != nil { + return err + } + } + + executor, err := s.GetSubtaskExecutor(ctx, task, nil) + if err != nil { + s.onError(err) + return s.getError() + } + subtask, err := s.taskTable.GetFirstSubtaskInStates(ctx, s.id, task.ID, task.Step, + proto.SubtaskStateRevertPending, proto.SubtaskStateReverting) + if err != nil { + s.onError(err) + return s.getError() + } + if subtask == nil { + logutil.BgLogger().Warn("taskExecutor rollback a step, but no subtask in revert_pending state", zap.Any("step", task.Step)) + return nil + } + if subtask.State == proto.SubtaskStateRevertPending { + s.updateSubtaskStateAndError(ctx, subtask, proto.SubtaskStateReverting, nil) + } + if err := s.getError(); err != nil { + return err + } + + // right now all impl of Rollback is empty, so we don't check idempotent here. + // will try to remove this rollback completely in the future. + err = executor.Rollback(rollbackCtx) + if err != nil { + s.updateSubtaskStateAndError(ctx, subtask, proto.SubtaskStateRevertFailed, nil) + s.onError(err) + } else { + s.updateSubtaskStateAndError(ctx, subtask, proto.SubtaskStateReverted, nil) + } + return s.getError() +} + +// Pause pause the TaskExecutor's subtasks. 
+func (s *BaseTaskExecutor) Pause(ctx context.Context, task *proto.Task) error {
+	logutil.Logger(s.logCtx).Info("taskExecutor pause subtasks")
+	// Ask the task table to pause the subtasks of this task, passing this executor's id.
+	pauseErr := s.taskTable.PauseSubtasks(ctx, s.id, task.ID)
+	if pauseErr == nil {
+		return nil
+	}
+	// Record the failure first; getError then returns the error kept by onError,
+	// which is the first one recorded and not necessarily pauseErr itself.
+	s.onError(pauseErr)
+	return s.getError()
+}
+
+// Close closes the TaskExecutor when all the subtasks are complete.
+// The base implementation has nothing to do.
+func (*BaseTaskExecutor) Close() {}
+
+// runSummaryCollectLoop starts the row-count update loop of the summary
+// configured for the task's type, if there is one.
+//
+// It returns the summary (nil when none is configured or when taskTable is not
+// a *storage.TaskManager), a cleanup function that persists the row count (a
+// no-op when there is no summary), and an error when the task type has no
+// registered option.
+func runSummaryCollectLoop(
+	ctx context.Context,
+	task *proto.Task,
+	taskTable TaskTable,
+) (summary *execute.Summary, cleanup func(), err error) {
+	noop := func() {}
+	// The summary needs the concrete task manager; other TaskTable
+	// implementations get no summary.
+	mgr, isTaskMgr := taskTable.(*storage.TaskManager)
+	if !isTaskMgr {
+		return nil, noop, nil
+	}
+	option, registered := taskTypes[task.Type]
+	if !registered {
+		return nil, noop, errors.Errorf("taskExecutor option for type %s not found", task.Type)
+	}
+	if option.Summary == nil {
+		return nil, noop, nil
+	}
+	go option.Summary.UpdateRowCountLoop(ctx, mgr)
+	persist := func() {
+		option.Summary.PersistRowCount(ctx, mgr)
+	}
+	return option.Summary, persist, nil
+}
+
+// registerCancelFunc stores the cancel function that onError invokes when an
+// error is recorded.
+func (s *BaseTaskExecutor) registerCancelFunc(cancel context.CancelCauseFunc) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	s.mu.runtimeCancel = cancel
+}
+
+// onError logs err, keeps it if it is the first error seen since the last
+// resetError, and calls the registered cancel function with it.
+// A nil err is ignored.
+func (s *BaseTaskExecutor) onError(err error) {
+	if err == nil {
+		return
+	}
+	traced := errors.Trace(err)
+	logger := logutil.Logger(s.logCtx)
+	logger.Error("onError", zap.Error(traced), zap.Stack("stack"))
+
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	// Only the first error is kept; later ones are logged above and passed to
+	// the cancel function below.
+	if s.mu.err == nil {
+		s.mu.err = traced
+		logger.Error("taskExecutor met first error", zap.Error(traced))
+	}
+	if cancel := s.mu.runtimeCancel; cancel != nil {
+		cancel(traced)
+	}
+}
+
+// markErrorHandled flags the recorded error as handled.
+func (s *BaseTaskExecutor) markErrorHandled() {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	s.mu.handled = true
+}
+
+// getError returns the error recorded by onError, or nil if there is none.
+func (s *BaseTaskExecutor) getError() error {
+	s.mu.RLock()
+	defer s.mu.RUnlock()
+	return s.mu.err
+}
+
+// resetError clears the recorded error together with its handled flag.
+func (s *BaseTaskExecutor) resetError() {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	s.mu.err = nil
+	s.mu.handled = false
+}
+
+// startSubtaskAndUpdateState starts the subtask in storage and, only on
+// success, moves the in-memory subtask and its metrics to the running state.
+func (s *BaseTaskExecutor) startSubtaskAndUpdateState(ctx context.Context, subtask *proto.Subtask) error {
+	if err := s.startSubtask(ctx, subtask.ID); err != nil {
+		return err
+	}
+	// Close out the metrics of the old state before switching the state, then
+	// open the metrics of the new one.
+	metrics.DecDistTaskSubTaskCnt(subtask)
+	metrics.EndDistTaskSubTask(subtask)
+	subtask.State = proto.SubtaskStateRunning
+	metrics.IncDistTaskSubTaskCnt(subtask)
+	metrics.StartDistTaskSubTask(subtask)
+	return nil
+}
+
+// updateSubtaskStateAndErrorImpl writes the subtask's state and error to
+// storage with retries; a final failure is reported through onError.
+func (s *BaseTaskExecutor) updateSubtaskStateAndErrorImpl(ctx context.Context, execID string, subtaskID int64, state proto.SubtaskState, subTaskErr error) {
+	// retry for 3+6+12+24+(30-4)*30 ~= 825s ~= 14 minutes
+	update := func(ctx context.Context) (bool, error) {
+		return true, s.taskTable.UpdateSubtaskStateAndError(ctx, execID, subtaskID, state, subTaskErr)
+	}
+	logger := logutil.Logger(s.logCtx)
+	backoffer := backoff.NewExponential(scheduler.RetrySQLInterval, 2, scheduler.RetrySQLMaxInterval)
+	if err := handle.RunWithRetry(ctx, scheduler.RetrySQLTimes, backoffer, logger, update); err != nil {
+		s.onError(err)
+	}
+}
+
+// startSubtask tries to change the state of the subtask to running.
+// If the subtask is not owned by the task executor,
+// the update will fail and the task executor should not run the subtask.
+func (s *BaseTaskExecutor) startSubtask(ctx context.Context, subtaskID int64) error {
+	// retry for 3+6+12+24+(30-4)*30 ~= 825s ~= 14 minutes
+	start := func(ctx context.Context) (bool, error) {
+		err := s.taskTable.StartSubtask(ctx, subtaskID, s.id)
+		// No need to retry when the subtask is not found.
+		retryable := err != storage.ErrSubtaskNotFound
+		return retryable, err
+	}
+	logger := logutil.Logger(s.logCtx)
+	backoffer := backoff.NewExponential(scheduler.RetrySQLInterval, 2, scheduler.RetrySQLMaxInterval)
+	return handle.RunWithRetry(ctx, scheduler.RetrySQLTimes, backoffer, logger, start)
+}
+
+// finishSubtask marks the subtask as finished in storage with retries; a final
+// failure is reported through onError.
+func (s *BaseTaskExecutor) finishSubtask(ctx context.Context, subtask *proto.Subtask) {
+	finish := func(ctx context.Context) (bool, error) {
+		return true, s.taskTable.FinishSubtask(ctx, subtask.ExecID, subtask.ID, subtask.Meta)
+	}
+	logger := logutil.Logger(s.logCtx)
+	backoffer := backoff.NewExponential(scheduler.RetrySQLInterval, 2, scheduler.RetrySQLMaxInterval)
+	if err := handle.RunWithRetry(ctx, scheduler.RetrySQLTimes, backoffer, logger, finish); err != nil {
+		s.onError(err)
+	}
+}
+
+// updateSubtaskStateAndError persists the new state and error of the subtask
+// and keeps the in-memory state and the metrics in step with it.
+func (s *BaseTaskExecutor) updateSubtaskStateAndError(ctx context.Context, subtask *proto.Subtask, state proto.SubtaskState, subTaskErr error) {
+	metrics.DecDistTaskSubTaskCnt(subtask)
+	metrics.EndDistTaskSubTask(subtask)
+	s.updateSubtaskStateAndErrorImpl(ctx, subtask.ExecID, subtask.ID, state, subTaskErr)
+	subtask.State = state
+	metrics.IncDistTaskSubTaskCnt(subtask)
+	if subtask.IsDone() {
+		return
+	}
+	// Only a subtask that is not done yet gets a new start metric.
+	metrics.StartDistTaskSubTask(subtask)
+}
+
+// finishSubtaskAndUpdateState finishes the subtask in storage and moves the
+// in-memory subtask and its metrics to the succeed state.
+func (s *BaseTaskExecutor) finishSubtaskAndUpdateState(ctx context.Context, subtask *proto.Subtask) {
+	metrics.DecDistTaskSubTaskCnt(subtask)
+	metrics.EndDistTaskSubTask(subtask)
+	s.finishSubtask(ctx, subtask)
+	subtask.State = proto.SubtaskStateSucceed
+	metrics.IncDistTaskSubTaskCnt(subtask)
+}
+
+// markSubTaskCanceledOrFailed checks the error type and decides the subtask's state.
+// 1. Only cancel the subtask when ErrCancelSubtask is met.
+// 2. Only fail the subtask when a non-retryable error is met.
+// 3. When other errors are met, don't change the subtask's state.
+func (s *BaseTaskExecutor) markSubTaskCanceledOrFailed(ctx context.Context, subtask *proto.Subtask) bool {
+	recorded := s.getError()
+	if recorded == nil {
+		// Nothing was recorded, so there is nothing to mark.
+		return false
+	}
+	cause := errors.Cause(recorded)
+	logger := logutil.Logger(s.logCtx)
+	// The cases are checked in order and only the first match applies.
+	// State updates are issued with s.ctx rather than the ctx argument.
+	switch {
+	case ctx.Err() != nil && context.Cause(ctx) == ErrCancelSubtask:
+		logger.Warn("subtask canceled", zap.Error(cause))
+		s.updateSubtaskStateAndError(s.ctx, subtask, proto.SubtaskStateCanceled, nil)
+	case s.IsRetryableError(cause):
+		logger.Warn("met retryable error", zap.Error(cause))
+	case common.IsContextCanceledError(cause):
+		logger.Info("met context canceled for gracefully shutdown", zap.Error(cause))
+	default:
+		logger.Warn("subtask failed", zap.Error(cause))
+		s.updateSubtaskStateAndError(s.ctx, subtask, proto.SubtaskStateFailed, cause)
+	}
+	s.markErrorHandled()
+	return true
+}
+
+// updateErrorToSubtask writes err to the subtask records of the given task for
+// this executor, with retries, and returns the error of the write itself.
+func (s *BaseTaskExecutor) updateErrorToSubtask(ctx context.Context, taskID int64, err error) error {
+	// NOTE(review): the retry loop is driven by s.logCtx while the storage call
+	// uses the ctx argument; the sibling helpers pass ctx to both. Kept as is;
+	// confirm whether this is intentional.
+	update := func(_ context.Context) (bool, error) {
+		return true, s.taskTable.UpdateErrorToSubtask(ctx, s.id, taskID, err)
+	}
+	logger := logutil.Logger(s.logCtx)
+	backoffer := backoff.NewExponential(scheduler.RetrySQLInterval, 2, scheduler.RetrySQLMaxInterval)
+	updateErr := handle.RunWithRetry(s.logCtx, scheduler.RetrySQLTimes, backoffer, logger, update)
+	if updateErr == nil {
+		// The logged error is the one that was written, not a failure of the write.
+		logger.Warn("update error to subtask success", zap.Error(err))
+	}
+	return updateErr
+}
diff --git a/pkg/disttask/framework/testutil/context.go b/pkg/disttask/framework/testutil/context.go
new file mode 100644
index 0000000000000..0c7ac9c58ff81
--- /dev/null
+++ b/pkg/disttask/framework/testutil/context.go
@@ -0,0 +1,100 @@
+// Copyright 2023 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package testutil
+
+import (
+	"context"
+	"fmt"
+	"sync"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/pingcap/failpoint"
+	"github.com/pingcap/tidb/pkg/disttask/framework/proto"
+	"github.com/pingcap/tidb/pkg/disttask/framework/storage"
+	"github.com/pingcap/tidb/pkg/testkit"
+	"github.com/stretchr/testify/require"
+	"github.com/tikv/client-go/v2/util"
+	"go.uber.org/mock/gomock"
+)
+
+// TestContext defines shared variables for disttask tests.
+type TestContext struct {
+	// The embedded lock is taken by CollectSubtask (write) and
+	// CollectedSubtaskCnt (read) around accesses to subtasksHasRun.
+	sync.RWMutex
+	// subtasksHasRun maps a "taskID/step" key (see getTaskStepKey) to the set
+	// of subtask IDs collected for that task step.
+	subtasksHasRun map[string]map[int64]struct{}
+	// RollbackCnt is a counter for rollback tests.
+	RollbackCnt atomic.Int32
+	// CallTime is a counter for plan error handling tests.
+	// NOTE(review): plain int, not guarded by anything in this file; callers
+	// presumably access it from one goroutine. Confirm.
+	CallTime int
+}
+
+// InitTestContext inits test context for disttask tests.
+func InitTestContext(t *testing.T, nodeNum int) (context.Context, *gomock.Controller, *TestContext, *testkit.DistExecutionContext) {
+	// gomock.NewController registers ctrl.Finish through t.Cleanup when it is
+	// given a *testing.T, so Finish must not be deferred here: a deferred call
+	// would finish the controller as soon as this function returns, before the
+	// caller has set any expectation, and the cleanup-time Finish would then
+	// return early without verifying missing calls.
+	ctrl := gomock.NewController(t)
+	ctx := context.Background()
+	ctx = util.WithInternalSourceType(ctx, "dispatcher")
+	// Pin the mocked CPU count for the duration of the test.
+	require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/util/cpu/mockNumCpu", "return(8)"))
+	t.Cleanup(func() {
+		require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/util/cpu/mockNumCpu"))
+	})
+
+	executionContext := testkit.NewDistExecutionContext(t, nodeNum)
+	WaitNodeRegistered(ctx, t)
+	testCtx := &TestContext{
+		subtasksHasRun: make(map[string]map[int64]struct{}),
+	}
+	return ctx, ctrl, testCtx, executionContext
+}
+
+// CollectSubtask records that the given subtask has been seen for its task step.
+// Collecting the same subtask ID twice for one step counts once.
+func (c *TestContext) CollectSubtask(subtask *proto.Subtask) {
+	key := getTaskStepKey(subtask.TaskID, subtask.Step)
+	c.Lock()
+	defer c.Unlock()
+	m, ok := c.subtasksHasRun[key]
+	if !ok {
+		// First subtask of this task step: create its ID set lazily.
+		m = make(map[int64]struct{})
+		c.subtasksHasRun[key] = m
+	}
+	m[subtask.ID] = struct{}{}
+}
+
+// CollectedSubtaskCnt returns the number of distinct subtasks collected for the
+// given task step; it is 0 when nothing was collected.
+func (c *TestContext) CollectedSubtaskCnt(taskID int64, step proto.Step) int {
+	key := getTaskStepKey(taskID, step)
+	c.RLock()
+	defer c.RUnlock()
+	return len(c.subtasksHasRun[key])
+}
+
+// getTaskStepKey returns the key of a task step, formatted as "taskID/step".
+func getTaskStepKey(id int64, step proto.Step) string {
+	return fmt.Sprintf("%d/%d", id, step)
+}
+
+// WaitNodeRegistered waits until some node is registered.
+// It polls every 100ms and fails the test if no node shows up within 5s.
+func WaitNodeRegistered(ctx context.Context, t *testing.T) {
+	require.Eventually(t, func() bool {
+		// The condition runs on a goroutine started by Eventually, so it must
+		// not call require.* (t.FailNow is only valid on the test goroutine).
+		// Errors are logged and treated as "not registered yet"; a persistent
+		// error ends in the Eventually timeout failure.
+		taskMgr, err := storage.GetTaskManager()
+		if err != nil {
+			t.Logf("get task manager: %v", err)
+			return false
+		}
+		nodes, err := taskMgr.GetAllNodes(ctx)
+		if err != nil {
+			t.Logf("get all nodes: %v", err)
+			return false
+		}
+		return len(nodes) > 0
+	}, 5*time.Second, 100*time.Millisecond)
+}
diff --git a/pkg/disttask/framework/testutil/task_util.go b/pkg/disttask/framework/testutil/task_util.go
new file mode 100644
index 0000000000000..0f8d92deac04e
--- /dev/null
+++ b/pkg/disttask/framework/testutil/task_util.go
@@ -0,0 +1,50 @@
+// Copyright 2023 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package testutil
+
+import (
+	"context"
+	"testing"
+
+	"github.com/pingcap/tidb/pkg/disttask/framework/proto"
+	"github.com/pingcap/tidb/pkg/disttask/framework/storage"
+	"github.com/pingcap/tidb/pkg/sessionctx"
+	"github.com/pingcap/tidb/pkg/util/sqlexec"
+	"github.com/stretchr/testify/require"
+	"github.com/tikv/client-go/v2/util"
+)
+
+// CreateSubTask adds a new pending (or revert-pending, when isRevert is set)
+// subtask to the subtask table. Used for testing.
+func CreateSubTask(t *testing.T, gm *storage.TaskManager, taskID int64, step proto.Step, execID string, meta []byte, tp proto.TaskType, concurrency int, isRevert bool) {
+	state := proto.SubtaskStatePending
+	if isRevert {
+		state = proto.SubtaskStateRevertPending
+	}
+	InsertSubtask(t, gm, taskID, step, execID, meta, state, tp, concurrency)
+}
+
+// InsertSubtask adds a new subtask of any state to subtask table.
+func InsertSubtask(t *testing.T, gm *storage.TaskManager, taskID int64, step proto.Step, execID string, meta []byte, state proto.SubtaskState, tp proto.TaskType, concurrency int) {
+	ctx := util.WithInternalSourceType(context.Background(), "table_test")
+	// insertRow writes one row into mysql.tidb_background_subtask using the
+	// session handed out by the task manager.
+	insertRow := func(se sessionctx.Context) error {
+		_, err := sqlexec.ExecSQL(ctx, se, `
+			insert into mysql.tidb_background_subtask(`+storage.InsertSubtaskColumns+`) values`+
+			`(%?, %?, %?, %?, %?, %?, %?, NULL, CURRENT_TIMESTAMP(), '{}', '{}')`,
+			step, taskID, execID, meta, state, proto.Type2Int(tp), concurrency)
+		return err
+	}
+	// Any failure to open the session or to insert the row fails the test.
+	require.NoError(t, gm.WithNewSession(insertRow))
+}