From 99f0349bfb684e108f73e07a1751ff718080fd95 Mon Sep 17 00:00:00 2001 From: D3Hunter Date: Tue, 2 Jan 2024 18:57:32 +0800 Subject: [PATCH] disttask: fix failed step is taken as success (#49971) close pingcap/tidb#49950 --- .../framework_pause_and_resume_test.go | 4 +- pkg/disttask/framework/handle/BUILD.bazel | 1 + pkg/disttask/framework/handle/handle_test.go | 3 + pkg/disttask/framework/mock/scheduler_mock.go | 19 ++--- pkg/disttask/framework/planner/BUILD.bazel | 1 + .../framework/planner/planner_test.go | 3 +- pkg/disttask/framework/scheduler/BUILD.bazel | 3 +- pkg/disttask/framework/scheduler/interface.go | 3 +- pkg/disttask/framework/scheduler/scheduler.go | 43 ++++++----- .../scheduler/scheduler_manager_test.go | 7 +- .../scheduler/scheduler_nokit_test.go | 40 ++++++++++ .../framework/scheduler/scheduler_test.go | 8 +- pkg/disttask/framework/storage/BUILD.bazel | 2 +- pkg/disttask/framework/storage/table_test.go | 76 ++++++++++++------- pkg/disttask/framework/storage/task_table.go | 28 ++++--- .../framework/taskexecutor/task_executor.go | 2 +- pkg/disttask/framework/testutil/context.go | 21 +++-- pkg/disttask/framework/testutil/task_util.go | 6 +- 18 files changed, 178 insertions(+), 92 deletions(-) create mode 100644 pkg/disttask/framework/scheduler/scheduler_nokit_test.go diff --git a/pkg/disttask/framework/framework_pause_and_resume_test.go b/pkg/disttask/framework/framework_pause_and_resume_test.go index bc20428db243b..b864857215782 100644 --- a/pkg/disttask/framework/framework_pause_and_resume_test.go +++ b/pkg/disttask/framework/framework_pause_and_resume_test.go @@ -31,11 +31,11 @@ func CheckSubtasksState(ctx context.Context, t *testing.T, taskID int64, state p mgr, err := storage.GetTaskManager() require.NoError(t, err) mgr.PrintSubtaskInfo(ctx, taskID) - cnt, err := mgr.GetSubtaskInStatesCnt(ctx, taskID, state) + cntByStates, err := mgr.GetSubtaskCntGroupByStates(ctx, taskID, proto.StepTwo) require.NoError(t, err) historySubTasksCnt, err := storage.GetSubtasksFromHistoryByTaskIDForTest(ctx, mgr, taskID) require.NoError(t, err) - require.Equal(t, expectedCnt, cnt+int64(historySubTasksCnt)) + require.Equal(t, expectedCnt, cntByStates[state]+int64(historySubTasksCnt)) } func TestFrameworkPauseAndResume(t *testing.T) { diff --git a/pkg/disttask/framework/handle/BUILD.bazel b/pkg/disttask/framework/handle/BUILD.bazel index 0d1b0f8355d9c..fa58fabc37f05 100644 --- a/pkg/disttask/framework/handle/BUILD.bazel +++ b/pkg/disttask/framework/handle/BUILD.bazel @@ -25,6 +25,7 @@ go_test( ":handle", "//pkg/disttask/framework/proto", "//pkg/disttask/framework/storage", + "//pkg/disttask/framework/testutil", "//pkg/testkit", "//pkg/util/backoff", "@com_github_ngaut_pools//:pools", diff --git a/pkg/disttask/framework/handle/handle_test.go b/pkg/disttask/framework/handle/handle_test.go index f22c57dbb687d..1e82b1639c1c1 100644 --- a/pkg/disttask/framework/handle/handle_test.go +++ b/pkg/disttask/framework/handle/handle_test.go @@ -28,6 +28,7 @@ import ( "github.com/pingcap/tidb/pkg/disttask/framework/handle" "github.com/pingcap/tidb/pkg/disttask/framework/proto" "github.com/pingcap/tidb/pkg/disttask/framework/storage" + "github.com/pingcap/tidb/pkg/disttask/framework/testutil" "github.com/pingcap/tidb/pkg/testkit" "github.com/pingcap/tidb/pkg/util/backoff" "github.com/stretchr/testify/require" @@ -52,6 +53,8 @@ func TestHandle(t *testing.T) { mgr := storage.NewTaskManager(pool) storage.SetTaskManager(mgr) + testutil.WaitNodeRegistered(ctx, t) + // no scheduler registered task, err := 
handle.SubmitTask(ctx, "1", proto.TaskTypeExample, 2, []byte("byte")) require.NoError(t, err) diff --git a/pkg/disttask/framework/mock/scheduler_mock.go b/pkg/disttask/framework/mock/scheduler_mock.go index 5c24654bcec26..933212d10557b 100644 --- a/pkg/disttask/framework/mock/scheduler_mock.go +++ b/pkg/disttask/framework/mock/scheduler_mock.go @@ -239,24 +239,19 @@ func (mr *MockTaskManagerMockRecorder) GetManagedNodes(arg0 any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetManagedNodes", reflect.TypeOf((*MockTaskManager)(nil).GetManagedNodes), arg0) } -// GetSubtaskInStatesCnt mocks base method. -func (m *MockTaskManager) GetSubtaskInStatesCnt(arg0 context.Context, arg1 int64, arg2 ...proto.SubtaskState) (int64, error) { +// GetSubtaskCntGroupByStates mocks base method. +func (m *MockTaskManager) GetSubtaskCntGroupByStates(arg0 context.Context, arg1 int64, arg2 proto.Step) (map[proto.SubtaskState]int64, error) { m.ctrl.T.Helper() - varargs := []any{arg0, arg1} - for _, a := range arg2 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "GetSubtaskInStatesCnt", varargs...) - ret0, _ := ret[0].(int64) + ret := m.ctrl.Call(m, "GetSubtaskCntGroupByStates", arg0, arg1, arg2) + ret0, _ := ret[0].(map[proto.SubtaskState]int64) ret1, _ := ret[1].(error) return ret0, ret1 } -// GetSubtaskInStatesCnt indicates an expected call of GetSubtaskInStatesCnt. -func (mr *MockTaskManagerMockRecorder) GetSubtaskInStatesCnt(arg0, arg1 any, arg2 ...any) *gomock.Call { +// GetSubtaskCntGroupByStates indicates an expected call of GetSubtaskCntGroupByStates. +func (mr *MockTaskManagerMockRecorder) GetSubtaskCntGroupByStates(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1}, arg2...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSubtaskInStatesCnt", reflect.TypeOf((*MockTaskManager)(nil).GetSubtaskInStatesCnt), varargs...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSubtaskCntGroupByStates", reflect.TypeOf((*MockTaskManager)(nil).GetSubtaskCntGroupByStates), arg0, arg1, arg2) } // GetSubtasksByExecIdsAndStepAndState mocks base method. 
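The mock above reflects the new TaskManager.GetSubtaskCntGroupByStates signature, which returns a per-state subtask count map for one step instead of a single count. The scheduler changes later in this patch interpret that map roughly as in the following self-contained sketch; this is illustrative only, not code from the patch, and the type and helper names (subtaskState, stepSucceeded, stepHasFailure) are hypothetical stand-ins rather than identifiers from the TiDB code base.

// Illustrative sketch (not part of the patch): how a per-state count map for a
// step can be interpreted. Names below are hypothetical, not TiDB identifiers.
package sketch

type subtaskState string

const (
	statePending  subtaskState = "pending"
	stateRunning  subtaskState = "running"
	stateSucceed  subtaskState = "succeed"
	stateFailed   subtaskState = "failed"
	stateCanceled subtaskState = "canceled"
)

// stepSucceeded reports whether a step finished successfully: the step either
// has no subtasks at all, or every subtask is in the succeed state. This
// mirrors the isStepSucceed check added in scheduler.go further down.
func stepSucceeded(cntByStates map[subtaskState]int64) bool {
	_, onlySucceed := cntByStates[stateSucceed]
	return len(cntByStates) == 0 || (len(cntByStates) == 1 && onlySucceed)
}

// stepHasFailure reports whether any subtask of the step failed or was
// canceled; in the patched onRunning this is what routes the task into error
// handling instead of silently advancing to the next step.
func stepHasFailure(cntByStates map[subtaskState]int64) bool {
	return cntByStates[stateFailed] > 0 || cntByStates[stateCanceled] > 0
}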
diff --git a/pkg/disttask/framework/planner/BUILD.bazel b/pkg/disttask/framework/planner/BUILD.bazel index 3f28cede80197..8ea3e7c63e3b7 100644 --- a/pkg/disttask/framework/planner/BUILD.bazel +++ b/pkg/disttask/framework/planner/BUILD.bazel @@ -27,6 +27,7 @@ go_test( ":planner", "//pkg/disttask/framework/mock", "//pkg/disttask/framework/storage", + "//pkg/disttask/framework/testutil", "//pkg/kv", "//pkg/testkit", "@com_github_ngaut_pools//:pools", diff --git a/pkg/disttask/framework/planner/planner_test.go b/pkg/disttask/framework/planner/planner_test.go index e515c3e0e266f..0801ce6ecdb7a 100644 --- a/pkg/disttask/framework/planner/planner_test.go +++ b/pkg/disttask/framework/planner/planner_test.go @@ -23,6 +23,7 @@ import ( "github.com/pingcap/tidb/pkg/disttask/framework/mock" "github.com/pingcap/tidb/pkg/disttask/framework/planner" "github.com/pingcap/tidb/pkg/disttask/framework/storage" + "github.com/pingcap/tidb/pkg/disttask/framework/testutil" "github.com/pingcap/tidb/pkg/kv" "github.com/pingcap/tidb/pkg/testkit" "github.com/stretchr/testify/require" @@ -45,7 +46,7 @@ func TestPlanner(t *testing.T) { defer pool.Close() mgr := storage.NewTaskManager(pool) storage.SetTaskManager(mgr) - + testutil.WaitNodeRegistered(ctx, t) p := &planner.Planner{} pCtx := planner.PlanCtx{ Ctx: ctx, diff --git a/pkg/disttask/framework/scheduler/BUILD.bazel b/pkg/disttask/framework/scheduler/BUILD.bazel index dd50741076b86..af08b49314aec 100644 --- a/pkg/disttask/framework/scheduler/BUILD.bazel +++ b/pkg/disttask/framework/scheduler/BUILD.bazel @@ -44,13 +44,14 @@ go_test( "nodes_test.go", "rebalance_test.go", "scheduler_manager_test.go", + "scheduler_nokit_test.go", "scheduler_test.go", "slots_test.go", ], embed = [":scheduler"], flaky = True, race = "off", - shard_count = 25, + shard_count = 26, deps = [ "//pkg/config", "//pkg/disttask/framework/mock", diff --git a/pkg/disttask/framework/scheduler/interface.go b/pkg/disttask/framework/scheduler/interface.go index d0c792d1a10dc..f2508db179fca 100644 --- a/pkg/disttask/framework/scheduler/interface.go +++ b/pkg/disttask/framework/scheduler/interface.go @@ -59,7 +59,8 @@ type TaskManager interface { // we only consider pending/running subtasks, subtasks related to revert are // not considered. GetUsedSlotsOnNodes(ctx context.Context) (map[string]int, error) - GetSubtaskInStatesCnt(ctx context.Context, taskID int64, states ...proto.SubtaskState) (int64, error) + // GetSubtaskCntGroupByStates returns the count of subtasks of some step group by state. + GetSubtaskCntGroupByStates(ctx context.Context, taskID int64, step proto.Step) (map[proto.SubtaskState]int64, error) ResumeSubtasks(ctx context.Context, taskID int64) error CollectSubTaskError(ctx context.Context, taskID int64) ([]error, error) TransferSubTasks2History(ctx context.Context, taskID int64) error diff --git a/pkg/disttask/framework/scheduler/scheduler.go b/pkg/disttask/framework/scheduler/scheduler.go index 89c053ae2669d..a9531f62f82c7 100644 --- a/pkg/disttask/framework/scheduler/scheduler.go +++ b/pkg/disttask/framework/scheduler/scheduler.go @@ -240,12 +240,13 @@ func (s *BaseScheduler) onCancelling() error { // handle task in pausing state, cancel all running subtasks. 
func (s *BaseScheduler) onPausing() error { logutil.Logger(s.logCtx).Info("on pausing state", zap.Stringer("state", s.Task.State), zap.Int64("step", int64(s.Task.Step))) - cnt, err := s.taskMgr.GetSubtaskInStatesCnt(s.ctx, s.Task.ID, proto.SubtaskStateRunning, proto.SubtaskStatePending) + cntByStates, err := s.taskMgr.GetSubtaskCntGroupByStates(s.ctx, s.Task.ID, s.Task.Step) if err != nil { logutil.Logger(s.logCtx).Warn("check task failed", zap.Error(err)) return err } - if cnt == 0 { + runningPendingCnt := cntByStates[proto.SubtaskStateRunning] + cntByStates[proto.SubtaskStatePending] + if runningPendingCnt == 0 { logutil.Logger(s.logCtx).Info("all running subtasks paused, update the task to paused state") return s.updateTask(proto.TaskStatePaused, nil, RetrySQLTimes) } @@ -273,12 +274,12 @@ var TestSyncChan = make(chan struct{}) // handle task in resuming state. func (s *BaseScheduler) onResuming() error { logutil.Logger(s.logCtx).Info("on resuming state", zap.Stringer("state", s.Task.State), zap.Int64("step", int64(s.Task.Step))) - cnt, err := s.taskMgr.GetSubtaskInStatesCnt(s.ctx, s.Task.ID, proto.SubtaskStatePaused) + cntByStates, err := s.taskMgr.GetSubtaskCntGroupByStates(s.ctx, s.Task.ID, s.Task.Step) if err != nil { logutil.Logger(s.logCtx).Warn("check task failed", zap.Error(err)) return err } - if cnt == 0 { + if cntByStates[proto.SubtaskStatePaused] == 0 { // Finish the resuming process. logutil.Logger(s.logCtx).Info("all paused tasks converted to pending state, update the task to running state") err := s.updateTask(proto.TaskStateRunning, nil, RetrySQLTimes) @@ -294,12 +295,13 @@ func (s *BaseScheduler) onResuming() error { // handle task in reverting state, check all revert subtasks finishes. func (s *BaseScheduler) onReverting() error { logutil.Logger(s.logCtx).Debug("on reverting state", zap.Stringer("state", s.Task.State), zap.Int64("step", int64(s.Task.Step))) - cnt, err := s.taskMgr.GetSubtaskInStatesCnt(s.ctx, s.Task.ID, proto.SubtaskStateRevertPending, proto.SubtaskStateReverting) + cntByStates, err := s.taskMgr.GetSubtaskCntGroupByStates(s.ctx, s.Task.ID, s.Task.Step) if err != nil { logutil.Logger(s.logCtx).Warn("check task failed", zap.Error(err)) return err } - if cnt == 0 { + activeRevertCnt := cntByStates[proto.SubtaskStateRevertPending] + cntByStates[proto.SubtaskStateReverting] + if activeRevertCnt == 0 { if err = s.OnDone(s.ctx, s, s.Task); err != nil { return errors.Trace(err) } @@ -323,23 +325,23 @@ func (s *BaseScheduler) onRunning() error { logutil.Logger(s.logCtx).Debug("on running state", zap.Stringer("state", s.Task.State), zap.Int64("step", int64(s.Task.Step))) - subTaskErrs, err := s.taskMgr.CollectSubTaskError(s.ctx, s.Task.ID) - if err != nil { - logutil.Logger(s.logCtx).Warn("collect subtask error failed", zap.Error(err)) - return err - } - if len(subTaskErrs) > 0 { - logutil.Logger(s.logCtx).Warn("subtasks encounter errors") - return s.onErrHandlingStage(subTaskErrs) - } // check current step finishes. 
- cnt, err := s.taskMgr.GetSubtaskInStatesCnt(s.ctx, s.Task.ID, proto.SubtaskStatePending, proto.SubtaskStateRunning) + cntByStates, err := s.taskMgr.GetSubtaskCntGroupByStates(s.ctx, s.Task.ID, s.Task.Step) if err != nil { logutil.Logger(s.logCtx).Warn("check task failed", zap.Error(err)) return err } - - if cnt == 0 { + if cntByStates[proto.SubtaskStateFailed] > 0 || cntByStates[proto.SubtaskStateCanceled] > 0 { + subTaskErrs, err := s.taskMgr.CollectSubTaskError(s.ctx, s.Task.ID) + if err != nil { + logutil.Logger(s.logCtx).Warn("collect subtask error failed", zap.Error(err)) + return err + } + if len(subTaskErrs) > 0 { + logutil.Logger(s.logCtx).Warn("subtasks encounter errors") + return s.onErrHandlingStage(subTaskErrs) + } + } else if s.isStepSucceed(cntByStates) { return s.switch2NextStep() } @@ -727,6 +729,11 @@ func (s *BaseScheduler) WithNewTxn(ctx context.Context, fn func(se sessionctx.Co return s.taskMgr.WithNewTxn(ctx, fn) } +func (*BaseScheduler) isStepSucceed(cntByStates map[proto.SubtaskState]int64) bool { + _, ok := cntByStates[proto.SubtaskStateSucceed] + return len(cntByStates) == 0 || (len(cntByStates) == 1 && ok) +} + // IsCancelledErr checks if the error is a cancelled error. func IsCancelledErr(err error) bool { return strings.Contains(err.Error(), taskCancelMsg) diff --git a/pkg/disttask/framework/scheduler/scheduler_manager_test.go b/pkg/disttask/framework/scheduler/scheduler_manager_test.go index f35a32cd8056a..3fb7804244683 100644 --- a/pkg/disttask/framework/scheduler/scheduler_manager_test.go +++ b/pkg/disttask/framework/scheduler/scheduler_manager_test.go @@ -22,6 +22,7 @@ import ( "github.com/ngaut/pools" "github.com/pingcap/tidb/pkg/disttask/framework/mock" "github.com/pingcap/tidb/pkg/disttask/framework/proto" + "github.com/pingcap/tidb/pkg/disttask/framework/testutil" "github.com/pingcap/tidb/pkg/testkit" "github.com/stretchr/testify/require" "github.com/tikv/client-go/v2/util" @@ -43,9 +44,9 @@ func TestCleanUpRoutine(t *testing.T) { sch, mgr := MockSchedulerManager(t, ctrl, pool, getNumberExampleSchedulerExt(ctrl), mockCleanupRoutine) mockCleanupRoutine.EXPECT().CleanUp(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() - require.NoError(t, mgr.StartManager(ctx, ":4000", "")) sch.Start() defer sch.Stop() + testutil.WaitNodeRegistered(ctx, t) taskID, err := mgr.CreateTask(ctx, "test", proto.TaskTypeExample, 1, nil) require.NoError(t, err) @@ -62,9 +63,9 @@ func TestCleanUpRoutine(t *testing.T) { checkSubtaskCnt := func(tasks []*proto.Task, taskID int64) { require.Eventually(t, func() bool { - cnt, err := mgr.GetSubtaskInStatesCnt(ctx, taskID, proto.SubtaskStatePending) + cntByStates, err := mgr.GetSubtaskCntGroupByStates(ctx, taskID, proto.StepOne) require.NoError(t, err) - return int64(subtaskCnt) == cnt + return int64(subtaskCnt) == cntByStates[proto.SubtaskStatePending] }, time.Second, 50*time.Millisecond) } diff --git a/pkg/disttask/framework/scheduler/scheduler_nokit_test.go b/pkg/disttask/framework/scheduler/scheduler_nokit_test.go new file mode 100644 index 0000000000000..e6320ded9ceec --- /dev/null +++ b/pkg/disttask/framework/scheduler/scheduler_nokit_test.go @@ -0,0 +1,40 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package scheduler + +import ( + "testing" + + "github.com/pingcap/tidb/pkg/disttask/framework/proto" + "github.com/stretchr/testify/require" +) + +func TestSchedulerIsStepSucceed(t *testing.T) { + s := &BaseScheduler{} + require.True(t, s.isStepSucceed(nil)) + require.True(t, s.isStepSucceed(map[proto.SubtaskState]int64{})) + require.True(t, s.isStepSucceed(map[proto.SubtaskState]int64{ + proto.SubtaskStateSucceed: 1, + })) + for _, state := range []proto.SubtaskState{ + proto.SubtaskStateCanceled, + proto.SubtaskStateFailed, + proto.SubtaskStateReverting, + } { + require.False(t, s.isStepSucceed(map[proto.SubtaskState]int64{ + state: 1, + })) + } +} diff --git a/pkg/disttask/framework/scheduler/scheduler_test.go b/pkg/disttask/framework/scheduler/scheduler_test.go index 5400cb7d8c83f..d0eb57fc781e2 100644 --- a/pkg/disttask/framework/scheduler/scheduler_test.go +++ b/pkg/disttask/framework/scheduler/scheduler_test.go @@ -228,6 +228,8 @@ func TestTaskFailInManager(t *testing.T) { schManager.Start() defer schManager.Stop() + testutil.WaitNodeRegistered(ctx, t) + // unknown task type taskID, err := mgr.CreateTask(ctx, "test", "test-type", 1, nil) require.NoError(t, err) @@ -309,9 +311,9 @@ func checkDispatch(t *testing.T, taskCnt int, isSucc, isCancel, isSubtaskCancel, for i, taskID := range taskIDs { require.Equal(t, int64(i+1), tasks[i].ID) require.Eventually(t, func() bool { - cnt, err := mgr.GetSubtaskInStatesCnt(ctx, taskID, proto.SubtaskStatePending) + cntByStates, err := mgr.GetSubtaskCntGroupByStates(ctx, taskID, proto.StepOne) require.NoError(t, err) - return int64(subtaskCnt) == cnt + return int64(subtaskCnt) == cntByStates[proto.SubtaskStatePending] }, time.Second, 50*time.Millisecond) } } @@ -627,7 +629,7 @@ func TestManagerDispatchLoop(t *testing.T) { require.NoError(t, err) for _, s := range serverInfos { execID := disttaskutil.GenerateExecID(s) - testutil.InsertSubtask(t, taskMgr, 1000000, proto.StepOne, execID, []byte(""), proto.TaskStatePending, proto.TaskTypeExample, 16) + testutil.InsertSubtask(t, taskMgr, 1000000, proto.StepOne, execID, []byte(""), proto.SubtaskStatePending, proto.TaskTypeExample, 16) } concurrencies := []int{4, 6, 16, 2, 4, 4} waitChannels := make([]chan struct{}, len(concurrencies)) diff --git a/pkg/disttask/framework/storage/BUILD.bazel b/pkg/disttask/framework/storage/BUILD.bazel index 2e7b182b5200e..385ea4d06a807 100644 --- a/pkg/disttask/framework/storage/BUILD.bazel +++ b/pkg/disttask/framework/storage/BUILD.bazel @@ -38,7 +38,7 @@ go_test( embed = [":storage"], flaky = True, race = "on", - shard_count = 13, + shard_count = 14, deps = [ "//pkg/config", "//pkg/disttask/framework/proto", diff --git a/pkg/disttask/framework/storage/table_test.go b/pkg/disttask/framework/storage/table_test.go index 5f0d56d3b45f2..a83929f7ca571 100644 --- a/pkg/disttask/framework/storage/table_test.go +++ b/pkg/disttask/framework/storage/table_test.go @@ -377,11 +377,11 @@ func TestGetTopUnfinishedTasks(t *testing.T) { func TestGetUsedSlotsOnNodes(t *testing.T) { _, sm, ctx := testutil.InitTableTest(t) - testutil.InsertSubtask(t, sm, 1, 
proto.StepOne, "tidb-1", []byte(""), proto.TaskStateRunning, "test", 12) - testutil.InsertSubtask(t, sm, 1, proto.StepOne, "tidb-2", []byte(""), proto.TaskStatePending, "test", 12) - testutil.InsertSubtask(t, sm, 2, proto.StepOne, "tidb-2", []byte(""), proto.TaskStatePending, "test", 8) - testutil.InsertSubtask(t, sm, 3, proto.StepOne, "tidb-3", []byte(""), proto.TaskStatePending, "test", 8) - testutil.InsertSubtask(t, sm, 4, proto.StepOne, "tidb-3", []byte(""), proto.TaskStateFailed, "test", 8) + testutil.InsertSubtask(t, sm, 1, proto.StepOne, "tidb-1", []byte(""), proto.SubtaskStateRunning, "test", 12) + testutil.InsertSubtask(t, sm, 1, proto.StepOne, "tidb-2", []byte(""), proto.SubtaskStatePending, "test", 12) + testutil.InsertSubtask(t, sm, 2, proto.StepOne, "tidb-2", []byte(""), proto.SubtaskStatePending, "test", 8) + testutil.InsertSubtask(t, sm, 3, proto.StepOne, "tidb-3", []byte(""), proto.SubtaskStatePending, "test", 8) + testutil.InsertSubtask(t, sm, 4, proto.StepOne, "tidb-3", []byte(""), proto.SubtaskStateFailed, "test", 8) slotsOnNodes, err := sm.GetUsedSlotsOnNodes(ctx) require.NoError(t, err) require.Equal(t, map[string]int{ @@ -448,13 +448,13 @@ func TestSubTaskTable(t *testing.T) { require.NoError(t, err) require.Len(t, ids, 0) - cnt, err := sm.GetSubtaskInStatesCnt(ctx, 1, proto.SubtaskStatePending) + cntByStates, err := sm.GetSubtaskCntGroupByStates(ctx, 1, proto.StepInit) require.NoError(t, err) - require.Equal(t, int64(1), cnt) + require.Equal(t, int64(1), cntByStates[proto.SubtaskStatePending]) - cnt, err = sm.GetSubtaskInStatesCnt(ctx, 1, proto.SubtaskStatePending, proto.SubtaskStateRevertPending) + cntByStates, err = sm.GetSubtaskCntGroupByStates(ctx, 1, proto.StepInit) require.NoError(t, err) - require.Equal(t, int64(1), cnt) + require.Equal(t, int64(1), cntByStates[proto.SubtaskStatePending]+cntByStates[proto.SubtaskStateRevertPending]) ok, err := sm.HasSubtasksInStates(ctx, "tidb1", 1, proto.StepInit, proto.SubtaskStatePending) require.NoError(t, err) @@ -490,9 +490,9 @@ func TestSubTaskTable(t *testing.T) { require.Equal(t, proto.SubtaskStateReverting, subtask2.State) require.Greater(t, subtask2.UpdateTime, subtask.UpdateTime) - cnt, err = sm.GetSubtaskInStatesCnt(ctx, 1, proto.SubtaskStatePending) + cntByStates, err = sm.GetSubtaskCntGroupByStates(ctx, 1, proto.StepInit) require.NoError(t, err) - require.Equal(t, int64(0), cnt) + require.Equal(t, int64(0), cntByStates[proto.SubtaskStatePending]) ok, err = sm.HasSubtasksInStates(ctx, "tidb1", 1, proto.StepInit, proto.SubtaskStatePending) require.NoError(t, err) @@ -507,9 +507,9 @@ func TestSubTaskTable(t *testing.T) { testutil.CreateSubTask(t, sm, 2, proto.StepInit, "tidb1", []byte("test"), proto.TaskTypeExample, 11, true) - cnt, err = sm.GetSubtaskInStatesCnt(ctx, 2, proto.SubtaskStateRevertPending) + cntByStates, err = sm.GetSubtaskCntGroupByStates(ctx, 2, proto.StepInit) require.NoError(t, err) - require.Equal(t, int64(1), cnt) + require.Equal(t, int64(1), cntByStates[proto.SubtaskStateRevertPending]) subtasks, err := sm.GetSubtasksByStepAndState(ctx, 2, proto.StepInit, proto.TaskStateSucceed) require.NoError(t, err) @@ -670,9 +670,10 @@ func TestBothTaskAndSubTaskTable(t *testing.T) { require.Equal(t, proto.TaskTypeExample, subtask2.Type) require.Equal(t, []byte("m2"), subtask2.Meta) - cnt, err := sm.GetSubtaskInStatesCnt(ctx, 1, proto.SubtaskStatePending) + cntByStates, err := sm.GetSubtaskCntGroupByStates(ctx, 1, proto.StepInit) require.NoError(t, err) - require.Equal(t, int64(2), cnt) + require.Len(t, 
cntByStates, 1) + require.Equal(t, int64(2), cntByStates[proto.SubtaskStatePending]) // isSubTaskRevert: true prevState = task.State @@ -711,9 +712,9 @@ func TestBothTaskAndSubTaskTable(t *testing.T) { require.Equal(t, proto.TaskTypeExample, subtask2.Type) require.Equal(t, []byte("m4"), subtask2.Meta) - cnt, err = sm.GetSubtaskInStatesCnt(ctx, 1, proto.SubtaskStateRevertPending) + cntByStates, err = sm.GetSubtaskCntGroupByStates(ctx, 1, proto.StepInit) require.NoError(t, err) - require.Equal(t, int64(2), cnt) + require.Equal(t, int64(2), cntByStates[proto.SubtaskStateRevertPending]) // test transactional require.NoError(t, sm.DeleteSubtasksByTaskID(ctx, 1)) @@ -731,9 +732,30 @@ func TestBothTaskAndSubTaskTable(t *testing.T) { require.NoError(t, err) require.Equal(t, proto.TaskStateReverting, task.State) - cnt, err = sm.GetSubtaskInStatesCnt(ctx, 1, proto.SubtaskStateRevertPending) + cntByStates, err = sm.GetSubtaskCntGroupByStates(ctx, 1, proto.StepInit) require.NoError(t, err) - require.Equal(t, int64(0), cnt) + require.Equal(t, int64(0), cntByStates[proto.SubtaskStateRevertPending]) +} + +func TestGetSubtaskCntByStates(t *testing.T) { + _, sm, ctx := testutil.InitTableTest(t) + testutil.InsertSubtask(t, sm, 1, proto.StepOne, "tidb1", nil, proto.SubtaskStatePending, "test", 1) + testutil.InsertSubtask(t, sm, 1, proto.StepOne, "tidb1", nil, proto.SubtaskStatePending, "test", 1) + testutil.InsertSubtask(t, sm, 1, proto.StepOne, "tidb1", nil, proto.SubtaskStateRunning, "test", 1) + testutil.InsertSubtask(t, sm, 1, proto.StepOne, "tidb1", nil, proto.SubtaskStateSucceed, "test", 1) + testutil.InsertSubtask(t, sm, 1, proto.StepOne, "tidb1", nil, proto.SubtaskStateFailed, "test", 1) + testutil.InsertSubtask(t, sm, 1, proto.StepTwo, "tidb1", nil, proto.SubtaskStateFailed, "test", 1) + cntByStates, err := sm.GetSubtaskCntGroupByStates(ctx, 1, proto.StepOne) + require.NoError(t, err) + require.Len(t, cntByStates, 4) + require.Equal(t, int64(2), cntByStates[proto.SubtaskStatePending]) + require.Equal(t, int64(1), cntByStates[proto.SubtaskStateRunning]) + require.Equal(t, int64(1), cntByStates[proto.SubtaskStateSucceed]) + require.Equal(t, int64(1), cntByStates[proto.SubtaskStateFailed]) + cntByStates, err = sm.GetSubtaskCntGroupByStates(ctx, 1, proto.StepTwo) + require.NoError(t, err) + require.Len(t, cntByStates, 1) + require.Equal(t, int64(1), cntByStates[proto.SubtaskStateFailed]) } func TestDistFrameworkMeta(t *testing.T) { @@ -920,26 +942,26 @@ func TestPauseAndResume(t *testing.T) { testutil.CreateSubTask(t, sm, 1, proto.StepInit, "tidb1", []byte("test"), proto.TaskTypeExample, 11, false) // 1.1 pause all subtasks. require.NoError(t, sm.PauseSubtasks(ctx, "tidb1", 1)) - cnt, err := sm.GetSubtaskInStatesCnt(ctx, 1, proto.SubtaskStatePaused) + cntByStates, err := sm.GetSubtaskCntGroupByStates(ctx, 1, proto.StepInit) require.NoError(t, err) - require.Equal(t, int64(3), cnt) + require.Equal(t, int64(3), cntByStates[proto.SubtaskStatePaused]) // 1.2 resume all subtasks. require.NoError(t, sm.ResumeSubtasks(ctx, 1)) - cnt, err = sm.GetSubtaskInStatesCnt(ctx, 1, proto.SubtaskStatePending) + cntByStates, err = sm.GetSubtaskCntGroupByStates(ctx, 1, proto.StepInit) require.NoError(t, err) - require.Equal(t, int64(3), cnt) + require.Equal(t, int64(3), cntByStates[proto.SubtaskStatePending]) // 2.1 pause 2 subtasks. 
require.NoError(t, sm.UpdateSubtaskStateAndError(ctx, "tidb1", 1, proto.SubtaskStateSucceed, nil)) require.NoError(t, sm.PauseSubtasks(ctx, "tidb1", 1)) - cnt, err = sm.GetSubtaskInStatesCnt(ctx, 1, proto.SubtaskStatePaused) + cntByStates, err = sm.GetSubtaskCntGroupByStates(ctx, 1, proto.StepInit) require.NoError(t, err) - require.Equal(t, int64(2), cnt) + require.Equal(t, int64(2), cntByStates[proto.SubtaskStatePaused]) // 2.2 resume 2 subtasks. require.NoError(t, sm.ResumeSubtasks(ctx, 1)) - cnt, err = sm.GetSubtaskInStatesCnt(ctx, 1, proto.SubtaskStatePending) + cntByStates, err = sm.GetSubtaskCntGroupByStates(ctx, 1, proto.StepInit) require.NoError(t, err) - require.Equal(t, int64(2), cnt) + require.Equal(t, int64(2), cntByStates[proto.SubtaskStatePending]) } func TestCancelAndExecIdChanged(t *testing.T) { diff --git a/pkg/disttask/framework/storage/task_table.go b/pkg/disttask/framework/storage/task_table.go index f9b7fa141c082..772f18ad86d10 100644 --- a/pkg/disttask/framework/storage/task_table.go +++ b/pkg/disttask/framework/storage/task_table.go @@ -605,26 +605,32 @@ func (stm *TaskManager) UpdateSubtaskRowCount(ctx context.Context, subtaskID int return err } -// GetSubtaskInStatesCnt gets the subtask count in the states. -func (stm *TaskManager) GetSubtaskInStatesCnt(ctx context.Context, taskID int64, states ...proto.SubtaskState) (int64, error) { - args := []interface{}{taskID} - for _, state := range states { - args = append(args, state) - } - rs, err := stm.executeSQLWithNewSession(ctx, `select count(*) from mysql.tidb_background_subtask - where task_key = %? and state in (`+strings.Repeat("%?,", len(states)-1)+"%?)", args...) +// GetSubtaskCntGroupByStates gets the subtask count by states. +func (stm *TaskManager) GetSubtaskCntGroupByStates(ctx context.Context, taskID int64, step proto.Step) (map[proto.SubtaskState]int64, error) { + rs, err := stm.executeSQLWithNewSession(ctx, ` + select state, count(*) + from mysql.tidb_background_subtask + where task_key = %? and step = %? + group by state`, + taskID, step) if err != nil { - return 0, err + return nil, err } - return rs[0].GetInt64(0), nil + res := make(map[proto.SubtaskState]int64, len(rs)) + for _, r := range rs { + state := proto.SubtaskState(r.GetString(0)) + res[state] = r.GetInt64(1) + } + + return res, nil } // CollectSubTaskError collects the subtask error. func (stm *TaskManager) CollectSubTaskError(ctx context.Context, taskID int64) ([]error, error) { rs, err := stm.executeSQLWithNewSession(ctx, `select error from mysql.tidb_background_subtask - where task_key = %? AND state in (%?, %?)`, taskID, proto.TaskStateFailed, proto.TaskStateCanceled) + where task_key = %? 
AND state in (%?, %?)`, taskID, proto.SubtaskStateFailed, proto.SubtaskStateCanceled) if err != nil { return nil, err } diff --git a/pkg/disttask/framework/taskexecutor/task_executor.go b/pkg/disttask/framework/taskexecutor/task_executor.go index 8e4619d85d0e9..fd585a259daef 100644 --- a/pkg/disttask/framework/taskexecutor/task_executor.go +++ b/pkg/disttask/framework/taskexecutor/task_executor.go @@ -96,7 +96,7 @@ func (s *BaseTaskExecutor) startCancelCheck(ctx context.Context, wg *sync.WaitGr for { select { case <-ctx.Done(): - logutil.Logger(s.logCtx).Info("taskExecutor exits", zap.Error(ctx.Err())) + logutil.Logger(s.logCtx).Info("task executor exits") return case <-ticker.C: canceled, err := s.taskTable.IsTaskExecutorCanceled(ctx, s.id, s.taskID) diff --git a/pkg/disttask/framework/testutil/context.go b/pkg/disttask/framework/testutil/context.go index 34b38ba27c80c..0c7ac9c58ff81 100644 --- a/pkg/disttask/framework/testutil/context.go +++ b/pkg/disttask/framework/testutil/context.go @@ -54,14 +54,7 @@ func InitTestContext(t *testing.T, nodeNum int) (context.Context, *gomock.Contro }) executionContext := testkit.NewDistExecutionContext(t, nodeNum) - // wait until some node is registered. - require.Eventually(t, func() bool { - taskMgr, err := storage.GetTaskManager() - require.NoError(t, err) - nodes, err := taskMgr.GetAllNodes(ctx) - require.NoError(t, err) - return len(nodes) > 0 - }, 5*time.Second, 100*time.Millisecond) + WaitNodeRegistered(ctx, t) testCtx := &TestContext{ subtasksHasRun: make(map[string]map[int64]struct{}), } @@ -93,3 +86,15 @@ func (c *TestContext) CollectedSubtaskCnt(taskID int64, step proto.Step) int { func getTaskStepKey(id int64, step proto.Step) string { return fmt.Sprintf("%d/%d", id, step) } + +// WaitNodeRegistered waits until some node is registered. +func WaitNodeRegistered(ctx context.Context, t *testing.T) { + // wait until some node is registered. + require.Eventually(t, func() bool { + taskMgr, err := storage.GetTaskManager() + require.NoError(t, err) + nodes, err := taskMgr.GetAllNodes(ctx) + require.NoError(t, err) + return len(nodes) > 0 + }, 5*time.Second, 100*time.Millisecond) +} diff --git a/pkg/disttask/framework/testutil/task_util.go b/pkg/disttask/framework/testutil/task_util.go index eaf4d31ecad1f..0f8d92deac04e 100644 --- a/pkg/disttask/framework/testutil/task_util.go +++ b/pkg/disttask/framework/testutil/task_util.go @@ -29,15 +29,15 @@ import ( // CreateSubTask adds a new task to subtask table. // used for testing. func CreateSubTask(t *testing.T, gm *storage.TaskManager, taskID int64, step proto.Step, execID string, meta []byte, tp proto.TaskType, concurrency int, isRevert bool) { - state := proto.TaskStatePending + state := proto.SubtaskStatePending if isRevert { - state = proto.TaskStateRevertPending + state = proto.SubtaskStateRevertPending } InsertSubtask(t, gm, taskID, step, execID, meta, state, tp, concurrency) } // InsertSubtask adds a new subtask of any state to subtask table. -func InsertSubtask(t *testing.T, gm *storage.TaskManager, taskID int64, step proto.Step, execID string, meta []byte, state proto.TaskState, tp proto.TaskType, concurrency int) { +func InsertSubtask(t *testing.T, gm *storage.TaskManager, taskID int64, step proto.Step, execID string, meta []byte, state proto.SubtaskState, tp proto.TaskType, concurrency int) { ctx := context.Background() ctx = util.WithInternalSourceType(ctx, "table_test") require.NoError(t, gm.WithNewSession(func(se sessionctx.Context) error {