Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

disttask: rollback subtasks without init taskexecutor #50987

Merged
merged 2 commits into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pkg/disttask/framework/taskexecutor/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ go_test(
],
embed = [":taskexecutor"],
flaky = True,
shard_count = 16,
shard_count = 15,
deps = [
"//pkg/disttask/framework/mock",
"//pkg/disttask/framework/mock/execute",
Expand Down
15 changes: 10 additions & 5 deletions pkg/disttask/framework/taskexecutor/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,18 +204,18 @@ func (m *Manager) handleTasks() {
executableTasks := make([]*storage.TaskExecInfo, 0, len(tasks))
for _, task := range tasks {
switch task.State {
case proto.TaskStateRunning, proto.TaskStateReverting:
if task.State == proto.TaskStateReverting {
m.cancelRunningSubtaskOf(task.ID)
}
// TaskStateReverting require executor to run rollback logic.
case proto.TaskStateRunning:
if !m.isExecutorStarted(task.ID) {
executableTasks = append(executableTasks, task)
}
case proto.TaskStatePausing:
if err := m.handlePausingTask(task.ID); err != nil {
m.logErr(err)
}
case proto.TaskStateReverting:
if err := m.handleRevertingTask(task.ID); err != nil {
m.logErr(err)
}
}
}

Expand Down Expand Up @@ -266,6 +266,11 @@ func (m *Manager) handlePausingTask(taskID int64) error {
return m.taskTable.PauseSubtasks(m.ctx, m.id, taskID)
}

func (m *Manager) handleRevertingTask(taskID int64) error {
m.cancelRunningSubtaskOf(taskID)
return m.taskTable.CancelSubtask(m.ctx, m.id, taskID)
}

// recoverMetaLoop recovers dist_framework_meta for the tidb node running the taskExecutor manager.
// This is necessary when the TiDB node experiences a prolonged network partition
// and the scheduler deletes `dist_framework_meta`.
Expand Down
14 changes: 10 additions & 4 deletions pkg/disttask/framework/taskexecutor/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,14 @@ func TestManageTaskExecutor(t *testing.T) {
mockTaskTable.EXPECT().PauseSubtasks(m.ctx, "test", int64(1)).Return(errors.New("pause failed"))
require.ErrorContains(t, m.handlePausingTask(1), "pause failed")
require.True(t, ctrl.Satisfied())

// handle reverting
executor1.EXPECT().GetTask().Return(&proto.Task{ID: 1})
executor1.EXPECT().CancelRunningSubtask()
m.addTaskExecutor(executor1)
mockTaskTable.EXPECT().CancelSubtask(m.ctx, "test", int64(1)).Return(nil)
require.NoError(t, m.handleRevertingTask(1))
require.True(t, ctrl.Satisfied())
}

func TestHandleExecutableTasks(t *testing.T) {
Expand Down Expand Up @@ -191,10 +199,7 @@ func TestManager(t *testing.T) {
mockInternalExecutors[task1.ID].EXPECT().Run(gomock.Any())
mockInternalExecutors[task1.ID].EXPECT().Close()
// task2
mockInternalExecutors[task2.ID].EXPECT().GetTask().Return(task2).Times(2)
mockInternalExecutors[task2.ID].EXPECT().Init(gomock.Any()).Return(nil)
mockInternalExecutors[task2.ID].EXPECT().Run(gomock.Any())
mockInternalExecutors[task2.ID].EXPECT().Close()
mockTaskTable.EXPECT().CancelSubtask(m.ctx, m.id, task2.ID)
// task3
mockTaskTable.EXPECT().PauseSubtasks(m.ctx, id, task3.ID).Return(nil).AnyTimes()

Expand Down Expand Up @@ -265,6 +270,7 @@ func TestManagerHandleTasks(t *testing.T) {
mockTaskTable.EXPECT().GetTaskExecInfoByExecID(m.ctx, m.id).
Return([]*storage.TaskExecInfo{{Task: task1}}, nil)
mockInternalExecutor.EXPECT().CancelRunningSubtask()
mockTaskTable.EXPECT().CancelSubtask(m.ctx, m.id, task1.ID)
m.handleTasks()
require.True(t, ctrl.Satisfied())
require.True(t, m.isExecutorStarted(task1.ID))
Expand Down
41 changes: 2 additions & 39 deletions pkg/disttask/framework/taskexecutor/task_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ func (e *BaseTaskExecutor) Run(resource *proto.StepResource) {
continue
}
task := e.task.Load()
if task.State != proto.TaskStateRunning && task.State != proto.TaskStateReverting {
if task.State != proto.TaskStateRunning {
return
}
if exist, err := e.taskTable.HasSubtasksInStates(e.ctx, e.id, task.ID, task.Step,
Expand All @@ -238,14 +238,7 @@ func (e *BaseTaskExecutor) Run(resource *proto.StepResource) {
}
// reset it when we get a subtask
checkInterval, noSubtaskCheckCnt = SubtaskCheckInterval, 0

switch task.State {
case proto.TaskStateRunning:
err = e.RunStep(resource)
case proto.TaskStateReverting:
// TODO: will remove it later, leave it now.
err = e.Rollback()
}
err = e.RunStep(resource)
if err != nil {
e.logger.Error("failed to handle task", zap.Error(err))
}
Expand Down Expand Up @@ -486,36 +479,6 @@ func (e *BaseTaskExecutor) onSubtaskFinished(ctx context.Context, executor execu
})
}

// Rollback rollbacks the subtask.
// TODO no need to start executor to do it, refactor it later.
func (e *BaseTaskExecutor) Rollback() error {
task := e.task.Load()
e.resetError()
e.logger.Info("task reverting, cancel unfinished subtasks")

// We should cancel all subtasks before rolling back
for {
// TODO we can update them using one sql, but requires change the metric
// gathering logic.
subtask, err := e.taskTable.GetFirstSubtaskInStates(e.ctx, e.id, task.ID, task.Step,
proto.SubtaskStatePending, proto.SubtaskStateRunning)
if err != nil {
e.onError(err)
return e.getError()
}

if subtask == nil {
break
}

e.updateSubtaskStateAndErrorImpl(e.ctx, subtask.ExecID, subtask.ID, proto.SubtaskStateCanceled, nil)
if err = e.getError(); err != nil {
return err
}
}
return e.getError()
}

// GetTask implements TaskExecutor.GetTask.
func (e *BaseTaskExecutor) GetTask() *proto.Task {
return e.task.Load()
Expand Down
70 changes: 1 addition & 69 deletions pkg/disttask/framework/taskexecutor/task_executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -278,72 +278,11 @@ func TestTaskExecutorRun(t *testing.T) {
taskExecutor.Run(nil)
require.True(t, ctrl.Satisfied())

task1.State = proto.TaskStateReverting
mockSubtaskTable.EXPECT().GetTaskByID(gomock.Any(), task1.ID).Return(task1, nil)
mockSubtaskTable.EXPECT().HasSubtasksInStates(gomock.Any(), "id", task1.ID, task1.Step,
unfinishedNormalSubtaskStates...).Return(true, nil)
mockSubtaskTable.EXPECT().GetFirstSubtaskInStates(gomock.Any(), "id", task1.ID, proto.StepOne,
unfinishedNormalSubtaskStates...).Return(nil, nil)
mockSubtaskTable.EXPECT().GetTaskByID(gomock.Any(), task1.ID).Return(nil, storage.ErrTaskNotFound)
taskExecutor.Run(nil)
require.True(t, ctrl.Satisfied())

taskExecutor.Cancel()
taskExecutor.Run(nil)
require.True(t, ctrl.Satisfied())
}

func TestTaskExecutorRollback(t *testing.T) {
var tp proto.TaskType = "test_executor_rollback"
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
runCtx, runCancel := context.WithCancel(ctx)
defer runCancel()
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSubtaskTable := mock.NewMockTaskTable(ctrl)
mockStepExecutor := mockexecute.NewMockStepExecutor(ctrl)
mockExtension := mock.NewMockExtension(ctrl)

// 1. no taskExecutor constructor
task1 := &proto.Task{Step: proto.StepOne, ID: 1, Type: tp}
taskExecutor := NewBaseTaskExecutor(ctx, "id", task1, mockSubtaskTable)
taskExecutor.Extension = mockExtension

mockExtension.EXPECT().GetStepExecutor(gomock.Any(), gomock.Any()).Return(mockStepExecutor, nil).AnyTimes()

// 2. get subtask failed
getSubtaskErr := errors.New("get subtask error")
mockSubtaskTable.EXPECT().GetFirstSubtaskInStates(runCtx, "id", task1.ID, proto.StepOne,
unfinishedNormalSubtaskStates...).Return(nil, getSubtaskErr)
err := taskExecutor.Rollback()
require.EqualError(t, err, getSubtaskErr.Error())

// 3. no subtask
mockSubtaskTable.EXPECT().GetFirstSubtaskInStates(runCtx, "id", task1.ID, proto.StepOne,
unfinishedNormalSubtaskStates...).Return(nil, nil)
err = taskExecutor.Rollback()
require.NoError(t, err)

// 5. rollback success
mockSubtaskTable.EXPECT().GetFirstSubtaskInStates(runCtx, "id", task1.ID, proto.StepOne,
unfinishedNormalSubtaskStates...).Return(&proto.Subtask{ID: 1, ExecID: "id"}, nil)
mockSubtaskTable.EXPECT().UpdateSubtaskStateAndError(runCtx, "id", int64(1), proto.SubtaskStateCanceled, nil).Return(nil)
mockSubtaskTable.EXPECT().GetFirstSubtaskInStates(runCtx, "id", task1.ID, proto.StepOne,
unfinishedNormalSubtaskStates...).Return(&proto.Subtask{ID: 2, ExecID: "id"}, nil)
mockSubtaskTable.EXPECT().UpdateSubtaskStateAndError(runCtx, "id", int64(2), proto.SubtaskStateCanceled, nil).Return(nil)
mockSubtaskTable.EXPECT().GetFirstSubtaskInStates(runCtx, "id", task1.ID, proto.StepOne,
unfinishedNormalSubtaskStates...).Return(nil, nil)
err = taskExecutor.Rollback()
require.NoError(t, err)

// rollback again for previous left subtask in TaskStateReverting state
mockSubtaskTable.EXPECT().GetFirstSubtaskInStates(runCtx, "id", task1.ID, proto.StepOne,
unfinishedNormalSubtaskStates...).Return(nil, nil)
err = taskExecutor.Rollback()
require.NoError(t, err)
}

func TestTaskExecutor(t *testing.T) {
var tp proto.TaskType = "test_task_executor"
var taskID int64 = 1
Expand Down Expand Up @@ -381,14 +320,7 @@ func TestTaskExecutor(t *testing.T) {
require.EqualError(t, err, runSubtaskErr.Error())
require.True(t, ctrl.Satisfied())

// 2. rollback success.
mockSubtaskTable.EXPECT().GetFirstSubtaskInStates(taskExecutor.ctx, "id", taskID, proto.StepOne,
unfinishedNormalSubtaskStates...).Return(nil, nil)
err = taskExecutor.Rollback()
require.NoError(t, err)
require.True(t, ctrl.Satisfied())

// 3. run one subtask, then task moved to history(ErrTaskNotFound).
// 2. run one subtask, then task moved to history(ErrTaskNotFound).
mockStepExecutor.EXPECT().Init(gomock.Any()).Return(nil)
mockSubtaskTable.EXPECT().GetFirstSubtaskInStates(gomock.Any(), "id", taskID, proto.StepOne,
unfinishedNormalSubtaskStates...).Return(subtasks[0], nil)
Expand Down