diff --git a/pkg/disttask/framework/mock/task_executor_mock.go b/pkg/disttask/framework/mock/task_executor_mock.go index a3ee3c8f26b8b..e3d576e274495 100644 --- a/pkg/disttask/framework/mock/task_executor_mock.go +++ b/pkg/disttask/framework/mock/task_executor_mock.go @@ -193,17 +193,17 @@ func (mr *MockTaskTableMockRecorder) StartManager(arg0, arg1, arg2 any) *gomock. } // StartSubtask mocks base method. -func (m *MockTaskTable) StartSubtask(arg0 context.Context, arg1 int64) error { +func (m *MockTaskTable) StartSubtask(arg0 context.Context, arg1 int64, arg2 string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "StartSubtask", arg0, arg1) + ret := m.ctrl.Call(m, "StartSubtask", arg0, arg1, arg2) ret0, _ := ret[0].(error) return ret0 } // StartSubtask indicates an expected call of StartSubtask. -func (mr *MockTaskTableMockRecorder) StartSubtask(arg0, arg1 any) *gomock.Call { +func (mr *MockTaskTableMockRecorder) StartSubtask(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartSubtask", reflect.TypeOf((*MockTaskTable)(nil).StartSubtask), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartSubtask", reflect.TypeOf((*MockTaskTable)(nil).StartSubtask), arg0, arg1, arg2) } // UpdateErrorToSubtask mocks base method. 
diff --git a/pkg/disttask/framework/scheduler/interface.go b/pkg/disttask/framework/scheduler/interface.go index ebe9e8d014be8..253b8e4ba1b2c 100644 --- a/pkg/disttask/framework/scheduler/interface.go +++ b/pkg/disttask/framework/scheduler/interface.go @@ -71,7 +71,7 @@ type TaskManager interface { GetManagedNodes(ctx context.Context) ([]string, error) GetTaskExecutorIDsByTaskID(ctx context.Context, taskID int64) ([]string, error) GetSubtasksByStepAndState(ctx context.Context, taskID int64, step proto.Step, state proto.TaskState) ([]*proto.Subtask, error) - GetSubtasksByExecIdsAndStepAndState(ctx context.Context, tidbIDs []string, taskID int64, step proto.Step, state proto.TaskState) ([]*proto.Subtask, error) + GetSubtasksByExecIdsAndStepAndState(ctx context.Context, execIDs []string, taskID int64, step proto.Step, state proto.TaskState) ([]*proto.Subtask, error) GetTaskExecutorIDsByTaskIDAndStep(ctx context.Context, taskID int64, step proto.Step) ([]string, error) WithNewSession(fn func(se sessionctx.Context) error) error diff --git a/pkg/disttask/framework/storage/table_test.go b/pkg/disttask/framework/storage/table_test.go index 7b0c6795c85a3..997c4516673e3 100644 --- a/pkg/disttask/framework/storage/table_test.go +++ b/pkg/disttask/framework/storage/table_test.go @@ -487,7 +487,11 @@ func TestSubTaskTable(t *testing.T) { ts := time.Now() time.Sleep(time.Second) - require.NoError(t, sm.StartSubtask(ctx, 1)) + err = sm.StartSubtask(ctx, 1, "tidb1") + require.NoError(t, err) + + err = sm.StartSubtask(ctx, 1, "tidb2") + require.ErrorIs(t, err, storage.ErrSubtaskNotFound) subtask, err = sm.GetFirstSubtaskInStates(ctx, "tidb1", 1, proto.StepInit, proto.TaskStatePending) require.NoError(t, err) @@ -565,7 +569,9 @@ func TestSubTaskTable(t *testing.T) { testutil.CreateSubTask(t, sm, 4, proto.StepInit, "for_test1", []byte("test"), proto.TaskTypeExample, 11, false) subtask, err = sm.GetFirstSubtaskInStates(ctx, "for_test1", 4, proto.StepInit, proto.TaskStatePending) 
require.NoError(t, err) - require.NoError(t, sm.StartSubtask(ctx, subtask.ID)) + err = sm.StartSubtask(ctx, subtask.ID, "for_test1") + require.NoError(t, err) + subtask, err = sm.GetFirstSubtaskInStates(ctx, "for_test1", 4, proto.StepInit, proto.TaskStateRunning) require.NoError(t, err) require.Greater(t, subtask.StartTime, ts) diff --git a/pkg/disttask/framework/storage/task_table.go b/pkg/disttask/framework/storage/task_table.go index 888d8747124d8..80e8515c34c9d 100644 --- a/pkg/disttask/framework/storage/task_table.go +++ b/pkg/disttask/framework/storage/task_table.go @@ -63,6 +63,9 @@ var ( // unstable, i.e. count, order and content of the subtasks are changed on // different call. ErrUnstableSubtasks = errors.New("unstable subtasks") + // ErrSubtaskNotFound is returned when no subtask matches the given subtask ID and exec ID, + // e.g. the scheduler changed the subtask's exec_id while rebalancing it to another node. + ErrSubtaskNotFound = errors.New("subtask not found") ) // SessionExecutor defines the interface for executing SQLs in a session. @@ -114,7 +117,6 @@ func SetTaskManager(is *TaskManager) { } // ExecSQL executes the sql and returns the result. -// TODO: consider retry. func ExecSQL(ctx context.Context, se sessionctx.Context, sql string, args ...interface{}) ([]chunk.Row, error) { rs, err := se.(sqlexec.SQLExecutor).ExecuteInternal(ctx, sql, args...) if err != nil { @@ -490,8 +492,8 @@ func row2SubTask(r chunk.Row) *proto.Subtask { } // GetSubtasksByStepAndStates gets all subtasks by given states. -func (stm *TaskManager) GetSubtasksByStepAndStates(ctx context.Context, tidbID string, taskID int64, step proto.Step, states ...interface{}) ([]*proto.Subtask, error) { - args := []interface{}{tidbID, taskID, step} +func (stm *TaskManager) GetSubtasksByStepAndStates(ctx context.Context, execID string, taskID int64, step proto.Step, states ...interface{}) ([]*proto.Subtask, error) { + args := []interface{}{execID, taskID, step} args = append(args, states...) 
rs, err := stm.executeSQLWithNewSession(ctx, `select `+subtaskColumns+` from mysql.tidb_background_subtask where exec_id = %? and task_key = %? and step = %? @@ -508,14 +510,14 @@ func (stm *TaskManager) GetSubtasksByStepAndStates(ctx context.Context, tidbID s } // GetSubtasksByExecIdsAndStepAndState gets all subtasks by given taskID, exec_id, step and state. -func (stm *TaskManager) GetSubtasksByExecIdsAndStepAndState(ctx context.Context, tidbIDs []string, taskID int64, step proto.Step, state proto.TaskState) ([]*proto.Subtask, error) { +func (stm *TaskManager) GetSubtasksByExecIdsAndStepAndState(ctx context.Context, execIDs []string, taskID int64, step proto.Step, state proto.TaskState) ([]*proto.Subtask, error) { args := []interface{}{taskID, step, state} - for _, tidbID := range tidbIDs { - args = append(args, tidbID) + for _, execID := range execIDs { + args = append(args, execID) } rs, err := stm.executeSQLWithNewSession(ctx, `select `+subtaskColumns+` from mysql.tidb_background_subtask where task_key = %? and step = %? and state = %? - and exec_id in (`+strings.Repeat("%?,", len(tidbIDs)-1)+"%?)", args...) + and exec_id in (`+strings.Repeat("%?,", len(execIDs)-1)+"%?)", args...) if err != nil { return nil, err } @@ -528,8 +530,8 @@ func (stm *TaskManager) GetSubtasksByExecIdsAndStepAndState(ctx context.Context, } // GetFirstSubtaskInStates gets the first subtask by given states. -func (stm *TaskManager) GetFirstSubtaskInStates(ctx context.Context, tidbID string, taskID int64, step proto.Step, states ...interface{}) (*proto.Subtask, error) { - args := []interface{}{tidbID, taskID, step} +func (stm *TaskManager) GetFirstSubtaskInStates(ctx context.Context, execID string, taskID int64, step proto.Step, states ...interface{}) (*proto.Subtask, error) { + args := []interface{}{execID, taskID, step} args = append(args, states...) rs, err := stm.executeSQLWithNewSession(ctx, `select `+subtaskColumns+` from mysql.tidb_background_subtask where exec_id = %? 
and task_key = %? and step = %? @@ -545,49 +547,34 @@ func (stm *TaskManager) GetFirstSubtaskInStates(ctx context.Context, tidbID stri } // UpdateSubtaskExecID updates the subtask's exec_id, used for testing now. -func (stm *TaskManager) UpdateSubtaskExecID(ctx context.Context, tidbID string, subtaskID int64) error { +func (stm *TaskManager) UpdateSubtaskExecID(ctx context.Context, execID string, subtaskID int64) error { _, err := stm.executeSQLWithNewSession(ctx, `update mysql.tidb_background_subtask set exec_id = %?, state_update_time = unix_timestamp() where id = %?`, - tidbID, subtaskID) + execID, subtaskID) return err } // UpdateErrorToSubtask updates the error to subtask. -func (stm *TaskManager) UpdateErrorToSubtask(ctx context.Context, tidbID string, taskID int64, err error) error { +func (stm *TaskManager) UpdateErrorToSubtask(ctx context.Context, execID string, taskID int64, err error) error { if err == nil { return nil } - _, err1 := stm.executeSQLWithNewSession(ctx, `update mysql.tidb_background_subtask + _, err1 := stm.executeSQLWithNewSession(ctx, + `update mysql.tidb_background_subtask set state = %?, error = %?, start_time = unix_timestamp(), state_update_time = unix_timestamp() - where exec_id = %? and task_key = %? and state in (%?, %?) limit 1;`, - proto.TaskStateFailed, serializeErr(err), tidbID, taskID, proto.TaskStatePending, proto.TaskStateRunning) + where exec_id = %? and + task_key = %? and + state in (%?, %?) + limit 1;`, + proto.TaskStateFailed, + serializeErr(err), + execID, + taskID, + proto.TaskStatePending, + proto.TaskStateRunning) return err1 } -// PrintSubtaskInfo log the subtask info by taskKey. Only used for UT. 
-func (stm *TaskManager) PrintSubtaskInfo(ctx context.Context, taskID int64) { - rs, _ := stm.executeSQLWithNewSession(ctx, - `select `+subtaskColumns+` from mysql.tidb_background_subtask_history where task_key = %?`, taskID) - rs2, _ := stm.executeSQLWithNewSession(ctx, - `select `+subtaskColumns+` from mysql.tidb_background_subtask where task_key = %?`, taskID) - rs = append(rs, rs2...) - - for _, r := range rs { - errBytes := r.GetBytes(13) - var err error - if len(errBytes) > 0 { - stdErr := errors.Normalize("") - err1 := stdErr.UnmarshalJSON(errBytes) - if err1 != nil { - err = err1 - } else { - err = stdErr - } - } - logutil.BgLogger().Info(fmt.Sprintf("subTask: %v\n", row2SubTask(r)), zap.Error(err)) - } -} - // GetSubtasksByStepAndState gets the subtask by step and state. func (stm *TaskManager) GetSubtasksByStepAndState(ctx context.Context, taskID int64, step proto.Step, state proto.TaskState) ([]*proto.Subtask, error) { rs, err := stm.executeSQLWithNewSession(ctx, `select `+subtaskColumns+` from mysql.tidb_background_subtask @@ -673,8 +660,8 @@ func (stm *TaskManager) CollectSubTaskError(ctx context.Context, taskID int64) ( } // HasSubtasksInStates checks if there are subtasks in the states. -func (stm *TaskManager) HasSubtasksInStates(ctx context.Context, tidbID string, taskID int64, step proto.Step, states ...interface{}) (bool, error) { - args := []interface{}{tidbID, taskID, step} +func (stm *TaskManager) HasSubtasksInStates(ctx context.Context, execID string, taskID int64, step proto.Step, states ...interface{}) (bool, error) { + args := []interface{}{execID, taskID, step} args = append(args, states...) rs, err := stm.executeSQLWithNewSession(ctx, `select 1 from mysql.tidb_background_subtask where exec_id = %? and task_key = %? and step = %? @@ -687,35 +674,49 @@ func (stm *TaskManager) HasSubtasksInStates(ctx context.Context, tidbID string, } // StartSubtask updates the subtask state to running. 
-func (stm *TaskManager) StartSubtask(ctx context.Context, subtaskID int64) error { - _, err := stm.executeSQLWithNewSession(ctx, `update mysql.tidb_background_subtask - set state = %?, start_time = unix_timestamp(), state_update_time = unix_timestamp() - where id = %?`, - proto.TaskStateRunning, subtaskID) +func (stm *TaskManager) StartSubtask(ctx context.Context, subtaskID int64, execID string) error { + err := stm.WithNewTxn(ctx, func(se sessionctx.Context) error { + vars := se.GetSessionVars() + _, err := ExecSQL(ctx, + se, + `update mysql.tidb_background_subtask + set state = %?, start_time = unix_timestamp(), state_update_time = unix_timestamp() + where id = %? and exec_id = %?`, + proto.TaskStateRunning, + subtaskID, + execID) + if err != nil { + return err + } + if vars.StmtCtx.AffectedRows() == 0 { + return ErrSubtaskNotFound + } + return nil + }) return err } // StartManager insert the manager information into dist_framework_meta. -func (stm *TaskManager) StartManager(ctx context.Context, tidbID string, role string) error { +func (stm *TaskManager) StartManager(ctx context.Context, execID string, role string) error { _, err := stm.executeSQLWithNewSession(ctx, `insert into mysql.dist_framework_meta(host, role, keyspace_id) SELECT %?, %?,-1 - WHERE NOT EXISTS (SELECT 1 FROM mysql.dist_framework_meta WHERE host = %?)`, tidbID, role, tidbID) + WHERE NOT EXISTS (SELECT 1 FROM mysql.dist_framework_meta WHERE host = %?)`, execID, role, execID) return err } // UpdateSubtaskStateAndError updates the subtask state. -func (stm *TaskManager) UpdateSubtaskStateAndError(ctx context.Context, tidbID string, id int64, state proto.TaskState, subTaskErr error) error { +func (stm *TaskManager) UpdateSubtaskStateAndError(ctx context.Context, execID string, id int64, state proto.TaskState, subTaskErr error) error { _, err := stm.executeSQLWithNewSession(ctx, `update mysql.tidb_background_subtask set state = %?, error = %?, state_update_time = unix_timestamp() where id = %? 
and exec_id = %?`, - state, serializeErr(subTaskErr), id, tidbID) + state, serializeErr(subTaskErr), id, execID) return err } // FinishSubtask updates the subtask meta and mark state to succeed. -func (stm *TaskManager) FinishSubtask(ctx context.Context, tidbID string, id int64, meta []byte) error { +func (stm *TaskManager) FinishSubtask(ctx context.Context, execID string, id int64, meta []byte) error { _, err := stm.executeSQLWithNewSession(ctx, `update mysql.tidb_background_subtask set meta = %?, state = %?, state_update_time = unix_timestamp() where id = %? and exec_id = %?`, - meta, proto.TaskStateSucceed, id, tidbID) + meta, proto.TaskStateSucceed, id, execID) return err } @@ -825,9 +826,9 @@ func (stm *TaskManager) DeleteDeadNodes(ctx context.Context, nodes []string) err } // PauseSubtasks update all running/pending subtasks to pasued state. -func (stm *TaskManager) PauseSubtasks(ctx context.Context, tidbID string, taskID int64) error { +func (stm *TaskManager) PauseSubtasks(ctx context.Context, execID string, taskID int64) error { _, err := stm.executeSQLWithNewSession(ctx, - `update mysql.tidb_background_subtask set state = "paused" where task_key = %? and state in ("running", "pending") and exec_id = %?`, taskID, tidbID) + `update mysql.tidb_background_subtask set state = "paused" where task_key = %? and state in ("running", "pending") and exec_id = %?`, taskID, execID) return err } @@ -864,7 +865,7 @@ func (stm *TaskManager) SwitchTaskStep( } if vars.StmtCtx.AffectedRows() == 0 { // on network partition or owner change, there might be multiple - // dispatchers for the same task, if other dispatcher has switched + // schedulers for the same task, if other scheduler has switched // the task to next step, skip the update process. // Or when there is no such task. 
return nil diff --git a/pkg/disttask/framework/storage/util.go b/pkg/disttask/framework/storage/util.go index 3a5bc97f9675b..6a1922c7624bb 100644 --- a/pkg/disttask/framework/storage/util.go +++ b/pkg/disttask/framework/storage/util.go @@ -16,8 +16,12 @@ package storage import ( "context" + "fmt" + "github.com/pingcap/errors" "github.com/pingcap/tidb/pkg/disttask/framework/proto" + "github.com/pingcap/tidb/pkg/util/logutil" + "go.uber.org/zap" ) // GetSubtasksFromHistoryForTest gets subtasks from history table for test. @@ -66,3 +70,27 @@ func GetTasksFromHistoryForTest(ctx context.Context, stm *TaskManager) (int, err } return len(rs), nil } + +// PrintSubtaskInfo log the subtask info by taskKey. Only used for UT. +func (stm *TaskManager) PrintSubtaskInfo(ctx context.Context, taskID int64) { + rs, _ := stm.executeSQLWithNewSession(ctx, + `select `+subtaskColumns+` from mysql.tidb_background_subtask_history where task_key = %?`, taskID) + rs2, _ := stm.executeSQLWithNewSession(ctx, + `select `+subtaskColumns+` from mysql.tidb_background_subtask where task_key = %?`, taskID) + rs = append(rs, rs2...) 
+ + for _, r := range rs { + errBytes := r.GetBytes(13) + var err error + if len(errBytes) > 0 { + stdErr := errors.Normalize("") + err1 := stdErr.UnmarshalJSON(errBytes) + if err1 != nil { + err = err1 + } else { + err = stdErr + } + } + logutil.BgLogger().Info(fmt.Sprintf("subTask: %v\n", row2SubTask(r)), zap.Error(err)) + } +} diff --git a/pkg/disttask/framework/taskexecutor/interface.go b/pkg/disttask/framework/taskexecutor/interface.go index f969979734b28..c25ae29d9ab92 100644 --- a/pkg/disttask/framework/taskexecutor/interface.go +++ b/pkg/disttask/framework/taskexecutor/interface.go @@ -26,17 +26,19 @@ type TaskTable interface { GetTasksInStates(ctx context.Context, states ...interface{}) (task []*proto.Task, err error) GetTaskByID(ctx context.Context, taskID int64) (task *proto.Task, err error) - GetSubtasksByStepAndStates(ctx context.Context, tidbID string, taskID int64, step proto.Step, states ...interface{}) ([]*proto.Subtask, error) + GetSubtasksByStepAndStates(ctx context.Context, execID string, taskID int64, step proto.Step, states ...interface{}) ([]*proto.Subtask, error) GetFirstSubtaskInStates(ctx context.Context, instanceID string, taskID int64, step proto.Step, states ...interface{}) (*proto.Subtask, error) - StartManager(ctx context.Context, tidbID string, role string) error - StartSubtask(ctx context.Context, subtaskID int64) error - UpdateSubtaskStateAndError(ctx context.Context, tidbID string, subtaskID int64, state proto.TaskState, err error) error - FinishSubtask(ctx context.Context, tidbID string, subtaskID int64, meta []byte) error - - HasSubtasksInStates(ctx context.Context, tidbID string, taskID int64, step proto.Step, states ...interface{}) (bool, error) - UpdateErrorToSubtask(ctx context.Context, tidbID string, taskID int64, err error) error - IsTaskExecutorCanceled(ctx context.Context, tidbID string, taskID int64) (bool, error) - PauseSubtasks(ctx context.Context, tidbID string, taskID int64) error + StartManager(ctx context.Context, 
execID string, role string) error + // StartSubtask tries to update the subtask's state to running if the subtask is owned by execID. + // If the update succeeds, the task executor identified by execID owns the subtask. + StartSubtask(ctx context.Context, subtaskID int64, execID string) error + UpdateSubtaskStateAndError(ctx context.Context, execID string, subtaskID int64, state proto.TaskState, err error) error + FinishSubtask(ctx context.Context, execID string, subtaskID int64, meta []byte) error + + HasSubtasksInStates(ctx context.Context, execID string, taskID int64, step proto.Step, states ...interface{}) (bool, error) + UpdateErrorToSubtask(ctx context.Context, execID string, taskID int64, err error) error + IsTaskExecutorCanceled(ctx context.Context, execID string, taskID int64) (bool, error) + PauseSubtasks(ctx context.Context, execID string, taskID int64) error } // Pool defines the interface of a pool. diff --git a/pkg/disttask/framework/taskexecutor/task_executor.go b/pkg/disttask/framework/taskexecutor/task_executor.go index c870bcee3f2e6..01ffb0e9c90cb 100644 --- a/pkg/disttask/framework/taskexecutor/task_executor.go +++ b/pkg/disttask/framework/taskexecutor/task_executor.go @@ -256,9 +256,15 @@ func (s *BaseTaskExecutor) run(ctx context.Context, task *proto.Task) (resErr er } } else { // subtask.State == proto.TaskStatePending - s.startSubtaskAndUpdateState(runCtx, subtask) - if err := s.getError(); err != nil { + err := s.startSubtaskAndUpdateState(runCtx, subtask) + if err != nil { logutil.Logger(s.logCtx).Warn("startSubtaskAndUpdateState meets error", zap.Error(err)) + // ErrSubtaskNotFound should be ignored, + // since it only indicates that the subtask is not owned by the current task executor. 
+ if err == storage.ErrSubtaskNotFound { + continue + } + s.onError(err) continue } } @@ -530,22 +536,25 @@ func (s *BaseTaskExecutor) resetError() { s.mu.handled = false } -func (s *BaseTaskExecutor) startSubtaskAndUpdateState(ctx context.Context, subtask *proto.Subtask) { - metrics.DecDistTaskSubTaskCnt(subtask) - metrics.EndDistTaskSubTask(subtask) - s.startSubtask(ctx, subtask.ID) - subtask.State = proto.TaskStateRunning - metrics.IncDistTaskSubTaskCnt(subtask) - metrics.StartDistTaskSubTask(subtask) +func (s *BaseTaskExecutor) startSubtaskAndUpdateState(ctx context.Context, subtask *proto.Subtask) error { + err := s.startSubtask(ctx, subtask.ID) + if err == nil { + metrics.DecDistTaskSubTaskCnt(subtask) + metrics.EndDistTaskSubTask(subtask) + subtask.State = proto.TaskStateRunning + metrics.IncDistTaskSubTaskCnt(subtask) + metrics.StartDistTaskSubTask(subtask) + } + return err } -func (s *BaseTaskExecutor) updateSubtaskStateAndErrorImpl(ctx context.Context, tidbID string, subtaskID int64, state proto.TaskState, subTaskErr error) { +func (s *BaseTaskExecutor) updateSubtaskStateAndErrorImpl(ctx context.Context, execID string, subtaskID int64, state proto.TaskState, subTaskErr error) { // retry for 3+6+12+24+(30-4)*30 ~= 825s ~= 14 minutes logger := logutil.Logger(s.logCtx) backoffer := backoff.NewExponential(scheduler.RetrySQLInterval, 2, scheduler.RetrySQLMaxInterval) err := handle.RunWithRetry(ctx, scheduler.RetrySQLTimes, backoffer, logger, func(ctx context.Context) (bool, error) { - return true, s.taskTable.UpdateSubtaskStateAndError(ctx, tidbID, subtaskID, state, subTaskErr) + return true, s.taskTable.UpdateSubtaskStateAndError(ctx, execID, subtaskID, state, subTaskErr) }, ) if err != nil { @@ -553,18 +562,23 @@ func (s *BaseTaskExecutor) updateSubtaskStateAndErrorImpl(ctx context.Context, t } } -func (s *BaseTaskExecutor) startSubtask(ctx context.Context, subtaskID int64) { +// startSubtask try to change the state of the subtask to running. 
+// If the subtask is not owned by the task executor, +// the update will fail and task executor should not run the subtask. +func (s *BaseTaskExecutor) startSubtask(ctx context.Context, subtaskID int64) error { // retry for 3+6+12+24+(30-4)*30 ~= 825s ~= 14 minutes logger := logutil.Logger(s.logCtx) backoffer := backoff.NewExponential(scheduler.RetrySQLInterval, 2, scheduler.RetrySQLMaxInterval) - err := handle.RunWithRetry(ctx, scheduler.RetrySQLTimes, backoffer, logger, + return handle.RunWithRetry(ctx, scheduler.RetrySQLTimes, backoffer, logger, func(ctx context.Context) (bool, error) { - return true, s.taskTable.StartSubtask(ctx, subtaskID) + err := s.taskTable.StartSubtask(ctx, subtaskID, s.id) + if err == storage.ErrSubtaskNotFound { + // No need to retry. + return false, err + } + return true, err }, ) - if err != nil { - s.onError(err) - } } func (s *BaseTaskExecutor) finishSubtask(ctx context.Context, subtask *proto.Subtask) { diff --git a/pkg/disttask/framework/taskexecutor/task_executor_test.go b/pkg/disttask/framework/taskexecutor/task_executor_test.go index bbe2d37196667..13676cf03b1f8 100644 --- a/pkg/disttask/framework/taskexecutor/task_executor_test.go +++ b/pkg/disttask/framework/taskexecutor/task_executor_test.go @@ -22,6 +22,7 @@ import ( "github.com/pingcap/tidb/pkg/disttask/framework/mock" mockexecute "github.com/pingcap/tidb/pkg/disttask/framework/mock/execute" "github.com/pingcap/tidb/pkg/disttask/framework/proto" + "github.com/pingcap/tidb/pkg/disttask/framework/storage" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" "google.golang.org/grpc/codes" @@ -85,7 +86,7 @@ func TestTaskExecutorRun(t *testing.T) { mockSubtaskTable.EXPECT().GetFirstSubtaskInStates(gomock.Any(), "id", taskID, proto.StepOne, unfinishedNormalSubtaskStates...).Return(&proto.Subtask{ ID: 1, Type: tp, Step: proto.StepOne, State: proto.TaskStatePending, ExecID: "id"}, nil) - mockSubtaskTable.EXPECT().StartSubtask(gomock.Any(), taskID).Return(nil) + 
mockSubtaskTable.EXPECT().StartSubtask(gomock.Any(), taskID, "id").Return(nil) mockSubtaskExecutor.EXPECT().RunSubtask(gomock.Any(), gomock.Any()).Return(runSubtaskErr) mockSubtaskTable.EXPECT().UpdateSubtaskStateAndError(gomock.Any(), "id", taskID, proto.TaskStateFailed, gomock.Any()).Return(nil) mockSubtaskExecutor.EXPECT().Cleanup(gomock.Any()).Return(nil) @@ -101,7 +102,7 @@ func TestTaskExecutorRun(t *testing.T) { mockSubtaskTable.EXPECT().GetFirstSubtaskInStates(gomock.Any(), "id", taskID, proto.StepOne, unfinishedNormalSubtaskStates...).Return(&proto.Subtask{ ID: 1, Type: tp, Step: proto.StepOne, State: proto.TaskStatePending, ExecID: "id"}, nil) - mockSubtaskTable.EXPECT().StartSubtask(gomock.Any(), taskID).Return(nil) + mockSubtaskTable.EXPECT().StartSubtask(gomock.Any(), taskID, "id").Return(nil) mockSubtaskExecutor.EXPECT().RunSubtask(gomock.Any(), gomock.Any()).Return(nil) mockSubtaskExecutor.EXPECT().OnFinished(gomock.Any(), gomock.Any()).Return(nil) mockSubtaskTable.EXPECT().FinishSubtask(gomock.Any(), "id", int64(1), gomock.Any()).Return(nil) @@ -124,7 +125,7 @@ func TestTaskExecutorRun(t *testing.T) { mockSubtaskTable.EXPECT().GetFirstSubtaskInStates(gomock.Any(), "id", taskID, proto.StepOne, unfinishedNormalSubtaskStates...).Return(&proto.Subtask{ ID: 1, Type: tp, Step: proto.StepOne, State: proto.TaskStatePending, ExecID: "id"}, nil) - mockSubtaskTable.EXPECT().StartSubtask(gomock.Any(), int64(1)).Return(nil) + mockSubtaskTable.EXPECT().StartSubtask(gomock.Any(), int64(1), "id").Return(nil) mockSubtaskExecutor.EXPECT().RunSubtask(gomock.Any(), gomock.Any()).Return(nil) mockSubtaskExecutor.EXPECT().OnFinished(gomock.Any(), gomock.Any()).Return(nil) mockSubtaskTable.EXPECT().FinishSubtask(gomock.Any(), "id", int64(1), gomock.Any()).Return(nil) @@ -132,7 +133,7 @@ func TestTaskExecutorRun(t *testing.T) { mockSubtaskTable.EXPECT().GetFirstSubtaskInStates(gomock.Any(), "id", taskID, proto.StepOne, 
unfinishedNormalSubtaskStates...).Return(&proto.Subtask{ ID: 2, Type: tp, Step: proto.StepOne, State: proto.TaskStatePending, ExecID: "id"}, nil) - mockSubtaskTable.EXPECT().StartSubtask(gomock.Any(), int64(2)).Return(nil) + mockSubtaskTable.EXPECT().StartSubtask(gomock.Any(), int64(2), "id").Return(nil) mockSubtaskExecutor.EXPECT().RunSubtask(gomock.Any(), gomock.Any()).Return(nil) mockSubtaskExecutor.EXPECT().OnFinished(gomock.Any(), gomock.Any()).Return(nil) mockSubtaskTable.EXPECT().FinishSubtask(gomock.Any(), "id", int64(2), gomock.Any()).Return(nil) @@ -188,7 +189,7 @@ func TestTaskExecutorRun(t *testing.T) { mockSubtaskTable.EXPECT().GetFirstSubtaskInStates(gomock.Any(), "id", taskID, proto.StepOne, unfinishedNormalSubtaskStates...).Return(&proto.Subtask{ ID: 1, Type: tp, Step: proto.StepOne, State: proto.TaskStatePending, ExecID: "id"}, nil) - mockSubtaskTable.EXPECT().StartSubtask(gomock.Any(), taskID).Return(nil) + mockSubtaskTable.EXPECT().StartSubtask(gomock.Any(), taskID, "id").Return(nil) mockSubtaskExecutor.EXPECT().RunSubtask(gomock.Any(), gomock.Any()).Return(ErrCancelSubtask) mockSubtaskTable.EXPECT().UpdateSubtaskStateAndError(gomock.Any(), "id", taskID, proto.TaskStateCanceled, gomock.Any()).Return(nil) mockSubtaskExecutor.EXPECT().Cleanup(gomock.Any()).Return(nil) @@ -203,7 +204,7 @@ func TestTaskExecutorRun(t *testing.T) { mockSubtaskTable.EXPECT().GetFirstSubtaskInStates(gomock.Any(), "id", taskID, proto.StepOne, unfinishedNormalSubtaskStates...).Return(&proto.Subtask{ ID: 1, Type: tp, Step: proto.StepOne, State: proto.TaskStatePending, ExecID: "id"}, nil) - mockSubtaskTable.EXPECT().StartSubtask(gomock.Any(), taskID).Return(nil) + mockSubtaskTable.EXPECT().StartSubtask(gomock.Any(), taskID, "id").Return(nil) mockSubtaskExecutor.EXPECT().RunSubtask(gomock.Any(), gomock.Any()).Return(context.Canceled) mockSubtaskExecutor.EXPECT().Cleanup(gomock.Any()).Return(nil) err = taskExecutor.Run(runCtx, task) @@ -217,7 +218,7 @@ func 
TestTaskExecutorRun(t *testing.T) { mockSubtaskTable.EXPECT().GetFirstSubtaskInStates(gomock.Any(), "id", taskID, proto.StepOne, unfinishedNormalSubtaskStates...).Return(&proto.Subtask{ ID: 1, Type: tp, Step: proto.StepOne, State: proto.TaskStatePending, ExecID: "id"}, nil) - mockSubtaskTable.EXPECT().StartSubtask(gomock.Any(), taskID).Return(nil) + mockSubtaskTable.EXPECT().StartSubtask(gomock.Any(), taskID, "id").Return(nil) grpcErr := status.Error(codes.Canceled, "test cancel") mockSubtaskExecutor.EXPECT().RunSubtask(gomock.Any(), gomock.Any()).Return(grpcErr) mockSubtaskExecutor.EXPECT().Cleanup(gomock.Any()).Return(nil) @@ -232,7 +233,7 @@ func TestTaskExecutorRun(t *testing.T) { mockSubtaskTable.EXPECT().GetFirstSubtaskInStates(gomock.Any(), "id", taskID, proto.StepOne, unfinishedNormalSubtaskStates...).Return(&proto.Subtask{ ID: 1, Type: tp, Step: proto.StepOne, State: proto.TaskStatePending, ExecID: "id"}, nil) - mockSubtaskTable.EXPECT().StartSubtask(gomock.Any(), taskID).Return(nil) + mockSubtaskTable.EXPECT().StartSubtask(gomock.Any(), taskID, "id").Return(nil) grpcErr = status.Error(codes.Canceled, "test cancel") annotatedError := errors.Annotatef( grpcErr, @@ -244,6 +245,21 @@ func TestTaskExecutorRun(t *testing.T) { err = taskExecutor.Run(runCtx, task) require.EqualError(t, err, annotatedError.Error()) + // 10. 
subtask owned by other executor + mockSubtaskExecutor.EXPECT().Init(gomock.Any()).Return(nil) + mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates(gomock.Any(), "id", taskID, proto.StepOne, + unfinishedNormalSubtaskStates...).Return([]*proto.Subtask{{ + ID: 1, Type: tp, Step: proto.StepOne, State: proto.TaskStatePending, ExecID: "id"}}, nil) + mockSubtaskTable.EXPECT().GetFirstSubtaskInStates(gomock.Any(), "id", taskID, proto.StepOne, + unfinishedNormalSubtaskStates...).Return(&proto.Subtask{ + ID: 1, Type: tp, Step: proto.StepOne, State: proto.TaskStatePending, ExecID: "id"}, nil) + mockSubtaskTable.EXPECT().GetFirstSubtaskInStates(gomock.Any(), "id", taskID, proto.StepOne, + unfinishedNormalSubtaskStates...).Return(nil, nil) + mockSubtaskTable.EXPECT().StartSubtask(gomock.Any(), taskID, "id").Return(storage.ErrSubtaskNotFound) + mockSubtaskTable.EXPECT().GetTaskByID(gomock.Any(), gomock.Any()).Return(&proto.Task{ID: taskID, Step: proto.StepTwo}, nil) + mockSubtaskExecutor.EXPECT().Cleanup(gomock.Any()).Return(nil) + err = taskExecutor.Run(runCtx, task) + require.NoError(t, err) runCancel() } @@ -383,7 +399,7 @@ func TestTaskExecutor(t *testing.T) { mockSubtaskTable.EXPECT().GetFirstSubtaskInStates(gomock.Any(), "id", taskID, proto.StepOne, unfinishedNormalSubtaskStates...).Return(&proto.Subtask{ ID: 1, Type: tp, Step: proto.StepOne, State: proto.TaskStatePending, ExecID: "id"}, nil) - mockSubtaskTable.EXPECT().StartSubtask(gomock.Any(), taskID).Return(nil) + mockSubtaskTable.EXPECT().StartSubtask(gomock.Any(), taskID, "id").Return(nil) mockSubtaskExecutor.EXPECT().RunSubtask(gomock.Any(), gomock.Any()).Return(runSubtaskErr) mockSubtaskExecutor.EXPECT().Cleanup(gomock.Any()).Return(nil) err := taskExecutor.run(runCtx, &proto.Task{Step: proto.StepOne, Type: tp, ID: taskID, Concurrency: concurrency})