From 226558d82a53b8b5c8194f44735dc0067c86be50 Mon Sep 17 00:00:00 2001 From: D3Hunter Date: Wed, 7 Feb 2024 10:42:23 +0800 Subject: [PATCH] Revert "This is an automated cherry-pick of #49971" This reverts commit c6b98281635678c0def60a19619fcb8a9b27c1f0. --- .../framework_pause_and_resume_test.go | 4 +- pkg/disttask/framework/handle/BUILD.bazel | 1 - pkg/disttask/framework/handle/handle_test.go | 15 - pkg/disttask/framework/mock/scheduler_mock.go | 216 ------ pkg/disttask/framework/planner/BUILD.bazel | 1 - .../framework/planner/planner_test.go | 3 +- pkg/disttask/framework/scheduler/BUILD.bazel | 13 - pkg/disttask/framework/scheduler/interface.go | 55 -- pkg/disttask/framework/scheduler/scheduler.go | 409 ----------- .../scheduler/scheduler_manager_test.go | 84 --- .../scheduler/scheduler_nokit_test.go | 40 -- .../framework/scheduler/scheduler_test.go | 434 ------------ pkg/disttask/framework/storage/BUILD.bazel | 4 - pkg/disttask/framework/storage/table_test.go | 356 +--------- pkg/disttask/framework/storage/task_table.go | 23 +- .../framework/taskexecutor/task_executor.go | 653 ------------------ pkg/disttask/framework/testutil/context.go | 100 --- pkg/disttask/framework/testutil/task_util.go | 50 -- 18 files changed, 17 insertions(+), 2444 deletions(-) delete mode 100644 pkg/disttask/framework/scheduler/scheduler_manager_test.go delete mode 100644 pkg/disttask/framework/scheduler/scheduler_nokit_test.go delete mode 100644 pkg/disttask/framework/taskexecutor/task_executor.go delete mode 100644 pkg/disttask/framework/testutil/context.go delete mode 100644 pkg/disttask/framework/testutil/task_util.go diff --git a/pkg/disttask/framework/framework_pause_and_resume_test.go b/pkg/disttask/framework/framework_pause_and_resume_test.go index a0c297d1acc49..28ba8b4cefaa3 100644 --- a/pkg/disttask/framework/framework_pause_and_resume_test.go +++ b/pkg/disttask/framework/framework_pause_and_resume_test.go @@ -34,11 +34,11 @@ func CheckSubtasksState(ctx context.Context, t *testing.T, taskID int64, state p mgr, err := storage.GetTaskManager() require.NoError(t, err) mgr.PrintSubtaskInfo(ctx, taskID) - cntByStates, err := mgr.GetSubtaskCntGroupByStates(ctx, taskID, proto.StepTwo) + cnt, err := mgr.GetSubtaskInStatesCnt(ctx, taskID, state) require.NoError(t, err) historySubTasksCnt, err := storage.GetSubtasksFromHistoryByTaskIDForTest(ctx, mgr, taskID) require.NoError(t, err) - require.Equal(t, expectedCnt, cntByStates[state]+int64(historySubTasksCnt)) + require.Equal(t, expectedCnt, cnt+int64(historySubTasksCnt)) } func TestFrameworkPauseAndResume(t *testing.T) { diff --git a/pkg/disttask/framework/handle/BUILD.bazel b/pkg/disttask/framework/handle/BUILD.bazel index a3065ce31eb18..dd0697e8d2f72 100644 --- a/pkg/disttask/framework/handle/BUILD.bazel +++ b/pkg/disttask/framework/handle/BUILD.bazel @@ -25,7 +25,6 @@ go_test( ":handle", "//pkg/disttask/framework/proto", "//pkg/disttask/framework/storage", - "//pkg/disttask/framework/testutil", "//pkg/testkit", "//pkg/util/backoff", "@com_github_ngaut_pools//:pools", diff --git a/pkg/disttask/framework/handle/handle_test.go b/pkg/disttask/framework/handle/handle_test.go index dc711850610bb..65e2eaa52c42b 100644 --- a/pkg/disttask/framework/handle/handle_test.go +++ b/pkg/disttask/framework/handle/handle_test.go @@ -27,7 +27,6 @@ import ( "github.com/pingcap/tidb/pkg/disttask/framework/handle" "github.com/pingcap/tidb/pkg/disttask/framework/proto" "github.com/pingcap/tidb/pkg/disttask/framework/storage" - 
"github.com/pingcap/tidb/pkg/disttask/framework/testutil" "github.com/pingcap/tidb/pkg/testkit" "github.com/pingcap/tidb/pkg/util/backoff" "github.com/stretchr/testify/require" @@ -47,23 +46,9 @@ func TestHandle(t *testing.T) { mgr := storage.NewTaskManager(pool) storage.SetTaskManager(mgr) -<<<<<<< HEAD // no dispatcher registered err := handle.SubmitAndRunGlobalTask(ctx, "1", proto.TaskTypeExample, 2, []byte("byte")) require.Error(t, err) -======= - testutil.WaitNodeRegistered(ctx, t) - - // no scheduler registered - task, err := handle.SubmitTask(ctx, "1", proto.TaskTypeExample, 2, []byte("byte")) - require.NoError(t, err) - waitedTask, err := handle.WaitTask(ctx, task.ID, func(task *proto.Task) bool { - return task.IsDone() - }) - require.NoError(t, err) - require.Equal(t, proto.TaskStateFailed, waitedTask.State) - require.ErrorContains(t, waitedTask.Error, "unknown task type") ->>>>>>> 99f0349bfb6 (disttask: fix failed step is taken as success (#49971)) task, err := mgr.GetGlobalTaskByID(ctx, 1) require.NoError(t, err) diff --git a/pkg/disttask/framework/mock/scheduler_mock.go b/pkg/disttask/framework/mock/scheduler_mock.go index d5e6cfa4ea673..6f99256c9cede 100644 --- a/pkg/disttask/framework/mock/scheduler_mock.go +++ b/pkg/disttask/framework/mock/scheduler_mock.go @@ -429,223 +429,7 @@ func (mr *MockExtensionMockRecorder) GetSubtaskExecutor(arg0, arg1, arg2 any) *g // IsIdempotent mocks base method. func (m *MockExtension) IsIdempotent(arg0 *proto.Subtask) bool { m.ctrl.T.Helper() -<<<<<<< HEAD ret := m.ctrl.Call(m, "IsIdempotent", arg0) -======= - ret := m.ctrl.Call(m, "DeleteDeadNodes", arg0, arg1) - ret0, _ := ret[0].(error) - return ret0 -} - -// DeleteDeadNodes indicates an expected call of DeleteDeadNodes. -func (mr *MockTaskManagerMockRecorder) DeleteDeadNodes(arg0, arg1 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteDeadNodes", reflect.TypeOf((*MockTaskManager)(nil).DeleteDeadNodes), arg0, arg1) -} - -// FailTask mocks base method. -func (m *MockTaskManager) FailTask(arg0 context.Context, arg1 int64, arg2 proto.TaskState, arg3 error) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "FailTask", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(error) - return ret0 -} - -// FailTask indicates an expected call of FailTask. -func (mr *MockTaskManagerMockRecorder) FailTask(arg0, arg1, arg2, arg3 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FailTask", reflect.TypeOf((*MockTaskManager)(nil).FailTask), arg0, arg1, arg2, arg3) -} - -// GCSubtasks mocks base method. -func (m *MockTaskManager) GCSubtasks(arg0 context.Context) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GCSubtasks", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// GCSubtasks indicates an expected call of GCSubtasks. -func (mr *MockTaskManagerMockRecorder) GCSubtasks(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GCSubtasks", reflect.TypeOf((*MockTaskManager)(nil).GCSubtasks), arg0) -} - -// GetAllNodes mocks base method. -func (m *MockTaskManager) GetAllNodes(arg0 context.Context) ([]proto.ManagedNode, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetAllNodes", arg0) - ret0, _ := ret[0].([]proto.ManagedNode) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetAllNodes indicates an expected call of GetAllNodes. 
-func (mr *MockTaskManagerMockRecorder) GetAllNodes(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllNodes", reflect.TypeOf((*MockTaskManager)(nil).GetAllNodes), arg0) -} - -// GetManagedNodes mocks base method. -func (m *MockTaskManager) GetManagedNodes(arg0 context.Context) ([]proto.ManagedNode, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetManagedNodes", arg0) - ret0, _ := ret[0].([]proto.ManagedNode) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetManagedNodes indicates an expected call of GetManagedNodes. -func (mr *MockTaskManagerMockRecorder) GetManagedNodes(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetManagedNodes", reflect.TypeOf((*MockTaskManager)(nil).GetManagedNodes), arg0) -} - -// GetSubtaskCntGroupByStates mocks base method. -func (m *MockTaskManager) GetSubtaskCntGroupByStates(arg0 context.Context, arg1 int64, arg2 proto.Step) (map[proto.SubtaskState]int64, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetSubtaskCntGroupByStates", arg0, arg1, arg2) - ret0, _ := ret[0].(map[proto.SubtaskState]int64) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetSubtaskCntGroupByStates indicates an expected call of GetSubtaskCntGroupByStates. -func (mr *MockTaskManagerMockRecorder) GetSubtaskCntGroupByStates(arg0, arg1, arg2 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSubtaskCntGroupByStates", reflect.TypeOf((*MockTaskManager)(nil).GetSubtaskCntGroupByStates), arg0, arg1, arg2) -} - -// GetSubtasksByExecIdsAndStepAndState mocks base method. -func (m *MockTaskManager) GetSubtasksByExecIdsAndStepAndState(arg0 context.Context, arg1 []string, arg2 int64, arg3 proto.Step, arg4 proto.SubtaskState) ([]*proto.Subtask, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetSubtasksByExecIdsAndStepAndState", arg0, arg1, arg2, arg3, arg4) - ret0, _ := ret[0].([]*proto.Subtask) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetSubtasksByExecIdsAndStepAndState indicates an expected call of GetSubtasksByExecIdsAndStepAndState. -func (mr *MockTaskManagerMockRecorder) GetSubtasksByExecIdsAndStepAndState(arg0, arg1, arg2, arg3, arg4 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSubtasksByExecIdsAndStepAndState", reflect.TypeOf((*MockTaskManager)(nil).GetSubtasksByExecIdsAndStepAndState), arg0, arg1, arg2, arg3, arg4) -} - -// GetSubtasksByStepAndState mocks base method. -func (m *MockTaskManager) GetSubtasksByStepAndState(arg0 context.Context, arg1 int64, arg2 proto.Step, arg3 proto.TaskState) ([]*proto.Subtask, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetSubtasksByStepAndState", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].([]*proto.Subtask) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetSubtasksByStepAndState indicates an expected call of GetSubtasksByStepAndState. -func (mr *MockTaskManagerMockRecorder) GetSubtasksByStepAndState(arg0, arg1, arg2, arg3 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSubtasksByStepAndState", reflect.TypeOf((*MockTaskManager)(nil).GetSubtasksByStepAndState), arg0, arg1, arg2, arg3) -} - -// GetTaskByID mocks base method. 
-func (m *MockTaskManager) GetTaskByID(arg0 context.Context, arg1 int64) (*proto.Task, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetTaskByID", arg0, arg1) - ret0, _ := ret[0].(*proto.Task) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetTaskByID indicates an expected call of GetTaskByID. -func (mr *MockTaskManagerMockRecorder) GetTaskByID(arg0, arg1 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTaskByID", reflect.TypeOf((*MockTaskManager)(nil).GetTaskByID), arg0, arg1) -} - -// GetTaskExecutorIDsByTaskID mocks base method. -func (m *MockTaskManager) GetTaskExecutorIDsByTaskID(arg0 context.Context, arg1 int64) ([]string, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetTaskExecutorIDsByTaskID", arg0, arg1) - ret0, _ := ret[0].([]string) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetTaskExecutorIDsByTaskID indicates an expected call of GetTaskExecutorIDsByTaskID. -func (mr *MockTaskManagerMockRecorder) GetTaskExecutorIDsByTaskID(arg0, arg1 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTaskExecutorIDsByTaskID", reflect.TypeOf((*MockTaskManager)(nil).GetTaskExecutorIDsByTaskID), arg0, arg1) -} - -// GetTaskExecutorIDsByTaskIDAndStep mocks base method. -func (m *MockTaskManager) GetTaskExecutorIDsByTaskIDAndStep(arg0 context.Context, arg1 int64, arg2 proto.Step) ([]string, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetTaskExecutorIDsByTaskIDAndStep", arg0, arg1, arg2) - ret0, _ := ret[0].([]string) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetTaskExecutorIDsByTaskIDAndStep indicates an expected call of GetTaskExecutorIDsByTaskIDAndStep. -func (mr *MockTaskManagerMockRecorder) GetTaskExecutorIDsByTaskIDAndStep(arg0, arg1, arg2 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTaskExecutorIDsByTaskIDAndStep", reflect.TypeOf((*MockTaskManager)(nil).GetTaskExecutorIDsByTaskIDAndStep), arg0, arg1, arg2) -} - -// GetTasksInStates mocks base method. -func (m *MockTaskManager) GetTasksInStates(arg0 context.Context, arg1 ...any) ([]*proto.Task, error) { - m.ctrl.T.Helper() - varargs := []any{arg0} - for _, a := range arg1 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "GetTasksInStates", varargs...) - ret0, _ := ret[0].([]*proto.Task) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetTasksInStates indicates an expected call of GetTasksInStates. -func (mr *MockTaskManagerMockRecorder) GetTasksInStates(arg0 any, arg1 ...any) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0}, arg1...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTasksInStates", reflect.TypeOf((*MockTaskManager)(nil).GetTasksInStates), varargs...) -} - -// GetTopUnfinishedTasks mocks base method. -func (m *MockTaskManager) GetTopUnfinishedTasks(arg0 context.Context) ([]*proto.Task, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetTopUnfinishedTasks", arg0) - ret0, _ := ret[0].([]*proto.Task) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetTopUnfinishedTasks indicates an expected call of GetTopUnfinishedTasks. -func (mr *MockTaskManagerMockRecorder) GetTopUnfinishedTasks(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTopUnfinishedTasks", reflect.TypeOf((*MockTaskManager)(nil).GetTopUnfinishedTasks), arg0) -} - -// GetUsedSlotsOnNodes mocks base method. 
-func (m *MockTaskManager) GetUsedSlotsOnNodes(arg0 context.Context) (map[string]int, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetUsedSlotsOnNodes", arg0) - ret0, _ := ret[0].(map[string]int) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetUsedSlotsOnNodes indicates an expected call of GetUsedSlotsOnNodes. -func (mr *MockTaskManagerMockRecorder) GetUsedSlotsOnNodes(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUsedSlotsOnNodes", reflect.TypeOf((*MockTaskManager)(nil).GetUsedSlotsOnNodes), arg0) -} - -// PauseTask mocks base method. -func (m *MockTaskManager) PauseTask(arg0 context.Context, arg1 string) (bool, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PauseTask", arg0, arg1) ->>>>>>> 99f0349bfb6 (disttask: fix failed step is taken as success (#49971)) ret0, _ := ret[0].(bool) return ret0 } diff --git a/pkg/disttask/framework/planner/BUILD.bazel b/pkg/disttask/framework/planner/BUILD.bazel index 8ea3e7c63e3b7..3f28cede80197 100644 --- a/pkg/disttask/framework/planner/BUILD.bazel +++ b/pkg/disttask/framework/planner/BUILD.bazel @@ -27,7 +27,6 @@ go_test( ":planner", "//pkg/disttask/framework/mock", "//pkg/disttask/framework/storage", - "//pkg/disttask/framework/testutil", "//pkg/kv", "//pkg/testkit", "@com_github_ngaut_pools//:pools", diff --git a/pkg/disttask/framework/planner/planner_test.go b/pkg/disttask/framework/planner/planner_test.go index 0801ce6ecdb7a..e515c3e0e266f 100644 --- a/pkg/disttask/framework/planner/planner_test.go +++ b/pkg/disttask/framework/planner/planner_test.go @@ -23,7 +23,6 @@ import ( "github.com/pingcap/tidb/pkg/disttask/framework/mock" "github.com/pingcap/tidb/pkg/disttask/framework/planner" "github.com/pingcap/tidb/pkg/disttask/framework/storage" - "github.com/pingcap/tidb/pkg/disttask/framework/testutil" "github.com/pingcap/tidb/pkg/kv" "github.com/pingcap/tidb/pkg/testkit" "github.com/stretchr/testify/require" @@ -46,7 +45,7 @@ func TestPlanner(t *testing.T) { defer pool.Close() mgr := storage.NewTaskManager(pool) storage.SetTaskManager(mgr) - testutil.WaitNodeRegistered(ctx, t) + p := &planner.Planner{} pCtx := planner.PlanCtx{ Ctx: ctx, diff --git a/pkg/disttask/framework/scheduler/BUILD.bazel b/pkg/disttask/framework/scheduler/BUILD.bazel index cb39089ed0a78..9ff3449a5411d 100644 --- a/pkg/disttask/framework/scheduler/BUILD.bazel +++ b/pkg/disttask/framework/scheduler/BUILD.bazel @@ -40,27 +40,14 @@ go_test( name = "scheduler_test", timeout = "short", srcs = [ -<<<<<<< HEAD "manager_test.go", "register_test.go", -======= - "main_test.go", - "nodes_test.go", - "rebalance_test.go", - "scheduler_manager_test.go", - "scheduler_nokit_test.go", ->>>>>>> 99f0349bfb6 (disttask: fix failed step is taken as success (#49971)) "scheduler_test.go", ], embed = [":scheduler"], flaky = True, -<<<<<<< HEAD race = "on", shard_count = 8, -======= - race = "off", - shard_count = 26, ->>>>>>> 99f0349bfb6 (disttask: fix failed step is taken as success (#49971)) deps = [ "//pkg/disttask/framework/mock", "//pkg/disttask/framework/mock/execute", diff --git a/pkg/disttask/framework/scheduler/interface.go b/pkg/disttask/framework/scheduler/interface.go index 0468c2d75da96..dc2c4aa9375af 100644 --- a/pkg/disttask/framework/scheduler/interface.go +++ b/pkg/disttask/framework/scheduler/interface.go @@ -21,65 +21,10 @@ import ( "github.com/pingcap/tidb/pkg/disttask/framework/scheduler/execute" ) -<<<<<<< HEAD // TaskTable defines the interface to access task table. 
type TaskTable interface { GetGlobalTasksInStates(ctx context.Context, states ...interface{}) (task []*proto.Task, err error) GetGlobalTaskByID(ctx context.Context, taskID int64) (task *proto.Task, err error) -======= -// TaskManager defines the interface to access task table. -type TaskManager interface { - // GetTopUnfinishedTasks returns unfinished tasks, limited by MaxConcurrentTask*2, - // to make sure lower priority tasks can be scheduled if resource is enough. - // The returned tasks are sorted by task order, see proto.Task, and only contains - // some fields, see row2TaskBasic. - GetTopUnfinishedTasks(ctx context.Context) ([]*proto.Task, error) - GetTasksInStates(ctx context.Context, states ...interface{}) (task []*proto.Task, err error) - GetTaskByID(ctx context.Context, taskID int64) (task *proto.Task, err error) - UpdateTaskAndAddSubTasks(ctx context.Context, task *proto.Task, subtasks []*proto.Subtask, prevState proto.TaskState) (bool, error) - GCSubtasks(ctx context.Context) error - GetAllNodes(ctx context.Context) ([]proto.ManagedNode, error) - DeleteDeadNodes(ctx context.Context, nodes []string) error - TransferTasks2History(ctx context.Context, tasks []*proto.Task) error - CancelTask(ctx context.Context, taskID int64) error - // FailTask updates task state to Failed and updates task error. - FailTask(ctx context.Context, taskID int64, currentState proto.TaskState, taskErr error) error - PauseTask(ctx context.Context, taskKey string) (bool, error) - // SwitchTaskStep switches the task to the next step and add subtasks in one - // transaction. It will change task state too if we're switch from InitStep to - // next step. - SwitchTaskStep(ctx context.Context, task *proto.Task, nextState proto.TaskState, nextStep proto.Step, subtasks []*proto.Subtask) error - // SwitchTaskStepInBatch similar to SwitchTaskStep, but it will insert subtasks - // in batch, and task step change will be in a separate transaction. - // Note: subtasks of this step must be stable, i.e. count, order and content - // should be the same on each try, else the subtasks inserted might be messed up. - // And each subtask of this step must be different, to handle the network - // partition or owner change. - SwitchTaskStepInBatch(ctx context.Context, task *proto.Task, nextState proto.TaskState, nextStep proto.Step, subtasks []*proto.Subtask) error - // SucceedTask updates a task to success state. - SucceedTask(ctx context.Context, taskID int64) error - // GetUsedSlotsOnNodes returns the used slots on nodes that have subtask scheduled. - // subtasks of each task on one node is only accounted once as we don't support - // running them concurrently. - // we only consider pending/running subtasks, subtasks related to revert are - // not considered. - GetUsedSlotsOnNodes(ctx context.Context) (map[string]int, error) - // GetSubtaskCntGroupByStates returns the count of subtasks of some step group by state. - GetSubtaskCntGroupByStates(ctx context.Context, taskID int64, step proto.Step) (map[proto.SubtaskState]int64, error) - ResumeSubtasks(ctx context.Context, taskID int64) error - CollectSubTaskError(ctx context.Context, taskID int64) ([]error, error) - TransferSubTasks2History(ctx context.Context, taskID int64) error - UpdateSubtasksExecIDs(ctx context.Context, taskID int64, subtasks []*proto.Subtask) error - // GetManagedNodes returns the nodes managed by dist framework and can be used - // to execute tasks. If there are any nodes with background role, we use them, - // else we use nodes without role. 
- // returned nodes are sorted by node id(host:port). - GetManagedNodes(ctx context.Context) ([]proto.ManagedNode, error) - GetTaskExecutorIDsByTaskID(ctx context.Context, taskID int64) ([]string, error) - GetSubtasksByStepAndState(ctx context.Context, taskID int64, step proto.Step, state proto.TaskState) ([]*proto.Subtask, error) - GetSubtasksByExecIdsAndStepAndState(ctx context.Context, tidbIDs []string, taskID int64, step proto.Step, state proto.SubtaskState) ([]*proto.Subtask, error) - GetTaskExecutorIDsByTaskIDAndStep(ctx context.Context, taskID int64, step proto.Step) ([]string, error) ->>>>>>> 99f0349bfb6 (disttask: fix failed step is taken as success (#49971)) GetSubtasksInStates(ctx context.Context, tidbID string, taskID int64, step proto.Step, states ...interface{}) ([]*proto.Subtask, error) GetFirstSubtaskInStates(ctx context.Context, instanceID string, taskID int64, step proto.Step, states ...interface{}) (*proto.Subtask, error) diff --git a/pkg/disttask/framework/scheduler/scheduler.go b/pkg/disttask/framework/scheduler/scheduler.go index 2e9accf189a3d..6ad8835abeb4e 100644 --- a/pkg/disttask/framework/scheduler/scheduler.go +++ b/pkg/disttask/framework/scheduler/scheduler.go @@ -517,28 +517,10 @@ func (s *BaseScheduler) markErrorHandled() { s.mu.handled = true } -<<<<<<< HEAD func (s *BaseScheduler) getError() error { s.mu.RLock() defer s.mu.RUnlock() return s.mu.err -======= -// handle task in pausing state, cancel all running subtasks. -func (s *BaseScheduler) onPausing() error { - logutil.Logger(s.logCtx).Info("on pausing state", zap.Stringer("state", s.Task.State), zap.Int64("step", int64(s.Task.Step))) - cntByStates, err := s.taskMgr.GetSubtaskCntGroupByStates(s.ctx, s.Task.ID, s.Task.Step) - if err != nil { - logutil.Logger(s.logCtx).Warn("check task failed", zap.Error(err)) - return err - } - runningPendingCnt := cntByStates[proto.SubtaskStateRunning] + cntByStates[proto.SubtaskStatePending] - if runningPendingCnt == 0 { - logutil.Logger(s.logCtx).Info("all running subtasks paused, update the task to paused state") - return s.updateTask(proto.TaskStatePaused, nil, RetrySQLTimes) - } - logutil.Logger(s.logCtx).Debug("on pausing state, this task keeps current state", zap.Stringer("state", s.Task.State)) - return nil ->>>>>>> 99f0349bfb6 (disttask: fix failed step is taken as success (#49971)) } func (s *BaseScheduler) resetError() { @@ -548,7 +530,6 @@ func (s *BaseScheduler) resetError() { s.mu.handled = false } -<<<<<<< HEAD func (s *BaseScheduler) startSubtaskAndUpdateState(ctx context.Context, subtask *proto.Subtask) { metrics.DecDistTaskSubTaskCnt(subtask) metrics.EndDistTaskSubTask(subtask) @@ -563,368 +544,6 @@ func (s *BaseScheduler) updateSubtaskStateAndErrorImpl(ctx context.Context, tidb logger := logutil.Logger(s.logCtx) backoffer := backoff.NewExponential(dispatcher.RetrySQLInterval, 2, dispatcher.RetrySQLMaxInterval) err := handle.RunWithRetry(ctx, dispatcher.RetrySQLTimes, backoffer, logger, -======= -// TestSyncChan is used to sync the test. -var TestSyncChan = make(chan struct{}) - -// handle task in resuming state. 
-func (s *BaseScheduler) onResuming() error { - logutil.Logger(s.logCtx).Info("on resuming state", zap.Stringer("state", s.Task.State), zap.Int64("step", int64(s.Task.Step))) - cntByStates, err := s.taskMgr.GetSubtaskCntGroupByStates(s.ctx, s.Task.ID, s.Task.Step) - if err != nil { - logutil.Logger(s.logCtx).Warn("check task failed", zap.Error(err)) - return err - } - if cntByStates[proto.SubtaskStatePaused] == 0 { - // Finish the resuming process. - logutil.Logger(s.logCtx).Info("all paused tasks converted to pending state, update the task to running state") - err := s.updateTask(proto.TaskStateRunning, nil, RetrySQLTimes) - failpoint.Inject("syncAfterResume", func() { - TestSyncChan <- struct{}{} - }) - return err - } - - return s.taskMgr.ResumeSubtasks(s.ctx, s.Task.ID) -} - -// handle task in reverting state, check all revert subtasks finishes. -func (s *BaseScheduler) onReverting() error { - logutil.Logger(s.logCtx).Debug("on reverting state", zap.Stringer("state", s.Task.State), zap.Int64("step", int64(s.Task.Step))) - cntByStates, err := s.taskMgr.GetSubtaskCntGroupByStates(s.ctx, s.Task.ID, s.Task.Step) - if err != nil { - logutil.Logger(s.logCtx).Warn("check task failed", zap.Error(err)) - return err - } - activeRevertCnt := cntByStates[proto.SubtaskStateRevertPending] + cntByStates[proto.SubtaskStateReverting] - if activeRevertCnt == 0 { - if err = s.OnDone(s.ctx, s, s.Task); err != nil { - return errors.Trace(err) - } - return s.updateTask(proto.TaskStateReverted, nil, RetrySQLTimes) - } - // Wait all subtasks in this step finishes. - s.OnTick(s.ctx, s.Task) - logutil.Logger(s.logCtx).Debug("on reverting state, this task keeps current state", zap.Stringer("state", s.Task.State)) - return nil -} - -// handle task in pending state, schedule subtasks. -func (s *BaseScheduler) onPending() error { - logutil.Logger(s.logCtx).Debug("on pending state", zap.Stringer("state", s.Task.State), zap.Int64("step", int64(s.Task.Step))) - return s.switch2NextStep() -} - -// handle task in running state, check all running subtasks finishes. -// If subtasks finished, run into the next step. -func (s *BaseScheduler) onRunning() error { - logutil.Logger(s.logCtx).Debug("on running state", - zap.Stringer("state", s.Task.State), - zap.Int64("step", int64(s.Task.Step))) - // check current step finishes. - cntByStates, err := s.taskMgr.GetSubtaskCntGroupByStates(s.ctx, s.Task.ID, s.Task.Step) - if err != nil { - logutil.Logger(s.logCtx).Warn("check task failed", zap.Error(err)) - return err - } - if cntByStates[proto.SubtaskStateFailed] > 0 || cntByStates[proto.SubtaskStateCanceled] > 0 { - subTaskErrs, err := s.taskMgr.CollectSubTaskError(s.ctx, s.Task.ID) - if err != nil { - logutil.Logger(s.logCtx).Warn("collect subtask error failed", zap.Error(err)) - return err - } - if len(subTaskErrs) > 0 { - logutil.Logger(s.logCtx).Warn("subtasks encounter errors") - return s.onErrHandlingStage(subTaskErrs) - } - } else if s.isStepSucceed(cntByStates) { - return s.switch2NextStep() - } - - if err := s.balanceSubtasks(); err != nil { - return err - } - // Wait all subtasks in this step finishes. 
- s.OnTick(s.ctx, s.Task) - logutil.Logger(s.logCtx).Debug("on running state, this task keeps current state", zap.Stringer("state", s.Task.State)) - return nil -} - -func (s *BaseScheduler) onFinished() error { - metrics.UpdateMetricsForFinishTask(s.Task) - logutil.Logger(s.logCtx).Debug("schedule task, task is finished", zap.Stringer("state", s.Task.State)) - return s.taskMgr.TransferSubTasks2History(s.ctx, s.Task.ID) -} - -// balanceSubtasks check the liveNode num every liveNodeFetchInterval then rebalance subtasks. -func (s *BaseScheduler) balanceSubtasks() error { - if len(s.TaskNodes) == 0 { - var err error - s.TaskNodes, err = s.taskMgr.GetTaskExecutorIDsByTaskIDAndStep(s.ctx, s.Task.ID, s.Task.Step) - if err != nil { - return err - } - } - s.balanceSubtaskTick++ - if s.balanceSubtaskTick == defaultBalanceSubtaskTicks { - s.balanceSubtaskTick = 0 - eligibleNodes, err := s.getEligibleNodes() - if err != nil { - return err - } - if len(eligibleNodes) > 0 { - return s.doBalanceSubtasks(eligibleNodes) - } - } - return nil -} - -// DoBalanceSubtasks make count of subtasks on each liveNodes balanced and clean up subtasks on dead nodes. -// TODO(ywqzzy): refine to make it easier for testing. -func (s *BaseScheduler) doBalanceSubtasks(eligibleNodes []string) error { - eligibleNodeMap := make(map[string]struct{}, len(eligibleNodes)) - for _, n := range eligibleNodes { - eligibleNodeMap[n] = struct{}{} - } - // 1. find out nodes need to clean subtasks. - deadNodes := make([]string, 0) - deadNodesMap := make(map[string]bool, 0) - for _, node := range s.TaskNodes { - if _, ok := eligibleNodeMap[node]; !ok { - deadNodes = append(deadNodes, node) - deadNodesMap[node] = true - } - } - // 2. get subtasks for each node before rebalance. - subtasks, err := s.taskMgr.GetSubtasksByStepAndState(s.ctx, s.Task.ID, s.Task.Step, proto.TaskStatePending) - if err != nil { - return err - } - if len(deadNodes) != 0 { - /// get subtask from deadNodes, since there might be some running subtasks on deadNodes. - /// In this case, all subtasks on deadNodes are in running/pending state. - subtasksOnDeadNodes, err := s.taskMgr.GetSubtasksByExecIdsAndStepAndState( - s.ctx, - deadNodes, - s.Task.ID, - s.Task.Step, - proto.SubtaskStateRunning) - if err != nil { - return err - } - subtasks = append(subtasks, subtasksOnDeadNodes...) - } - // 3. group subtasks for each task executor. - subtasksOnTaskExecutor := make(map[string][]*proto.Subtask, len(eligibleNodes)+len(deadNodes)) - for _, node := range eligibleNodes { - subtasksOnTaskExecutor[node] = make([]*proto.Subtask, 0) - } - for _, subtask := range subtasks { - subtasksOnTaskExecutor[subtask.ExecID] = append( - subtasksOnTaskExecutor[subtask.ExecID], - subtask) - } - // 4. prepare subtasks that need to rebalance to other nodes. - averageSubtaskCnt := len(subtasks) / len(eligibleNodes) - rebalanceSubtasks := make([]*proto.Subtask, 0) - for k, v := range subtasksOnTaskExecutor { - if ok := deadNodesMap[k]; ok { - rebalanceSubtasks = append(rebalanceSubtasks, v...) - continue - } - // When no tidb scale-in/out and averageSubtaskCnt*len(eligibleNodes) < len(subtasks), - // no need to send subtask to other nodes. - // eg: tidb1 with 3 subtasks, tidb2 with 2 subtasks, subtasks are balanced now. - if averageSubtaskCnt*len(eligibleNodes) < len(subtasks) && len(s.TaskNodes) == len(eligibleNodes) { - if len(v) > averageSubtaskCnt+1 { - rebalanceSubtasks = append(rebalanceSubtasks, v[0:len(v)-averageSubtaskCnt]...) 
- } - continue - } - if len(v) > averageSubtaskCnt { - rebalanceSubtasks = append(rebalanceSubtasks, v[0:len(v)-averageSubtaskCnt]...) - } - } - // 5. skip rebalance. - if len(rebalanceSubtasks) == 0 { - return nil - } - // 6.rebalance subtasks to other nodes. - rebalanceIdx := 0 - for k, v := range subtasksOnTaskExecutor { - if ok := deadNodesMap[k]; !ok { - if len(v) < averageSubtaskCnt { - for i := 0; i < averageSubtaskCnt-len(v) && rebalanceIdx < len(rebalanceSubtasks); i++ { - rebalanceSubtasks[rebalanceIdx].ExecID = k - rebalanceIdx++ - } - } - } - } - // 7. rebalance rest subtasks evenly to liveNodes. - liveNodeIdx := 0 - for rebalanceIdx < len(rebalanceSubtasks) { - rebalanceSubtasks[rebalanceIdx].ExecID = eligibleNodes[liveNodeIdx] - rebalanceIdx++ - liveNodeIdx++ - } - - // 8. update subtasks and do clean up logic. - if err = s.taskMgr.UpdateSubtasksExecIDs(s.ctx, s.Task.ID, subtasks); err != nil { - return err - } - logutil.Logger(s.logCtx).Info("balance subtasks", - zap.Stringers("subtasks-rebalanced", subtasks)) - s.TaskNodes = append([]string{}, eligibleNodes...) - return nil -} - -// updateTask update the task in tidb_global_task table. -func (s *BaseScheduler) updateTask(taskState proto.TaskState, newSubTasks []*proto.Subtask, retryTimes int) (err error) { - prevState := s.Task.State - s.Task.State = taskState - logutil.BgLogger().Info("task state transform", zap.Stringer("from", prevState), zap.Stringer("to", taskState)) - if !VerifyTaskStateTransform(prevState, taskState) { - return errors.Errorf("invalid task state transform, from %s to %s", prevState, taskState) - } - - var retryable bool - for i := 0; i < retryTimes; i++ { - retryable, err = s.taskMgr.UpdateTaskAndAddSubTasks(s.ctx, s.Task, newSubTasks, prevState) - if err == nil || !retryable { - break - } - if err1 := s.ctx.Err(); err1 != nil { - return err1 - } - if i%10 == 0 { - logutil.Logger(s.logCtx).Warn("updateTask first failed", zap.Stringer("from", prevState), zap.Stringer("to", s.Task.State), - zap.Int("retry times", i), zap.Error(err)) - } - time.Sleep(RetrySQLInterval) - } - if err != nil && retryTimes != nonRetrySQLTime { - logutil.Logger(s.logCtx).Warn("updateTask failed", - zap.Stringer("from", prevState), zap.Stringer("to", s.Task.State), zap.Int("retry times", retryTimes), zap.Error(err)) - } - return err -} - -func (s *BaseScheduler) onErrHandlingStage(receiveErrs []error) error { - // we only store the first error. - s.Task.Error = receiveErrs[0] - - var subTasks []*proto.Subtask - // when step of task is `StepInit`, no need to do revert - if s.Task.Step != proto.StepInit { - instanceIDs, err := s.GetAllTaskExecutorIDs(s.ctx, s.Task) - if err != nil { - logutil.Logger(s.logCtx).Warn("get task's all instances failed", zap.Error(err)) - return err - } - - subTasks = make([]*proto.Subtask, 0, len(instanceIDs)) - for _, id := range instanceIDs { - // reverting subtasks belong to the same step as current active step. 
- subTasks = append(subTasks, proto.NewSubtask( - s.Task.Step, s.Task.ID, s.Task.Type, id, - s.Task.Concurrency, proto.EmptyMeta, 0)) - } - } - return s.updateTask(proto.TaskStateReverting, subTasks, RetrySQLTimes) -} - -func (s *BaseScheduler) switch2NextStep() (err error) { - nextStep := s.GetNextStep(s.Task) - logutil.Logger(s.logCtx).Info("on next step", - zap.Int64("current-step", int64(s.Task.Step)), - zap.Int64("next-step", int64(nextStep))) - - if nextStep == proto.StepDone { - s.Task.Step = nextStep - s.Task.StateUpdateTime = time.Now().UTC() - if err = s.OnDone(s.ctx, s, s.Task); err != nil { - return errors.Trace(err) - } - return s.taskMgr.SucceedTask(s.ctx, s.Task.ID) - } - - serverNodes, err := s.getEligibleNodes() - if err != nil { - return err - } - logutil.Logger(s.logCtx).Info("eligible instances", zap.Int("num", len(serverNodes))) - if len(serverNodes) == 0 { - return errors.New("no available TiDB node to dispatch subtasks") - } - - metas, err := s.OnNextSubtasksBatch(s.ctx, s, s.Task, serverNodes, nextStep) - if err != nil { - logutil.Logger(s.logCtx).Warn("generate part of subtasks failed", zap.Error(err)) - return s.handlePlanErr(err) - } - - return s.scheduleSubTask(nextStep, metas, serverNodes) -} - -// getEligibleNodes returns the eligible(live) nodes for the task. -// if the task can only be scheduled to some specific nodes, return them directly, -// we don't care liveliness of them. -func (s *BaseScheduler) getEligibleNodes() ([]string, error) { - serverNodes, err := s.GetEligibleInstances(s.ctx, s.Task) - if err != nil { - return nil, err - } - logutil.Logger(s.logCtx).Debug("eligible instances", zap.Int("num", len(serverNodes))) - if len(serverNodes) == 0 { - serverNodes = append([]string{}, s.nodeMgr.getManagedNodes()...) - } - return serverNodes, nil -} - -func (s *BaseScheduler) scheduleSubTask( - subtaskStep proto.Step, - metas [][]byte, - serverNodes []string) error { - logutil.Logger(s.logCtx).Info("schedule subtasks", - zap.Stringer("state", s.Task.State), - zap.Int64("step", int64(s.Task.Step)), - zap.Int("concurrency", s.Task.Concurrency), - zap.Int("subtasks", len(metas))) - s.TaskNodes = serverNodes - var size uint64 - subTasks := make([]*proto.Subtask, 0, len(metas)) - for i, meta := range metas { - // we assign the subtask to the instance in a round-robin way. - // TODO: assign the subtask to the instance according to the system load of each nodes - pos := i % len(serverNodes) - instanceID := serverNodes[pos] - logutil.Logger(s.logCtx).Debug("create subtasks", zap.String("instanceID", instanceID)) - subTasks = append(subTasks, proto.NewSubtask( - subtaskStep, s.Task.ID, s.Task.Type, instanceID, s.Task.Concurrency, meta, i+1)) - - size += uint64(len(meta)) - } - failpoint.Inject("cancelBeforeUpdateTask", func() { - _ = s.taskMgr.CancelTask(s.ctx, s.Task.ID) - }) - - // as other fields and generated key and index KV takes space too, we limit - // the size of subtasks to 80% of the transaction limit. - limit := max(uint64(float64(kv.TxnTotalSizeLimit.Load())*0.8), 1) - fn := s.taskMgr.SwitchTaskStep - if size >= limit { - // On default, transaction size limit is controlled by tidb_mem_quota_query - // which is 1G on default, so it's unlikely to reach this limit, but in - // case user set txn-total-size-limit explicitly, we insert in batch. 
- logutil.Logger(s.logCtx).Info("subtasks size exceed limit, will insert in batch", - zap.Uint64("size", size), zap.Uint64("limit", limit)) - fn = s.taskMgr.SwitchTaskStepInBatch - } - - backoffer := backoff.NewExponential(RetrySQLInterval, 2, RetrySQLMaxInterval) - return handle.RunWithRetry(s.ctx, RetrySQLTimes, backoffer, logutil.Logger(s.logCtx), ->>>>>>> 99f0349bfb6 (disttask: fix failed step is taken as success (#49971)) func(ctx context.Context) (bool, error) { return true, s.taskTable.UpdateSubtaskStateAndError(ctx, tidbID, subtaskID, state, subTaskErr) }, @@ -1027,33 +646,5 @@ func (s *BaseScheduler) updateErrorToSubtask(ctx context.Context, taskID int64, if err1 == nil { logger.Warn("update error to subtask success", zap.Error(err)) } -<<<<<<< HEAD return err1 -======= - previousSubtaskMetas := make([][]byte, 0, len(previousSubtasks)) - for _, subtask := range previousSubtasks { - previousSubtaskMetas = append(previousSubtaskMetas, subtask.Meta) - } - return previousSubtaskMetas, nil -} - -// WithNewSession executes the function with a new session. -func (s *BaseScheduler) WithNewSession(fn func(se sessionctx.Context) error) error { - return s.taskMgr.WithNewSession(fn) -} - -// WithNewTxn executes the fn in a new transaction. -func (s *BaseScheduler) WithNewTxn(ctx context.Context, fn func(se sessionctx.Context) error) error { - return s.taskMgr.WithNewTxn(ctx, fn) -} - -func (*BaseScheduler) isStepSucceed(cntByStates map[proto.SubtaskState]int64) bool { - _, ok := cntByStates[proto.SubtaskStateSucceed] - return len(cntByStates) == 0 || (len(cntByStates) == 1 && ok) -} - -// IsCancelledErr checks if the error is a cancelled error. -func IsCancelledErr(err error) bool { - return strings.Contains(err.Error(), taskCancelMsg) ->>>>>>> 99f0349bfb6 (disttask: fix failed step is taken as success (#49971)) } diff --git a/pkg/disttask/framework/scheduler/scheduler_manager_test.go b/pkg/disttask/framework/scheduler/scheduler_manager_test.go deleted file mode 100644 index 3fb7804244683..0000000000000 --- a/pkg/disttask/framework/scheduler/scheduler_manager_test.go +++ /dev/null @@ -1,84 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package scheduler_test - -import ( - "context" - "testing" - "time" - - "github.com/ngaut/pools" - "github.com/pingcap/tidb/pkg/disttask/framework/mock" - "github.com/pingcap/tidb/pkg/disttask/framework/proto" - "github.com/pingcap/tidb/pkg/disttask/framework/testutil" - "github.com/pingcap/tidb/pkg/testkit" - "github.com/stretchr/testify/require" - "github.com/tikv/client-go/v2/util" - "go.uber.org/mock/gomock" -) - -func TestCleanUpRoutine(t *testing.T) { - store := testkit.CreateMockStore(t) - gtk := testkit.NewTestKit(t, store) - pool := pools.NewResourcePool(func() (pools.Resource, error) { - return gtk.Session(), nil - }, 1, 1, time.Second) - defer pool.Close() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - ctx := context.Background() - ctx = util.WithInternalSourceType(ctx, "scheduler_manager") - mockCleanupRoutine := mock.NewMockCleanUpRoutine(ctrl) - - sch, mgr := MockSchedulerManager(t, ctrl, pool, getNumberExampleSchedulerExt(ctrl), mockCleanupRoutine) - mockCleanupRoutine.EXPECT().CleanUp(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() - sch.Start() - defer sch.Stop() - testutil.WaitNodeRegistered(ctx, t) - taskID, err := mgr.CreateTask(ctx, "test", proto.TaskTypeExample, 1, nil) - require.NoError(t, err) - - checkTaskRunningCnt := func() []*proto.Task { - var tasks []*proto.Task - require.Eventually(t, func() bool { - var err error - tasks, err = mgr.GetTasksInStates(ctx, proto.TaskStateRunning) - require.NoError(t, err) - return len(tasks) == 1 - }, time.Second, 50*time.Millisecond) - return tasks - } - - checkSubtaskCnt := func(tasks []*proto.Task, taskID int64) { - require.Eventually(t, func() bool { - cntByStates, err := mgr.GetSubtaskCntGroupByStates(ctx, taskID, proto.StepOne) - require.NoError(t, err) - return int64(subtaskCnt) == cntByStates[proto.SubtaskStatePending] - }, time.Second, 50*time.Millisecond) - } - - tasks := checkTaskRunningCnt() - checkSubtaskCnt(tasks, taskID) - for i := 1; i <= subtaskCnt; i++ { - err = mgr.UpdateSubtaskStateAndError(ctx, ":4000", int64(i), proto.SubtaskStateSucceed, nil) - require.NoError(t, err) - } - sch.DoCleanUpRoutine() - require.Eventually(t, func() bool { - tasks, err := mgr.GetTasksFromHistoryInStates(ctx, proto.TaskStateSucceed) - require.NoError(t, err) - return len(tasks) != 0 - }, time.Second*10, time.Millisecond*300) -} diff --git a/pkg/disttask/framework/scheduler/scheduler_nokit_test.go b/pkg/disttask/framework/scheduler/scheduler_nokit_test.go deleted file mode 100644 index e6320ded9ceec..0000000000000 --- a/pkg/disttask/framework/scheduler/scheduler_nokit_test.go +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright 2024 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package scheduler - -import ( - "testing" - - "github.com/pingcap/tidb/pkg/disttask/framework/proto" - "github.com/stretchr/testify/require" -) - -func TestSchedulerIsStepSucceed(t *testing.T) { - s := &BaseScheduler{} - require.True(t, s.isStepSucceed(nil)) - require.True(t, s.isStepSucceed(map[proto.SubtaskState]int64{})) - require.True(t, s.isStepSucceed(map[proto.SubtaskState]int64{ - proto.SubtaskStateSucceed: 1, - })) - for _, state := range []proto.SubtaskState{ - proto.SubtaskStateCanceled, - proto.SubtaskStateFailed, - proto.SubtaskStateReverting, - } { - require.False(t, s.isStepSucceed(map[proto.SubtaskState]int64{ - state: 1, - })) - } -} diff --git a/pkg/disttask/framework/scheduler/scheduler_test.go b/pkg/disttask/framework/scheduler/scheduler_test.go index ca54c59be84cc..508cdcfae1a13 100644 --- a/pkg/disttask/framework/scheduler/scheduler_test.go +++ b/pkg/disttask/framework/scheduler/scheduler_test.go @@ -208,7 +208,6 @@ func TestSchedulerRun(t *testing.T) { err = scheduler.Run(runCtx, task) require.EqualError(t, err, context.Canceled.Error()) -<<<<<<< HEAD // 8. grpc cancel mockSubtaskTable.EXPECT().GetSubtasksInStates(gomock.Any(), "id", taskID, proto.StepOne, unfinishedNormalSubtaskStates...).Return([]*proto.Subtask{{ @@ -238,439 +237,6 @@ func TestSchedulerRun(t *testing.T) { grpcErr, " %s", "test annotate", -======= - testutil.WaitNodeRegistered(ctx, t) - - // unknown task type - taskID, err := mgr.CreateTask(ctx, "test", "test-type", 1, nil) - require.NoError(t, err) - require.Eventually(t, func() bool { - task, err := mgr.GetTaskByID(ctx, taskID) - require.NoError(t, err) - return task.State == proto.TaskStateFailed && - strings.Contains(task.Error.Error(), "unknown task type") - }, time.Second*10, time.Millisecond*300) - - // scheduler init error - taskID, err = mgr.CreateTask(ctx, "test2", proto.TaskTypeExample, 1, nil) - require.NoError(t, err) - require.Eventually(t, func() bool { - task, err := mgr.GetTaskByID(ctx, taskID) - require.NoError(t, err) - return task.State == proto.TaskStateFailed && - strings.Contains(task.Error.Error(), "mock scheduler init error") - }, time.Second*10, time.Millisecond*300) -} - -func checkDispatch(t *testing.T, taskCnt int, isSucc, isCancel, isSubtaskCancel, isPauseAndResume bool) { - require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/domain/MockDisableDistTask", "return(true)")) - defer func() { - require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/domain/MockDisableDistTask")) - }() - // test DispatchTaskLoop - // test parallelism control - var originalConcurrency int - if taskCnt == 1 { - originalConcurrency = proto.MaxConcurrentTask - proto.MaxConcurrentTask = 1 - } - - store := testkit.CreateMockStore(t) - gtk := testkit.NewTestKit(t, store) - pool := pools.NewResourcePool(func() (pools.Resource, error) { - return gtk.Session(), nil - }, 1, 1, time.Second) - defer pool.Close() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - ctx := context.Background() - ctx = util.WithInternalSourceType(ctx, "scheduler") - - sch, mgr := MockSchedulerManager(t, ctrl, pool, getNumberExampleSchedulerExt(ctrl), nil) - sch.Start() - defer func() { - sch.Stop() - // make data race happy - if taskCnt == 1 { - proto.MaxConcurrentTask = originalConcurrency - } - }() - - require.NoError(t, mgr.StartManager(ctx, ":4000", "background")) - - // 3s - cnt := 60 - checkGetRunningTaskCnt := func(expected int) { - require.Eventually(t, func() bool { - return sch.GetRunningTaskCnt() == expected - }, time.Second, 
50*time.Millisecond) - } - - checkTaskRunningCnt := func() []*proto.Task { - var tasks []*proto.Task - require.Eventually(t, func() bool { - var err error - tasks, err = mgr.GetTasksInStates(ctx, proto.TaskStateRunning) - require.NoError(t, err) - return len(tasks) == taskCnt - }, time.Second, 50*time.Millisecond) - return tasks - } - - checkSubtaskCnt := func(tasks []*proto.Task, taskIDs []int64) { - for i, taskID := range taskIDs { - require.Equal(t, int64(i+1), tasks[i].ID) - require.Eventually(t, func() bool { - cntByStates, err := mgr.GetSubtaskCntGroupByStates(ctx, taskID, proto.StepOne) - require.NoError(t, err) - return int64(subtaskCnt) == cntByStates[proto.SubtaskStatePending] - }, time.Second, 50*time.Millisecond) - } - } - - // Mock add tasks. - taskIDs := make([]int64, 0, taskCnt) - for i := 0; i < taskCnt; i++ { - taskID, err := mgr.CreateTask(ctx, fmt.Sprintf("%d", i), proto.TaskTypeExample, 0, nil) - require.NoError(t, err) - taskIDs = append(taskIDs, taskID) - } - // test OnNextSubtasksBatch. - checkGetRunningTaskCnt(taskCnt) - tasks := checkTaskRunningCnt() - checkSubtaskCnt(tasks, taskIDs) - // test parallelism control - if taskCnt == 1 { - taskID, err := mgr.CreateTask(ctx, fmt.Sprintf("%d", taskCnt), proto.TaskTypeExample, 0, nil) - require.NoError(t, err) - checkGetRunningTaskCnt(taskCnt) - // Clean the task. - deleteTasks(t, store, taskID) - sch.DelRunningTask(taskID) - } - - // test DetectTaskLoop - checkGetTaskState := func(expectedState proto.TaskState) { - i := 0 - for ; i < cnt; i++ { - tasks, err := mgr.GetTasksInStates(ctx, expectedState) - require.NoError(t, err) - if len(tasks) == taskCnt { - break - } - historyTasks, err := mgr.GetTasksFromHistoryInStates(ctx, expectedState) - require.NoError(t, err) - if len(tasks)+len(historyTasks) == taskCnt { - break - } - time.Sleep(time.Millisecond * 50) - } - require.Less(t, i, cnt) - } - // Test all subtasks are successful. - var err error - if isSucc { - // Mock subtasks succeed. - for i := 1; i <= subtaskCnt*taskCnt; i++ { - err = mgr.UpdateSubtaskStateAndError(ctx, ":4000", int64(i), proto.SubtaskStateSucceed, nil) - require.NoError(t, err) - } - checkGetTaskState(proto.TaskStateSucceed) - require.Len(t, tasks, taskCnt) - - checkGetRunningTaskCnt(0) - return - } - - if isCancel { - for i := 1; i <= taskCnt; i++ { - err = mgr.CancelTask(ctx, int64(i)) - require.NoError(t, err) - } - } else if isPauseAndResume { - for i := 0; i < taskCnt; i++ { - found, err := mgr.PauseTask(ctx, fmt.Sprintf("%d", i)) - require.Equal(t, true, found) - require.NoError(t, err) - } - for i := 1; i <= subtaskCnt*taskCnt; i++ { - err = mgr.UpdateSubtaskStateAndError(ctx, ":4000", int64(i), proto.SubtaskStatePaused, nil) - require.NoError(t, err) - } - checkGetTaskState(proto.TaskStatePaused) - for i := 0; i < taskCnt; i++ { - found, err := mgr.ResumeTask(ctx, fmt.Sprintf("%d", i)) - require.Equal(t, true, found) - require.NoError(t, err) - } - - // Mock subtasks succeed. - for i := 1; i <= subtaskCnt*taskCnt; i++ { - err = mgr.UpdateSubtaskStateAndError(ctx, ":4000", int64(i), proto.SubtaskStateSucceed, nil) - require.NoError(t, err) - } - checkGetTaskState(proto.TaskStateSucceed) - return - } else { - // Test each task has a subtask failed. 
- require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/disttask/framework/storage/MockUpdateTaskErr", "1*return(true)")) - defer func() { - require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/disttask/framework/storage/MockUpdateTaskErr")) - }() - - if isSubtaskCancel { - // Mock a subtask canceled - for i := 1; i <= subtaskCnt*taskCnt; i += subtaskCnt { - err = mgr.UpdateSubtaskStateAndError(ctx, ":4000", int64(i), proto.SubtaskStateCanceled, nil) - require.NoError(t, err) - } - } else { - // Mock a subtask fails. - for i := 1; i <= subtaskCnt*taskCnt; i += subtaskCnt { - err = mgr.UpdateSubtaskStateAndError(ctx, ":4000", int64(i), proto.SubtaskStateFailed, nil) - require.NoError(t, err) - } - } - } - - checkGetTaskState(proto.TaskStateReverting) - require.Len(t, tasks, taskCnt) - // Mock all subtask reverted. - start := subtaskCnt * taskCnt - for i := start; i <= start+subtaskCnt*taskCnt; i++ { - err = mgr.UpdateSubtaskStateAndError(ctx, ":4000", int64(i), proto.SubtaskStateReverted, nil) - require.NoError(t, err) - } - checkGetTaskState(proto.TaskStateReverted) - require.Len(t, tasks, taskCnt) -} - -func TestSimple(t *testing.T) { - checkDispatch(t, 1, true, false, false, false) -} - -func TestSimpleErrStage(t *testing.T) { - checkDispatch(t, 1, false, false, false, false) -} - -func TestSimpleCancel(t *testing.T) { - checkDispatch(t, 1, false, true, false, false) -} - -func TestSimpleSubtaskCancel(t *testing.T) { - checkDispatch(t, 1, false, false, true, false) -} - -func TestParallel(t *testing.T) { - checkDispatch(t, 3, true, false, false, false) -} - -func TestParallelErrStage(t *testing.T) { - checkDispatch(t, 3, false, false, false, false) -} - -func TestParallelCancel(t *testing.T) { - checkDispatch(t, 3, false, true, false, false) -} - -func TestParallelSubtaskCancel(t *testing.T) { - checkDispatch(t, 3, false, false, true, false) -} - -func TestPause(t *testing.T) { - checkDispatch(t, 1, false, false, false, true) -} - -func TestParallelPause(t *testing.T) { - checkDispatch(t, 3, false, false, false, true) -} - -func TestVerifyTaskStateTransform(t *testing.T) { - testCases := []struct { - oldState proto.TaskState - newState proto.TaskState - expect bool - }{ - {proto.TaskStateRunning, proto.TaskStateRunning, true}, - {proto.TaskStatePending, proto.TaskStateRunning, true}, - {proto.TaskStatePending, proto.TaskStateReverting, false}, - {proto.TaskStateRunning, proto.TaskStateReverting, true}, - {proto.TaskStateReverting, proto.TaskStateReverted, true}, - {proto.TaskStateReverting, proto.TaskStateSucceed, false}, - {proto.TaskStateRunning, proto.TaskStatePausing, true}, - {proto.TaskStateRunning, proto.TaskStateResuming, false}, - {proto.TaskStateCancelling, proto.TaskStateRunning, false}, - {proto.TaskStateCanceled, proto.TaskStateRunning, false}, - } - for _, tc := range testCases { - require.Equal(t, tc.expect, scheduler.VerifyTaskStateTransform(tc.oldState, tc.newState)) - } -} - -func TestIsCancelledErr(t *testing.T) { - require.False(t, scheduler.IsCancelledErr(errors.New("some err"))) - require.False(t, scheduler.IsCancelledErr(context.Canceled)) - require.True(t, scheduler.IsCancelledErr(errors.New("cancelled by user"))) -} - -func TestDispatcherOnNextStage(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - taskMgr := mock.NewMockTaskManager(ctrl) - schExt := mockDispatch.NewMockExtension(ctrl) - - ctx := context.Background() - ctx = util.WithInternalSourceType(ctx, "dispatcher") - task := proto.Task{ - ID: 1, - State: 
proto.TaskStatePending, - Step: proto.StepInit, - } - cloneTask := task - nodeMgr := scheduler.NewNodeManager() - sch := scheduler.NewBaseScheduler(ctx, taskMgr, nodeMgr, &cloneTask) - sch.Extension = schExt - - // test next step is done - schExt.EXPECT().GetNextStep(gomock.Any()).Return(proto.StepDone) - schExt.EXPECT().OnDone(gomock.Any(), gomock.Any(), gomock.Any()).Return(errors.New("done err")) - require.ErrorContains(t, sch.Switch2NextStep(), "done err") - require.True(t, ctrl.Satisfied()) - // we update task step before OnDone - require.Equal(t, proto.StepDone, sch.Task.Step) - *sch.Task = task - schExt.EXPECT().GetNextStep(gomock.Any()).Return(proto.StepDone) - schExt.EXPECT().OnDone(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) - taskMgr.EXPECT().SucceedTask(gomock.Any(), gomock.Any()).Return(nil) - require.NoError(t, sch.Switch2NextStep()) - require.True(t, ctrl.Satisfied()) - - // GetEligibleInstances err - schExt.EXPECT().GetNextStep(gomock.Any()).Return(proto.StepOne) - schExt.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).Return(nil, errors.New("GetEligibleInstances err")) - require.ErrorContains(t, sch.Switch2NextStep(), "GetEligibleInstances err") - require.True(t, ctrl.Satisfied()) - // GetEligibleInstances no instance - schExt.EXPECT().GetNextStep(gomock.Any()).Return(proto.StepOne) - schExt.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).Return(nil, nil) - require.ErrorContains(t, sch.Switch2NextStep(), "no available TiDB node to dispatch subtasks") - require.True(t, ctrl.Satisfied()) - - serverNodes := []string{":4000"} - // OnNextSubtasksBatch err - schExt.EXPECT().GetNextStep(gomock.Any()).Return(proto.StepOne) - schExt.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).Return(serverNodes, nil) - schExt.EXPECT().OnNextSubtasksBatch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - Return(nil, errors.New("OnNextSubtasksBatch err")) - schExt.EXPECT().IsRetryableErr(gomock.Any()).Return(true) - require.ErrorContains(t, sch.Switch2NextStep(), "OnNextSubtasksBatch err") - require.True(t, ctrl.Satisfied()) - - bak := kv.TxnTotalSizeLimit.Load() - t.Cleanup(func() { - kv.TxnTotalSizeLimit.Store(bak) - }) - - // dispatch in batch - subtaskMetas := [][]byte{ - []byte(`{"xx": "1"}`), - []byte(`{"xx": "2"}`), - } - schExt.EXPECT().GetNextStep(gomock.Any()).Return(proto.StepOne) - schExt.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).Return(serverNodes, nil) - schExt.EXPECT().OnNextSubtasksBatch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - Return(subtaskMetas, nil) - taskMgr.EXPECT().SwitchTaskStepInBatch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) - kv.TxnTotalSizeLimit.Store(1) - require.NoError(t, sch.Switch2NextStep()) - require.True(t, ctrl.Satisfied()) - // met unstable subtasks - schExt.EXPECT().GetNextStep(gomock.Any()).Return(proto.StepOne) - schExt.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).Return(serverNodes, nil) - schExt.EXPECT().OnNextSubtasksBatch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - Return(subtaskMetas, nil) - taskMgr.EXPECT().SwitchTaskStepInBatch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). 
- Return(errors.Annotatef(storage.ErrUnstableSubtasks, "expected %d, got %d", - 2, 100)) - kv.TxnTotalSizeLimit.Store(1) - startTime := time.Now() - err := sch.Switch2NextStep() - require.ErrorIs(t, err, storage.ErrUnstableSubtasks) - require.ErrorContains(t, err, "expected 2, got 100") - require.WithinDuration(t, startTime, time.Now(), 10*time.Second) - require.True(t, ctrl.Satisfied()) - - // dispatch in one txn - schExt.EXPECT().GetNextStep(gomock.Any()).Return(proto.StepOne) - schExt.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).Return(serverNodes, nil) - schExt.EXPECT().OnNextSubtasksBatch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - Return(subtaskMetas, nil) - taskMgr.EXPECT().SwitchTaskStep(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) - kv.TxnTotalSizeLimit.Store(config.DefTxnTotalSizeLimit) - require.NoError(t, sch.Switch2NextStep()) - require.True(t, ctrl.Satisfied()) -} - -func TestManagerDispatchLoop(t *testing.T) { - // Mock 16 cpu node. - require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/util/cpu/mockNumCpu", "return(16)")) - t.Cleanup(func() { - require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/util/cpu/mockNumCpu")) - }) - ctrl := gomock.NewController(t) - defer ctrl.Finish() - mockScheduler := mock.NewMockScheduler(ctrl) - - _ = testkit.CreateMockStore(t) - require.Eventually(t, func() bool { - taskMgr, err := storage.GetTaskManager() - return err == nil && taskMgr != nil - }, 10*time.Second, 100*time.Millisecond) - - ctx := context.Background() - ctx = util.WithInternalSourceType(ctx, "scheduler") - taskMgr, err := storage.GetTaskManager() - require.NoError(t, err) - require.NotNil(t, taskMgr) - - // in this test, we only test scheduler manager, so we add a subtask takes 16 - // slots to avoid reserve by slots, and make sure below test cases works. - serverInfos, err := infosync.GetAllServerInfo(ctx) - require.NoError(t, err) - for _, s := range serverInfos { - execID := disttaskutil.GenerateExecID(s) - testutil.InsertSubtask(t, taskMgr, 1000000, proto.StepOne, execID, []byte(""), proto.SubtaskStatePending, proto.TaskTypeExample, 16) - } - concurrencies := []int{4, 6, 16, 2, 4, 4} - waitChannels := make([]chan struct{}, len(concurrencies)) - for i := range waitChannels { - waitChannels[i] = make(chan struct{}) - } - var counter atomic.Int32 - scheduler.RegisterSchedulerFactory(proto.TaskTypeExample, - func(ctx context.Context, taskMgr scheduler.TaskManager, nodeMgr *scheduler.NodeManager, task *proto.Task) scheduler.Scheduler { - idx := counter.Load() - mockScheduler = mock.NewMockScheduler(ctrl) - mockScheduler.EXPECT().Init().Return(nil) - mockScheduler.EXPECT().ScheduleTask().Do(func() { - require.NoError(t, taskMgr.WithNewSession(func(se sessionctx.Context) error { - _, err := sqlexec.ExecSQL(ctx, se, "update mysql.tidb_global_task set state=%?, step=%? where id=%?", - proto.TaskStateRunning, proto.StepOne, task.ID) - return err - })) - <-waitChannels[idx] - require.NoError(t, taskMgr.WithNewSession(func(se sessionctx.Context) error { - _, err := sqlexec.ExecSQL(ctx, se, "update mysql.tidb_global_task set state=%?, step=%? 
where id=%?", - proto.TaskStateSucceed, proto.StepDone, task.ID) - return err - })) - }) - mockScheduler.EXPECT().Close() - counter.Add(1) - return mockScheduler - }, ->>>>>>> 99f0349bfb6 (disttask: fix failed step is taken as success (#49971)) ) mockSubtaskExecutor.EXPECT().RunSubtask(gomock.Any(), gomock.Any()).Return(annotatedError) mockSubtaskExecutor.EXPECT().Cleanup(gomock.Any()).Return(nil) diff --git a/pkg/disttask/framework/storage/BUILD.bazel b/pkg/disttask/framework/storage/BUILD.bazel index 817d3ce767d8b..bc9f7331912d9 100644 --- a/pkg/disttask/framework/storage/BUILD.bazel +++ b/pkg/disttask/framework/storage/BUILD.bazel @@ -31,11 +31,7 @@ go_test( srcs = ["table_test.go"], flaky = True, race = "on", -<<<<<<< HEAD shard_count = 8, -======= - shard_count = 14, ->>>>>>> 99f0349bfb6 (disttask: fix failed step is taken as success (#49971)) deps = [ ":storage", "//pkg/disttask/framework/proto", diff --git a/pkg/disttask/framework/storage/table_test.go b/pkg/disttask/framework/storage/table_test.go index a0ba8b4de4340..260e08d98ed42 100644 --- a/pkg/disttask/framework/storage/table_test.go +++ b/pkg/disttask/framework/storage/table_test.go @@ -126,274 +126,6 @@ func TestGlobalTaskTable(t *testing.T) { cancelling, err = gm.IsGlobalTaskCancelling(ctx, id) require.NoError(t, err) require.True(t, cancelling) -<<<<<<< HEAD -======= - - id, err = gm.CreateTask(ctx, "key-fail", "test2", 4, []byte("test2")) - require.NoError(t, err) - // state not right, update nothing - require.NoError(t, gm.FailTask(ctx, id, proto.TaskStateRunning, errors.New("test error"))) - task, err = gm.GetTaskByID(ctx, id) - require.NoError(t, err) - require.Equal(t, proto.TaskStatePending, task.State) - require.Nil(t, task.Error) - require.NoError(t, gm.FailTask(ctx, id, proto.TaskStatePending, errors.New("test error"))) - task, err = gm.GetTaskByID(ctx, id) - require.NoError(t, err) - require.Equal(t, proto.TaskStateFailed, task.State) - require.ErrorContains(t, task.Error, "test error") - - // succeed a pending task, no effect - id, err = gm.CreateTask(ctx, "key-success", "test", 4, []byte("test")) - require.NoError(t, err) - require.NoError(t, gm.SucceedTask(ctx, id)) - task, err = gm.GetTaskByID(ctx, id) - require.NoError(t, err) - checkTaskStateStep(t, task, proto.TaskStatePending, proto.StepInit) - // succeed a running task - require.NoError(t, gm.SwitchTaskStep(ctx, task, proto.TaskStateRunning, proto.StepOne, nil)) - task, err = gm.GetTaskByID(ctx, id) - require.NoError(t, err) - checkTaskStateStep(t, task, proto.TaskStateRunning, proto.StepOne) - startTime := time.Unix(time.Now().Unix(), 0) - require.NoError(t, gm.SucceedTask(ctx, id)) - task, err = gm.GetTaskByID(ctx, id) - require.NoError(t, err) - checkTaskStateStep(t, task, proto.TaskStateSucceed, proto.StepDone) - require.GreaterOrEqual(t, task.StateUpdateTime, startTime) -} - -func checkAfterSwitchStep(t *testing.T, startTime time.Time, task *proto.Task, subtasks []*proto.Subtask, step proto.Step) { - tm, err := storage.GetTaskManager() - require.NoError(t, err) - ctx := context.Background() - ctx = util.WithInternalSourceType(ctx, "table_test") - - checkTaskStateStep(t, task, proto.TaskStateRunning, step) - require.GreaterOrEqual(t, task.StartTime, startTime) - require.GreaterOrEqual(t, task.StateUpdateTime, startTime) - gotSubtasks, err := tm.GetSubtasksByStepAndState(ctx, task.ID, task.Step, proto.TaskStatePending) - require.NoError(t, err) - require.Len(t, gotSubtasks, len(subtasks)) - sort.Slice(gotSubtasks, func(i, j int) bool { - return 
gotSubtasks[i].Ordinal < gotSubtasks[j].Ordinal - }) - for i := 0; i < len(gotSubtasks); i++ { - subtask := gotSubtasks[i] - require.Equal(t, []byte(fmt.Sprintf("%d", i)), subtask.Meta) - require.Equal(t, i+1, subtask.Ordinal) - require.Equal(t, 11, subtask.Concurrency) - require.Equal(t, ":4000", subtask.ExecID) - require.Equal(t, proto.TaskTypeExample, subtask.Type) - require.GreaterOrEqual(t, subtask.CreateTime, startTime) - } -} - -func TestSwitchTaskStep(t *testing.T) { - store, tm, ctx := testutil.InitTableTest(t) - tk := testkit.NewTestKit(t, store) - - require.NoError(t, tm.StartManager(ctx, ":4000", "")) - taskID, err := tm.CreateTask(ctx, "key1", "test", 4, []byte("test")) - require.NoError(t, err) - task, err := tm.GetTaskByID(ctx, taskID) - require.NoError(t, err) - checkTaskStateStep(t, task, proto.TaskStatePending, proto.StepInit) - subtasksStepOne := make([]*proto.Subtask, 3) - for i := 0; i < len(subtasksStepOne); i++ { - subtasksStepOne[i] = proto.NewSubtask(proto.StepOne, taskID, proto.TaskTypeExample, - ":4000", 11, []byte(fmt.Sprintf("%d", i)), i+1) - } - startTime := time.Unix(time.Now().Unix(), 0) - task.Meta = []byte("changed meta") - require.NoError(t, tm.SwitchTaskStep(ctx, task, proto.TaskStateRunning, proto.StepOne, subtasksStepOne)) - task, err = tm.GetTaskByID(ctx, taskID) - require.NoError(t, err) - require.Equal(t, []byte("changed meta"), task.Meta) - checkAfterSwitchStep(t, startTime, task, subtasksStepOne, proto.StepOne) - // switch step again, no effect - require.NoError(t, tm.SwitchTaskStep(ctx, task, proto.TaskStateRunning, proto.StepOne, subtasksStepOne)) - task, err = tm.GetTaskByID(ctx, taskID) - require.NoError(t, err) - checkAfterSwitchStep(t, startTime, task, subtasksStepOne, proto.StepOne) - // switch step to step two - time.Sleep(time.Second) - taskStartTime := task.StartTime - subtasksStepTwo := make([]*proto.Subtask, 3) - for i := 0; i < len(subtasksStepTwo); i++ { - subtasksStepTwo[i] = proto.NewSubtask(proto.StepTwo, taskID, proto.TaskTypeExample, - ":4000", 11, []byte(fmt.Sprintf("%d", i)), i+1) - } - require.NoError(t, tk.Session().GetSessionVars().SetSystemVar(variable.TiDBMemQuotaQuery, "1024")) - require.NoError(t, tm.SwitchTaskStep(ctx, task, proto.TaskStateRunning, proto.StepTwo, subtasksStepTwo)) - value, ok := tk.Session().GetSessionVars().GetSystemVar(variable.TiDBMemQuotaQuery) - require.True(t, ok) - require.Equal(t, "1024", value) - task, err = tm.GetTaskByID(ctx, taskID) - require.NoError(t, err) - // start time should not change - require.Equal(t, taskStartTime, task.StartTime) - checkAfterSwitchStep(t, startTime, task, subtasksStepTwo, proto.StepTwo) -} - -func TestSwitchTaskStepInBatch(t *testing.T) { - store, tm, ctx := testutil.InitTableTest(t) - tk := testkit.NewTestKit(t, store) - - require.NoError(t, tm.StartManager(ctx, ":4000", "")) - // normal flow - prepare := func(taskKey string) (*proto.Task, []*proto.Subtask) { - taskID, err := tm.CreateTask(ctx, taskKey, "test", 4, []byte("test")) - require.NoError(t, err) - task, err := tm.GetTaskByID(ctx, taskID) - require.NoError(t, err) - checkTaskStateStep(t, task, proto.TaskStatePending, proto.StepInit) - subtasks := make([]*proto.Subtask, 3) - for i := 0; i < len(subtasks); i++ { - subtasks[i] = proto.NewSubtask(proto.StepOne, taskID, proto.TaskTypeExample, - ":4000", 11, []byte(fmt.Sprintf("%d", i)), i+1) - } - return task, subtasks - } - startTime := time.Unix(time.Now().Unix(), 0) - task1, subtasks1 := prepare("key1") - require.NoError(t, tm.SwitchTaskStepInBatch(ctx, 
task1, proto.TaskStateRunning, proto.StepOne, subtasks1)) - task1, err := tm.GetTaskByID(ctx, task1.ID) - require.NoError(t, err) - checkAfterSwitchStep(t, startTime, task1, subtasks1, proto.StepOne) - - // mock another dispatcher inserted some subtasks - require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/disttask/framework/storage/waitBeforeInsertSubtasks", `1*return(true)`)) - t.Cleanup(func() { - require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/disttask/framework/storage/waitBeforeInsertSubtasks")) - }) - task2, subtasks2 := prepare("key2") - go func() { - storage.TestChannel <- struct{}{} - tk2 := testkit.NewTestKit(t, store) - subtask := subtasks2[0] - _, err = sqlexec.ExecSQL(ctx, tk2.Session(), ` - insert into mysql.tidb_background_subtask( - step, task_key, exec_id, meta, state, type, concurrency, ordinal, create_time, checkpoint, summary) - values (%?, %?, %?, %?, %?, %?, %?, %?, CURRENT_TIMESTAMP(), '{}', '{}')`, - subtask.Step, subtask.TaskID, subtask.ExecID, subtask.Meta, - proto.TaskStatePending, proto.Type2Int(subtask.Type), subtask.Concurrency, subtask.Ordinal) - require.NoError(t, err) - storage.TestChannel <- struct{}{} - }() - err = tm.SwitchTaskStepInBatch(ctx, task2, proto.TaskStateRunning, proto.StepOne, subtasks2) - require.ErrorIs(t, err, kv.ErrKeyExists) - task2, err = tm.GetTaskByID(ctx, task2.ID) - require.NoError(t, err) - checkTaskStateStep(t, task2, proto.TaskStatePending, proto.StepInit) - gotSubtasks, err := tm.GetSubtasksByStepAndState(ctx, task2.ID, proto.StepOne, proto.TaskStatePending) - require.NoError(t, err) - require.Len(t, gotSubtasks, 1) - // run again, should success - require.NoError(t, tm.SwitchTaskStepInBatch(ctx, task2, proto.TaskStateRunning, proto.StepOne, subtasks2)) - task2, err = tm.GetTaskByID(ctx, task2.ID) - require.NoError(t, err) - checkAfterSwitchStep(t, startTime, task2, subtasks2, proto.StepOne) - - // mock subtasks unstable - task3, subtasks3 := prepare("key3") - for i := 0; i < 2; i++ { - subtask := subtasks3[i] - _, err = sqlexec.ExecSQL(ctx, tk.Session(), ` - insert into mysql.tidb_background_subtask( - step, task_key, exec_id, meta, state, type, concurrency, ordinal, create_time, checkpoint, summary) - values (%?, %?, %?, %?, %?, %?, %?, %?, CURRENT_TIMESTAMP(), '{}', '{}')`, - subtask.Step, subtask.TaskID, subtask.ExecID, subtask.Meta, - proto.TaskStatePending, proto.Type2Int(subtask.Type), subtask.Concurrency, subtask.Ordinal) - require.NoError(t, err) - } - err = tm.SwitchTaskStepInBatch(ctx, task3, proto.TaskStateRunning, proto.StepOne, subtasks3[:1]) - require.ErrorIs(t, err, storage.ErrUnstableSubtasks) - require.ErrorContains(t, err, "expected 1, got 2") -} - -func TestGetTopUnfinishedTasks(t *testing.T) { - _, gm, ctx := testutil.InitTableTest(t) - - require.NoError(t, gm.StartManager(ctx, ":4000", "")) - taskStates := []proto.TaskState{ - proto.TaskStateSucceed, - proto.TaskStatePending, - proto.TaskStateRunning, - proto.TaskStateReverting, - proto.TaskStateCancelling, - proto.TaskStatePausing, - proto.TaskStateResuming, - proto.TaskStateFailed, - proto.TaskStatePending, - proto.TaskStatePending, - proto.TaskStatePending, - proto.TaskStatePending, - } - for i, state := range taskStates { - taskKey := fmt.Sprintf("key/%d", i) - _, err := gm.CreateTask(ctx, taskKey, "test", 4, []byte("test")) - require.NoError(t, err) - require.NoError(t, gm.WithNewSession(func(se sessionctx.Context) error { - _, err := se.(sqlexec.SQLExecutor).ExecuteInternal(ctx, ` - update mysql.tidb_global_task set 
state = %? where task_key = %?`, - state, taskKey) - return err - })) - } - // adjust task order - require.NoError(t, gm.WithNewSession(func(se sessionctx.Context) error { - _, err := se.(sqlexec.SQLExecutor).ExecuteInternal(ctx, ` - update mysql.tidb_global_task set create_time = current_timestamp`) - return err - })) - require.NoError(t, gm.WithNewSession(func(se sessionctx.Context) error { - _, err := se.(sqlexec.SQLExecutor).ExecuteInternal(ctx, ` - update mysql.tidb_global_task - set create_time = timestampadd(minute, -10, current_timestamp) - where task_key = 'key/5'`) - return err - })) - require.NoError(t, gm.WithNewSession(func(se sessionctx.Context) error { - _, err := se.(sqlexec.SQLExecutor).ExecuteInternal(ctx, ` - update mysql.tidb_global_task set priority = 100 where task_key = 'key/6'`) - return err - })) - require.NoError(t, gm.WithNewSession(func(se sessionctx.Context) error { - rs, err := sqlexec.ExecSQL(ctx, se, ` - select count(1) from mysql.tidb_global_task`) - require.Len(t, rs, 1) - require.Equal(t, int64(12), rs[0].GetInt64(0)) - return err - })) - tasks, err := gm.GetTopUnfinishedTasks(ctx) - require.NoError(t, err) - require.Len(t, tasks, 8) - taskKeys := make([]string, 0, len(tasks)) - for _, task := range tasks { - taskKeys = append(taskKeys, task.Key) - // not filled - require.Empty(t, task.Meta) - } - require.Equal(t, []string{"key/6", "key/5", "key/1", "key/2", "key/3", "key/4", "key/8", "key/9"}, taskKeys) -} - -func TestGetUsedSlotsOnNodes(t *testing.T) { - _, sm, ctx := testutil.InitTableTest(t) - - testutil.InsertSubtask(t, sm, 1, proto.StepOne, "tidb-1", []byte(""), proto.SubtaskStateRunning, "test", 12) - testutil.InsertSubtask(t, sm, 1, proto.StepOne, "tidb-2", []byte(""), proto.SubtaskStatePending, "test", 12) - testutil.InsertSubtask(t, sm, 2, proto.StepOne, "tidb-2", []byte(""), proto.SubtaskStatePending, "test", 8) - testutil.InsertSubtask(t, sm, 3, proto.StepOne, "tidb-3", []byte(""), proto.SubtaskStatePending, "test", 8) - testutil.InsertSubtask(t, sm, 4, proto.StepOne, "tidb-3", []byte(""), proto.SubtaskStateFailed, "test", 8) - slotsOnNodes, err := sm.GetUsedSlotsOnNodes(ctx) - require.NoError(t, err) - require.Equal(t, map[string]int{ - "tidb-1": 12, - "tidb-2": 20, - "tidb-3": 8, - }, slotsOnNodes) ->>>>>>> 99f0349bfb6 (disttask: fix failed step is taken as success (#49971)) } func TestSubTaskTable(t *testing.T) { @@ -435,21 +167,13 @@ func TestSubTaskTable(t *testing.T) { require.NoError(t, err) require.Len(t, ids, 0) -<<<<<<< HEAD cnt, err := sm.GetSubtaskInStatesCnt(ctx, 1, proto.TaskStatePending) -======= - cntByStates, err := sm.GetSubtaskCntGroupByStates(ctx, 1, proto.StepInit) ->>>>>>> 99f0349bfb6 (disttask: fix failed step is taken as success (#49971)) require.NoError(t, err) - require.Equal(t, int64(1), cntByStates[proto.SubtaskStatePending]) + require.Equal(t, int64(1), cnt) -<<<<<<< HEAD cnt, err = sm.GetSubtaskInStatesCnt(ctx, 1, proto.TaskStatePending, proto.TaskStateRevertPending) -======= - cntByStates, err = sm.GetSubtaskCntGroupByStates(ctx, 1, proto.StepInit) ->>>>>>> 99f0349bfb6 (disttask: fix failed step is taken as success (#49971)) require.NoError(t, err) - require.Equal(t, int64(1), cntByStates[proto.SubtaskStatePending]+cntByStates[proto.SubtaskStateRevertPending]) + require.Equal(t, int64(1), cnt) ok, err := sm.HasSubtasksInStates(ctx, "tidb1", 1, proto.StepInit, proto.TaskStatePending) require.NoError(t, err) @@ -481,13 +205,9 @@ func TestSubTaskTable(t *testing.T) { require.Equal(t, proto.TaskStateCancelling, 
subtask2.State) require.Greater(t, subtask2.UpdateTime, subtask.UpdateTime) -<<<<<<< HEAD cnt, err = sm.GetSubtaskInStatesCnt(ctx, 1, proto.TaskStatePending) -======= - cntByStates, err = sm.GetSubtaskCntGroupByStates(ctx, 1, proto.StepInit) ->>>>>>> 99f0349bfb6 (disttask: fix failed step is taken as success (#49971)) require.NoError(t, err) - require.Equal(t, int64(0), cntByStates[proto.SubtaskStatePending]) + require.Equal(t, int64(0), cnt) ok, err = sm.HasSubtasksInStates(ctx, "tidb1", 1, proto.StepInit, proto.TaskStatePending) require.NoError(t, err) @@ -503,13 +223,9 @@ func TestSubTaskTable(t *testing.T) { err = sm.AddNewSubTask(ctx, 2, proto.StepInit, "tidb1", []byte("test"), proto.TaskTypeExample, true) require.NoError(t, err) -<<<<<<< HEAD cnt, err = sm.GetSubtaskInStatesCnt(ctx, 2, proto.TaskStateRevertPending) -======= - cntByStates, err = sm.GetSubtaskCntGroupByStates(ctx, 2, proto.StepInit) ->>>>>>> 99f0349bfb6 (disttask: fix failed step is taken as success (#49971)) require.NoError(t, err) - require.Equal(t, int64(1), cntByStates[proto.SubtaskStateRevertPending]) + require.Equal(t, int64(1), cnt) subtasks, err := sm.GetSucceedSubtasksByStep(ctx, 2, proto.StepInit) require.NoError(t, err) @@ -638,14 +354,9 @@ func TestBothGlobalAndSubTaskTable(t *testing.T) { require.Equal(t, proto.TaskTypeExample, subtask2.Type) require.Equal(t, []byte("m2"), subtask2.Meta) -<<<<<<< HEAD cnt, err := sm.GetSubtaskInStatesCnt(ctx, 1, proto.TaskStatePending) -======= - cntByStates, err := sm.GetSubtaskCntGroupByStates(ctx, 1, proto.StepInit) ->>>>>>> 99f0349bfb6 (disttask: fix failed step is taken as success (#49971)) require.NoError(t, err) - require.Len(t, cntByStates, 1) - require.Equal(t, int64(2), cntByStates[proto.SubtaskStatePending]) + require.Equal(t, int64(2), cnt) // isSubTaskRevert: true prevState = task.State @@ -684,13 +395,9 @@ func TestBothGlobalAndSubTaskTable(t *testing.T) { require.Equal(t, proto.TaskTypeExample, subtask2.Type) require.Equal(t, []byte("m4"), subtask2.Meta) -<<<<<<< HEAD cnt, err = sm.GetSubtaskInStatesCnt(ctx, 1, proto.TaskStateRevertPending) -======= - cntByStates, err = sm.GetSubtaskCntGroupByStates(ctx, 1, proto.StepInit) ->>>>>>> 99f0349bfb6 (disttask: fix failed step is taken as success (#49971)) require.NoError(t, err) - require.Equal(t, int64(2), cntByStates[proto.SubtaskStateRevertPending]) + require.Equal(t, int64(2), cnt) // test transactional require.NoError(t, sm.DeleteSubtasksByTaskID(ctx, 1)) @@ -708,34 +415,9 @@ func TestBothGlobalAndSubTaskTable(t *testing.T) { require.NoError(t, err) require.Equal(t, proto.TaskStateReverting, task.State) -<<<<<<< HEAD cnt, err = sm.GetSubtaskInStatesCnt(ctx, 1, proto.TaskStateRevertPending) -======= - cntByStates, err = sm.GetSubtaskCntGroupByStates(ctx, 1, proto.StepInit) ->>>>>>> 99f0349bfb6 (disttask: fix failed step is taken as success (#49971)) require.NoError(t, err) - require.Equal(t, int64(0), cntByStates[proto.SubtaskStateRevertPending]) -} - -func TestGetSubtaskCntByStates(t *testing.T) { - _, sm, ctx := testutil.InitTableTest(t) - testutil.InsertSubtask(t, sm, 1, proto.StepOne, "tidb1", nil, proto.SubtaskStatePending, "test", 1) - testutil.InsertSubtask(t, sm, 1, proto.StepOne, "tidb1", nil, proto.SubtaskStatePending, "test", 1) - testutil.InsertSubtask(t, sm, 1, proto.StepOne, "tidb1", nil, proto.SubtaskStateRunning, "test", 1) - testutil.InsertSubtask(t, sm, 1, proto.StepOne, "tidb1", nil, proto.SubtaskStateSucceed, "test", 1) - testutil.InsertSubtask(t, sm, 1, proto.StepOne, "tidb1", nil, 
proto.SubtaskStateFailed, "test", 1) - testutil.InsertSubtask(t, sm, 1, proto.StepTwo, "tidb1", nil, proto.SubtaskStateFailed, "test", 1) - cntByStates, err := sm.GetSubtaskCntGroupByStates(ctx, 1, proto.StepOne) - require.NoError(t, err) - require.Len(t, cntByStates, 4) - require.Equal(t, int64(2), cntByStates[proto.SubtaskStatePending]) - require.Equal(t, int64(1), cntByStates[proto.SubtaskStateRunning]) - require.Equal(t, int64(1), cntByStates[proto.SubtaskStateSucceed]) - require.Equal(t, int64(1), cntByStates[proto.SubtaskStateFailed]) - cntByStates, err = sm.GetSubtaskCntGroupByStates(ctx, 1, proto.StepTwo) - require.NoError(t, err) - require.Len(t, cntByStates, 1) - require.Equal(t, int64(1), cntByStates[proto.SubtaskStateFailed]) + require.Equal(t, int64(0), cnt) } func TestDistFrameworkMeta(t *testing.T) { @@ -902,42 +584,26 @@ func TestPauseAndResume(t *testing.T) { require.NoError(t, sm.AddNewSubTask(ctx, 1, proto.StepInit, "tidb1", []byte("test"), proto.TaskTypeExample, false)) // 1.1 pause all subtasks. require.NoError(t, sm.PauseSubtasks(ctx, "tidb1", 1)) -<<<<<<< HEAD cnt, err := sm.GetSubtaskInStatesCnt(ctx, 1, proto.TaskStatePaused) -======= - cntByStates, err := sm.GetSubtaskCntGroupByStates(ctx, 1, proto.StepInit) ->>>>>>> 99f0349bfb6 (disttask: fix failed step is taken as success (#49971)) require.NoError(t, err) - require.Equal(t, int64(3), cntByStates[proto.SubtaskStatePaused]) + require.Equal(t, int64(3), cnt) // 1.2 resume all subtasks. require.NoError(t, sm.ResumeSubtasks(ctx, 1)) -<<<<<<< HEAD cnt, err = sm.GetSubtaskInStatesCnt(ctx, 1, proto.TaskStatePending) -======= - cntByStates, err = sm.GetSubtaskCntGroupByStates(ctx, 1, proto.StepInit) ->>>>>>> 99f0349bfb6 (disttask: fix failed step is taken as success (#49971)) require.NoError(t, err) - require.Equal(t, int64(3), cntByStates[proto.SubtaskStatePending]) + require.Equal(t, int64(3), cnt) // 2.1 pause 2 subtasks. sm.UpdateSubtaskStateAndError(ctx, "tidb1", 1, proto.TaskStateSucceed, nil) require.NoError(t, sm.PauseSubtasks(ctx, "tidb1", 1)) -<<<<<<< HEAD cnt, err = sm.GetSubtaskInStatesCnt(ctx, 1, proto.TaskStatePaused) -======= - cntByStates, err = sm.GetSubtaskCntGroupByStates(ctx, 1, proto.StepInit) ->>>>>>> 99f0349bfb6 (disttask: fix failed step is taken as success (#49971)) require.NoError(t, err) - require.Equal(t, int64(2), cntByStates[proto.SubtaskStatePaused]) + require.Equal(t, int64(2), cnt) // 2.2 resume 2 subtasks. require.NoError(t, sm.ResumeSubtasks(ctx, 1)) -<<<<<<< HEAD cnt, err = sm.GetSubtaskInStatesCnt(ctx, 1, proto.TaskStatePending) -======= - cntByStates, err = sm.GetSubtaskCntGroupByStates(ctx, 1, proto.StepInit) ->>>>>>> 99f0349bfb6 (disttask: fix failed step is taken as success (#49971)) require.NoError(t, err) - require.Equal(t, int64(2), cntByStates[proto.SubtaskStatePending]) + require.Equal(t, int64(2), cnt) } func TestCancelAndExecIdChanged(t *testing.T) { diff --git a/pkg/disttask/framework/storage/task_table.go b/pkg/disttask/framework/storage/task_table.go index cbfa0b1ec532a..d5205ffce2254 100644 --- a/pkg/disttask/framework/storage/task_table.go +++ b/pkg/disttask/framework/storage/task_table.go @@ -488,41 +488,24 @@ func (stm *TaskManager) UpdateSubtaskRowCount(ctx context.Context, subtaskID int return err } -<<<<<<< HEAD // GetSubtaskInStatesCnt gets the subtask count in the states. func (stm *TaskManager) GetSubtaskInStatesCnt(ctx context.Context, taskID int64, states ...interface{}) (int64, error) { args := []interface{}{taskID} args = append(args, states...) 
rs, err := stm.executeSQLWithNewSession(ctx, `select count(*) from mysql.tidb_background_subtask where task_key = %? and state in (`+strings.Repeat("%?,", len(states)-1)+"%?)", args...) -======= -// GetSubtaskCntGroupByStates gets the subtask count by states. -func (stm *TaskManager) GetSubtaskCntGroupByStates(ctx context.Context, taskID int64, step proto.Step) (map[proto.SubtaskState]int64, error) { - rs, err := stm.executeSQLWithNewSession(ctx, ` - select state, count(*) - from mysql.tidb_background_subtask - where task_key = %? and step = %? - group by state`, - taskID, step) ->>>>>>> 99f0349bfb6 (disttask: fix failed step is taken as success (#49971)) if err != nil { - return nil, err - } - - res := make(map[proto.SubtaskState]int64, len(rs)) - for _, r := range rs { - state := proto.SubtaskState(r.GetString(0)) - res[state] = r.GetInt64(1) + return 0, err } - return res, nil + return rs[0].GetInt64(0), nil } // CollectSubTaskError collects the subtask error. func (stm *TaskManager) CollectSubTaskError(ctx context.Context, taskID int64) ([]error, error) { rs, err := stm.executeSQLWithNewSession(ctx, `select error from mysql.tidb_background_subtask - where task_key = %? AND state in (%?, %?)`, taskID, proto.SubtaskStateFailed, proto.SubtaskStateCanceled) + where task_key = %? AND state in (%?, %?)`, taskID, proto.TaskStateFailed, proto.TaskStateCanceled) if err != nil { return nil, err } diff --git a/pkg/disttask/framework/taskexecutor/task_executor.go b/pkg/disttask/framework/taskexecutor/task_executor.go deleted file mode 100644 index fd585a259daef..0000000000000 --- a/pkg/disttask/framework/taskexecutor/task_executor.go +++ /dev/null @@ -1,653 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package taskexecutor - -import ( - "context" - "sync" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/br/pkg/lightning/common" - "github.com/pingcap/tidb/br/pkg/lightning/log" - "github.com/pingcap/tidb/pkg/disttask/framework/handle" - "github.com/pingcap/tidb/pkg/disttask/framework/proto" - "github.com/pingcap/tidb/pkg/disttask/framework/scheduler" - "github.com/pingcap/tidb/pkg/disttask/framework/storage" - "github.com/pingcap/tidb/pkg/disttask/framework/taskexecutor/execute" - "github.com/pingcap/tidb/pkg/domain/infosync" - "github.com/pingcap/tidb/pkg/metrics" - "github.com/pingcap/tidb/pkg/util/backoff" - "github.com/pingcap/tidb/pkg/util/gctuner" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/memory" - "go.uber.org/zap" -) - -const ( - // DefaultCheckSubtaskCanceledInterval is the default check interval for cancel cancelled subtasks. - DefaultCheckSubtaskCanceledInterval = 2 * time.Second -) - -var ( - // ErrCancelSubtask is the cancel cause when cancelling subtasks. - ErrCancelSubtask = errors.New("cancel subtasks") - // ErrFinishSubtask is the cancel cause when TaskExecutor successfully processed subtasks. 
- ErrFinishSubtask = errors.New("finish subtasks") - // ErrFinishRollback is the cancel cause when TaskExecutor rollback successfully. - ErrFinishRollback = errors.New("finish rollback") - - // TestSyncChan is used to sync the test. - TestSyncChan = make(chan struct{}) -) - -// BaseTaskExecutor is the base implementation of TaskExecutor. -type BaseTaskExecutor struct { - // id, it's the same as server id now, i.e. host:port. - id string - taskID int64 - taskTable TaskTable - logCtx context.Context - // ctx from manager - ctx context.Context - Extension - - mu struct { - sync.RWMutex - err error - // handled indicates whether the error has been updated to one of the subtask. - handled bool - // runtimeCancel is used to cancel the Run/Rollback when error occurs. - runtimeCancel context.CancelCauseFunc - } -} - -// NewBaseTaskExecutor creates a new BaseTaskExecutor. -func NewBaseTaskExecutor(ctx context.Context, id string, taskID int64, taskTable TaskTable) *BaseTaskExecutor { - taskExecutorImpl := &BaseTaskExecutor{ - id: id, - taskID: taskID, - taskTable: taskTable, - ctx: ctx, - logCtx: logutil.WithFields(context.Background(), zap.Int64("task-id", taskID)), - } - return taskExecutorImpl -} - -func (s *BaseTaskExecutor) startCancelCheck(ctx context.Context, wg *sync.WaitGroup, cancelFn context.CancelCauseFunc) { - wg.Add(1) - go func() { - defer wg.Done() - ticker := time.NewTicker(DefaultCheckSubtaskCanceledInterval) - defer ticker.Stop() - for { - select { - case <-ctx.Done(): - logutil.Logger(s.logCtx).Info("task executor exits") - return - case <-ticker.C: - canceled, err := s.taskTable.IsTaskExecutorCanceled(ctx, s.id, s.taskID) - if err != nil { - continue - } - if canceled { - logutil.Logger(s.logCtx).Info("taskExecutor canceled") - if cancelFn != nil { - // subtask transferred to other tidb, don't mark subtask as canceled. - // Should not change the subtask's state. - cancelFn(nil) - } - } - } - } - }() -} - -// Init implements the TaskExecutor interface. -func (*BaseTaskExecutor) Init(_ context.Context) error { - return nil -} - -// Run start to fetch and run all subtasks of the task on the node. -func (s *BaseTaskExecutor) Run(ctx context.Context, task *proto.Task) (err error) { - defer func() { - if r := recover(); r != nil { - logutil.Logger(s.logCtx).Error("BaseTaskExecutor panicked", zap.Any("recover", r), zap.Stack("stack")) - err4Panic := errors.Errorf("%v", r) - err1 := s.updateErrorToSubtask(ctx, task.ID, err4Panic) - if err == nil { - err = err1 - } - } - }() - err = s.run(ctx, task) - if s.mu.handled { - return err - } - if err == nil { - return nil - } - return s.updateErrorToSubtask(ctx, task.ID, err) -} - -func (s *BaseTaskExecutor) run(ctx context.Context, task *proto.Task) (resErr error) { - if ctx.Err() != nil { - s.onError(ctx.Err()) - return s.getError() - } - runCtx, runCancel := context.WithCancelCause(ctx) - defer runCancel(ErrFinishSubtask) - s.registerCancelFunc(runCancel) - s.resetError() - stepLogger := log.BeginTask(logutil.Logger(s.logCtx).With( - zap.Any("step", task.Step), - zap.Int("concurrency", task.Concurrency), - zap.Float64("mem-limit-percent", gctuner.GlobalMemoryLimitTuner.GetPercentage()), - zap.String("server-mem-limit", memory.ServerMemoryLimitOriginText.Load()), - ), "execute task") - // log as info level, subtask might be cancelled, let caller check it. 
- defer func() { - stepLogger.End(zap.InfoLevel, resErr) - }() - - summary, cleanup, err := runSummaryCollectLoop(ctx, task, s.taskTable) - if err != nil { - s.onError(err) - return s.getError() - } - defer cleanup() - subtaskExecutor, err := s.GetSubtaskExecutor(ctx, task, summary) - if err != nil { - s.onError(err) - return s.getError() - } - - failpoint.Inject("mockExecSubtaskInitEnvErr", func() { - failpoint.Return(errors.New("mockExecSubtaskInitEnvErr")) - }) - if err := subtaskExecutor.Init(runCtx); err != nil { - s.onError(err) - return s.getError() - } - - var wg sync.WaitGroup - cancelCtx, checkCancel := context.WithCancel(ctx) - s.startCancelCheck(cancelCtx, &wg, runCancel) - - defer func() { - err := subtaskExecutor.Cleanup(runCtx) - if err != nil { - logutil.Logger(s.logCtx).Error("cleanup subtask exec env failed", zap.Error(err)) - } - checkCancel() - wg.Wait() - }() - - subtasks, err := s.taskTable.GetSubtasksByStepAndStates( - runCtx, s.id, task.ID, task.Step, - proto.SubtaskStatePending, proto.SubtaskStateRunning) - if err != nil { - s.onError(err) - if common.IsRetryableError(err) { - logutil.Logger(s.logCtx).Warn("met retryable error", zap.Error(err)) - return nil - } - return s.getError() - } - for _, subtask := range subtasks { - metrics.IncDistTaskSubTaskCnt(subtask) - metrics.StartDistTaskSubTask(subtask) - } - - for { - // check if any error occurs. - if err := s.getError(); err != nil { - break - } - if runCtx.Err() != nil { - logutil.Logger(s.logCtx).Info("taskExecutor runSubtask loop exit") - break - } - - subtask, err := s.taskTable.GetFirstSubtaskInStates(runCtx, s.id, task.ID, task.Step, - proto.SubtaskStatePending, proto.SubtaskStateRunning) - if err != nil { - logutil.Logger(s.logCtx).Warn("GetFirstSubtaskInStates meets error", zap.Error(err)) - continue - } - if subtask == nil { - failpoint.Inject("breakInTaskExecutorUT", func() { - failpoint.Break() - }) - newTask, err := s.taskTable.GetTaskByID(runCtx, task.ID) - if err != nil { - logutil.Logger(s.logCtx).Warn("GetTaskByID meets error", zap.Error(err)) - continue - } - // When the task move to next step or task state changes, the TaskExecutor should exit. - if newTask.Step != task.Step || newTask.State != task.State { - break - } - time.Sleep(checkTime) - continue - } - - if subtask.State == proto.SubtaskStateRunning { - if !s.IsIdempotent(subtask) { - logutil.Logger(s.logCtx).Info("subtask in running state and is not idempotent, fail it", - zap.Int64("subtask-id", subtask.ID)) - subtaskErr := errors.New("subtask in running state and is not idempotent") - s.onError(subtaskErr) - s.updateSubtaskStateAndError(runCtx, subtask, proto.SubtaskStateFailed, subtaskErr) - s.markErrorHandled() - break - } - } else { - // subtask.State == proto.TaskStatePending - err := s.startSubtaskAndUpdateState(runCtx, subtask) - if err != nil { - logutil.Logger(s.logCtx).Warn("startSubtaskAndUpdateState meets error", zap.Error(err)) - // should ignore ErrSubtaskNotFound - // since the err only indicate that the subtask not owned by current task executor. 
- if err == storage.ErrSubtaskNotFound { - continue - } - s.onError(err) - continue - } - } - - failpoint.Inject("mockCleanExecutor", func() { - v, ok := testContexts.Load(s.id) - if ok { - if v.(*TestContext).mockDown.Load() { - failpoint.Break() - } - } - }) - - failpoint.Inject("cancelBeforeRunSubtask", func() { - runCancel(nil) - }) - - s.runSubtask(runCtx, subtaskExecutor, subtask) - } - return s.getError() -} - -func (s *BaseTaskExecutor) runSubtask(ctx context.Context, subtaskExecutor execute.SubtaskExecutor, subtask *proto.Subtask) { - err := subtaskExecutor.RunSubtask(ctx, subtask) - failpoint.Inject("MockRunSubtaskCancel", func(val failpoint.Value) { - if val.(bool) { - err = ErrCancelSubtask - } - }) - - failpoint.Inject("MockRunSubtaskContextCanceled", func(val failpoint.Value) { - if val.(bool) { - err = context.Canceled - } - }) - - if err != nil { - s.onError(err) - } - - finished := s.markSubTaskCanceledOrFailed(ctx, subtask) - if finished { - return - } - - failpoint.Inject("mockTiDBDown", func(val failpoint.Value) { - logutil.Logger(s.logCtx).Info("trigger mockTiDBDown") - if s.id == val.(string) || s.id == ":4001" || s.id == ":4002" { - v, ok := testContexts.Load(s.id) - if ok { - v.(*TestContext).TestSyncSubtaskRun <- struct{}{} - v.(*TestContext).mockDown.Store(true) - logutil.Logger(s.logCtx).Info("mockTiDBDown") - time.Sleep(2 * time.Second) - failpoint.Return() - } - } - }) - failpoint.Inject("mockTiDBDown2", func() { - if s.id == ":4003" && subtask.Step == proto.StepTwo { - v, ok := testContexts.Load(s.id) - if ok { - v.(*TestContext).TestSyncSubtaskRun <- struct{}{} - v.(*TestContext).mockDown.Store(true) - time.Sleep(2 * time.Second) - return - } - } - }) - - failpoint.Inject("mockTiDBPartitionThenResume", func(val failpoint.Value) { - if val.(bool) && (s.id == ":4000" || s.id == ":4001" || s.id == ":4002") { - infosync.MockGlobalServerInfoManagerEntry.DeleteByExecID(s.id) - time.Sleep(20 * time.Second) - } - }) - - failpoint.Inject("MockExecutorRunErr", func(val failpoint.Value) { - if val.(bool) { - s.onError(errors.New("MockExecutorRunErr")) - } - }) - failpoint.Inject("MockExecutorRunCancel", func(val failpoint.Value) { - if taskID, ok := val.(int); ok { - mgr, err := storage.GetTaskManager() - if err != nil { - logutil.BgLogger().Error("get task manager failed", zap.Error(err)) - } else { - err = mgr.CancelTask(ctx, int64(taskID)) - if err != nil { - logutil.BgLogger().Error("cancel task failed", zap.Error(err)) - } - } - } - }) - s.onSubtaskFinished(ctx, subtaskExecutor, subtask) -} - -func (s *BaseTaskExecutor) onSubtaskFinished(ctx context.Context, executor execute.SubtaskExecutor, subtask *proto.Subtask) { - if err := s.getError(); err == nil { - if err = executor.OnFinished(ctx, subtask); err != nil { - s.onError(err) - } - } - failpoint.Inject("MockSubtaskFinishedCancel", func(val failpoint.Value) { - if val.(bool) { - s.onError(ErrCancelSubtask) - } - }) - - finished := s.markSubTaskCanceledOrFailed(ctx, subtask) - if finished { - return - } - - s.finishSubtaskAndUpdateState(ctx, subtask) - - finished = s.markSubTaskCanceledOrFailed(ctx, subtask) - if finished { - return - } - - failpoint.Inject("syncAfterSubtaskFinish", func() { - TestSyncChan <- struct{}{} - <-TestSyncChan - }) -} - -// Rollback rollbacks the subtask. 
-func (s *BaseTaskExecutor) Rollback(ctx context.Context, task *proto.Task) error { - rollbackCtx, rollbackCancel := context.WithCancelCause(ctx) - defer rollbackCancel(ErrFinishRollback) - s.registerCancelFunc(rollbackCancel) - - s.resetError() - logutil.Logger(s.logCtx).Info("taskExecutor rollback a step", zap.Any("step", task.Step)) - - // We should cancel all subtasks before rolling back - for { - subtask, err := s.taskTable.GetFirstSubtaskInStates(ctx, s.id, task.ID, task.Step, - proto.SubtaskStatePending, proto.SubtaskStateRunning) - if err != nil { - s.onError(err) - return s.getError() - } - - if subtask == nil { - break - } - - s.updateSubtaskStateAndError(ctx, subtask, proto.SubtaskStateCanceled, nil) - if err = s.getError(); err != nil { - return err - } - } - - executor, err := s.GetSubtaskExecutor(ctx, task, nil) - if err != nil { - s.onError(err) - return s.getError() - } - subtask, err := s.taskTable.GetFirstSubtaskInStates(ctx, s.id, task.ID, task.Step, - proto.SubtaskStateRevertPending, proto.SubtaskStateReverting) - if err != nil { - s.onError(err) - return s.getError() - } - if subtask == nil { - logutil.BgLogger().Warn("taskExecutor rollback a step, but no subtask in revert_pending state", zap.Any("step", task.Step)) - return nil - } - if subtask.State == proto.SubtaskStateRevertPending { - s.updateSubtaskStateAndError(ctx, subtask, proto.SubtaskStateReverting, nil) - } - if err := s.getError(); err != nil { - return err - } - - // right now all impl of Rollback is empty, so we don't check idempotent here. - // will try to remove this rollback completely in the future. - err = executor.Rollback(rollbackCtx) - if err != nil { - s.updateSubtaskStateAndError(ctx, subtask, proto.SubtaskStateRevertFailed, nil) - s.onError(err) - } else { - s.updateSubtaskStateAndError(ctx, subtask, proto.SubtaskStateReverted, nil) - } - return s.getError() -} - -// Pause pause the TaskExecutor's subtasks. -func (s *BaseTaskExecutor) Pause(ctx context.Context, task *proto.Task) error { - logutil.Logger(s.logCtx).Info("taskExecutor pause subtasks") - // pause all running subtasks. - if err := s.taskTable.PauseSubtasks(ctx, s.id, task.ID); err != nil { - s.onError(err) - return s.getError() - } - return nil -} - -// Close closes the TaskExecutor when all the subtasks are complete. 
-func (*BaseTaskExecutor) Close() { -} - -func runSummaryCollectLoop( - ctx context.Context, - task *proto.Task, - taskTable TaskTable, -) (summary *execute.Summary, cleanup func(), err error) { - taskMgr, ok := taskTable.(*storage.TaskManager) - if !ok { - return nil, func() {}, nil - } - opt, ok := taskTypes[task.Type] - if !ok { - return nil, func() {}, errors.Errorf("taskExecutor option for type %s not found", task.Type) - } - if opt.Summary != nil { - go opt.Summary.UpdateRowCountLoop(ctx, taskMgr) - return opt.Summary, func() { - opt.Summary.PersistRowCount(ctx, taskMgr) - }, nil - } - return nil, func() {}, nil -} - -func (s *BaseTaskExecutor) registerCancelFunc(cancel context.CancelCauseFunc) { - s.mu.Lock() - defer s.mu.Unlock() - s.mu.runtimeCancel = cancel -} - -func (s *BaseTaskExecutor) onError(err error) { - if err == nil { - return - } - err = errors.Trace(err) - logutil.Logger(s.logCtx).Error("onError", zap.Error(err), zap.Stack("stack")) - s.mu.Lock() - defer s.mu.Unlock() - - if s.mu.err == nil { - s.mu.err = err - logutil.Logger(s.logCtx).Error("taskExecutor met first error", zap.Error(err)) - } - - if s.mu.runtimeCancel != nil { - s.mu.runtimeCancel(err) - } -} - -func (s *BaseTaskExecutor) markErrorHandled() { - s.mu.Lock() - defer s.mu.Unlock() - s.mu.handled = true -} - -func (s *BaseTaskExecutor) getError() error { - s.mu.RLock() - defer s.mu.RUnlock() - return s.mu.err -} - -func (s *BaseTaskExecutor) resetError() { - s.mu.Lock() - defer s.mu.Unlock() - s.mu.err = nil - s.mu.handled = false -} - -func (s *BaseTaskExecutor) startSubtaskAndUpdateState(ctx context.Context, subtask *proto.Subtask) error { - err := s.startSubtask(ctx, subtask.ID) - if err == nil { - metrics.DecDistTaskSubTaskCnt(subtask) - metrics.EndDistTaskSubTask(subtask) - subtask.State = proto.SubtaskStateRunning - metrics.IncDistTaskSubTaskCnt(subtask) - metrics.StartDistTaskSubTask(subtask) - } - return err -} - -func (s *BaseTaskExecutor) updateSubtaskStateAndErrorImpl(ctx context.Context, execID string, subtaskID int64, state proto.SubtaskState, subTaskErr error) { - // retry for 3+6+12+24+(30-4)*30 ~= 825s ~= 14 minutes - logger := logutil.Logger(s.logCtx) - backoffer := backoff.NewExponential(scheduler.RetrySQLInterval, 2, scheduler.RetrySQLMaxInterval) - err := handle.RunWithRetry(ctx, scheduler.RetrySQLTimes, backoffer, logger, - func(ctx context.Context) (bool, error) { - return true, s.taskTable.UpdateSubtaskStateAndError(ctx, execID, subtaskID, state, subTaskErr) - }, - ) - if err != nil { - s.onError(err) - } -} - -// startSubtask try to change the state of the subtask to running. -// If the subtask is not owned by the task executor, -// the update will fail and task executor should not run the subtask. -func (s *BaseTaskExecutor) startSubtask(ctx context.Context, subtaskID int64) error { - // retry for 3+6+12+24+(30-4)*30 ~= 825s ~= 14 minutes - logger := logutil.Logger(s.logCtx) - backoffer := backoff.NewExponential(scheduler.RetrySQLInterval, 2, scheduler.RetrySQLMaxInterval) - return handle.RunWithRetry(ctx, scheduler.RetrySQLTimes, backoffer, logger, - func(ctx context.Context) (bool, error) { - err := s.taskTable.StartSubtask(ctx, subtaskID, s.id) - if err == storage.ErrSubtaskNotFound { - // No need to retry. 
- return false, err - } - return true, err - }, - ) -} - -func (s *BaseTaskExecutor) finishSubtask(ctx context.Context, subtask *proto.Subtask) { - logger := logutil.Logger(s.logCtx) - backoffer := backoff.NewExponential(scheduler.RetrySQLInterval, 2, scheduler.RetrySQLMaxInterval) - err := handle.RunWithRetry(ctx, scheduler.RetrySQLTimes, backoffer, logger, - func(ctx context.Context) (bool, error) { - return true, s.taskTable.FinishSubtask(ctx, subtask.ExecID, subtask.ID, subtask.Meta) - }, - ) - if err != nil { - s.onError(err) - } -} - -func (s *BaseTaskExecutor) updateSubtaskStateAndError(ctx context.Context, subtask *proto.Subtask, state proto.SubtaskState, subTaskErr error) { - metrics.DecDistTaskSubTaskCnt(subtask) - metrics.EndDistTaskSubTask(subtask) - s.updateSubtaskStateAndErrorImpl(ctx, subtask.ExecID, subtask.ID, state, subTaskErr) - subtask.State = state - metrics.IncDistTaskSubTaskCnt(subtask) - if !subtask.IsDone() { - metrics.StartDistTaskSubTask(subtask) - } -} - -func (s *BaseTaskExecutor) finishSubtaskAndUpdateState(ctx context.Context, subtask *proto.Subtask) { - metrics.DecDistTaskSubTaskCnt(subtask) - metrics.EndDistTaskSubTask(subtask) - s.finishSubtask(ctx, subtask) - subtask.State = proto.SubtaskStateSucceed - metrics.IncDistTaskSubTaskCnt(subtask) -} - -// markSubTaskCanceledOrFailed check the error type and decide the subtasks' state. -// 1. Only cancel subtasks when meet ErrCancelSubtask. -// 2. Only fail subtasks when meet non retryable error. -// 3. When meet other errors, don't change subtasks' state. -func (s *BaseTaskExecutor) markSubTaskCanceledOrFailed(ctx context.Context, subtask *proto.Subtask) bool { - if err := s.getError(); err != nil { - err := errors.Cause(err) - if ctx.Err() != nil && context.Cause(ctx) == ErrCancelSubtask { - logutil.Logger(s.logCtx).Warn("subtask canceled", zap.Error(err)) - s.updateSubtaskStateAndError(s.ctx, subtask, proto.SubtaskStateCanceled, nil) - } else if s.IsRetryableError(err) { - logutil.Logger(s.logCtx).Warn("met retryable error", zap.Error(err)) - } else if common.IsContextCanceledError(err) { - logutil.Logger(s.logCtx).Info("met context canceled for gracefully shutdown", zap.Error(err)) - } else { - logutil.Logger(s.logCtx).Warn("subtask failed", zap.Error(err)) - s.updateSubtaskStateAndError(s.ctx, subtask, proto.SubtaskStateFailed, err) - } - s.markErrorHandled() - return true - } - return false -} - -func (s *BaseTaskExecutor) updateErrorToSubtask(ctx context.Context, taskID int64, err error) error { - logger := logutil.Logger(s.logCtx) - backoffer := backoff.NewExponential(scheduler.RetrySQLInterval, 2, scheduler.RetrySQLMaxInterval) - err1 := handle.RunWithRetry(s.logCtx, scheduler.RetrySQLTimes, backoffer, logger, - func(_ context.Context) (bool, error) { - return true, s.taskTable.UpdateErrorToSubtask(ctx, s.id, taskID, err) - }, - ) - if err1 == nil { - logger.Warn("update error to subtask success", zap.Error(err)) - } - return err1 -} diff --git a/pkg/disttask/framework/testutil/context.go b/pkg/disttask/framework/testutil/context.go deleted file mode 100644 index 0c7ac9c58ff81..0000000000000 --- a/pkg/disttask/framework/testutil/context.go +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package testutil - -import ( - "context" - "fmt" - "sync" - "sync/atomic" - "testing" - "time" - - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/disttask/framework/proto" - "github.com/pingcap/tidb/pkg/disttask/framework/storage" - "github.com/pingcap/tidb/pkg/testkit" - "github.com/stretchr/testify/require" - "github.com/tikv/client-go/v2/util" - "go.uber.org/mock/gomock" -) - -// TestContext defines shared variables for disttask tests. -type TestContext struct { - sync.RWMutex - // taskID/step -> subtask map. - subtasksHasRun map[string]map[int64]struct{} - // for rollback tests. - RollbackCnt atomic.Int32 - // for plan err handling tests. - CallTime int -} - -// InitTestContext inits test context for disttask tests. -func InitTestContext(t *testing.T, nodeNum int) (context.Context, *gomock.Controller, *TestContext, *testkit.DistExecutionContext) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - ctx := context.Background() - ctx = util.WithInternalSourceType(ctx, "dispatcher") - require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/util/cpu/mockNumCpu", "return(8)")) - t.Cleanup(func() { - require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/util/cpu/mockNumCpu")) - }) - - executionContext := testkit.NewDistExecutionContext(t, nodeNum) - WaitNodeRegistered(ctx, t) - testCtx := &TestContext{ - subtasksHasRun: make(map[string]map[int64]struct{}), - } - return ctx, ctrl, testCtx, executionContext -} - -// CollectSubtask collects subtask info -func (c *TestContext) CollectSubtask(subtask *proto.Subtask) { - key := getTaskStepKey(subtask.TaskID, subtask.Step) - c.Lock() - defer c.Unlock() - m, ok := c.subtasksHasRun[key] - if !ok { - m = make(map[int64]struct{}) - c.subtasksHasRun[key] = m - } - m[subtask.ID] = struct{}{} -} - -// CollectedSubtaskCnt returns the collected subtask count. -func (c *TestContext) CollectedSubtaskCnt(taskID int64, step proto.Step) int { - key := getTaskStepKey(taskID, step) - c.RLock() - defer c.RUnlock() - return len(c.subtasksHasRun[key]) -} - -// getTaskStepKey returns the key of a task step. -func getTaskStepKey(id int64, step proto.Step) string { - return fmt.Sprintf("%d/%d", id, step) -} - -// WaitNodeRegistered waits until some node is registered. -func WaitNodeRegistered(ctx context.Context, t *testing.T) { - // wait until some node is registered. - require.Eventually(t, func() bool { - taskMgr, err := storage.GetTaskManager() - require.NoError(t, err) - nodes, err := taskMgr.GetAllNodes(ctx) - require.NoError(t, err) - return len(nodes) > 0 - }, 5*time.Second, 100*time.Millisecond) -} diff --git a/pkg/disttask/framework/testutil/task_util.go b/pkg/disttask/framework/testutil/task_util.go deleted file mode 100644 index 0f8d92deac04e..0000000000000 --- a/pkg/disttask/framework/testutil/task_util.go +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package testutil - -import ( - "context" - "testing" - - "github.com/pingcap/tidb/pkg/disttask/framework/proto" - "github.com/pingcap/tidb/pkg/disttask/framework/storage" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/util/sqlexec" - "github.com/stretchr/testify/require" - "github.com/tikv/client-go/v2/util" -) - -// CreateSubTask adds a new task to subtask table. -// used for testing. -func CreateSubTask(t *testing.T, gm *storage.TaskManager, taskID int64, step proto.Step, execID string, meta []byte, tp proto.TaskType, concurrency int, isRevert bool) { - state := proto.SubtaskStatePending - if isRevert { - state = proto.SubtaskStateRevertPending - } - InsertSubtask(t, gm, taskID, step, execID, meta, state, tp, concurrency) -} - -// InsertSubtask adds a new subtask of any state to subtask table. -func InsertSubtask(t *testing.T, gm *storage.TaskManager, taskID int64, step proto.Step, execID string, meta []byte, state proto.SubtaskState, tp proto.TaskType, concurrency int) { - ctx := context.Background() - ctx = util.WithInternalSourceType(ctx, "table_test") - require.NoError(t, gm.WithNewSession(func(se sessionctx.Context) error { - _, err := sqlexec.ExecSQL(ctx, se, ` - insert into mysql.tidb_background_subtask(`+storage.InsertSubtaskColumns+`) values`+ - `(%?, %?, %?, %?, %?, %?, %?, NULL, CURRENT_TIMESTAMP(), '{}', '{}')`, - step, taskID, execID, meta, state, proto.Type2Int(tp), concurrency) - return err - })) -}
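
A note on the restored GetSubtaskInStatesCnt in task_table.go above: it builds the IN-clause placeholder list with strings.Repeat("%?,", len(states)-1) plus a final "%?", so every state value is bound as a parameter instead of being concatenated into the SQL text. The construction assumes at least one state is passed (strings.Repeat panics on a negative count), which holds for every caller visible in this patch. Below is a minimal standalone sketch of the same pattern; buildCntQuery is a hypothetical name for illustration, not part of the TiDB codebase:

package main

import (
	"fmt"
	"strings"
)

// buildCntQuery mirrors the placeholder construction used by the restored
// GetSubtaskInStatesCnt: one %? per state, all joined inside IN (...).
// It assumes len(states) >= 1, as the callers in this patch do.
func buildCntQuery(taskID int64, states ...interface{}) (string, []interface{}) {
	args := []interface{}{taskID}
	args = append(args, states...)
	query := `select count(*) from mysql.tidb_background_subtask
		where task_key = %? and state in (` +
		strings.Repeat("%?,", len(states)-1) + "%?)"
	return query, args
}

func main() {
	query, args := buildCntQuery(1, "pending", "revert_pending")
	fmt.Println(query)    // ... state in (%?,%?)
	fmt.Println(len(args)) // 3: the task ID plus two states
}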
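The deleted task_executor.go estimates its SQL retry window as "retry for 3+6+12+24+(30-4)*30 ~= 825s ~= 14 minutes". That figure checks out under the values the comment implies: a 3s base interval, multiplier 2, a 30s cap, and 30 attempts. The sketch below reproduces the arithmetic; the concrete values of scheduler.RetrySQLInterval, scheduler.RetrySQLMaxInterval, and scheduler.RetrySQLTimes are assumptions inferred from that comment, not quoted from this patch:

package main

import "fmt"

func main() {
	const (
		baseInterval = 3  // seconds; assumed value of scheduler.RetrySQLInterval
		factor       = 2  // multiplier passed to backoff.NewExponential
		maxInterval  = 30 // seconds; assumed value of scheduler.RetrySQLMaxInterval
		attempts     = 30 // assumed value of scheduler.RetrySQLTimes
	)
	total, wait := 0, baseInterval
	for i := 0; i < attempts; i++ {
		total += wait
		wait *= factor
		if wait > maxInterval {
			wait = maxInterval
		}
	}
	// 3+6+12+24 for the first four attempts, then 26 capped waits of 30s:
	// 45 + 780 = 825 seconds, i.e. roughly 14 minutes.
	fmt.Printf("total retry window ~%ds\n", total)
}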