From 386bb43fea146b1cfc1bfb16cfb766ca99762e02 Mon Sep 17 00:00:00 2001 From: D3Hunter Date: Tue, 9 Jan 2024 17:06:24 +0800 Subject: [PATCH 1/3] disttask: consider node load when schedule/balance subtasks (#49758) ref pingcap/tidb#49008 --- pkg/ddl/backfilling_dist_scheduler.go | 18 +- pkg/ddl/ddl.go | 4 +- pkg/disttask/framework/mock/BUILD.bazel | 1 + pkg/disttask/framework/mock/scheduler_mock.go | 122 +++- pkg/disttask/framework/scheduler/BUILD.bazel | 6 +- pkg/disttask/framework/scheduler/balancer.go | 219 +++++++ .../framework/scheduler/balancer_test.go | 433 ++++++++++++++ pkg/disttask/framework/scheduler/interface.go | 19 +- pkg/disttask/framework/scheduler/main_test.go | 4 - .../framework/scheduler/mock/BUILD.bazel | 2 +- .../scheduler/mock/scheduler_mock.go | 6 +- pkg/disttask/framework/scheduler/nodes.go | 11 +- .../framework/scheduler/rebalance_test.go | 560 ------------------ pkg/disttask/framework/scheduler/scheduler.go | 378 +++++------- .../framework/scheduler/scheduler_manager.go | 64 +- .../scheduler/scheduler_manager_test.go | 6 +- .../scheduler/scheduler_nokit_test.go | 179 ++++++ .../framework/scheduler/scheduler_test.go | 122 +--- pkg/disttask/framework/scheduler/slots.go | 74 ++- .../framework/scheduler/slots_test.go | 54 +- pkg/disttask/framework/storage/BUILD.bazel | 2 +- pkg/disttask/framework/storage/table_test.go | 41 +- pkg/disttask/framework/storage/task_table.go | 95 +-- pkg/disttask/framework/taskexecutor/slot.go | 1 + .../framework/testutil/disttest_util.go | 4 +- .../framework/testutil/scheduler_util.go | 11 +- pkg/disttask/importinto/scheduler.go | 29 +- pkg/disttask/importinto/scheduler_test.go | 16 +- 28 files changed, 1413 insertions(+), 1068 deletions(-) create mode 100644 pkg/disttask/framework/scheduler/balancer.go create mode 100644 pkg/disttask/framework/scheduler/balancer_test.go delete mode 100644 pkg/disttask/framework/scheduler/rebalance_test.go diff --git a/pkg/ddl/backfilling_dist_scheduler.go b/pkg/ddl/backfilling_dist_scheduler.go index d7c71113d00ae..6ee22fdaa33c2 100644 --- a/pkg/ddl/backfilling_dist_scheduler.go +++ b/pkg/ddl/backfilling_dist_scheduler.go @@ -30,6 +30,7 @@ import ( "github.com/pingcap/tidb/pkg/ddl/ingest" "github.com/pingcap/tidb/pkg/disttask/framework/proto" "github.com/pingcap/tidb/pkg/disttask/framework/scheduler" + diststorage "github.com/pingcap/tidb/pkg/disttask/framework/storage" "github.com/pingcap/tidb/pkg/kv" "github.com/pingcap/tidb/pkg/meta" "github.com/pingcap/tidb/pkg/parser/model" @@ -67,7 +68,7 @@ func (*BackfillingSchedulerExt) OnTick(_ context.Context, _ *proto.Task) { // OnNextSubtasksBatch generate batch of next step's plan. func (sch *BackfillingSchedulerExt) OnNextSubtasksBatch( ctx context.Context, - taskHandle scheduler.TaskHandle, + taskHandle diststorage.TaskHandle, task *proto.Task, execIDs []string, nextStep proto.Step, @@ -175,7 +176,7 @@ func skipMergeSort(stats []external.MultipleFilesStat) bool { } // OnDone implements scheduler.Extension interface. 
-func (*BackfillingSchedulerExt) OnDone(_ context.Context, _ scheduler.TaskHandle, _ *proto.Task) error { +func (*BackfillingSchedulerExt) OnDone(_ context.Context, _ diststorage.TaskHandle, _ *proto.Task) error { return nil } @@ -195,11 +196,10 @@ type LitBackfillScheduler struct { d *ddl } -func newLitBackfillScheduler(ctx context.Context, d *ddl, taskMgr scheduler.TaskManager, - nodeMgr *scheduler.NodeManager, task *proto.Task) scheduler.Scheduler { +func newLitBackfillScheduler(ctx context.Context, d *ddl, task *proto.Task, param scheduler.Param) scheduler.Scheduler { sch := LitBackfillScheduler{ d: d, - BaseScheduler: scheduler.NewBaseScheduler(ctx, taskMgr, nodeMgr, task), + BaseScheduler: scheduler.NewBaseScheduler(ctx, task, param), } return &sch } @@ -207,7 +207,7 @@ func newLitBackfillScheduler(ctx context.Context, d *ddl, taskMgr scheduler.Task // Init implements BaseScheduler interface. func (sch *LitBackfillScheduler) Init() (err error) { taskMeta := &BackfillTaskMeta{} - if err = json.Unmarshal(sch.BaseScheduler.Task.Meta, taskMeta); err != nil { + if err = json.Unmarshal(sch.BaseScheduler.GetTask().Meta, taskMeta); err != nil { return errors.Annotate(err, "unmarshal task meta failed") } sch.BaseScheduler.Extension = &BackfillingSchedulerExt{ @@ -333,7 +333,7 @@ func calculateRegionBatch(totalRegionCnt int, instanceCnt int, useLocalDisk bool func generateGlobalSortIngestPlan( ctx context.Context, - taskHandle scheduler.TaskHandle, + taskHandle diststorage.TaskHandle, task *proto.Task, jobID int64, cloudStorageURI string, @@ -408,7 +408,7 @@ func generateGlobalSortIngestPlan( } func generateMergePlan( - taskHandle scheduler.TaskHandle, + taskHandle diststorage.TaskHandle, task *proto.Task, logger *zap.Logger, ) ([][]byte, error) { @@ -507,7 +507,7 @@ func getRangeSplitter( } func getSummaryFromLastStep( - taskHandle scheduler.TaskHandle, + taskHandle diststorage.TaskHandle, gTaskID int64, step proto.Step, ) (startKey, endKey kv.Key, totalKVSize uint64, multiFileStat []external.MultipleFilesStat, err error) { diff --git a/pkg/ddl/ddl.go b/pkg/ddl/ddl.go index 8ec723a8d0229..36156d9c27500 100644 --- a/pkg/ddl/ddl.go +++ b/pkg/ddl/ddl.go @@ -749,8 +749,8 @@ func newDDL(ctx context.Context, options ...Option) *ddl { ) scheduler.RegisterSchedulerFactory(proto.Backfill, - func(ctx context.Context, taskMgr scheduler.TaskManager, nodeMgr *scheduler.NodeManager, task *proto.Task) scheduler.Scheduler { - return newLitBackfillScheduler(ctx, d, taskMgr, nodeMgr, task) + func(ctx context.Context, task *proto.Task, param scheduler.Param) scheduler.Scheduler { + return newLitBackfillScheduler(ctx, d, task, param) }) scheduler.RegisterSchedulerCleanUpFactory(proto.Backfill, newBackfillCleanUpS3) // Register functions for enable/disable ddl when changing system variable `tidb_enable_ddl`. 
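As a minimal sketch of the new registration contract (not part of this patch): scheduler factories now receive the task plus a scheduler.Param bundle instead of separate taskMgr/nodeMgr arguments, mirroring the backfill registration in the ddl.go hunk above. The names taskTypeExample, exampleScheduler and registerExampleScheduler are hypothetical and exist only for illustration, and the sketch assumes proto.TaskType is the string-based type used by proto.Backfill.

package example

import (
	"context"

	"github.com/pingcap/tidb/pkg/disttask/framework/proto"
	"github.com/pingcap/tidb/pkg/disttask/framework/scheduler"
)

// taskTypeExample is a hypothetical task type used only in this sketch,
// assuming proto.TaskType is a string-based type.
const taskTypeExample proto.TaskType = "example"

// exampleScheduler embeds *scheduler.BaseScheduler the same way
// LitBackfillScheduler does in the hunk above.
type exampleScheduler struct {
	*scheduler.BaseScheduler
}

func registerExampleScheduler() {
	scheduler.RegisterSchedulerFactory(taskTypeExample,
		func(ctx context.Context, task *proto.Task, param scheduler.Param) scheduler.Scheduler {
			// taskMgr, nodeMgr and slotMgr now travel together inside param,
			// so the factory signature no longer lists them individually.
			return &exampleScheduler{
				BaseScheduler: scheduler.NewBaseScheduler(ctx, task, param),
			}
		})
}

A concrete implementation would, like LitBackfillScheduler above, set BaseScheduler.Extension during Init so the promoted Extension methods (OnNextSubtasksBatch, OnDone, and so on) are backed by real logic.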
diff --git a/pkg/disttask/framework/mock/BUILD.bazel b/pkg/disttask/framework/mock/BUILD.bazel index fd3ab164569da..20b7f0379b257 100644 --- a/pkg/disttask/framework/mock/BUILD.bazel +++ b/pkg/disttask/framework/mock/BUILD.bazel @@ -12,6 +12,7 @@ go_library( deps = [ "//pkg/disttask/framework/planner", "//pkg/disttask/framework/proto", + "//pkg/disttask/framework/storage", "//pkg/disttask/framework/taskexecutor/execute", "//pkg/sessionctx", "@org_uber_go_mock//gomock", diff --git a/pkg/disttask/framework/mock/scheduler_mock.go b/pkg/disttask/framework/mock/scheduler_mock.go index f6c6ab30ae80e..9ef2f9f5dc955 100644 --- a/pkg/disttask/framework/mock/scheduler_mock.go +++ b/pkg/disttask/framework/mock/scheduler_mock.go @@ -13,6 +13,7 @@ import ( reflect "reflect" proto "github.com/pingcap/tidb/pkg/disttask/framework/proto" + storage "github.com/pingcap/tidb/pkg/disttask/framework/storage" sessionctx "github.com/pingcap/tidb/pkg/sessionctx" gomock "go.uber.org/mock/gomock" ) @@ -52,6 +53,49 @@ func (mr *MockSchedulerMockRecorder) Close() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockScheduler)(nil).Close)) } +// GetEligibleInstances mocks base method. +func (m *MockScheduler) GetEligibleInstances(arg0 context.Context, arg1 *proto.Task) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetEligibleInstances", arg0, arg1) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetEligibleInstances indicates an expected call of GetEligibleInstances. +func (mr *MockSchedulerMockRecorder) GetEligibleInstances(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEligibleInstances", reflect.TypeOf((*MockScheduler)(nil).GetEligibleInstances), arg0, arg1) +} + +// GetNextStep mocks base method. +func (m *MockScheduler) GetNextStep(arg0 *proto.Task) proto.Step { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetNextStep", arg0) + ret0, _ := ret[0].(proto.Step) + return ret0 +} + +// GetNextStep indicates an expected call of GetNextStep. +func (mr *MockSchedulerMockRecorder) GetNextStep(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetNextStep", reflect.TypeOf((*MockScheduler)(nil).GetNextStep), arg0) +} + +// GetTask mocks base method. +func (m *MockScheduler) GetTask() *proto.Task { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetTask") + ret0, _ := ret[0].(*proto.Task) + return ret0 +} + +// GetTask indicates an expected call of GetTask. +func (mr *MockSchedulerMockRecorder) GetTask() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTask", reflect.TypeOf((*MockScheduler)(nil).GetTask)) +} + // Init mocks base method. func (m *MockScheduler) Init() error { m.ctrl.T.Helper() @@ -66,6 +110,61 @@ func (mr *MockSchedulerMockRecorder) Init() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Init", reflect.TypeOf((*MockScheduler)(nil).Init)) } +// IsRetryableErr mocks base method. +func (m *MockScheduler) IsRetryableErr(arg0 error) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsRetryableErr", arg0) + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsRetryableErr indicates an expected call of IsRetryableErr. 
+func (mr *MockSchedulerMockRecorder) IsRetryableErr(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsRetryableErr", reflect.TypeOf((*MockScheduler)(nil).IsRetryableErr), arg0) +} + +// OnDone mocks base method. +func (m *MockScheduler) OnDone(arg0 context.Context, arg1 storage.TaskHandle, arg2 *proto.Task) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "OnDone", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// OnDone indicates an expected call of OnDone. +func (mr *MockSchedulerMockRecorder) OnDone(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnDone", reflect.TypeOf((*MockScheduler)(nil).OnDone), arg0, arg1, arg2) +} + +// OnNextSubtasksBatch mocks base method. +func (m *MockScheduler) OnNextSubtasksBatch(arg0 context.Context, arg1 storage.TaskHandle, arg2 *proto.Task, arg3 []string, arg4 proto.Step) ([][]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "OnNextSubtasksBatch", arg0, arg1, arg2, arg3, arg4) + ret0, _ := ret[0].([][]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// OnNextSubtasksBatch indicates an expected call of OnNextSubtasksBatch. +func (mr *MockSchedulerMockRecorder) OnNextSubtasksBatch(arg0, arg1, arg2, arg3, arg4 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnNextSubtasksBatch", reflect.TypeOf((*MockScheduler)(nil).OnNextSubtasksBatch), arg0, arg1, arg2, arg3, arg4) +} + +// OnTick mocks base method. +func (m *MockScheduler) OnTick(arg0 context.Context, arg1 *proto.Task) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "OnTick", arg0, arg1) +} + +// OnTick indicates an expected call of OnTick. +func (mr *MockSchedulerMockRecorder) OnTick(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnTick", reflect.TypeOf((*MockScheduler)(nil).OnTick), arg0, arg1) +} + // ScheduleTask mocks base method. func (m *MockScheduler) ScheduleTask() { m.ctrl.T.Helper() @@ -209,6 +308,21 @@ func (mr *MockTaskManagerMockRecorder) GCSubtasks(arg0 any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GCSubtasks", reflect.TypeOf((*MockTaskManager)(nil).GCSubtasks), arg0) } +// GetActiveSubtasks mocks base method. +func (m *MockTaskManager) GetActiveSubtasks(arg0 context.Context, arg1 int64) ([]*proto.Subtask, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetActiveSubtasks", arg0, arg1) + ret0, _ := ret[0].([]*proto.Subtask) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetActiveSubtasks indicates an expected call of GetActiveSubtasks. +func (mr *MockTaskManagerMockRecorder) GetActiveSubtasks(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveSubtasks", reflect.TypeOf((*MockTaskManager)(nil).GetActiveSubtasks), arg0, arg1) +} + // GetAllNodes mocks base method. func (m *MockTaskManager) GetAllNodes(arg0 context.Context) ([]proto.ManagedNode, error) { m.ctrl.T.Helper() @@ -507,17 +621,17 @@ func (mr *MockTaskManagerMockRecorder) TransferTasks2History(arg0, arg1 any) *go } // UpdateSubtasksExecIDs mocks base method. 
-func (m *MockTaskManager) UpdateSubtasksExecIDs(arg0 context.Context, arg1 int64, arg2 []*proto.Subtask) error { +func (m *MockTaskManager) UpdateSubtasksExecIDs(arg0 context.Context, arg1 []*proto.Subtask) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateSubtasksExecIDs", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "UpdateSubtasksExecIDs", arg0, arg1) ret0, _ := ret[0].(error) return ret0 } // UpdateSubtasksExecIDs indicates an expected call of UpdateSubtasksExecIDs. -func (mr *MockTaskManagerMockRecorder) UpdateSubtasksExecIDs(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockTaskManagerMockRecorder) UpdateSubtasksExecIDs(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSubtasksExecIDs", reflect.TypeOf((*MockTaskManager)(nil).UpdateSubtasksExecIDs), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSubtasksExecIDs", reflect.TypeOf((*MockTaskManager)(nil).UpdateSubtasksExecIDs), arg0, arg1) } // UpdateTaskAndAddSubTasks mocks base method. diff --git a/pkg/disttask/framework/scheduler/BUILD.bazel b/pkg/disttask/framework/scheduler/BUILD.bazel index af08b49314aec..6dd3e40fbeffb 100644 --- a/pkg/disttask/framework/scheduler/BUILD.bazel +++ b/pkg/disttask/framework/scheduler/BUILD.bazel @@ -3,6 +3,7 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "scheduler", srcs = [ + "balancer.go", "interface.go", "nodes.go", "scheduler.go", @@ -40,9 +41,9 @@ go_test( name = "scheduler_test", timeout = "short", srcs = [ + "balancer_test.go", "main_test.go", "nodes_test.go", - "rebalance_test.go", "scheduler_manager_test.go", "scheduler_nokit_test.go", "scheduler_test.go", @@ -51,7 +52,7 @@ go_test( embed = [":scheduler"], flaky = True, race = "off", - shard_count = 26, + shard_count = 29, deps = [ "//pkg/config", "//pkg/disttask/framework/mock", @@ -71,6 +72,7 @@ go_test( "@com_github_ngaut_pools//:pools", "@com_github_pingcap_errors//:errors", "@com_github_pingcap_failpoint//:failpoint", + "@com_github_stretchr_testify//assert", "@com_github_stretchr_testify//require", "@com_github_tikv_client_go_v2//util", "@org_uber_go_goleak//:goleak", diff --git a/pkg/disttask/framework/scheduler/balancer.go b/pkg/disttask/framework/scheduler/balancer.go new file mode 100644 index 0000000000000..fcba934d14943 --- /dev/null +++ b/pkg/disttask/framework/scheduler/balancer.go @@ -0,0 +1,219 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package scheduler + +import ( + "context" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/tidb/br/pkg/lightning/log" + "github.com/pingcap/tidb/pkg/disttask/framework/proto" + "github.com/pingcap/tidb/pkg/util/logutil" + "go.uber.org/zap" +) + +var ( + // balanceCheckInterval is the interval to check whether we need to balance the subtasks. 
+ balanceCheckInterval = 3 * checkTaskFinishedInterval +) + +// balancer is used to balance subtasks on managed nodes +// it handles 2 cases: +// - managed node scale in/out. +// - nodes might run subtasks in different speed, the amount of data processed by subtasks varies, cause the subtasks are not balanced. +// +// we will try balance in task order, subtasks will be scheduled to the node with +// enough slots to run them, if there is no such node, we will skip balance for +// the task and try next one. +type balancer struct { + Param + + // a helper temporary map to record the used slots of each node during balance + // to avoid passing it around. + currUsedSlots map[string]int +} + +func newBalancer(param Param) *balancer { + return &balancer{ + Param: param, + currUsedSlots: make(map[string]int), + } +} + +func (b *balancer) balanceLoop(ctx context.Context, sm *Manager) { + for { + select { + case <-ctx.Done(): + return + case <-time.After(balanceCheckInterval): + } + b.balance(ctx, sm) + } +} + +func (b *balancer) balance(ctx context.Context, sm *Manager) { + // we will use currUsedSlots to calculate adjusted eligible nodes during balance, + // it's initial value depends on the managed nodes, to have a consistent view, + // DO NOT call getManagedNodes twice during 1 balance. + managedNodes := b.nodeMgr.getManagedNodes() + b.currUsedSlots = make(map[string]int, len(managedNodes)) + for _, n := range managedNodes { + b.currUsedSlots[n] = 0 + } + + schedulers := sm.getSchedulers() + for _, sch := range schedulers { + if err := b.balanceSubtasks(ctx, sch, managedNodes); err != nil { + logutil.Logger(ctx).Warn("failed to balance subtasks", + zap.Int64("task-id", sch.GetTask().ID), log.ShortError(err)) + return + } + } +} + +func (b *balancer) balanceSubtasks(ctx context.Context, sch Scheduler, managedNodes []string) error { + task := sch.GetTask() + eligibleNodes, err := getEligibleNodes(ctx, sch, managedNodes) + if err != nil { + return err + } + if len(eligibleNodes) == 0 { + return errors.New("no eligible nodes to balance subtasks") + } + return b.doBalanceSubtasks(ctx, task.ID, eligibleNodes) +} + +func (b *balancer) doBalanceSubtasks(ctx context.Context, taskID int64, eligibleNodes []string) (err error) { + subtasks, err := b.taskMgr.GetActiveSubtasks(ctx, taskID) + if err != nil { + return err + } + if len(subtasks) == 0 { + return nil + } + + // balance subtasks only to nodes with enough slots, from the view of all + // managed nodes, subtasks of task might not be balanced. + adjustedNodes := filterNodesWithEnoughSlots(b.currUsedSlots, b.slotMgr.getCapacity(), + eligibleNodes, subtasks[0].Concurrency) + if len(adjustedNodes) == 0 { + // no node has enough slots to run the subtasks, skip balance and skip + // update used slots. + return nil + } + adjustedNodeMap := make(map[string]struct{}, len(adjustedNodes)) + for _, n := range adjustedNodes { + adjustedNodeMap[n] = struct{}{} + } + + defer func() { + if err == nil { + b.updateUsedNodes(subtasks) + } + }() + + averageSubtaskCnt := len(subtasks) / len(adjustedNodes) + averageSubtaskRemainder := len(subtasks) - averageSubtaskCnt*len(adjustedNodes) + executorSubtasks := make(map[string][]*proto.Subtask, len(adjustedNodes)) + executorPendingCnts := make(map[string]int, len(adjustedNodes)) + for _, node := range adjustedNodes { + executorSubtasks[node] = make([]*proto.Subtask, 0, averageSubtaskCnt+1) + } + for _, subtask := range subtasks { + // put running subtask in the front of slice. 
+ // if subtask fail-over, it's possible that there are multiple running + // subtasks for one task executor. + if subtask.State == proto.SubtaskStateRunning { + executorSubtasks[subtask.ExecID] = append([]*proto.Subtask{subtask}, executorSubtasks[subtask.ExecID]...) + } else { + executorSubtasks[subtask.ExecID] = append(executorSubtasks[subtask.ExecID], subtask) + executorPendingCnts[subtask.ExecID]++ + } + } + + subtasksNeedSchedule := make([]*proto.Subtask, 0) + remainder := averageSubtaskRemainder + executorWithOneMoreSubtask := make(map[string]struct{}, remainder) + for node, sts := range executorSubtasks { + if _, ok := adjustedNodeMap[node]; !ok { + // dead node or not have enough slots + subtasksNeedSchedule = append(subtasksNeedSchedule, sts...) + delete(executorSubtasks, node) + continue + } + if remainder > 0 { + // first remainder nodes will get 1 more subtask. + if len(sts) >= averageSubtaskCnt+1 { + needScheduleCnt := len(sts) - (averageSubtaskCnt + 1) + // running subtasks are never balanced. + needScheduleCnt = min(executorPendingCnts[node], needScheduleCnt) + subtasksNeedSchedule = append(subtasksNeedSchedule, sts[len(sts)-needScheduleCnt:]...) + executorSubtasks[node] = sts[:len(sts)-needScheduleCnt] + + executorWithOneMoreSubtask[node] = struct{}{} + remainder-- + } + } else if len(sts) > averageSubtaskCnt { + // running subtasks are never balanced. + cnt := min(executorPendingCnts[node], len(sts)-averageSubtaskCnt) + subtasksNeedSchedule = append(subtasksNeedSchedule, sts[len(sts)-cnt:]...) + executorSubtasks[node] = sts[:len(sts)-cnt] + } + } + if len(subtasksNeedSchedule) == 0 { + return nil + } + + for i := 0; i < len(adjustedNodes) && remainder > 0; i++ { + if _, ok := executorWithOneMoreSubtask[adjustedNodes[i]]; !ok { + executorWithOneMoreSubtask[adjustedNodes[i]] = struct{}{} + remainder-- + } + } + + fillIdx := 0 + for _, node := range adjustedNodes { + sts := executorSubtasks[node] + targetSubtaskCnt := averageSubtaskCnt + if _, ok := executorWithOneMoreSubtask[node]; ok { + targetSubtaskCnt = averageSubtaskCnt + 1 + } + for i := len(sts); i < targetSubtaskCnt && fillIdx < len(subtasksNeedSchedule); i++ { + subtasksNeedSchedule[fillIdx].ExecID = node + fillIdx++ + } + } + + if err = b.taskMgr.UpdateSubtasksExecIDs(ctx, subtasksNeedSchedule); err != nil { + return err + } + logutil.BgLogger().Info("balance subtasks", zap.Stringers("subtasks", subtasksNeedSchedule)) + return nil +} + +func (b *balancer) updateUsedNodes(subtasks []*proto.Subtask) { + used := make(map[string]int, len(b.currUsedSlots)) + // see slotManager.alloc in task executor. + for _, st := range subtasks { + if _, ok := used[st.ExecID]; !ok { + used[st.ExecID] = st.Concurrency + } + } + + for node, slots := range used { + b.currUsedSlots[node] += slots + } +} diff --git a/pkg/disttask/framework/scheduler/balancer_test.go b/pkg/disttask/framework/scheduler/balancer_test.go new file mode 100644 index 0000000000000..06f87c223ef73 --- /dev/null +++ b/pkg/disttask/framework/scheduler/balancer_test.go @@ -0,0 +1,433 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package scheduler + +import ( + "context" + "fmt" + "testing" + + "github.com/pingcap/errors" + "github.com/pingcap/tidb/pkg/disttask/framework/mock" + "github.com/pingcap/tidb/pkg/disttask/framework/proto" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" +) + +type balanceTestCase struct { + subtasks []*proto.Subtask + eligibleNodes []string + initUsedSlots map[string]int + expectedSubtasks []*proto.Subtask + expectedUsedSlots map[string]int +} + +func TestBalanceOneTask(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + testCases := []balanceTestCase{ + // no subtasks to balance, no need to do anything. + { + subtasks: []*proto.Subtask{}, + eligibleNodes: []string{"tidb1"}, + initUsedSlots: map[string]int{"tidb1": 0}, + expectedSubtasks: []*proto.Subtask{}, + expectedUsedSlots: map[string]int{"tidb1": 0}, + }, + // balanced, no need to do anything. + { + subtasks: []*proto.Subtask{ + {ID: 1, ExecID: "tidb1", Concurrency: 16, State: proto.SubtaskStateRunning}, + {ID: 2, ExecID: "tidb2", Concurrency: 16, State: proto.SubtaskStateRunning}, + {ID: 3, ExecID: "tidb2", Concurrency: 16, State: proto.SubtaskStatePending}, + }, + eligibleNodes: []string{"tidb1", "tidb2"}, + initUsedSlots: map[string]int{"tidb1": 0, "tidb2": 0}, + expectedSubtasks: []*proto.Subtask{ + {ID: 1, ExecID: "tidb1", Concurrency: 16, State: proto.SubtaskStateRunning}, + {ID: 2, ExecID: "tidb2", Concurrency: 16, State: proto.SubtaskStateRunning}, + {ID: 3, ExecID: "tidb2", Concurrency: 16, State: proto.SubtaskStatePending}, + }, + expectedUsedSlots: map[string]int{"tidb1": 16, "tidb2": 16}, + }, + // balanced case 2, make sure the remainder calculate part is right, so we don't + // balance subtasks to 2:2:0 + { + subtasks: []*proto.Subtask{ + {ID: 1, ExecID: "tidb2", Concurrency: 16, State: proto.SubtaskStateRunning}, + {ID: 2, ExecID: "tidb2", Concurrency: 16, State: proto.SubtaskStatePending}, + {ID: 3, ExecID: "tidb1", Concurrency: 16, State: proto.SubtaskStatePending}, + {ID: 4, ExecID: "tidb3", Concurrency: 16, State: proto.SubtaskStatePending}, + }, + eligibleNodes: []string{"tidb1", "tidb2", "tidb3"}, + initUsedSlots: map[string]int{"tidb1": 0, "tidb2": 0, "tidb3": 0}, + expectedSubtasks: []*proto.Subtask{ + {ID: 1, ExecID: "tidb2", Concurrency: 16, State: proto.SubtaskStateRunning}, + {ID: 2, ExecID: "tidb2", Concurrency: 16, State: proto.SubtaskStatePending}, + {ID: 3, ExecID: "tidb1", Concurrency: 16, State: proto.SubtaskStatePending}, + {ID: 4, ExecID: "tidb3", Concurrency: 16, State: proto.SubtaskStatePending}, + }, + expectedUsedSlots: map[string]int{"tidb1": 16, "tidb2": 16, "tidb3": 16}, + }, + // no eligible nodes to run those subtasks, leave it unbalanced. + // used slots will not be changed. 
+ { + subtasks: []*proto.Subtask{ + {ID: 1, ExecID: "tidb1", Concurrency: 16, State: proto.SubtaskStateRunning}, + {ID: 2, ExecID: "tidb1", Concurrency: 16, State: proto.SubtaskStatePending}, + }, + eligibleNodes: []string{"tidb1", "tidb2"}, + initUsedSlots: map[string]int{"tidb1": 8, "tidb2": 8}, + expectedSubtasks: []*proto.Subtask{ + {ID: 1, ExecID: "tidb1", Concurrency: 16, State: proto.SubtaskStateRunning}, + {ID: 2, ExecID: "tidb1", Concurrency: 16, State: proto.SubtaskStatePending}, + }, + expectedUsedSlots: map[string]int{"tidb1": 8, "tidb2": 8}, + }, + // balance subtasks to eligible nodes, tidb1 has 8 used slots cannot run target subtasks. + // all subtasks will be balanced to tidb2. + { + subtasks: []*proto.Subtask{ + {ID: 1, ExecID: "tidb1", Concurrency: 16, State: proto.SubtaskStateRunning}, + {ID: 2, ExecID: "tidb1", Concurrency: 16, State: proto.SubtaskStatePending}, + }, + eligibleNodes: []string{"tidb1", "tidb2"}, + initUsedSlots: map[string]int{"tidb1": 8, "tidb2": 0}, + expectedSubtasks: []*proto.Subtask{ + {ID: 1, ExecID: "tidb2", Concurrency: 16, State: proto.SubtaskStateRunning}, + {ID: 2, ExecID: "tidb2", Concurrency: 16, State: proto.SubtaskStatePending}, + }, + expectedUsedSlots: map[string]int{"tidb1": 8, "tidb2": 16}, + }, + // running subtasks are not re-scheduled if the node is eligible, we leave it un-balanced. + // task executor should mark those subtasks as pending, then we can balance them. + { + subtasks: []*proto.Subtask{ + {ID: 1, ExecID: "tidb1", Concurrency: 16, State: proto.SubtaskStateRunning}, + {ID: 2, ExecID: "tidb1", Concurrency: 16, State: proto.SubtaskStateRunning}, + {ID: 3, ExecID: "tidb1", Concurrency: 16, State: proto.SubtaskStateRunning}, + }, + eligibleNodes: []string{"tidb1", "tidb2"}, + initUsedSlots: map[string]int{"tidb1": 0, "tidb2": 0}, + expectedSubtasks: []*proto.Subtask{ + {ID: 1, ExecID: "tidb1", Concurrency: 16, State: proto.SubtaskStateRunning}, + {ID: 2, ExecID: "tidb1", Concurrency: 16, State: proto.SubtaskStateRunning}, + {ID: 3, ExecID: "tidb1", Concurrency: 16, State: proto.SubtaskStateRunning}, + }, + expectedUsedSlots: map[string]int{"tidb1": 16, "tidb2": 0}, + }, + // balance from 1:4 to 2:3 + { + subtasks: []*proto.Subtask{ + {ID: 1, ExecID: "tidb1", Concurrency: 16, State: proto.SubtaskStateRunning}, + {ID: 2, ExecID: "tidb2", Concurrency: 16, State: proto.SubtaskStateRunning}, + {ID: 3, ExecID: "tidb2", Concurrency: 16, State: proto.SubtaskStatePending}, + {ID: 4, ExecID: "tidb2", Concurrency: 16, State: proto.SubtaskStatePending}, + {ID: 5, ExecID: "tidb2", Concurrency: 16, State: proto.SubtaskStatePending}, + }, + eligibleNodes: []string{"tidb1", "tidb2"}, + initUsedSlots: map[string]int{"tidb1": 0, "tidb2": 0}, + expectedSubtasks: []*proto.Subtask{ + {ID: 1, ExecID: "tidb1", Concurrency: 16, State: proto.SubtaskStateRunning}, + {ID: 2, ExecID: "tidb2", Concurrency: 16, State: proto.SubtaskStateRunning}, + {ID: 3, ExecID: "tidb2", Concurrency: 16, State: proto.SubtaskStatePending}, + {ID: 4, ExecID: "tidb2", Concurrency: 16, State: proto.SubtaskStatePending}, + {ID: 5, ExecID: "tidb1", Concurrency: 16, State: proto.SubtaskStatePending}, + }, + expectedUsedSlots: map[string]int{"tidb1": 16, "tidb2": 16}, + }, + // scale out, balance from 5 to 2:2:1 + { + subtasks: []*proto.Subtask{ + {ID: 1, ExecID: "tidb1", Concurrency: 16, State: proto.SubtaskStateRunning}, + {ID: 2, ExecID: "tidb1", Concurrency: 16, State: proto.SubtaskStatePending}, + {ID: 3, ExecID: "tidb1", Concurrency: 16, State: proto.SubtaskStatePending}, 
+ {ID: 4, ExecID: "tidb1", Concurrency: 16, State: proto.SubtaskStatePending}, + {ID: 5, ExecID: "tidb1", Concurrency: 16, State: proto.SubtaskStatePending}, + }, + eligibleNodes: []string{"tidb1", "tidb2", "tidb3"}, + initUsedSlots: map[string]int{"tidb1": 0, "tidb2": 0, "tidb3": 0}, + expectedSubtasks: []*proto.Subtask{ + {ID: 1, ExecID: "tidb1", Concurrency: 16, State: proto.SubtaskStateRunning}, + {ID: 2, ExecID: "tidb1", Concurrency: 16, State: proto.SubtaskStatePending}, + {ID: 3, ExecID: "tidb2", Concurrency: 16, State: proto.SubtaskStatePending}, + {ID: 4, ExecID: "tidb2", Concurrency: 16, State: proto.SubtaskStatePending}, + {ID: 5, ExecID: "tidb3", Concurrency: 16, State: proto.SubtaskStatePending}, + }, + expectedUsedSlots: map[string]int{"tidb1": 16, "tidb2": 16, "tidb3": 16}, + }, + // scale out case 2: balance from 4 to 2:1:1 + // this case checks the remainder part is right, so we don't balance it as 2:2:0. + { + subtasks: []*proto.Subtask{ + {ID: 1, ExecID: "tidb2", Concurrency: 16, State: proto.SubtaskStateRunning}, + {ID: 2, ExecID: "tidb2", Concurrency: 16, State: proto.SubtaskStatePending}, + {ID: 3, ExecID: "tidb2", Concurrency: 16, State: proto.SubtaskStatePending}, + {ID: 4, ExecID: "tidb2", Concurrency: 16, State: proto.SubtaskStatePending}, + }, + eligibleNodes: []string{"tidb1", "tidb2", "tidb3"}, + initUsedSlots: map[string]int{"tidb1": 0, "tidb2": 0, "tidb3": 0}, + expectedSubtasks: []*proto.Subtask{ + {ID: 1, ExecID: "tidb2", Concurrency: 16, State: proto.SubtaskStateRunning}, + {ID: 2, ExecID: "tidb2", Concurrency: 16, State: proto.SubtaskStatePending}, + {ID: 3, ExecID: "tidb1", Concurrency: 16, State: proto.SubtaskStatePending}, + {ID: 4, ExecID: "tidb3", Concurrency: 16, State: proto.SubtaskStatePending}, + }, + expectedUsedSlots: map[string]int{"tidb1": 16, "tidb2": 16, "tidb3": 16}, + }, + // scale in, balance from 1:3:1 to 3:2 + { + subtasks: []*proto.Subtask{ + {ID: 1, ExecID: "tidb1", Concurrency: 16, State: proto.SubtaskStateRunning}, + {ID: 2, ExecID: "tidb2", Concurrency: 16, State: proto.SubtaskStateRunning}, + {ID: 3, ExecID: "tidb2", Concurrency: 16, State: proto.SubtaskStatePending}, + {ID: 4, ExecID: "tidb2", Concurrency: 16, State: proto.SubtaskStatePending}, + {ID: 5, ExecID: "tidb3", Concurrency: 16, State: proto.SubtaskStateRunning}, + }, + eligibleNodes: []string{"tidb1", "tidb3"}, + initUsedSlots: map[string]int{"tidb1": 0, "tidb3": 0}, + expectedSubtasks: []*proto.Subtask{ + {ID: 1, ExecID: "tidb1", Concurrency: 16, State: proto.SubtaskStateRunning}, + {ID: 2, ExecID: "tidb1", Concurrency: 16, State: proto.SubtaskStateRunning}, + {ID: 3, ExecID: "tidb1", Concurrency: 16, State: proto.SubtaskStatePending}, + {ID: 4, ExecID: "tidb3", Concurrency: 16, State: proto.SubtaskStatePending}, + {ID: 5, ExecID: "tidb3", Concurrency: 16, State: proto.SubtaskStateRunning}, + }, + expectedUsedSlots: map[string]int{"tidb1": 16, "tidb3": 16}, + }, + // scale in and out at the same time, balance from 2:1 to 2:1 + { + subtasks: []*proto.Subtask{ + {ID: 1, ExecID: "tidb1", Concurrency: 16, State: proto.SubtaskStateRunning}, + {ID: 2, ExecID: "tidb1", Concurrency: 16, State: proto.SubtaskStatePending}, + {ID: 3, ExecID: "tidb2", Concurrency: 16, State: proto.SubtaskStatePending}, + }, + eligibleNodes: []string{"tidb2", "tidb3"}, + initUsedSlots: map[string]int{"tidb2": 0, "tidb3": 0}, + expectedSubtasks: []*proto.Subtask{ + {ID: 1, ExecID: "tidb2", Concurrency: 16, State: proto.SubtaskStateRunning}, + {ID: 2, ExecID: "tidb3", Concurrency: 16, State: 
proto.SubtaskStatePending}, + {ID: 3, ExecID: "tidb2", Concurrency: 16, State: proto.SubtaskStatePending}, + }, + expectedUsedSlots: map[string]int{"tidb2": 16, "tidb3": 16}, + }, + } + + ctx := context.Background() + for i, c := range testCases { + t.Run(fmt.Sprintf("case %d", i), func(t *testing.T) { + mockTaskMgr := mock.NewMockTaskManager(ctrl) + mockTaskMgr.EXPECT().GetActiveSubtasks(gomock.Any(), gomock.Any()).Return(c.subtasks, nil) + if !assert.ObjectsAreEqual(c.subtasks, c.expectedSubtasks) { + mockTaskMgr.EXPECT().UpdateSubtasksExecIDs(gomock.Any(), gomock.Any()).Return(nil) + } + mockScheduler := mock.NewMockScheduler(ctrl) + mockScheduler.EXPECT().GetTask().Return(&proto.Task{ID: 1}).Times(2) + mockScheduler.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).Return(nil, nil) + + slotMgr := newSlotManager() + slotMgr.updateCapacity(16) + b := newBalancer(Param{ + taskMgr: mockTaskMgr, + nodeMgr: newNodeManager(), + slotMgr: slotMgr, + }) + b.currUsedSlots = c.initUsedSlots + require.NoError(t, b.balanceSubtasks(ctx, mockScheduler, c.eligibleNodes)) + require.Equal(t, c.expectedUsedSlots, b.currUsedSlots) + // c.subtasks is updated in-place + require.Equal(t, c.expectedSubtasks, c.subtasks) + require.True(t, ctrl.Satisfied()) + }) + } + + t.Run("scheduler err or no instance", func(t *testing.T) { + mockTaskMgr := mock.NewMockTaskManager(ctrl) + mockScheduler := mock.NewMockScheduler(ctrl) + mockScheduler.EXPECT().GetTask().Return(&proto.Task{ID: 1}).Times(2) + mockScheduler.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).Return(nil, errors.New("mock error")) + slotMgr := newSlotManager() + slotMgr.updateCapacity(16) + b := newBalancer(Param{ + taskMgr: mockTaskMgr, + nodeMgr: newNodeManager(), + slotMgr: slotMgr, + }) + require.ErrorContains(t, b.balanceSubtasks(ctx, mockScheduler, []string{"tidb1"}), "mock error") + require.True(t, ctrl.Satisfied()) + + mockScheduler.EXPECT().GetTask().Return(&proto.Task{ID: 1}).Times(2) + mockScheduler.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).Return(nil, nil) + require.ErrorContains(t, b.balanceSubtasks(ctx, mockScheduler, nil), "no eligible nodes to balance subtasks") + require.True(t, ctrl.Satisfied()) + }) + + t.Run("task mgr failed", func(t *testing.T) { + mockTaskMgr := mock.NewMockTaskManager(ctrl) + mockTaskMgr.EXPECT().GetActiveSubtasks(gomock.Any(), gomock.Any()).Return(nil, errors.New("mock error")) + mockScheduler := mock.NewMockScheduler(ctrl) + mockScheduler.EXPECT().GetTask().Return(&proto.Task{ID: 1}).Times(2) + mockScheduler.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).Return([]string{"tidb1"}, nil) + + slotMgr := newSlotManager() + slotMgr.updateCapacity(16) + b := newBalancer(Param{ + taskMgr: mockTaskMgr, + nodeMgr: newNodeManager(), + slotMgr: slotMgr, + }) + require.ErrorContains(t, b.balanceSubtasks(ctx, mockScheduler, []string{"tidb1"}), "mock error") + require.True(t, ctrl.Satisfied()) + + b.currUsedSlots = map[string]int{"tidb1": 0, "tidb2": 0} + mockTaskMgr.EXPECT().GetActiveSubtasks(gomock.Any(), gomock.Any()).Return( + []*proto.Subtask{ + {ID: 1, ExecID: "tidb1", Concurrency: 16, State: proto.SubtaskStateRunning}, + {ID: 2, ExecID: "tidb1", Concurrency: 16, State: proto.SubtaskStatePending}, + }, nil) + mockTaskMgr.EXPECT().UpdateSubtasksExecIDs(gomock.Any(), gomock.Any()).Return(errors.New("mock error2")) + mockScheduler.EXPECT().GetTask().Return(&proto.Task{ID: 1}).Times(2) + mockScheduler.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).Return(nil, nil) + 
require.ErrorContains(t, b.balanceSubtasks(ctx, mockScheduler, []string{"tidb1", "tidb2"}), "mock error2") + // not updated + require.Equal(t, map[string]int{"tidb1": 0, "tidb2": 0}, b.currUsedSlots) + require.True(t, ctrl.Satisfied()) + }) +} + +func TestBalanceMultipleTasks(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockTaskMgr := mock.NewMockTaskManager(ctrl) + + taskCases := []struct { + subtasks, expectedSubtasks []*proto.Subtask + }{ + // task 1 is balanced + // used slots will be {tidb1: 8, tidb2: 8, tidb3: 0} + { + subtasks: []*proto.Subtask{ + {ID: 1, ExecID: "tidb1", Concurrency: 8, State: proto.SubtaskStateRunning}, + {ID: 2, ExecID: "tidb2", Concurrency: 8, State: proto.SubtaskStateRunning}, + }, + expectedSubtasks: []*proto.Subtask{ + {ID: 1, ExecID: "tidb1", Concurrency: 8, State: proto.SubtaskStateRunning}, + {ID: 2, ExecID: "tidb2", Concurrency: 8, State: proto.SubtaskStateRunning}, + }, + }, + // task 2 require balance + // used slots will be {tidb1: 16, tidb2: 16, tidb3: 8} + { + subtasks: []*proto.Subtask{ + {ID: 3, ExecID: "tidb1", Concurrency: 8, State: proto.SubtaskStateRunning}, + {ID: 4, ExecID: "tidb4", Concurrency: 8, State: proto.SubtaskStateRunning}, + {ID: 5, ExecID: "tidb4", Concurrency: 8, State: proto.SubtaskStatePending}, + }, + expectedSubtasks: []*proto.Subtask{ + {ID: 3, ExecID: "tidb1", Concurrency: 8, State: proto.SubtaskStateRunning}, + {ID: 4, ExecID: "tidb2", Concurrency: 8, State: proto.SubtaskStateRunning}, + {ID: 5, ExecID: "tidb3", Concurrency: 8, State: proto.SubtaskStatePending}, + }, + }, + // task 3 require balance, but no eligible node, so it's not balanced, and + // used slots are not updated + // used slots will be {tidb1: 16, tidb2: 16, tidb3: 8} + { + subtasks: []*proto.Subtask{ + {ID: 6, ExecID: "tidb4", Concurrency: 16, State: proto.SubtaskStatePending}, + {ID: 7, ExecID: "tidb4", Concurrency: 16, State: proto.SubtaskStatePending}, + }, + expectedSubtasks: []*proto.Subtask{ + {ID: 6, ExecID: "tidb4", Concurrency: 16, State: proto.SubtaskStatePending}, + {ID: 7, ExecID: "tidb4", Concurrency: 16, State: proto.SubtaskStatePending}, + }, + }, + // task 4 require balance + // used slots will be {tidb1: 16, tidb2: 16, tidb3: 16} + { + subtasks: []*proto.Subtask{ + {ID: 8, ExecID: "tidb1", Concurrency: 8, State: proto.SubtaskStatePending}, + }, + expectedSubtasks: []*proto.Subtask{ + {ID: 8, ExecID: "tidb3", Concurrency: 8, State: proto.SubtaskStatePending}, + }, + }, + } + ctx := context.Background() + + manager, err := NewManager(ctx, mockTaskMgr, "1") + require.NoError(t, err) + manager.slotMgr.updateCapacity(16) + manager.nodeMgr.managedNodes.Store(&[]string{"tidb1", "tidb2", "tidb3"}) + b := newBalancer(Param{ + taskMgr: manager.taskMgr, + nodeMgr: manager.nodeMgr, + slotMgr: manager.slotMgr, + }) + for i := range taskCases { + taskID := int64(i + 1) + sch := mock.NewMockScheduler(ctrl) + sch.EXPECT().GetTask().Return(&proto.Task{ID: taskID}).AnyTimes() + manager.addScheduler(taskID, sch) + } + require.Len(t, manager.getSchedulers(), 4) + + // fail fast if balance failed on some task + manager.getSchedulers()[0].(*mock.MockScheduler).EXPECT(). 
+ GetEligibleInstances(gomock.Any(), gomock.Any()).Return(nil, errors.New("mock error")) + b.balance(ctx, manager) + require.True(t, ctrl.Satisfied()) + + // balance multiple tasks + for i, c := range taskCases { + taskID := int64(i + 1) + if !assert.ObjectsAreEqual(c.subtasks, c.expectedSubtasks) { + gomock.InOrder( + manager.getSchedulers()[i].(*mock.MockScheduler).EXPECT(). + GetEligibleInstances(gomock.Any(), gomock.Any()).Return(nil, nil), + mockTaskMgr.EXPECT().GetActiveSubtasks(gomock.Any(), taskID).Return(c.subtasks, nil), + mockTaskMgr.EXPECT().UpdateSubtasksExecIDs(gomock.Any(), gomock.Any()).Return(nil), + ) + } else { + gomock.InOrder( + manager.getSchedulers()[i].(*mock.MockScheduler).EXPECT(). + GetEligibleInstances(gomock.Any(), gomock.Any()).Return(nil, nil), + mockTaskMgr.EXPECT().GetActiveSubtasks(gomock.Any(), taskID).Return(c.subtasks, nil), + ) + } + } + b.balance(ctx, manager) + require.Equal(t, map[string]int{"tidb1": 16, "tidb2": 16, "tidb3": 16}, b.currUsedSlots) + require.True(t, ctrl.Satisfied()) + for _, c := range taskCases { + require.Equal(t, c.expectedSubtasks, c.subtasks) + } +} + +func TestBalancerUpdateUsedNodes(t *testing.T) { + b := newBalancer(Param{}) + b.updateUsedNodes([]*proto.Subtask{ + {ID: 1, ExecID: "tidb1", Concurrency: 16, State: proto.SubtaskStateRunning}, + {ID: 2, ExecID: "tidb1", Concurrency: 16, State: proto.SubtaskStatePending}, + }) + require.Equal(t, map[string]int{"tidb1": 16}, b.currUsedSlots) + b.updateUsedNodes([]*proto.Subtask{ + {ID: 3, ExecID: "tidb1", Concurrency: 4, State: proto.SubtaskStateRunning}, + {ID: 4, ExecID: "tidb2", Concurrency: 8, State: proto.SubtaskStatePending}, + {ID: 5, ExecID: "tidb3", Concurrency: 12, State: proto.SubtaskStatePending}, + }) + require.Equal(t, map[string]int{"tidb1": 20, "tidb2": 8, "tidb3": 12}, b.currUsedSlots) +} diff --git a/pkg/disttask/framework/scheduler/interface.go b/pkg/disttask/framework/scheduler/interface.go index 77c949dd2c661..0cca913fa1479 100644 --- a/pkg/disttask/framework/scheduler/interface.go +++ b/pkg/disttask/framework/scheduler/interface.go @@ -18,6 +18,7 @@ import ( "context" "github.com/pingcap/tidb/pkg/disttask/framework/proto" + "github.com/pingcap/tidb/pkg/disttask/framework/storage" "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/util/syncutil" ) @@ -65,12 +66,15 @@ type TaskManager interface { // we only consider pending/running subtasks, subtasks related to revert are // not considered. GetUsedSlotsOnNodes(ctx context.Context) (map[string]int, error) + // GetActiveSubtasks returns subtasks of the task that are in pending/running state. + // the returned subtasks only contains some fields, see row2SubtaskBasic. + GetActiveSubtasks(ctx context.Context, taskID int64) ([]*proto.Subtask, error) // GetSubtaskCntGroupByStates returns the count of subtasks of some step group by state. GetSubtaskCntGroupByStates(ctx context.Context, taskID int64, step proto.Step) (map[proto.SubtaskState]int64, error) ResumeSubtasks(ctx context.Context, taskID int64) error CollectSubTaskError(ctx context.Context, taskID int64) ([]error, error) TransferSubTasks2History(ctx context.Context, taskID int64) error - UpdateSubtasksExecIDs(ctx context.Context, taskID int64, subtasks []*proto.Subtask) error + UpdateSubtasksExecIDs(ctx context.Context, subtasks []*proto.Subtask) error // GetManagedNodes returns the nodes managed by dist framework and can be used // to execute tasks. If there are any nodes with background role, we use them, // else we use nodes without role. 
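As an illustration of how the two TaskManager additions above fit together (a naive sketch under stated assumptions, not the balancing policy implemented in balancer.go): GetActiveSubtasks fetches the pending/running subtasks of a task, the caller rewrites ExecID in memory, and UpdateSubtasksExecIDs persists all reassignments in one call. reassignDeadNodeSubtasks and activeSubtaskMover are hypothetical names introduced only for this sketch; the real balancer additionally accounts for slot capacity and never moves running subtasks on live nodes.

package example

import (
	"context"

	"github.com/pingcap/tidb/pkg/disttask/framework/proto"
)

// activeSubtaskMover captures just the two TaskManager methods this sketch needs.
type activeSubtaskMover interface {
	GetActiveSubtasks(ctx context.Context, taskID int64) ([]*proto.Subtask, error)
	UpdateSubtasksExecIDs(ctx context.Context, subtasks []*proto.Subtask) error
}

// reassignDeadNodeSubtasks moves active subtasks off non-eligible executors and
// persists the new exec IDs. It uses a deliberately naive round-robin placement.
func reassignDeadNodeSubtasks(ctx context.Context, mgr activeSubtaskMover, taskID int64, eligible []string) error {
	if len(eligible) == 0 {
		return nil
	}
	subtasks, err := mgr.GetActiveSubtasks(ctx, taskID)
	if err != nil {
		return err
	}
	eligibleSet := make(map[string]struct{}, len(eligible))
	for _, n := range eligible {
		eligibleSet[n] = struct{}{}
	}
	moved := make([]*proto.Subtask, 0, len(subtasks))
	for i, st := range subtasks {
		if _, ok := eligibleSet[st.ExecID]; ok {
			continue
		}
		// the subtask's executor is gone or ineligible: pick an eligible node.
		st.ExecID = eligible[i%len(eligible)]
		moved = append(moved, st)
	}
	if len(moved) == 0 {
		return nil
	}
	return mgr.UpdateSubtasksExecIDs(ctx, moved)
}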
@@ -101,13 +105,13 @@ type Extension interface { // 1. task is pending and entering it's first step. // 2. subtasks scheduled has all finished with no error. // when next step is StepDone, it should return nil, nil. - OnNextSubtasksBatch(ctx context.Context, h TaskHandle, task *proto.Task, execIDs []string, step proto.Step) (subtaskMetas [][]byte, err error) + OnNextSubtasksBatch(ctx context.Context, h storage.TaskHandle, task *proto.Task, execIDs []string, step proto.Step) (subtaskMetas [][]byte, err error) // OnDone is called when task is done, either finished successfully or failed // with error. // if the task is failed when initializing scheduler, or it's an unknown task, // we don't call this function. - OnDone(ctx context.Context, h TaskHandle, task *proto.Task) error + OnDone(ctx context.Context, h storage.TaskHandle, task *proto.Task) error // GetEligibleInstances is used to get the eligible instances for the task. // on certain condition we may want to use some instances to do the task, such as instances with more disk. @@ -125,8 +129,15 @@ type Extension interface { GetNextStep(task *proto.Task) proto.Step } +// Param is used to pass parameters when creating scheduler. +type Param struct { + taskMgr TaskManager + nodeMgr *NodeManager + slotMgr *SlotManager +} + // schedulerFactoryFn is used to create a scheduler. -type schedulerFactoryFn func(ctx context.Context, taskMgr TaskManager, nodeMgr *NodeManager, task *proto.Task) Scheduler +type schedulerFactoryFn func(ctx context.Context, task *proto.Task, param Param) Scheduler var schedulerFactoryMap = struct { syncutil.RWMutex diff --git a/pkg/disttask/framework/scheduler/main_test.go b/pkg/disttask/framework/scheduler/main_test.go index 8a16e29e849a7..fadbe0d4adbef 100644 --- a/pkg/disttask/framework/scheduler/main_test.go +++ b/pkg/disttask/framework/scheduler/main_test.go @@ -40,10 +40,6 @@ func (s *BaseScheduler) Switch2NextStep() (err error) { return s.switch2NextStep() } -func (s *BaseScheduler) DoBalanceSubtasks(eligibleNodes []string) error { - return s.doBalanceSubtasks(eligibleNodes) -} - func NewNodeManager() *NodeManager { return newNodeManager() } diff --git a/pkg/disttask/framework/scheduler/mock/BUILD.bazel b/pkg/disttask/framework/scheduler/mock/BUILD.bazel index 541b2013c3bdd..76642e205e655 100644 --- a/pkg/disttask/framework/scheduler/mock/BUILD.bazel +++ b/pkg/disttask/framework/scheduler/mock/BUILD.bazel @@ -7,7 +7,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/disttask/framework/proto", - "//pkg/disttask/framework/scheduler", + "//pkg/disttask/framework/storage", "@org_uber_go_mock//gomock", ], ) diff --git a/pkg/disttask/framework/scheduler/mock/scheduler_mock.go b/pkg/disttask/framework/scheduler/mock/scheduler_mock.go index c65047b74aeda..a6b241273a43e 100644 --- a/pkg/disttask/framework/scheduler/mock/scheduler_mock.go +++ b/pkg/disttask/framework/scheduler/mock/scheduler_mock.go @@ -13,7 +13,7 @@ import ( reflect "reflect" proto "github.com/pingcap/tidb/pkg/disttask/framework/proto" - scheduler "github.com/pingcap/tidb/pkg/disttask/framework/scheduler" + storage "github.com/pingcap/tidb/pkg/disttask/framework/storage" gomock "go.uber.org/mock/gomock" ) @@ -84,7 +84,7 @@ func (mr *MockExtensionMockRecorder) IsRetryableErr(arg0 any) *gomock.Call { } // OnDone mocks base method. 
-func (m *MockExtension) OnDone(arg0 context.Context, arg1 scheduler.TaskHandle, arg2 *proto.Task) error { +func (m *MockExtension) OnDone(arg0 context.Context, arg1 storage.TaskHandle, arg2 *proto.Task) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "OnDone", arg0, arg1, arg2) ret0, _ := ret[0].(error) @@ -98,7 +98,7 @@ func (mr *MockExtensionMockRecorder) OnDone(arg0, arg1, arg2 any) *gomock.Call { } // OnNextSubtasksBatch mocks base method. -func (m *MockExtension) OnNextSubtasksBatch(arg0 context.Context, arg1 scheduler.TaskHandle, arg2 *proto.Task, arg3 []string, arg4 proto.Step) ([][]byte, error) { +func (m *MockExtension) OnNextSubtasksBatch(arg0 context.Context, arg1 storage.TaskHandle, arg2 *proto.Task, arg3 []string, arg4 proto.Step) ([][]byte, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "OnNextSubtasksBatch", arg0, arg1, arg2, arg3, arg4) ret0, _ := ret[0].([][]byte) diff --git a/pkg/disttask/framework/scheduler/nodes.go b/pkg/disttask/framework/scheduler/nodes.go index ee968c67bccd6..7f9c2a8eb77af 100644 --- a/pkg/disttask/framework/scheduler/nodes.go +++ b/pkg/disttask/framework/scheduler/nodes.go @@ -111,7 +111,7 @@ func (nm *NodeManager) maintainLiveNodes(ctx context.Context, taskMgr TaskManage nm.prevLiveNodes = currLiveNodes } -func (nm *NodeManager) refreshManagedNodesLoop(ctx context.Context, taskMgr TaskManager, slotMgr *slotManager) { +func (nm *NodeManager) refreshManagedNodesLoop(ctx context.Context, taskMgr TaskManager, slotMgr *SlotManager) { ticker := time.NewTicker(nodesCheckInterval) defer ticker.Stop() for { @@ -128,7 +128,7 @@ func (nm *NodeManager) refreshManagedNodesLoop(ctx context.Context, taskMgr Task var TestRefreshedChan = make(chan struct{}) // refreshManagedNodes maintains the nodes managed by the framework. -func (nm *NodeManager) refreshManagedNodes(ctx context.Context, taskMgr TaskManager, slotMgr *slotManager) { +func (nm *NodeManager) refreshManagedNodes(ctx context.Context, taskMgr TaskManager, slotMgr *SlotManager) { newNodes, err := taskMgr.GetManagedNodes(ctx) if err != nil { logutil.BgLogger().Warn("get managed nodes met error", log.ShortError(err)) @@ -151,7 +151,10 @@ func (nm *NodeManager) refreshManagedNodes(ctx context.Context, taskMgr TaskMana } // GetManagedNodes returns the nodes managed by the framework. -// The returned map is read-only, don't write to it. +// return a copy of the managed nodes. func (nm *NodeManager) getManagedNodes() []string { - return *nm.managedNodes.Load() + nodes := *nm.managedNodes.Load() + res := make([]string, len(nodes)) + copy(res, nodes) + return res } diff --git a/pkg/disttask/framework/scheduler/rebalance_test.go b/pkg/disttask/framework/scheduler/rebalance_test.go deleted file mode 100644 index 61387d5b0ce3d..0000000000000 --- a/pkg/disttask/framework/scheduler/rebalance_test.go +++ /dev/null @@ -1,560 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package scheduler_test - -import ( - "context" - "slices" - "strings" - "testing" - "time" - - "github.com/ngaut/pools" - "github.com/pingcap/tidb/pkg/disttask/framework/mock" - "github.com/pingcap/tidb/pkg/disttask/framework/proto" - "github.com/pingcap/tidb/pkg/disttask/framework/scheduler" - "github.com/pingcap/tidb/pkg/testkit" - "github.com/stretchr/testify/require" - "go.uber.org/mock/gomock" -) - -type scaleTestCase struct { - subtasks []*proto.Subtask - liveNodes []string - taskNodes []string - cleanedNodes []string - expectedTaskNodes []string - expectedSubtasks []*proto.Subtask -} - -type balanceTestCase struct { - subtasks []*proto.Subtask - liveNodes []string - taskNodes []string - expectedSubtasks []*proto.Subtask -} - -func scaleTest(t *testing.T, - mockTaskMgr *mock.MockTaskManager, - testCase scaleTestCase, - id int) { - ctx := context.Background() - mockTaskMgr.EXPECT().GetSubtasksByStepAndState(ctx, int64(id), proto.StepInit, proto.TaskStatePending).Return( - testCase.subtasks, - nil) - mockTaskMgr.EXPECT().UpdateSubtasksExecIDs(ctx, int64(id), testCase.subtasks).Return(nil).AnyTimes() - mockTaskMgr.EXPECT().DeleteDeadNodes(ctx, testCase.cleanedNodes).Return(nil).AnyTimes() - if len(testCase.cleanedNodes) > 0 { - mockTaskMgr.EXPECT().GetSubtasksByExecIdsAndStepAndState(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) - } - nodeMgr := scheduler.NewNodeManager() - sch := scheduler.NewBaseScheduler(ctx, mockTaskMgr, nodeMgr, &proto.Task{Step: proto.StepInit, ID: int64(id)}) - sch.TaskNodes = testCase.taskNodes - require.NoError(t, sch.DoBalanceSubtasks(testCase.liveNodes)) - slices.SortFunc(sch.TaskNodes, func(i, j string) int { - return strings.Compare(i, j) - }) - slices.SortFunc(testCase.subtasks, func(i, j *proto.Subtask) int { - return strings.Compare(i.ExecID, j.ExecID) - }) - require.Equal(t, testCase.expectedTaskNodes, sch.TaskNodes) - require.Equal(t, testCase.expectedSubtasks, testCase.subtasks) -} - -func balanceTest(t *testing.T, - mockTaskMgr *mock.MockTaskManager, - testCase balanceTestCase, - id int) { - ctx := context.Background() - mockTaskMgr.EXPECT().GetSubtasksByStepAndState(ctx, int64(id), proto.StepInit, proto.TaskStatePending).Return( - testCase.subtasks, - nil) - mockTaskMgr.EXPECT().DeleteDeadNodes(ctx, gomock.Any()).Return(nil).AnyTimes() - - nodeMgr := scheduler.NewNodeManager() - mockTaskMgr.EXPECT().UpdateSubtasksExecIDs(ctx, int64(id), testCase.subtasks).Return(nil).AnyTimes() - sch := scheduler.NewBaseScheduler(ctx, mockTaskMgr, nodeMgr, &proto.Task{Step: proto.StepInit, ID: int64(id)}) - sch.TaskNodes = testCase.taskNodes - require.NoError(t, sch.DoBalanceSubtasks(testCase.liveNodes)) - slices.SortFunc(sch.TaskNodes, func(i, j string) int { - return strings.Compare(i, j) - }) - slices.SortFunc(testCase.subtasks, func(i, j *proto.Subtask) int { - return strings.Compare(i.ExecID, j.ExecID) - }) - require.Equal(t, testCase.expectedSubtasks, testCase.subtasks) -} - -func TestScaleOutNodes(t *testing.T) { - store := testkit.CreateMockStore(t) - gtk := testkit.NewTestKit(t, store) - pool := pools.NewResourcePool(func() (pools.Resource, error) { - return gtk.Session(), nil - }, 1, 1, time.Second) - defer pool.Close() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - mockTaskMgr := mock.NewMockTaskManager(ctrl) - testCases := []scaleTestCase{ - // 1. scale out from 1 node to 2 nodes. 4 subtasks. 
- { - []*proto.Subtask{ - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.1:4000"}}, - []string{"1.1.1.1:4000", "1.1.1.2:4000"}, - []string{"1.1.1.1:4000"}, - []string{}, - []string{"1.1.1.1:4000", "1.1.1.2:4000"}, - []*proto.Subtask{ - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.2:4000"}}, - }, - // 2. scale out from 1 node to 2 nodes. 3 subtasks. - { - []*proto.Subtask{ - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.1:4000"}}, - []string{"1.1.1.1:4000", "1.1.1.2:4000"}, - []string{"1.1.1.1:4000"}, - []string{}, - []string{"1.1.1.1:4000", "1.1.1.2:4000"}, - []*proto.Subtask{ - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.2:4000"}}, - }, - // 3. scale out from 2 nodes to 4 nodes. 4 subtasks. - { - []*proto.Subtask{ - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.1:4000"}}, - []string{"1.1.1.1:4000", "1.1.1.2:4000", "1.1.1.3:4000", "1.1.1.4:4000"}, - []string{"1.1.1.1:4000", "1.1.1.2:4000"}, - []string{}, - []string{"1.1.1.1:4000", "1.1.1.2:4000", "1.1.1.3:4000", "1.1.1.4:4000"}, - []*proto.Subtask{ - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.3:4000"}, - {ExecID: "1.1.1.4:4000"}}, - }, - // 4. scale out from 2 nodes to 4 nodes. 9 subtasks. - { - []*proto.Subtask{ - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.1:4000"}}, - []string{"1.1.1.1:4000", "1.1.1.2:4000", "1.1.1.3:4000", "1.1.1.4:4000"}, - []string{"1.1.1.1:4000", "1.1.1.2:4000"}, - []string{}, - []string{"1.1.1.1:4000", "1.1.1.2:4000", "1.1.1.3:4000", "1.1.1.4:4000"}, - []*proto.Subtask{ - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.3:4000"}, - {ExecID: "1.1.1.3:4000"}, - {ExecID: "1.1.1.4:4000"}, - {ExecID: "1.1.1.4:4000"}}, - }, - // 5. scale out from 2 nodes to 3 nodes. - { - []*proto.Subtask{ - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.1:4000"}}, - []string{"1.1.1.1:4000", "1.1.1.2:4000", "1.1.1.3:4000"}, - []string{"1.1.1.1:4000", "1.1.1.2:4000"}, - []string{}, - []string{"1.1.1.1:4000", "1.1.1.2:4000", "1.1.1.3:4000"}, - []*proto.Subtask{ - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.3:4000"}}, - }, - // 6. scale out from 1 node to another 2 node. - { - []*proto.Subtask{ - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.1:4000"}}, - []string{"1.1.1.2:4000", "1.1.1.3:4000"}, - []string{"1.1.1.1:4000"}, - []string{"1.1.1.1:4000"}, - []string{"1.1.1.2:4000", "1.1.1.3:4000"}, - []*proto.Subtask{ - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.3:4000"}, - {ExecID: "1.1.1.3:4000"}}, - }, - // 7. scale out from tidb1, tidb2 to tidb2, tidb3. 
- { - []*proto.Subtask{ - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.2:4000"}}, - []string{"1.1.1.2:4000", "1.1.1.3:4000"}, - []string{"1.1.1.1:4000", "1.1.1.2:4000"}, - []string{"1.1.1.1:4000"}, - []string{"1.1.1.2:4000", "1.1.1.3:4000"}, - []*proto.Subtask{ - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.3:4000"}, - {ExecID: "1.1.1.3:4000"}}, - }, - // 8. scale from tidb1, tidb2 to tidb3, tidb4. - { - []*proto.Subtask{ - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.2:4000"}}, - []string{"1.1.1.3:4000", "1.1.1.4:4000"}, - []string{"1.1.1.1:4000", "1.1.1.2:4000"}, - []string{"1.1.1.1:4000", "1.1.1.2:4000"}, - []string{"1.1.1.3:4000", "1.1.1.4:4000"}, - []*proto.Subtask{ - {ExecID: "1.1.1.3:4000"}, - {ExecID: "1.1.1.3:4000"}, - {ExecID: "1.1.1.4:4000"}, - {ExecID: "1.1.1.4:4000"}}, - }, - // 9. scale from tidb1, tidb2 to tidb2, tidb3, tidb4. - { - []*proto.Subtask{ - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.2:4000"}}, - []string{"1.1.1.2:4000", "1.1.1.3:4000", "1.1.1.4:4000"}, - []string{"1.1.1.1:4000", "1.1.1.2:4000"}, - []string{"1.1.1.1:4000"}, - []string{"1.1.1.2:4000", "1.1.1.3:4000", "1.1.1.4:4000"}, - []*proto.Subtask{ - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.3:4000"}, - {ExecID: "1.1.1.4:4000"}}, - }, - // 10. scale form tidb1, tidb2 to tidb2, tidb3, tidb4. - { - - []*proto.Subtask{ - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.1:4000"}}, - []string{"1.1.1.2:4000", "1.1.1.3:4000", "1.1.1.4:4000"}, - []string{"1.1.1.1:4000", "1.1.1.2:4000"}, - []string{"1.1.1.1:4000"}, - []string{"1.1.1.2:4000", "1.1.1.3:4000", "1.1.1.4:4000"}, - []*proto.Subtask{ - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.3:4000"}, - {ExecID: "1.1.1.3:4000"}, - {ExecID: "1.1.1.4:4000"}}, - }, - // 11. scale from tidb1(2 subtasks), tidb2(3 subtasks), tidb3(0 subtasks) to tidb1, tidb3, tidb4, tidb5, tidb6. - { - []*proto.Subtask{ - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.2:4000"}}, - []string{"1.1.1.1:4000", "1.1.1.3:4000", "1.1.1.4:4000", "1.1.1.5:4000", "1.1.1.6:4000"}, - []string{"1.1.1.1:4000", "1.1.1.2:4000"}, - []string{"1.1.1.2:4000"}, - []string{"1.1.1.1:4000", "1.1.1.3:4000", "1.1.1.4:4000", "1.1.1.5:4000", "1.1.1.6:4000"}, - []*proto.Subtask{ - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.3:4000"}, - {ExecID: "1.1.1.4:4000"}, - {ExecID: "1.1.1.5:4000"}, - {ExecID: "1.1.1.6:4000"}}, - }, - } - for i, testCase := range testCases { - scaleTest(t, mockTaskMgr, testCase, i+1) - } -} - -func TestScaleInNodes(t *testing.T) { - store := testkit.CreateMockStore(t) - gtk := testkit.NewTestKit(t, store) - pool := pools.NewResourcePool(func() (pools.Resource, error) { - return gtk.Session(), nil - }, 1, 1, time.Second) - defer pool.Close() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - mockTaskMgr := mock.NewMockTaskManager(ctrl) - testCases := []scaleTestCase{ - // 1. scale in from tidb1, tidb2 to tidb1. 
- { - []*proto.Subtask{ - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.2:4000"}}, - []string{"1.1.1.1:4000"}, - []string{"1.1.1.1:4000", "1.1.1.2:4000"}, - []string{"1.1.1.2:4000"}, - []string{"1.1.1.1:4000"}, - []*proto.Subtask{ - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.1:4000"}}, - }, - // 2. scale in from tidb1, tidb2 to tidb3. - { - []*proto.Subtask{ - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.2:4000"}}, - []string{"1.1.1.3:4000"}, - []string{"1.1.1.1:4000", "1.1.1.2:4000"}, - []string{"1.1.1.1:4000", "1.1.1.2:4000"}, - []string{"1.1.1.3:4000"}, - []*proto.Subtask{ - {ExecID: "1.1.1.3:4000"}, - {ExecID: "1.1.1.3:4000"}, - {ExecID: "1.1.1.3:4000"}, - {ExecID: "1.1.1.3:4000"}}, - }, - // 5. scale in from 10 nodes to 2 nodes. - { - []*proto.Subtask{ - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.3:4000"}, - {ExecID: "1.1.1.4:4000"}, - {ExecID: "1.1.1.5:4000"}, - {ExecID: "1.1.1.6:4000"}, - {ExecID: "1.1.1.7:4000"}, - {ExecID: "1.1.1.8:4000"}, - {ExecID: "1.1.1.9:4000"}, - {ExecID: "1.1.1.10:4000"}}, - []string{"1.1.1.2:4000", "1.1.1.3:4000"}, - []string{ - "1.1.1.1:4000", - "1.1.1.2:4000", - "1.1.1.3:4000", - "1.1.1.4:4000", - "1.1.1.5:4000", - "1.1.1.6:4000", - "1.1.1.7:4000", - "1.1.1.8:4000", - "1.1.1.9:4000", - "1.1.1.10:4000"}, - []string{ - "1.1.1.1:4000", - "1.1.1.4:4000", - "1.1.1.5:4000", - "1.1.1.6:4000", - "1.1.1.7:4000", - "1.1.1.8:4000", - "1.1.1.9:4000", - "1.1.1.10:4000"}, - []string{"1.1.1.2:4000", "1.1.1.3:4000"}, - []*proto.Subtask{ - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.3:4000"}, - {ExecID: "1.1.1.3:4000"}, - {ExecID: "1.1.1.3:4000"}, - {ExecID: "1.1.1.3:4000"}, - {ExecID: "1.1.1.3:4000"}, - }, - }, - // 6. scale in from 1 node with 10 subtasks, 1 node with 1 subtasks to 1 node. - { - []*proto.Subtask{ - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.2:4000"}}, - []string{"1.1.1.1:4000"}, - []string{"1.1.1.1:4000", "1.1.1.2:4000"}, - []string{"1.1.1.2:4000"}, - []string{"1.1.1.1:4000"}, - []*proto.Subtask{ - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.1:4000"}}, - }, - } - for i, testCase := range testCases { - scaleTest(t, mockTaskMgr, testCase, i+1) - } -} - -func TestBalanceWithoutScale(t *testing.T) { - store := testkit.CreateMockStore(t) - gtk := testkit.NewTestKit(t, store) - pool := pools.NewResourcePool(func() (pools.Resource, error) { - return gtk.Session(), nil - }, 1, 1, time.Second) - defer pool.Close() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - mockTaskMgr := mock.NewMockTaskManager(ctrl) - testCases := []balanceTestCase{ - // 1. 
from tidb1:1, tidb2:3 to tidb1:2, tidb2:2 - { - []*proto.Subtask{ - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.2:4000"}}, - []string{"1.1.1.1:4000", "1.1.1.2:4000"}, - []string{"1.1.1.1:4000", "1.1.1.2:4000"}, - []*proto.Subtask{ - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.2:4000"}}, - }, - // 2. from tidb1:3, tidb2:2, tidb3:1 to tidb1:2, tidb2:2, tidb3:2 - { - []*proto.Subtask{ - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.3:4000"}}, - []string{"1.1.1.1:4000", "1.1.1.2:4000", "1.1.1.3:4000"}, - []string{"1.1.1.1:4000", "1.1.1.2:4000", "1.1.1.3:4000"}, - []*proto.Subtask{ - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.3:4000"}, - {ExecID: "1.1.1.3:4000"}}, - }, - // 3. from tidb1: 0, tidb2: 5 to tidb1: 3, tidb2: 2 - { - []*proto.Subtask{ - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.2:4000"}}, - []string{"1.1.1.1:4000", "1.1.1.2:4000"}, - []string{"1.1.1.1:4000", "1.1.1.2:4000"}, - []*proto.Subtask{ - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.2:4000"}}, - }, - // 4. from tidb1:5, tidb2:0, tidb3:0, tidb4:0, tidb5:0, tidb6:0 to 1,1,1,1,1,0. - { - []*proto.Subtask{ - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.1:4000"}}, - []string{"1.1.1.1:4000", "1.1.1.2:4000", "1.1.1.3:4000", "1.1.1.4:4000", "1.1.1.5:4000", "1.1.1.6:4000"}, - []string{"1.1.1.1:4000", "1.1.1.2:4000", "1.1.1.3:4000", "1.1.1.4:4000", "1.1.1.5:4000", "1.1.1.6:4000"}, - []*proto.Subtask{ - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.3:4000"}, - {ExecID: "1.1.1.4:4000"}, - {ExecID: "1.1.1.5:4000"}}, - }, - // 5. no balance needed. tidb1:2, tidb2:3 - { - []*proto.Subtask{ - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.2:4000"}}, - []string{"1.1.1.1:4000", "1.1.1.2:4000"}, - []string{"1.1.1.1:4000", "1.1.1.2:4000"}, - []*proto.Subtask{ - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.2:4000"}}, - }, - // 6. no balance needed. tidb1:2, tidb2:2 - { - []*proto.Subtask{ - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.2:4000"}}, - []string{"1.1.1.1:4000", "1.1.1.2:4000"}, - []string{"1.1.1.1:4000", "1.1.1.2:4000"}, - []*proto.Subtask{ - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.1:4000"}, - {ExecID: "1.1.1.2:4000"}, - {ExecID: "1.1.1.2:4000"}}, - }, - } - for i, testCase := range testCases { - balanceTest(t, mockTaskMgr, testCase, i+1) - } -} diff --git a/pkg/disttask/framework/scheduler/scheduler.go b/pkg/disttask/framework/scheduler/scheduler.go index a034b6a733746..cba85f9ab6979 100644 --- a/pkg/disttask/framework/scheduler/scheduler.go +++ b/pkg/disttask/framework/scheduler/scheduler.go @@ -18,6 +18,7 @@ import ( "context" "math/rand" "strings" + "sync/atomic" "time" "github.com/pingcap/errors" @@ -60,14 +61,6 @@ var ( RetrySQLMaxInterval = 30 * time.Second ) -// TaskHandle provides the interface for operations needed by Scheduler. 
-// Then we can use scheduler's function in Scheduler interface. -type TaskHandle interface { - // GetPreviousSubtaskMetas gets previous subtask metas. - GetPreviousSubtaskMetas(taskID int64, step proto.Step) ([][]byte, error) - storage.SessionExecutor -} - // Scheduler manages the lifetime of a task // including submitting subtasks and updating the status of a task. type Scheduler interface { @@ -79,22 +72,22 @@ type Scheduler interface { ScheduleTask() // Close closes the scheduler, should be called if Init returns nil. Close() + // GetTask returns the task that the scheduler is managing. + GetTask() *proto.Task + Extension } // BaseScheduler is the base struct for Scheduler. // each task type embed this struct and implement the Extension interface. type BaseScheduler struct { - ctx context.Context - taskMgr TaskManager - nodeMgr *NodeManager - Task *proto.Task - logCtx context.Context + ctx context.Context + Param + task atomic.Pointer[proto.Task] + logCtx context.Context // when RegisterSchedulerFactory, the factory MUST initialize this fields. Extension balanceSubtaskTick int - // TaskNodes stores the exec id of current task executor nodes. - TaskNodes []string // rand is for generating random selection of nodes. rand *rand.Rand } @@ -103,18 +96,17 @@ type BaseScheduler struct { var MockOwnerChange func() // NewBaseScheduler creates a new BaseScheduler. -func NewBaseScheduler(ctx context.Context, taskMgr TaskManager, nodeMgr *NodeManager, task *proto.Task) *BaseScheduler { +func NewBaseScheduler(ctx context.Context, task *proto.Task, param Param) *BaseScheduler { logCtx := logutil.WithFields(context.Background(), zap.Int64("task-id", task.ID), zap.Stringer("task-type", task.Type)) - return &BaseScheduler{ - ctx: ctx, - taskMgr: taskMgr, - nodeMgr: nodeMgr, - Task: task, - logCtx: logCtx, - TaskNodes: nil, - rand: rand.New(rand.NewSource(time.Now().UnixNano())), + s := &BaseScheduler{ + ctx: ctx, + Param: param, + logCtx: logCtx, + rand: rand.New(rand.NewSource(time.Now().UnixNano())), } + s.task.Store(task) + return s } // Init implements the Scheduler interface. @@ -124,8 +116,9 @@ func (*BaseScheduler) Init() error { // ScheduleTask implements the Scheduler interface. func (s *BaseScheduler) ScheduleTask() { + task := s.GetTask() logutil.Logger(s.logCtx).Info("schedule task", - zap.Stringer("state", s.Task.State), zap.Int("concurrency", s.Task.Concurrency)) + zap.Stringer("state", task.State), zap.Int("concurrency", task.Concurrency)) s.scheduleTask() } @@ -133,14 +126,25 @@ func (s *BaseScheduler) ScheduleTask() { func (*BaseScheduler) Close() { } +// GetTask implements the Scheduler interface. +func (s *BaseScheduler) GetTask() *proto.Task { + // Note: be careful when accessing state/step/meta/error of the task, they + // will be changed in scheduler, so there might be data race if they're accessed + // in other goroutines. Also clone them won't work, as we need read first. + // balancer is the only goroutine that accesses the task except scheduler now, + // and it only uses ID field, so it's safe. + return s.task.Load() +} + // refreshTask fetch task state from tidb_global_task table. 
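Editor's note: the change that matters most in this hunk is that the scheduler now keeps its task behind sync/atomic's Pointer and republishes a whole new *proto.Task instead of mutating fields of a shared one. A minimal, self-contained sketch of that copy-on-write pattern, using a stand-in task struct rather than the real proto.Task (all names below are illustrative only):

package main

import (
    "fmt"
    "sync/atomic"
)

type task struct {
    ID    int64
    State string
}

type scheduler struct {
    // published snapshot; readers load it, writers replace it wholesale
    task atomic.Pointer[task]
}

// getTask returns the current snapshot; callers must treat it as read-only.
func (s *scheduler) getTask() *task { return s.task.Load() }

// refresh publishes a brand-new snapshot instead of mutating the old one,
// so a concurrent reader never observes a half-updated struct.
func (s *scheduler) refresh(latest *task) { s.task.Store(latest) }

func main() {
    s := &scheduler{}
    s.task.Store(&task{ID: 1, State: "pending"})
    s.refresh(&task{ID: 1, State: "running"})
    fmt.Println(s.getTask().ID, s.getTask().State) // 1 running
}

Under this pattern a concurrent reader such as the balancer only ever sees a complete snapshot, which is what the GetTask comment above relies on when it says reading only the ID field is safe.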
func (s *BaseScheduler) refreshTask() error { - newTask, err := s.taskMgr.GetTaskByID(s.ctx, s.Task.ID) + task := s.GetTask() + newTask, err := s.taskMgr.GetTaskByID(s.ctx, task.ID) if err != nil { logutil.Logger(s.logCtx).Error("refresh task failed", zap.Error(err)) return err } - s.Task = newTask + s.task.Store(newTask) return nil } @@ -158,9 +162,10 @@ func (s *BaseScheduler) scheduleTask() { if err != nil { continue } + task := s.GetTask() failpoint.Inject("cancelTaskAfterRefreshTask", func(val failpoint.Value) { - if val.(bool) && s.Task.State == proto.TaskStateRunning { - err := s.taskMgr.CancelTask(s.ctx, s.Task.ID) + if val.(bool) && task.State == proto.TaskStateRunning { + err := s.taskMgr.CancelTask(s.ctx, task.ID) if err != nil { logutil.Logger(s.logCtx).Error("cancel task failed", zap.Error(err)) } @@ -168,26 +173,26 @@ func (s *BaseScheduler) scheduleTask() { }) failpoint.Inject("pausePendingTask", func(val failpoint.Value) { - if val.(bool) && s.Task.State == proto.TaskStatePending { - _, err := s.taskMgr.PauseTask(s.ctx, s.Task.Key) + if val.(bool) && task.State == proto.TaskStatePending { + _, err := s.taskMgr.PauseTask(s.ctx, task.Key) if err != nil { logutil.Logger(s.logCtx).Error("pause task failed", zap.Error(err)) } - s.Task.State = proto.TaskStatePausing + task.State = proto.TaskStatePausing } }) failpoint.Inject("pauseTaskAfterRefreshTask", func(val failpoint.Value) { - if val.(bool) && s.Task.State == proto.TaskStateRunning { - _, err := s.taskMgr.PauseTask(s.ctx, s.Task.Key) + if val.(bool) && task.State == proto.TaskStateRunning { + _, err := s.taskMgr.PauseTask(s.ctx, task.Key) if err != nil { logutil.Logger(s.logCtx).Error("pause task failed", zap.Error(err)) } - s.Task.State = proto.TaskStatePausing + task.State = proto.TaskStatePausing } }) - switch s.Task.State { + switch task.State { case proto.TaskStateCancelling: err = s.onCancelling() case proto.TaskStatePausing: @@ -208,7 +213,7 @@ func (s *BaseScheduler) scheduleTask() { err = s.onRunning() case proto.TaskStateSucceed, proto.TaskStateReverted, proto.TaskStateFailed: if err := s.onFinished(); err != nil { - logutil.Logger(s.logCtx).Error("schedule task meet error", zap.Stringer("state", s.Task.State), zap.Error(err)) + logutil.Logger(s.logCtx).Error("schedule task meet error", zap.Stringer("state", task.State), zap.Error(err)) } return } @@ -229,15 +234,17 @@ func (s *BaseScheduler) scheduleTask() { // handle task in cancelling state, schedule revert subtasks. func (s *BaseScheduler) onCancelling() error { - logutil.Logger(s.logCtx).Info("on cancelling state", zap.Stringer("state", s.Task.State), zap.Int64("step", int64(s.Task.Step))) + task := s.GetTask() + logutil.Logger(s.logCtx).Info("on cancelling state", zap.Stringer("state", task.State), zap.Int64("step", int64(task.Step))) errs := []error{errors.New(taskCancelMsg)} return s.onErrHandlingStage(errs) } // handle task in pausing state, cancel all running subtasks. 
func (s *BaseScheduler) onPausing() error { - logutil.Logger(s.logCtx).Info("on pausing state", zap.Stringer("state", s.Task.State), zap.Int64("step", int64(s.Task.Step))) - cntByStates, err := s.taskMgr.GetSubtaskCntGroupByStates(s.ctx, s.Task.ID, s.Task.Step) + task := s.GetTask() + logutil.Logger(s.logCtx).Info("on pausing state", zap.Stringer("state", task.State), zap.Int64("step", int64(task.Step))) + cntByStates, err := s.taskMgr.GetSubtaskCntGroupByStates(s.ctx, task.ID, task.Step) if err != nil { logutil.Logger(s.logCtx).Warn("check task failed", zap.Error(err)) return err @@ -245,9 +252,9 @@ func (s *BaseScheduler) onPausing() error { runningPendingCnt := cntByStates[proto.SubtaskStateRunning] + cntByStates[proto.SubtaskStatePending] if runningPendingCnt == 0 { logutil.Logger(s.logCtx).Info("all running subtasks paused, update the task to paused state") - return s.taskMgr.PausedTask(s.ctx, s.Task.ID) + return s.taskMgr.PausedTask(s.ctx, task.ID) } - logutil.Logger(s.logCtx).Debug("on pausing state, this task keeps current state", zap.Stringer("state", s.Task.State)) + logutil.Logger(s.logCtx).Debug("on pausing state, this task keeps current state", zap.Stringer("state", task.State)) return nil } @@ -256,10 +263,11 @@ var MockDMLExecutionOnPausedState func(task *proto.Task) // handle task in paused state. func (s *BaseScheduler) onPaused() error { - logutil.Logger(s.logCtx).Info("on paused state", zap.Stringer("state", s.Task.State), zap.Int64("step", int64(s.Task.Step))) + task := s.GetTask() + logutil.Logger(s.logCtx).Info("on paused state", zap.Stringer("state", task.State), zap.Int64("step", int64(task.Step))) failpoint.Inject("mockDMLExecutionOnPausedState", func(val failpoint.Value) { if val.(bool) { - MockDMLExecutionOnPausedState(s.Task) + MockDMLExecutionOnPausedState(task) } }) return nil @@ -270,8 +278,9 @@ var TestSyncChan = make(chan struct{}) // handle task in resuming state. func (s *BaseScheduler) onResuming() error { - logutil.Logger(s.logCtx).Info("on resuming state", zap.Stringer("state", s.Task.State), zap.Int64("step", int64(s.Task.Step))) - cntByStates, err := s.taskMgr.GetSubtaskCntGroupByStates(s.ctx, s.Task.ID, s.Task.Step) + task := s.GetTask() + logutil.Logger(s.logCtx).Info("on resuming state", zap.Stringer("state", task.State), zap.Int64("step", int64(task.Step))) + cntByStates, err := s.taskMgr.GetSubtaskCntGroupByStates(s.ctx, task.ID, task.Step) if err != nil { logutil.Logger(s.logCtx).Warn("check task failed", zap.Error(err)) return err @@ -286,50 +295,53 @@ func (s *BaseScheduler) onResuming() error { return err } - return s.taskMgr.ResumeSubtasks(s.ctx, s.Task.ID) + return s.taskMgr.ResumeSubtasks(s.ctx, task.ID) } // handle task in reverting state, check all revert subtasks finishes. 
func (s *BaseScheduler) onReverting() error { - logutil.Logger(s.logCtx).Debug("on reverting state", zap.Stringer("state", s.Task.State), zap.Int64("step", int64(s.Task.Step))) - cntByStates, err := s.taskMgr.GetSubtaskCntGroupByStates(s.ctx, s.Task.ID, s.Task.Step) + task := s.GetTask() + logutil.Logger(s.logCtx).Debug("on reverting state", zap.Stringer("state", task.State), zap.Int64("step", int64(task.Step))) + cntByStates, err := s.taskMgr.GetSubtaskCntGroupByStates(s.ctx, task.ID, task.Step) if err != nil { logutil.Logger(s.logCtx).Warn("check task failed", zap.Error(err)) return err } activeRevertCnt := cntByStates[proto.SubtaskStateRevertPending] + cntByStates[proto.SubtaskStateReverting] if activeRevertCnt == 0 { - if err = s.OnDone(s.ctx, s, s.Task); err != nil { + if err = s.OnDone(s.ctx, s, task); err != nil { return errors.Trace(err) } - return s.taskMgr.RevertedTask(s.ctx, s.Task.ID) + return s.taskMgr.RevertedTask(s.ctx, task.ID) } // Wait all subtasks in this step finishes. - s.OnTick(s.ctx, s.Task) - logutil.Logger(s.logCtx).Debug("on reverting state, this task keeps current state", zap.Stringer("state", s.Task.State)) + s.OnTick(s.ctx, task) + logutil.Logger(s.logCtx).Debug("on reverting state, this task keeps current state", zap.Stringer("state", task.State)) return nil } // handle task in pending state, schedule subtasks. func (s *BaseScheduler) onPending() error { - logutil.Logger(s.logCtx).Debug("on pending state", zap.Stringer("state", s.Task.State), zap.Int64("step", int64(s.Task.Step))) + task := s.GetTask() + logutil.Logger(s.logCtx).Debug("on pending state", zap.Stringer("state", task.State), zap.Int64("step", int64(task.Step))) return s.switch2NextStep() } // handle task in running state, check all running subtasks finishes. // If subtasks finished, run into the next step. func (s *BaseScheduler) onRunning() error { + task := s.GetTask() logutil.Logger(s.logCtx).Debug("on running state", - zap.Stringer("state", s.Task.State), - zap.Int64("step", int64(s.Task.Step))) + zap.Stringer("state", task.State), + zap.Int64("step", int64(task.Step))) // check current step finishes. - cntByStates, err := s.taskMgr.GetSubtaskCntGroupByStates(s.ctx, s.Task.ID, s.Task.Step) + cntByStates, err := s.taskMgr.GetSubtaskCntGroupByStates(s.ctx, task.ID, task.Step) if err != nil { logutil.Logger(s.logCtx).Warn("check task failed", zap.Error(err)) return err } if cntByStates[proto.SubtaskStateFailed] > 0 || cntByStates[proto.SubtaskStateCanceled] > 0 { - subTaskErrs, err := s.taskMgr.CollectSubTaskError(s.ctx, s.Task.ID) + subTaskErrs, err := s.taskMgr.CollectSubTaskError(s.ctx, task.ID) if err != nil { logutil.Logger(s.logCtx).Warn("collect subtask error failed", zap.Error(err)) return err @@ -342,148 +354,24 @@ func (s *BaseScheduler) onRunning() error { return s.switch2NextStep() } - if err := s.balanceSubtasks(); err != nil { - return err - } // Wait all subtasks in this step finishes. 
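Editor's note: the running-state handler (continued in the next hunk) boils down to a three-way decision on the per-state subtask counts. A hedged sketch of that decision, with state names as plain strings instead of the real proto.SubtaskState values:

package main

import "fmt"

type action string

// decide mirrors the shape of the running-state check: any failed or canceled
// subtask triggers error handling, a step with nothing left pending or running
// is treated as finished, anything else keeps waiting for executors.
func decide(cnt map[string]int64) action {
    switch {
    case cnt["failed"] > 0 || cnt["canceled"] > 0:
        return "handle-error"
    case cnt["pending"] == 0 && cnt["running"] == 0:
        return "next-step"
    default:
        return "wait"
    }
}

func main() {
    fmt.Println(decide(map[string]int64{"pending": 2})) // wait
    fmt.Println(decide(map[string]int64{"failed": 1}))  // handle-error
    fmt.Println(decide(nil))                            // next-step
}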
- s.OnTick(s.ctx, s.Task) - logutil.Logger(s.logCtx).Debug("on running state, this task keeps current state", zap.Stringer("state", s.Task.State)) + s.OnTick(s.ctx, task) + logutil.Logger(s.logCtx).Debug("on running state, this task keeps current state", zap.Stringer("state", task.State)) return nil } func (s *BaseScheduler) onFinished() error { - metrics.UpdateMetricsForFinishTask(s.Task) - logutil.Logger(s.logCtx).Debug("schedule task, task is finished", zap.Stringer("state", s.Task.State)) - return s.taskMgr.TransferSubTasks2History(s.ctx, s.Task.ID) -} - -// balanceSubtasks check the liveNode num every liveNodeFetchInterval then rebalance subtasks. -func (s *BaseScheduler) balanceSubtasks() error { - if len(s.TaskNodes) == 0 { - var err error - s.TaskNodes, err = s.taskMgr.GetTaskExecutorIDsByTaskIDAndStep(s.ctx, s.Task.ID, s.Task.Step) - if err != nil { - return err - } - } - s.balanceSubtaskTick++ - if s.balanceSubtaskTick == defaultBalanceSubtaskTicks { - s.balanceSubtaskTick = 0 - eligibleNodes, err := s.getEligibleNodes() - if err != nil { - return err - } - if len(eligibleNodes) > 0 { - return s.doBalanceSubtasks(eligibleNodes) - } - } - return nil -} - -// DoBalanceSubtasks make count of subtasks on each liveNodes balanced and clean up subtasks on dead nodes. -// TODO(ywqzzy): refine to make it easier for testing. -func (s *BaseScheduler) doBalanceSubtasks(eligibleNodes []string) error { - eligibleNodeMap := make(map[string]struct{}, len(eligibleNodes)) - for _, n := range eligibleNodes { - eligibleNodeMap[n] = struct{}{} - } - // 1. find out nodes need to clean subtasks. - deadNodes := make([]string, 0) - deadNodesMap := make(map[string]bool, 0) - for _, node := range s.TaskNodes { - if _, ok := eligibleNodeMap[node]; !ok { - deadNodes = append(deadNodes, node) - deadNodesMap[node] = true - } - } - // 2. get subtasks for each node before rebalance. - subtasks, err := s.taskMgr.GetSubtasksByStepAndState(s.ctx, s.Task.ID, s.Task.Step, proto.TaskStatePending) - if err != nil { - return err - } - if len(deadNodes) != 0 { - /// get subtask from deadNodes, since there might be some running subtasks on deadNodes. - /// In this case, all subtasks on deadNodes are in running/pending state. - subtasksOnDeadNodes, err := s.taskMgr.GetSubtasksByExecIdsAndStepAndState( - s.ctx, - deadNodes, - s.Task.ID, - s.Task.Step, - proto.SubtaskStateRunning) - if err != nil { - return err - } - subtasks = append(subtasks, subtasksOnDeadNodes...) - } - // 3. group subtasks for each task executor. - subtasksOnTaskExecutor := make(map[string][]*proto.Subtask, len(eligibleNodes)+len(deadNodes)) - for _, node := range eligibleNodes { - subtasksOnTaskExecutor[node] = make([]*proto.Subtask, 0) - } - for _, subtask := range subtasks { - subtasksOnTaskExecutor[subtask.ExecID] = append( - subtasksOnTaskExecutor[subtask.ExecID], - subtask) - } - // 4. prepare subtasks that need to rebalance to other nodes. - averageSubtaskCnt := len(subtasks) / len(eligibleNodes) - rebalanceSubtasks := make([]*proto.Subtask, 0) - for k, v := range subtasksOnTaskExecutor { - if ok := deadNodesMap[k]; ok { - rebalanceSubtasks = append(rebalanceSubtasks, v...) - continue - } - // When no tidb scale-in/out and averageSubtaskCnt*len(eligibleNodes) < len(subtasks), - // no need to send subtask to other nodes. - // eg: tidb1 with 3 subtasks, tidb2 with 2 subtasks, subtasks are balanced now. 
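Editor's note: the block being deleted here is the old even-split rebalance rule that this PR moves into the new balancer.go. Roughly, each live node should end up with about len(subtasks)/len(eligibleNodes) subtasks (the real code tolerates a surplus of one when no node joined or left), and nodes that are no longer eligible give up everything. A simplified sketch of that rule with made-up executor IDs:

package main

import "fmt"

// excessPerNode returns how many subtasks each executor should give up so
// that every live node ends with roughly total/len(liveNodes) of them; an
// executor that is no longer live gives up all of its subtasks.
func excessPerNode(counts map[string]int, liveNodes []string, total int) map[string]int {
    avg := total / len(liveNodes)
    live := make(map[string]bool, len(liveNodes))
    for _, n := range liveNodes {
        live[n] = true
    }
    excess := make(map[string]int)
    for node, cnt := range counts {
        switch {
        case !live[node]:
            excess[node] = cnt // dead node: move everything away
        case cnt > avg:
            excess[node] = cnt - avg
        }
    }
    return excess
}

func main() {
    counts := map[string]int{"tidb1": 0, "tidb2": 5}
    fmt.Println(excessPerNode(counts, []string{"tidb1", "tidb2"}, 5)) // map[tidb2:3]
}

With 5 subtasks all on tidb2 and two live nodes, tidb2 sheds 3, which matches the tidb1:3/tidb2:2 expectation in the deleted balance test table above.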
- if averageSubtaskCnt*len(eligibleNodes) < len(subtasks) && len(s.TaskNodes) == len(eligibleNodes) { - if len(v) > averageSubtaskCnt+1 { - rebalanceSubtasks = append(rebalanceSubtasks, v[0:len(v)-averageSubtaskCnt]...) - } - continue - } - if len(v) > averageSubtaskCnt { - rebalanceSubtasks = append(rebalanceSubtasks, v[0:len(v)-averageSubtaskCnt]...) - } - } - // 5. skip rebalance. - if len(rebalanceSubtasks) == 0 { - return nil - } - // 6.rebalance subtasks to other nodes. - rebalanceIdx := 0 - for k, v := range subtasksOnTaskExecutor { - if ok := deadNodesMap[k]; !ok { - if len(v) < averageSubtaskCnt { - for i := 0; i < averageSubtaskCnt-len(v) && rebalanceIdx < len(rebalanceSubtasks); i++ { - rebalanceSubtasks[rebalanceIdx].ExecID = k - rebalanceIdx++ - } - } - } - } - // 7. rebalance rest subtasks evenly to liveNodes. - liveNodeIdx := 0 - for rebalanceIdx < len(rebalanceSubtasks) { - rebalanceSubtasks[rebalanceIdx].ExecID = eligibleNodes[liveNodeIdx] - rebalanceIdx++ - liveNodeIdx++ - } - - // 8. update subtasks and do clean up logic. - if err = s.taskMgr.UpdateSubtasksExecIDs(s.ctx, s.Task.ID, subtasks); err != nil { - return err - } - logutil.Logger(s.logCtx).Info("balance subtasks", - zap.Stringers("subtasks-rebalanced", subtasks)) - s.TaskNodes = append([]string{}, eligibleNodes...) - return nil + task := s.GetTask() + metrics.UpdateMetricsForFinishTask(task) + logutil.Logger(s.logCtx).Debug("schedule task, task is finished", zap.Stringer("state", task.State)) + return s.taskMgr.TransferSubTasks2History(s.ctx, task.ID) } // updateTask update the task in tidb_global_task table. func (s *BaseScheduler) updateTask(taskState proto.TaskState, newSubTasks []*proto.Subtask, retryTimes int) (err error) { - prevState := s.Task.State - s.Task.State = taskState + task := s.GetTask() + prevState := task.State + task.State = taskState logutil.BgLogger().Info("task state transform", zap.Stringer("from", prevState), zap.Stringer("to", taskState)) if !VerifyTaskStateTransform(prevState, taskState) { return errors.Errorf("invalid task state transform, from %s to %s", prevState, taskState) @@ -491,7 +379,7 @@ func (s *BaseScheduler) updateTask(taskState proto.TaskState, newSubTasks []*pro var retryable bool for i := 0; i < retryTimes; i++ { - retryable, err = s.taskMgr.UpdateTaskAndAddSubTasks(s.ctx, s.Task, newSubTasks, prevState) + retryable, err = s.taskMgr.UpdateTaskAndAddSubTasks(s.ctx, task, newSubTasks, prevState) if err == nil || !retryable { break } @@ -499,26 +387,27 @@ func (s *BaseScheduler) updateTask(taskState proto.TaskState, newSubTasks []*pro return err1 } if i%10 == 0 { - logutil.Logger(s.logCtx).Warn("updateTask first failed", zap.Stringer("from", prevState), zap.Stringer("to", s.Task.State), + logutil.Logger(s.logCtx).Warn("updateTask first failed", zap.Stringer("from", prevState), zap.Stringer("to", task.State), zap.Int("retry times", i), zap.Error(err)) } time.Sleep(RetrySQLInterval) } if err != nil && retryTimes != nonRetrySQLTime { logutil.Logger(s.logCtx).Warn("updateTask failed", - zap.Stringer("from", prevState), zap.Stringer("to", s.Task.State), zap.Int("retry times", retryTimes), zap.Error(err)) + zap.Stringer("from", prevState), zap.Stringer("to", task.State), zap.Int("retry times", retryTimes), zap.Error(err)) } return err } func (s *BaseScheduler) onErrHandlingStage(receiveErrs []error) error { + task := s.GetTask() // we only store the first error. 
- s.Task.Error = receiveErrs[0] + task.Error = receiveErrs[0] var subTasks []*proto.Subtask // when step of task is `StepInit`, no need to do revert - if s.Task.Step != proto.StepInit { - instanceIDs, err := s.GetAllTaskExecutorIDs(s.ctx, s.Task) + if task.Step != proto.StepInit { + instanceIDs, err := s.GetAllTaskExecutorIDs(s.ctx, task) if err != nil { logutil.Logger(s.logCtx).Warn("get task's all instances failed", zap.Error(err)) return err @@ -528,86 +417,80 @@ func (s *BaseScheduler) onErrHandlingStage(receiveErrs []error) error { for _, id := range instanceIDs { // reverting subtasks belong to the same step as current active step. subTasks = append(subTasks, proto.NewSubtask( - s.Task.Step, s.Task.ID, s.Task.Type, id, - s.Task.Concurrency, proto.EmptyMeta, 0)) + task.Step, task.ID, task.Type, id, + task.Concurrency, proto.EmptyMeta, 0)) } } return s.updateTask(proto.TaskStateReverting, subTasks, RetrySQLTimes) } func (s *BaseScheduler) switch2NextStep() (err error) { - nextStep := s.GetNextStep(s.Task) + task := s.GetTask() + nextStep := s.GetNextStep(task) logutil.Logger(s.logCtx).Info("on next step", - zap.Int64("current-step", int64(s.Task.Step)), + zap.Int64("current-step", int64(task.Step)), zap.Int64("next-step", int64(nextStep))) if nextStep == proto.StepDone { - s.Task.Step = nextStep - s.Task.StateUpdateTime = time.Now().UTC() - if err = s.OnDone(s.ctx, s, s.Task); err != nil { + task.Step = nextStep + task.StateUpdateTime = time.Now().UTC() + if err = s.OnDone(s.ctx, s, task); err != nil { return errors.Trace(err) } - return s.taskMgr.SucceedTask(s.ctx, s.Task.ID) + return s.taskMgr.SucceedTask(s.ctx, task.ID) } - serverNodes, err := s.getEligibleNodes() + eligibleNodes, err := getEligibleNodes(s.ctx, s, s.nodeMgr.getManagedNodes()) if err != nil { return err } - logutil.Logger(s.logCtx).Info("eligible instances", zap.Int("num", len(serverNodes))) - if len(serverNodes) == 0 { + logutil.Logger(s.logCtx).Info("eligible instances", zap.Int("num", len(eligibleNodes))) + if len(eligibleNodes) == 0 { return errors.New("no available TiDB node to dispatch subtasks") } - metas, err := s.OnNextSubtasksBatch(s.ctx, s, s.Task, serverNodes, nextStep) + metas, err := s.OnNextSubtasksBatch(s.ctx, s, task, eligibleNodes, nextStep) if err != nil { logutil.Logger(s.logCtx).Warn("generate part of subtasks failed", zap.Error(err)) return s.handlePlanErr(err) } - return s.scheduleSubTask(nextStep, metas, serverNodes) -} - -// getEligibleNodes returns the eligible(live) nodes for the task. -// if the task can only be scheduled to some specific nodes, return them directly, -// we don't care liveliness of them. -func (s *BaseScheduler) getEligibleNodes() ([]string, error) { - serverNodes, err := s.GetEligibleInstances(s.ctx, s.Task) - if err != nil { - return nil, err - } - logutil.Logger(s.logCtx).Debug("eligible instances", zap.Int("num", len(serverNodes))) - if len(serverNodes) == 0 { - serverNodes = append([]string{}, s.nodeMgr.getManagedNodes()...) 
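Editor's note: after this rewrite, node selection happens in three stages that the hunks below implement: GetEligibleInstances (falling back to all managed nodes), a slot-availability filter, and round-robin assignment of the subtask metas. A hypothetical, self-contained sketch of that flow; the function names here are local stand-ins, not the framework's exported API:

package main

import "fmt"

// withEnoughSlots keeps only nodes that can still take `concurrency` more
// slots; if none qualifies, the caller falls back to every eligible node.
func withEnoughSlots(used map[string]int, capacity int, nodes []string, concurrency int) []string {
    out := make([]string, 0, len(nodes))
    for _, n := range nodes {
        if used[n]+concurrency <= capacity {
            out = append(out, n)
        }
    }
    if len(out) == 0 {
        return nodes
    }
    return out
}

// assignRoundRobin spreads subtask metas over the chosen nodes in order.
func assignRoundRobin(metas []string, nodes []string) map[string][]string {
    plan := make(map[string][]string, len(nodes))
    for i, m := range metas {
        node := nodes[i%len(nodes)]
        plan[node] = append(plan[node], m)
    }
    return plan
}

func main() {
    nodes := withEnoughSlots(map[string]int{"tidb1": 14, "tidb2": 4}, 16, []string{"tidb1", "tidb2"}, 8)
    fmt.Println(nodes)                                               // [tidb2]
    fmt.Println(assignRoundRobin([]string{"m1", "m2", "m3"}, nodes)) // map[tidb2:[m1 m2 m3]]
}

If no node has enough free slots the filter deliberately falls back to every eligible node, mirroring what adjustEligibleNodes does further down in slots.go.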
- } - return serverNodes, nil + return s.scheduleSubTask(nextStep, metas, eligibleNodes) } func (s *BaseScheduler) scheduleSubTask( subtaskStep proto.Step, metas [][]byte, - serverNodes []string) error { + eligibleNodes []string) error { + task := s.GetTask() logutil.Logger(s.logCtx).Info("schedule subtasks", - zap.Stringer("state", s.Task.State), - zap.Int64("step", int64(s.Task.Step)), - zap.Int("concurrency", s.Task.Concurrency), + zap.Stringer("state", task.State), + zap.Int64("step", int64(task.Step)), + zap.Int("concurrency", task.Concurrency), zap.Int("subtasks", len(metas))) - s.TaskNodes = serverNodes + + // the scheduled node of the subtask might not be optimal, as we run all + // scheduler in parallel, and update might be called too many times when + // multiple tasks are switching to next step. + if err := s.slotMgr.update(s.ctx, s.nodeMgr, s.taskMgr); err != nil { + return err + } + adjustedEligibleNodes := s.slotMgr.adjustEligibleNodes(eligibleNodes, task.Concurrency) var size uint64 subTasks := make([]*proto.Subtask, 0, len(metas)) for i, meta := range metas { // we assign the subtask to the instance in a round-robin way. // TODO: assign the subtask to the instance according to the system load of each nodes - pos := i % len(serverNodes) - instanceID := serverNodes[pos] + pos := i % len(adjustedEligibleNodes) + instanceID := adjustedEligibleNodes[pos] logutil.Logger(s.logCtx).Debug("create subtasks", zap.String("instanceID", instanceID)) subTasks = append(subTasks, proto.NewSubtask( - subtaskStep, s.Task.ID, s.Task.Type, instanceID, s.Task.Concurrency, meta, i+1)) + subtaskStep, task.ID, task.Type, instanceID, task.Concurrency, meta, i+1)) size += uint64(len(meta)) } failpoint.Inject("cancelBeforeUpdateTask", func() { - _ = s.taskMgr.CancelTask(s.ctx, s.Task.ID) + _ = s.taskMgr.CancelTask(s.ctx, task.ID) }) // as other fields and generated key and index KV takes space too, we limit @@ -626,7 +509,7 @@ func (s *BaseScheduler) scheduleSubTask( backoffer := backoff.NewExponential(RetrySQLInterval, 2, RetrySQLMaxInterval) return handle.RunWithRetry(s.ctx, RetrySQLTimes, backoffer, logutil.Logger(s.logCtx), func(ctx context.Context) (bool, error) { - err := fn(s.ctx, s.Task, proto.TaskStateRunning, subtaskStep, subTasks) + err := fn(s.ctx, task, proto.TaskStateRunning, subtaskStep, subTasks) if errors.Cause(err) == storage.ErrUnstableSubtasks { return false, err } @@ -636,16 +519,18 @@ func (s *BaseScheduler) scheduleSubTask( } func (s *BaseScheduler) handlePlanErr(err error) error { - logutil.Logger(s.logCtx).Warn("generate plan failed", zap.Error(err), zap.Stringer("state", s.Task.State)) + task := s.GetTask() + logutil.Logger(s.logCtx).Warn("generate plan failed", zap.Error(err), zap.Stringer("state", task.State)) if s.IsRetryableErr(err) { return err } - s.Task.Error = err - if err = s.OnDone(s.ctx, s, s.Task); err != nil { + task.Error = err + + if err = s.OnDone(s.ctx, s, task); err != nil { return errors.Trace(err) } - return s.taskMgr.FailTask(s.ctx, s.Task.ID, s.Task.State, s.Task.Error) + return s.taskMgr.FailTask(s.ctx, task.ID, task.State, task.Error) } // MockServerInfo exported for scheduler_test.go @@ -735,3 +620,18 @@ func (*BaseScheduler) isStepSucceed(cntByStates map[proto.SubtaskState]int64) bo func IsCancelledErr(err error) bool { return strings.Contains(err.Error(), taskCancelMsg) } + +// getEligibleNodes returns the eligible(live) nodes for the task. 
+// if the task can only be scheduled to some specific nodes, return them directly, +// we don't care liveliness of them. +func getEligibleNodes(ctx context.Context, sch Scheduler, managedNodes []string) ([]string, error) { + serverNodes, err := sch.GetEligibleInstances(ctx, sch.GetTask()) + if err != nil { + return nil, err + } + logutil.BgLogger().Debug("eligible instances", zap.Int("num", len(serverNodes))) + if len(serverNodes) == 0 { + serverNodes = managedNodes + } + return serverNodes, nil +} diff --git a/pkg/disttask/framework/scheduler/scheduler_manager.go b/pkg/disttask/framework/scheduler/scheduler_manager.go index b95292895eabf..4bdba0f0282b2 100644 --- a/pkg/disttask/framework/scheduler/scheduler_manager.go +++ b/pkg/disttask/framework/scheduler/scheduler_manager.go @@ -16,6 +16,7 @@ package scheduler import ( "context" + "slices" "time" "github.com/pingcap/errors" @@ -46,32 +47,52 @@ var WaitTaskFinished = make(chan struct{}) func (sm *Manager) getSchedulerCount() int { sm.mu.RLock() defer sm.mu.RUnlock() - return len(sm.mu.schedulers) + return len(sm.mu.schedulerMap) } func (sm *Manager) addScheduler(taskID int64, scheduler Scheduler) { sm.mu.Lock() defer sm.mu.Unlock() - sm.mu.schedulers[taskID] = scheduler + sm.mu.schedulerMap[taskID] = scheduler + sm.mu.schedulers = append(sm.mu.schedulers, scheduler) + slices.SortFunc(sm.mu.schedulers, func(i, j Scheduler) int { + return i.GetTask().Compare(j.GetTask()) + }) } func (sm *Manager) hasScheduler(taskID int64) bool { sm.mu.Lock() defer sm.mu.Unlock() - _, ok := sm.mu.schedulers[taskID] + _, ok := sm.mu.schedulerMap[taskID] return ok } func (sm *Manager) delScheduler(taskID int64) { sm.mu.Lock() defer sm.mu.Unlock() - delete(sm.mu.schedulers, taskID) + delete(sm.mu.schedulerMap, taskID) + for i, scheduler := range sm.mu.schedulers { + if scheduler.GetTask().ID == taskID { + sm.mu.schedulers = append(sm.mu.schedulers[:i], sm.mu.schedulers[i+1:]...) + break + } + } } func (sm *Manager) clearSchedulers() { sm.mu.Lock() defer sm.mu.Unlock() - sm.mu.schedulers = make(map[int64]Scheduler) + sm.mu.schedulerMap = make(map[int64]Scheduler) + sm.mu.schedulers = sm.mu.schedulers[:0] +} + +// getSchedulers returns a copy of schedulers. +func (sm *Manager) getSchedulers() []Scheduler { + sm.mu.RLock() + defer sm.mu.RUnlock() + res := make([]Scheduler, len(sm.mu.schedulers)) + copy(res, sm.mu.schedulers) + return res } // Manager manage a bunch of schedulers. @@ -83,8 +104,9 @@ type Manager struct { taskMgr TaskManager wg tidbutil.WaitGroupWrapper gPool *spool.Pool - slotMgr *slotManager + slotMgr *SlotManager nodeMgr *NodeManager + balancer *balancer initialized bool // serverID, it's value is ip:port now. 
serverID string @@ -93,7 +115,9 @@ type Manager struct { mu struct { syncutil.RWMutex - schedulers map[int64]Scheduler + schedulerMap map[int64]Scheduler + // in task order + schedulers []Scheduler } } @@ -111,8 +135,13 @@ func NewManager(ctx context.Context, taskMgr TaskManager, serverID string) (*Man } schedulerManager.gPool = gPool schedulerManager.ctx, schedulerManager.cancel = context.WithCancel(ctx) - schedulerManager.mu.schedulers = make(map[int64]Scheduler) + schedulerManager.mu.schedulerMap = make(map[int64]Scheduler) schedulerManager.finishCh = make(chan struct{}, proto.MaxConcurrentTask) + schedulerManager.balancer = newBalancer(Param{ + taskMgr: taskMgr, + nodeMgr: schedulerManager.nodeMgr, + slotMgr: schedulerManager.slotMgr, + }) return schedulerManager, nil } @@ -134,6 +163,9 @@ func (sm *Manager) Start() { sm.wg.Run(func() { sm.nodeMgr.refreshManagedNodesLoop(sm.ctx, sm.taskMgr, sm.slotMgr) }) + sm.wg.Run(func() { + sm.balancer.balanceLoop(sm.ctx, sm) + }) sm.initialized = true } @@ -152,7 +184,7 @@ func (sm *Manager) Initialized() bool { return sm.initialized } -// scheduleTaskLoop schedulees the tasks. +// scheduleTaskLoop schedules the tasks. func (sm *Manager) scheduleTaskLoop() { logutil.BgLogger().Info("schedule task loop start") ticker := time.NewTicker(checkTaskRunningInterval) @@ -200,7 +232,7 @@ func (sm *Manager) scheduleTaskLoop() { continue } - if err = sm.slotMgr.update(sm.ctx, sm.taskMgr); err != nil { + if err = sm.slotMgr.update(sm.ctx, sm.nodeMgr, sm.taskMgr); err != nil { logutil.BgLogger().Warn("update used slot failed", zap.Error(err)) continue } @@ -265,7 +297,11 @@ func (sm *Manager) startScheduler(basicTask *proto.Task, reservedExecID string) } schedulerFactory := getSchedulerFactory(task.Type) - scheduler := schedulerFactory(sm.ctx, sm.taskMgr, sm.nodeMgr, task) + scheduler := schedulerFactory(sm.ctx, task, Param{ + taskMgr: sm.taskMgr, + nodeMgr: sm.nodeMgr, + slotMgr: sm.slotMgr, + }) if err = scheduler.Init(); err != nil { logutil.BgLogger().Error("init scheduler failed", zap.Error(err)) sm.failTask(task.ID, task.State, err) @@ -365,5 +401,9 @@ func (sm *Manager) cleanUpFinishedTasks(tasks []*proto.Task) error { // MockScheduler mock one scheduler for one task, only used for tests. 
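Editor's note: the manager now tracks schedulers twice, in schedulerMap for O(1) lookup by task ID and in a slice kept sorted by proto.Task.Compare, presumably so the new balancer loop can walk them in a stable task order. A small illustrative sketch of that map-plus-ordered-slice registry (ordering by ID below stands in for Task.Compare; none of these names are the real Manager API):

package main

import (
    "cmp"
    "fmt"
    "slices"
)

type entry struct{ taskID int64 }

type registry struct {
    byID    map[int64]*entry // fast membership / delete by ID
    ordered []*entry         // stable iteration order for a background loop
}

func (r *registry) add(e *entry) {
    r.byID[e.taskID] = e
    r.ordered = append(r.ordered, e)
    slices.SortFunc(r.ordered, func(a, b *entry) int { return cmp.Compare(a.taskID, b.taskID) })
}

func (r *registry) del(id int64) {
    delete(r.byID, id)
    r.ordered = slices.DeleteFunc(r.ordered, func(e *entry) bool { return e.taskID == id })
}

func main() {
    r := &registry{byID: map[int64]*entry{}}
    for _, id := range []int64{30, 10, 20} {
        r.add(&entry{taskID: id})
    }
    r.del(20)
    for _, e := range r.ordered {
        fmt.Println(e.taskID) // 10, then 30
    }
}

Handing out a copy of the ordered slice, as getSchedulers does above under the read lock, keeps that background iteration from racing with add/delete.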
func (sm *Manager) MockScheduler(task *proto.Task) *BaseScheduler { - return NewBaseScheduler(sm.ctx, sm.taskMgr, sm.nodeMgr, task) + return NewBaseScheduler(sm.ctx, task, Param{ + taskMgr: sm.taskMgr, + nodeMgr: sm.nodeMgr, + slotMgr: sm.slotMgr, + }) } diff --git a/pkg/disttask/framework/scheduler/scheduler_manager_test.go b/pkg/disttask/framework/scheduler/scheduler_manager_test.go index c54b2bd462f07..1ebc2ec541c54 100644 --- a/pkg/disttask/framework/scheduler/scheduler_manager_test.go +++ b/pkg/disttask/framework/scheduler/scheduler_manager_test.go @@ -55,7 +55,7 @@ func TestCleanUpRoutine(t *testing.T) { tasks, err = mgr.GetTasksInStates(ctx, proto.TaskStateRunning) require.NoError(t, err) return len(tasks) == 1 - }, time.Second, 50*time.Millisecond) + }, 5*time.Second, 50*time.Millisecond) return tasks } @@ -64,7 +64,7 @@ func TestCleanUpRoutine(t *testing.T) { cntByStates, err := mgr.GetSubtaskCntGroupByStates(ctx, taskID, proto.StepOne) require.NoError(t, err) return int64(subtaskCnt) == cntByStates[proto.SubtaskStatePending] - }, time.Second, 50*time.Millisecond) + }, 5*time.Second, 50*time.Millisecond) } tasks := checkTaskRunningCnt() @@ -78,5 +78,5 @@ func TestCleanUpRoutine(t *testing.T) { tasks, err := mgr.GetTasksFromHistoryInStates(ctx, proto.TaskStateSucceed) require.NoError(t, err) return len(tasks) != 0 - }, time.Second*10, time.Millisecond*300) + }, 5*time.Second*10, time.Millisecond*300) } diff --git a/pkg/disttask/framework/scheduler/scheduler_nokit_test.go b/pkg/disttask/framework/scheduler/scheduler_nokit_test.go index e6320ded9ceec..534babbaafe7d 100644 --- a/pkg/disttask/framework/scheduler/scheduler_nokit_test.go +++ b/pkg/disttask/framework/scheduler/scheduler_nokit_test.go @@ -15,12 +15,191 @@ package scheduler import ( + "context" "testing" + "time" + "github.com/pingcap/errors" + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/disttask/framework/mock" "github.com/pingcap/tidb/pkg/disttask/framework/proto" + mockDispatch "github.com/pingcap/tidb/pkg/disttask/framework/scheduler/mock" + "github.com/pingcap/tidb/pkg/disttask/framework/storage" + "github.com/pingcap/tidb/pkg/kv" "github.com/stretchr/testify/require" + "github.com/tikv/client-go/v2/util" + "go.uber.org/mock/gomock" ) +func TestDispatcherOnNextStage(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + taskMgr := mock.NewMockTaskManager(ctrl) + schExt := mockDispatch.NewMockExtension(ctrl) + + ctx := context.Background() + ctx = util.WithInternalSourceType(ctx, "dispatcher") + task := proto.Task{ + ID: 1, + State: proto.TaskStatePending, + Step: proto.StepInit, + } + cloneTask := task + nodeMgr := NewNodeManager() + sch := NewBaseScheduler(ctx, &cloneTask, Param{ + taskMgr: taskMgr, + nodeMgr: nodeMgr, + slotMgr: newSlotManager(), + }) + sch.Extension = schExt + + // test next step is done + schExt.EXPECT().GetNextStep(gomock.Any()).Return(proto.StepDone) + schExt.EXPECT().OnDone(gomock.Any(), gomock.Any(), gomock.Any()).Return(errors.New("done err")) + require.ErrorContains(t, sch.Switch2NextStep(), "done err") + require.True(t, ctrl.Satisfied()) + // we update task step before OnDone + require.Equal(t, proto.StepDone, sch.GetTask().Step) + taskClone2 := task + sch.task.Store(&taskClone2) + schExt.EXPECT().GetNextStep(gomock.Any()).Return(proto.StepDone) + schExt.EXPECT().OnDone(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + taskMgr.EXPECT().SucceedTask(gomock.Any(), gomock.Any()).Return(nil) + require.NoError(t, sch.Switch2NextStep()) + 
require.True(t, ctrl.Satisfied()) + + // GetEligibleInstances err + schExt.EXPECT().GetNextStep(gomock.Any()).Return(proto.StepOne) + schExt.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).Return(nil, errors.New("GetEligibleInstances err")) + require.ErrorContains(t, sch.Switch2NextStep(), "GetEligibleInstances err") + require.True(t, ctrl.Satisfied()) + // GetEligibleInstances no instance + schExt.EXPECT().GetNextStep(gomock.Any()).Return(proto.StepOne) + schExt.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).Return(nil, nil) + require.ErrorContains(t, sch.Switch2NextStep(), "no available TiDB node to dispatch subtasks") + require.True(t, ctrl.Satisfied()) + + serverNodes := []string{":4000"} + // OnNextSubtasksBatch err + schExt.EXPECT().GetNextStep(gomock.Any()).Return(proto.StepOne) + schExt.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).Return(serverNodes, nil) + schExt.EXPECT().OnNextSubtasksBatch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil, errors.New("OnNextSubtasksBatch err")) + schExt.EXPECT().IsRetryableErr(gomock.Any()).Return(true) + require.ErrorContains(t, sch.Switch2NextStep(), "OnNextSubtasksBatch err") + require.True(t, ctrl.Satisfied()) + + bak := kv.TxnTotalSizeLimit.Load() + t.Cleanup(func() { + kv.TxnTotalSizeLimit.Store(bak) + }) + + // dispatch in batch + subtaskMetas := [][]byte{ + []byte(`{"xx": "1"}`), + []byte(`{"xx": "2"}`), + } + schExt.EXPECT().GetNextStep(gomock.Any()).Return(proto.StepOne) + schExt.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).Return(serverNodes, nil) + schExt.EXPECT().OnNextSubtasksBatch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(subtaskMetas, nil) + taskMgr.EXPECT().GetUsedSlotsOnNodes(gomock.Any()).Return(nil, nil) + taskMgr.EXPECT().SwitchTaskStepInBatch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + kv.TxnTotalSizeLimit.Store(1) + require.NoError(t, sch.Switch2NextStep()) + require.True(t, ctrl.Satisfied()) + // met unstable subtasks + schExt.EXPECT().GetNextStep(gomock.Any()).Return(proto.StepOne) + schExt.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).Return(serverNodes, nil) + schExt.EXPECT().OnNextSubtasksBatch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(subtaskMetas, nil) + taskMgr.EXPECT().GetUsedSlotsOnNodes(gomock.Any()).Return(nil, nil) + taskMgr.EXPECT().SwitchTaskStepInBatch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(errors.Annotatef(storage.ErrUnstableSubtasks, "expected %d, got %d", + 2, 100)) + kv.TxnTotalSizeLimit.Store(1) + startTime := time.Now() + err := sch.Switch2NextStep() + require.ErrorIs(t, err, storage.ErrUnstableSubtasks) + require.ErrorContains(t, err, "expected 2, got 100") + require.WithinDuration(t, startTime, time.Now(), 10*time.Second) + require.True(t, ctrl.Satisfied()) + + // dispatch in one txn + schExt.EXPECT().GetNextStep(gomock.Any()).Return(proto.StepOne) + schExt.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).Return(serverNodes, nil) + schExt.EXPECT().OnNextSubtasksBatch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). 
+ Return(subtaskMetas, nil) + taskMgr.EXPECT().GetUsedSlotsOnNodes(gomock.Any()).Return(nil, nil) + taskMgr.EXPECT().SwitchTaskStep(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + kv.TxnTotalSizeLimit.Store(config.DefTxnTotalSizeLimit) + require.NoError(t, sch.Switch2NextStep()) + require.True(t, ctrl.Satisfied()) +} + +func TestManagerSchedulersOrdered(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mgr, err := NewManager(context.Background(), nil, "1") + require.NoError(t, err) + for i := 1; i <= 5; i++ { + task := &proto.Task{ + ID: int64(i * 10), + } + mockScheduler := mock.NewMockScheduler(ctrl) + mockScheduler.EXPECT().GetTask().Return(task).AnyTimes() + mgr.addScheduler(task.ID, mockScheduler) + } + ordered := func(schedulers []Scheduler) bool { + for i := 1; i < len(schedulers); i++ { + if schedulers[i-1].GetTask().Compare(schedulers[i].GetTask()) >= 0 { + return false + } + } + return true + } + require.Len(t, mgr.getSchedulers(), 5) + require.True(t, ordered(mgr.getSchedulers())) + + task35 := &proto.Task{ + ID: int64(35), + } + mockScheduler35 := mock.NewMockScheduler(ctrl) + mockScheduler35.EXPECT().GetTask().Return(task35).AnyTimes() + + mgr.delScheduler(30) + require.False(t, mgr.hasScheduler(30)) + mgr.addScheduler(task35.ID, mockScheduler35) + require.True(t, mgr.hasScheduler(35)) + require.Len(t, mgr.getSchedulers(), 5) + require.True(t, ordered(mgr.getSchedulers())) +} + +func TestGetEligibleNodes(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + ctx := context.Background() + mockSch := mock.NewMockScheduler(ctrl) + mockSch.EXPECT().GetTask().Return(&proto.Task{ID: 1}).AnyTimes() + + mockSch.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).Return(nil, errors.New("mock err")) + _, err := getEligibleNodes(ctx, mockSch, []string{":4000"}) + require.ErrorContains(t, err, "mock err") + require.True(t, ctrl.Satisfied()) + + mockSch.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).Return([]string{":4000"}, nil) + nodes, err := getEligibleNodes(ctx, mockSch, []string{":4000", ":4001"}) + require.NoError(t, err) + require.Equal(t, []string{":4000"}, nodes) + require.True(t, ctrl.Satisfied()) + + mockSch.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).Return(nil, nil) + nodes, err = getEligibleNodes(ctx, mockSch, []string{":4000", ":4001"}) + require.NoError(t, err) + require.Equal(t, []string{":4000", ":4001"}, nodes) + require.True(t, ctrl.Satisfied()) +} + func TestSchedulerIsStepSucceed(t *testing.T) { s := &BaseScheduler{} require.True(t, s.isStepSucceed(nil)) diff --git a/pkg/disttask/framework/scheduler/scheduler_test.go b/pkg/disttask/framework/scheduler/scheduler_test.go index 15ca02692ce3b..14749d92f1634 100644 --- a/pkg/disttask/framework/scheduler/scheduler_test.go +++ b/pkg/disttask/framework/scheduler/scheduler_test.go @@ -26,7 +26,6 @@ import ( "github.com/ngaut/pools" "github.com/pingcap/errors" "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/config" "github.com/pingcap/tidb/pkg/disttask/framework/mock" "github.com/pingcap/tidb/pkg/disttask/framework/proto" "github.com/pingcap/tidb/pkg/disttask/framework/scheduler" @@ -66,7 +65,7 @@ func getTestSchedulerExt(ctrl *gomock.Controller) scheduler.Extension { }, ).AnyTimes() mockScheduler.EXPECT().OnNextSubtasksBatch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( - func(_ context.Context, _ scheduler.TaskHandle, _ *proto.Task, _ []string, _ proto.Step) (metas 
[][]byte, err error) { + func(_ context.Context, _ storage.TaskHandle, _ *proto.Task, _ []string, _ proto.Step) (metas [][]byte, err error) { return nil, nil }, ).AnyTimes() @@ -95,7 +94,7 @@ func getNumberExampleSchedulerExt(ctrl *gomock.Controller) scheduler.Extension { }, ).AnyTimes() mockScheduler.EXPECT().OnNextSubtasksBatch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( - func(_ context.Context, _ scheduler.TaskHandle, task *proto.Task, _ []string, _ proto.Step) (metas [][]byte, err error) { + func(_ context.Context, _ storage.TaskHandle, task *proto.Task, _ []string, _ proto.Step) (metas [][]byte, err error) { switch task.Step { case proto.StepInit: for i := 0; i < subtaskCnt; i++ { @@ -123,7 +122,7 @@ func MockSchedulerManager(t *testing.T, ctrl *gomock.Controller, pool *pools.Res sch, err := scheduler.NewManager(util.WithInternalSourceType(ctx, "scheduler"), mgr, "host:port") require.NoError(t, err) scheduler.RegisterSchedulerFactory(proto.TaskTypeExample, - func(ctx context.Context, taskMgr scheduler.TaskManager, nodeMgr *scheduler.NodeManager, task *proto.Task) scheduler.Scheduler { + func(ctx context.Context, task *proto.Task, param scheduler.Param) scheduler.Scheduler { mockScheduler := sch.MockScheduler(task) mockScheduler.Extension = ext return mockScheduler @@ -222,7 +221,7 @@ func TestTaskFailInManager(t *testing.T) { mockScheduler.EXPECT().Init().Return(errors.New("mock scheduler init error")) schManager, mgr := MockSchedulerManager(t, ctrl, pool, getTestSchedulerExt(ctrl), nil) scheduler.RegisterSchedulerFactory(proto.TaskTypeExample, - func(ctx context.Context, taskMgr scheduler.TaskManager, nodeMgr *scheduler.NodeManager, task *proto.Task) scheduler.Scheduler { + func(ctx context.Context, task *proto.Task, param scheduler.Param) scheduler.Scheduler { return mockScheduler }) schManager.Start() @@ -275,6 +274,7 @@ func checkDispatch(t *testing.T, taskCnt int, isSucc, isCancel, isSubtaskCancel, ctx = util.WithInternalSourceType(ctx, "scheduler") sch, mgr := MockSchedulerManager(t, ctrl, pool, getNumberExampleSchedulerExt(ctrl), nil) + require.NoError(t, mgr.InitMeta(ctx, ":4000", "background")) sch.Start() defer func() { sch.Stop() @@ -284,14 +284,12 @@ func checkDispatch(t *testing.T, taskCnt int, isSucc, isCancel, isSubtaskCancel, } }() - require.NoError(t, mgr.InitMeta(ctx, ":4000", "background")) - // 3s cnt := 60 checkGetRunningTaskCnt := func(expected int) { require.Eventually(t, func() bool { return sch.GetRunningTaskCnt() == expected - }, time.Second, 50*time.Millisecond) + }, 5*time.Second, 50*time.Millisecond) } checkTaskRunningCnt := func() []*proto.Task { @@ -301,7 +299,7 @@ func checkDispatch(t *testing.T, taskCnt int, isSucc, isCancel, isSubtaskCancel, tasks, err = mgr.GetTasksInStates(ctx, proto.TaskStateRunning) require.NoError(t, err) return len(tasks) == taskCnt - }, time.Second, 50*time.Millisecond) + }, 5*time.Second, 50*time.Millisecond) return tasks } @@ -312,7 +310,7 @@ func checkDispatch(t *testing.T, taskCnt int, isSucc, isCancel, isSubtaskCancel, cntByStates, err := mgr.GetSubtaskCntGroupByStates(ctx, taskID, proto.StepOne) require.NoError(t, err) return int64(subtaskCnt) == cntByStates[proto.SubtaskStatePending] - }, time.Second, 50*time.Millisecond) + }, 5*time.Second, 50*time.Millisecond) } } @@ -501,104 +499,6 @@ func TestIsCancelledErr(t *testing.T) { require.True(t, scheduler.IsCancelledErr(errors.New("cancelled by user"))) } -func TestDispatcherOnNextStage(t *testing.T) { - ctrl := 
gomock.NewController(t) - defer ctrl.Finish() - taskMgr := mock.NewMockTaskManager(ctrl) - schExt := mockDispatch.NewMockExtension(ctrl) - - ctx := context.Background() - ctx = util.WithInternalSourceType(ctx, "dispatcher") - task := proto.Task{ - ID: 1, - State: proto.TaskStatePending, - Step: proto.StepInit, - } - cloneTask := task - nodeMgr := scheduler.NewNodeManager() - sch := scheduler.NewBaseScheduler(ctx, taskMgr, nodeMgr, &cloneTask) - sch.Extension = schExt - - // test next step is done - schExt.EXPECT().GetNextStep(gomock.Any()).Return(proto.StepDone) - schExt.EXPECT().OnDone(gomock.Any(), gomock.Any(), gomock.Any()).Return(errors.New("done err")) - require.ErrorContains(t, sch.Switch2NextStep(), "done err") - require.True(t, ctrl.Satisfied()) - // we update task step before OnDone - require.Equal(t, proto.StepDone, sch.Task.Step) - *sch.Task = task - schExt.EXPECT().GetNextStep(gomock.Any()).Return(proto.StepDone) - schExt.EXPECT().OnDone(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) - taskMgr.EXPECT().SucceedTask(gomock.Any(), gomock.Any()).Return(nil) - require.NoError(t, sch.Switch2NextStep()) - require.True(t, ctrl.Satisfied()) - - // GetEligibleInstances err - schExt.EXPECT().GetNextStep(gomock.Any()).Return(proto.StepOne) - schExt.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).Return(nil, errors.New("GetEligibleInstances err")) - require.ErrorContains(t, sch.Switch2NextStep(), "GetEligibleInstances err") - require.True(t, ctrl.Satisfied()) - // GetEligibleInstances no instance - schExt.EXPECT().GetNextStep(gomock.Any()).Return(proto.StepOne) - schExt.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).Return(nil, nil) - require.ErrorContains(t, sch.Switch2NextStep(), "no available TiDB node to dispatch subtasks") - require.True(t, ctrl.Satisfied()) - - serverNodes := []string{":4000"} - // OnNextSubtasksBatch err - schExt.EXPECT().GetNextStep(gomock.Any()).Return(proto.StepOne) - schExt.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).Return(serverNodes, nil) - schExt.EXPECT().OnNextSubtasksBatch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - Return(nil, errors.New("OnNextSubtasksBatch err")) - schExt.EXPECT().IsRetryableErr(gomock.Any()).Return(true) - require.ErrorContains(t, sch.Switch2NextStep(), "OnNextSubtasksBatch err") - require.True(t, ctrl.Satisfied()) - - bak := kv.TxnTotalSizeLimit.Load() - t.Cleanup(func() { - kv.TxnTotalSizeLimit.Store(bak) - }) - - // dispatch in batch - subtaskMetas := [][]byte{ - []byte(`{"xx": "1"}`), - []byte(`{"xx": "2"}`), - } - schExt.EXPECT().GetNextStep(gomock.Any()).Return(proto.StepOne) - schExt.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).Return(serverNodes, nil) - schExt.EXPECT().OnNextSubtasksBatch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - Return(subtaskMetas, nil) - taskMgr.EXPECT().SwitchTaskStepInBatch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) - kv.TxnTotalSizeLimit.Store(1) - require.NoError(t, sch.Switch2NextStep()) - require.True(t, ctrl.Satisfied()) - // met unstable subtasks - schExt.EXPECT().GetNextStep(gomock.Any()).Return(proto.StepOne) - schExt.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).Return(serverNodes, nil) - schExt.EXPECT().OnNextSubtasksBatch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - Return(subtaskMetas, nil) - taskMgr.EXPECT().SwitchTaskStepInBatch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). 
- Return(errors.Annotatef(storage.ErrUnstableSubtasks, "expected %d, got %d", - 2, 100)) - kv.TxnTotalSizeLimit.Store(1) - startTime := time.Now() - err := sch.Switch2NextStep() - require.ErrorIs(t, err, storage.ErrUnstableSubtasks) - require.ErrorContains(t, err, "expected 2, got 100") - require.WithinDuration(t, startTime, time.Now(), 10*time.Second) - require.True(t, ctrl.Satisfied()) - - // dispatch in one txn - schExt.EXPECT().GetNextStep(gomock.Any()).Return(proto.StepOne) - schExt.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).Return(serverNodes, nil) - schExt.EXPECT().OnNextSubtasksBatch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - Return(subtaskMetas, nil) - taskMgr.EXPECT().SwitchTaskStep(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) - kv.TxnTotalSizeLimit.Store(config.DefTxnTotalSizeLimit) - require.NoError(t, sch.Switch2NextStep()) - require.True(t, ctrl.Satisfied()) -} - func TestManagerDispatchLoop(t *testing.T) { // Mock 16 cpu node. require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/util/cpu/mockNumCpu", "return(16)")) @@ -636,9 +536,13 @@ func TestManagerDispatchLoop(t *testing.T) { } var counter atomic.Int32 scheduler.RegisterSchedulerFactory(proto.TaskTypeExample, - func(ctx context.Context, taskMgr scheduler.TaskManager, nodeMgr *scheduler.NodeManager, task *proto.Task) scheduler.Scheduler { + func(ctx context.Context, task *proto.Task, param scheduler.Param) scheduler.Scheduler { idx := counter.Load() mockScheduler = mock.NewMockScheduler(ctrl) + // below 2 are for balancer loop, it's async, cannot determine how + // many times it will be called. + mockScheduler.EXPECT().GetTask().Return(task).AnyTimes() + mockScheduler.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() mockScheduler.EXPECT().Init().Return(nil) mockScheduler.EXPECT().ScheduleTask().Do(func() { require.NoError(t, taskMgr.WithNewSession(func(se sessionctx.Context) error { diff --git a/pkg/disttask/framework/scheduler/slots.go b/pkg/disttask/framework/scheduler/slots.go index 528dd45ad48d8..c572c54eb5333 100644 --- a/pkg/disttask/framework/scheduler/slots.go +++ b/pkg/disttask/framework/scheduler/slots.go @@ -31,7 +31,7 @@ type taskStripes struct { stripes int } -// slotManager is used to manage the resource slots and stripes. +// SlotManager is used to manage the resource slots and stripes. // // Slot is the resource unit of dist framework on each node, each slot represents // 1 cpu core, 1/total-core of memory, 1/total-core of disk, etc. @@ -48,7 +48,7 @@ type taskStripes struct { // // Dist framework will try to allocate resource by slots and stripes, and give // quota to subtask, but subtask can determine what to conform. -type slotManager struct { +type SlotManager struct { // Capacity is the total number of slots and stripes. capacity atomic.Int32 @@ -70,16 +70,17 @@ type slotManager struct { // to schedule lower priority task, but next step of A has many subtasks. // once initialized, the length of usedSlots should be equal to number of nodes // managed by dist framework. - usedSlots map[string]int + usedSlots atomic.Pointer[map[string]int] } -// newSlotManager creates a new slotManager. -func newSlotManager() *slotManager { - s := &slotManager{ +// newSlotManager creates a new SlotManager. 
+func newSlotManager() *SlotManager { + usedSlots := make(map[string]int) + s := &SlotManager{ task2Index: make(map[int64]int), reservedSlots: make(map[string]int), - usedSlots: make(map[string]int), } + s.usedSlots.Store(&usedSlots) // this node might not be the managed node of the framework, but we initialize // capacity with the cpu count of this node, it will be updated when node // manager starts. @@ -89,22 +90,18 @@ func newSlotManager() *slotManager { // Update updates the used slots on each node. // TODO: on concurrent call, update once. -func (sm *slotManager) update(ctx context.Context, taskMgr TaskManager) error { - nodes, err := taskMgr.GetManagedNodes(ctx) - if err != nil { - return err - } +func (sm *SlotManager) update(ctx context.Context, nodeMgr *NodeManager, taskMgr TaskManager) error { + nodes := nodeMgr.getManagedNodes() slotsOnNodes, err := taskMgr.GetUsedSlotsOnNodes(ctx) if err != nil { return err } newUsedSlots := make(map[string]int, len(nodes)) for _, node := range nodes { - newUsedSlots[node.ID] = slotsOnNodes[node.ID] + newUsedSlots[node] = slotsOnNodes[node] } - sm.mu.Lock() - defer sm.mu.Unlock() - sm.usedSlots = newUsedSlots + + sm.usedSlots.Store(&newUsedSlots) return nil } @@ -114,11 +111,12 @@ func (sm *slotManager) update(ctx context.Context, taskMgr TaskManager) error { // as usedSlots is updated asynchronously, it might return false even if there // are enough resources, or return true on resource shortage when some task // scheduled subtasks. -func (sm *slotManager) canReserve(task *proto.Task) (execID string, ok bool) { +func (sm *SlotManager) canReserve(task *proto.Task) (execID string, ok bool) { + usedSlots := *sm.usedSlots.Load() capacity := int(sm.capacity.Load()) sm.mu.RLock() defer sm.mu.RUnlock() - if len(sm.usedSlots) == 0 { + if len(usedSlots) == 0 { // no node managed by dist framework return "", false } @@ -134,7 +132,7 @@ func (sm *slotManager) canReserve(task *proto.Task) (execID string, ok bool) { return "", true } - for id, count := range sm.usedSlots { + for id, count := range usedSlots { if count+sm.reservedSlots[id]+task.Concurrency <= capacity { return id, true } @@ -144,7 +142,7 @@ func (sm *slotManager) canReserve(task *proto.Task) (execID string, ok bool) { // Reserve reserves resources for a task. // Reserve and UnReserve should be called in pair with same parameters. -func (sm *slotManager) reserve(task *proto.Task, execID string) { +func (sm *SlotManager) reserve(task *proto.Task, execID string) { taskClone := *task sm.mu.Lock() @@ -163,7 +161,7 @@ func (sm *slotManager) reserve(task *proto.Task, execID string) { } // UnReserve un-reserve resources for a task. -func (sm *slotManager) unReserve(task *proto.Task, execID string) { +func (sm *SlotManager) unReserve(task *proto.Task, execID string) { sm.mu.Lock() defer sm.mu.Unlock() idx, ok := sm.task2Index[task.ID] @@ -184,7 +182,22 @@ func (sm *slotManager) unReserve(task *proto.Task, execID string) { } } -func (sm *slotManager) updateCapacity(cpuCount int) { +func (sm *SlotManager) getCapacity() int { + return int(sm.capacity.Load()) +} + +// we schedule subtasks to the nodes with enough slots first, if no such nodes, +// schedule to all nodes. 
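// A worked example of the rule described in the comment above, as a standalone
// sketch (the helper name is illustrative; the numbers follow
// TestSchedulerAdjustEligibleNodes below): with capacity 16 and a task
// concurrency of 10, only nodes whose used slots plus 10 still fit within 16
// are kept, and if no eligible node fits, all eligible nodes are returned.
package main

import "fmt"

func pickNodes(usedSlots map[string]int, capacity int, eligible []string, concurrency int) []string {
	fits := make(map[string]struct{}, len(usedSlots))
	for node, used := range usedSlots {
		if used+concurrency <= capacity {
			fits[node] = struct{}{}
		}
	}
	picked := make([]string, 0, len(eligible))
	for _, node := range eligible {
		if _, ok := fits[node]; ok {
			picked = append(picked, node)
		}
	}
	if len(picked) == 0 {
		// fall back to every eligible node when none has enough free slots.
		return eligible
	}
	return picked
}

func main() {
	used := map[string]int{":4000": 12, ":4001": 4, ":4003": 0}
	eligible := []string{":4000", ":4001", ":4002"}
	fmt.Println(pickNodes(used, 16, eligible, 10)) // [:4001]
	fmt.Println(pickNodes(used, 16, eligible, 16)) // no node fits -> [:4000 :4001 :4002]
}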
+func (sm *SlotManager) adjustEligibleNodes(eligibleNodes []string, concurrency int) []string { + usedSlots := *sm.usedSlots.Load() + nodes := filterNodesWithEnoughSlots(usedSlots, sm.getCapacity(), eligibleNodes, concurrency) + if len(nodes) == 0 { + nodes = eligibleNodes + } + return nodes +} + +func (sm *SlotManager) updateCapacity(cpuCount int) { old := sm.capacity.Load() if cpuCount > 0 && cpuCount != int(old) { sm.capacity.Store(int32(cpuCount)) @@ -196,3 +209,20 @@ func (sm *slotManager) updateCapacity(cpuCount int) { } } } + +func filterNodesWithEnoughSlots(usedSlots map[string]int, capacity int, eligibleNodes []string, concurrency int) []string { + nodesOfEnoughSlots := make(map[string]struct{}, len(usedSlots)) + for node, slots := range usedSlots { + if slots+concurrency <= capacity { + nodesOfEnoughSlots[node] = struct{}{} + } + } + + result := make([]string, 0, len(eligibleNodes)) + for _, node := range eligibleNodes { + if _, ok := nodesOfEnoughSlots[node]; ok { + result = append(result, node) + } + } + return result +} diff --git a/pkg/disttask/framework/scheduler/slots_test.go b/pkg/disttask/framework/scheduler/slots_test.go index 04b72c6fdac9b..135a233b5ce07 100644 --- a/pkg/disttask/framework/scheduler/slots_test.go +++ b/pkg/disttask/framework/scheduler/slots_test.go @@ -34,9 +34,9 @@ func TestSlotManagerReserve(t *testing.T) { require.False(t, ok) // reserve by stripes - sm.usedSlots = map[string]int{ + sm.usedSlots.Store(&map[string]int{ "tidb-1": 16, - } + }) task := proto.Task{ Priority: proto.NormalPriority, Concurrency: 16, @@ -86,10 +86,10 @@ func TestSlotManagerReserve(t *testing.T) { require.True(t, ok) // reserve by slots - sm.usedSlots = map[string]int{ + sm.usedSlots.Store(&map[string]int{ "tidb-1": 12, "tidb-2": 8, - } + }) task40 := task task40.ID = 40 task40.Concurrency = 16 @@ -179,49 +179,63 @@ func TestSlotManagerReserve(t *testing.T) { func TestSlotManagerUpdate(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() + ctx := context.Background() + nodeMgr := newNodeManager() taskMgr := mock.NewMockTaskManager(ctrl) - taskMgr.EXPECT().GetManagedNodes(gomock.Any()).Return([]proto.ManagedNode{{ID: "tidb-1"}, {ID: "tidb-2"}, {ID: "tidb-3"}}, nil) + nodeMgr.managedNodes.Store(&[]string{"tidb-1", "tidb-2", "tidb-3"}) taskMgr.EXPECT().GetUsedSlotsOnNodes(gomock.Any()).Return(map[string]int{ "tidb-1": 12, "tidb-2": 8, }, nil) sm := newSlotManager() sm.updateCapacity(16) - require.Empty(t, sm.usedSlots) + require.Empty(t, sm.usedSlots.Load()) require.Empty(t, sm.reservedSlots) - require.NoError(t, sm.update(context.Background(), taskMgr)) + require.NoError(t, sm.update(ctx, nodeMgr, taskMgr)) require.Empty(t, sm.reservedSlots) require.Equal(t, map[string]int{ "tidb-1": 12, "tidb-2": 8, "tidb-3": 0, - }, sm.usedSlots) + }, *sm.usedSlots.Load()) + require.True(t, ctrl.Satisfied()) + // some node scaled in, should be reflected - taskMgr.EXPECT().GetManagedNodes(gomock.Any()).Return([]proto.ManagedNode{{ID: "tidb-1"}}, nil) + nodeMgr.managedNodes.Store(&[]string{"tidb-1"}) taskMgr.EXPECT().GetUsedSlotsOnNodes(gomock.Any()).Return(map[string]int{ "tidb-1": 12, "tidb-2": 8, }, nil) - require.NoError(t, sm.update(context.Background(), taskMgr)) + require.NoError(t, sm.update(ctx, nodeMgr, taskMgr)) require.Empty(t, sm.reservedSlots) require.Equal(t, map[string]int{ "tidb-1": 12, - }, sm.usedSlots) + }, *sm.usedSlots.Load()) + require.True(t, ctrl.Satisfied()) // on error, the usedSlots should not be changed - 
taskMgr.EXPECT().GetManagedNodes(gomock.Any()).Return(nil, errors.New("mock err")) - require.ErrorContains(t, sm.update(context.Background(), taskMgr), "mock err") - require.Empty(t, sm.reservedSlots) - require.Equal(t, map[string]int{ - "tidb-1": 12, - }, sm.usedSlots) - taskMgr.EXPECT().GetManagedNodes(gomock.Any()).Return([]proto.ManagedNode{{ID: "tidb-1"}}, nil) taskMgr.EXPECT().GetUsedSlotsOnNodes(gomock.Any()).Return(nil, errors.New("mock err")) - require.ErrorContains(t, sm.update(context.Background(), taskMgr), "mock err") + require.ErrorContains(t, sm.update(ctx, nodeMgr, taskMgr), "mock err") require.Empty(t, sm.reservedSlots) require.Equal(t, map[string]int{ "tidb-1": 12, - }, sm.usedSlots) + }, *sm.usedSlots.Load()) +} + +func TestSchedulerAdjustEligibleNodes(t *testing.T) { + slotMgr := newSlotManager() + slotMgr.updateCapacity(16) + + allNodes := []string{":4000", ":4001", ":4002"} + require.Equal(t, allNodes, slotMgr.adjustEligibleNodes(allNodes, 10)) + + usedSlots := map[string]int{ + ":4000": 12, + ":4001": 4, + ":4003": 0, // stale node + } + slotMgr.usedSlots.Store(&usedSlots) + require.Equal(t, []string{":4001"}, slotMgr.adjustEligibleNodes(allNodes, 10)) } func TestSlotManagerUpdateCapacity(t *testing.T) { diff --git a/pkg/disttask/framework/storage/BUILD.bazel b/pkg/disttask/framework/storage/BUILD.bazel index a00fc6baeb98c..ee7f49eb4d793 100644 --- a/pkg/disttask/framework/storage/BUILD.bazel +++ b/pkg/disttask/framework/storage/BUILD.bazel @@ -40,7 +40,7 @@ go_test( embed = [":storage"], flaky = True, race = "on", - shard_count = 17, + shard_count = 18, deps = [ "//pkg/config", "//pkg/disttask/framework/proto", diff --git a/pkg/disttask/framework/storage/table_test.go b/pkg/disttask/framework/storage/table_test.go index 78b97f690d69b..a39026b92ff2c 100644 --- a/pkg/disttask/framework/storage/table_test.go +++ b/pkg/disttask/framework/storage/table_test.go @@ -17,6 +17,7 @@ package storage_test import ( "context" "fmt" + "slices" "sort" "testing" "time" @@ -437,6 +438,37 @@ func TestGetUsedSlotsOnNodes(t *testing.T) { }, slotsOnNodes) } +func TestGetActiveSubtasks(t *testing.T) { + _, tm, ctx := testutil.InitTableTest(t) + require.NoError(t, tm.InitMeta(ctx, ":4000", "")) + id, err := tm.CreateTask(ctx, "key1", "test", 4, []byte("test")) + require.NoError(t, err) + require.Equal(t, int64(1), id) + task, err := tm.GetTaskByID(ctx, id) + require.NoError(t, err) + + subtasks := make([]*proto.Subtask, 0, 3) + for i := 0; i < 3; i++ { + subtasks = append(subtasks, + proto.NewSubtask(proto.StepOne, id, "test", fmt.Sprintf("tidb%d", i), 8, []byte("{}}"), i+1), + ) + } + require.NoError(t, tm.SwitchTaskStep(ctx, task, proto.TaskStateRunning, proto.StepOne, subtasks)) + require.NoError(t, tm.FinishSubtask(ctx, "tidb0", 1, []byte("{}}"))) + require.NoError(t, tm.StartSubtask(ctx, 2, "tidb1")) + + activeSubtasks, err := tm.GetActiveSubtasks(ctx, task.ID) + require.NoError(t, err) + require.Len(t, activeSubtasks, 2) + slices.SortFunc(activeSubtasks, func(i, j *proto.Subtask) int { + return int(i.ID - j.ID) + }) + require.Equal(t, int64(2), activeSubtasks[0].ID) + require.Equal(t, proto.SubtaskStateRunning, activeSubtasks[0].State) + require.Equal(t, int64(3), activeSubtasks[1].ID) + require.Equal(t, proto.SubtaskStatePending, activeSubtasks[1].State) +} + func TestSubTaskTable(t *testing.T) { _, sm, ctx := testutil.InitTableTest(t) timeBeforeCreate := time.Unix(time.Now().Unix(), 0) @@ -457,6 +489,7 @@ func TestSubTaskTable(t *testing.T) { Concurrency: 11, ExecID: "tidb1", Meta: 
[]byte("test"), + Ordinal: 1, }, }, proto.TaskStatePending, ) @@ -469,7 +502,6 @@ func TestSubTaskTable(t *testing.T) { subtask, err := sm.GetFirstSubtaskInStates(ctx, "tidb1", 1, proto.StepInit, proto.SubtaskStatePending) require.NoError(t, err) require.Equal(t, proto.StepInit, subtask.Step) - require.Equal(t, int64(1), subtask.TaskID) require.Equal(t, proto.TaskTypeExample, subtask.Type) require.Equal(t, int64(1), subtask.TaskID) require.Equal(t, proto.SubtaskStatePending, subtask.State) @@ -477,6 +509,7 @@ func TestSubTaskTable(t *testing.T) { require.Equal(t, []byte("test"), subtask.Meta) require.Equal(t, 11, subtask.Concurrency) require.GreaterOrEqual(t, subtask.CreateTime, timeBeforeCreate) + require.Equal(t, 0, subtask.Ordinal) require.Zero(t, subtask.StartTime) require.Zero(t, subtask.UpdateTime) require.Equal(t, "{}", subtask.Summary) @@ -629,7 +662,7 @@ func TestSubTaskTable(t *testing.T) { subtasks, err = sm.GetSubtasksByStepAndState(ctx, 5, proto.StepInit, proto.TaskStatePending) require.NoError(t, err) subtasks[0].ExecID = "tidb2" - require.NoError(t, sm.UpdateSubtasksExecIDs(ctx, 5, subtasks)) + require.NoError(t, sm.UpdateSubtasksExecIDs(ctx, subtasks)) subtasks, err = sm.GetSubtasksByStepAndState(ctx, 5, proto.StepInit, proto.TaskStatePending) require.NoError(t, err) require.Equal(t, "tidb2", subtasks[0].ExecID) @@ -638,7 +671,7 @@ func TestSubTaskTable(t *testing.T) { subtasks, err = sm.GetSubtasksByStepAndState(ctx, 5, proto.StepInit, proto.TaskStatePending) require.NoError(t, err) subtasks[0].ExecID = "tidb3" - require.NoError(t, sm.UpdateSubtasksExecIDs(ctx, 5, subtasks)) + require.NoError(t, sm.UpdateSubtasksExecIDs(ctx, subtasks)) subtasks, err = sm.GetSubtasksByStepAndState(ctx, 5, proto.StepInit, proto.TaskStatePending) require.NoError(t, err) require.Equal(t, "tidb3", subtasks[0].ExecID) @@ -650,7 +683,7 @@ func TestSubTaskTable(t *testing.T) { require.Equal(t, "tidb3", subtasks[0].ExecID) subtasks[0].ExecID = "tidb2" // update success - require.NoError(t, sm.UpdateSubtasksExecIDs(ctx, 5, subtasks)) + require.NoError(t, sm.UpdateSubtasksExecIDs(ctx, subtasks)) subtasks, err = sm.GetSubtasksByStepAndState(ctx, 5, proto.StepInit, proto.TaskStatePending) require.NoError(t, err) require.Equal(t, "tidb2", subtasks[0].ExecID) diff --git a/pkg/disttask/framework/storage/task_table.go b/pkg/disttask/framework/storage/task_table.go index 310e248ee80c5..b913e5426016a 100644 --- a/pkg/disttask/framework/storage/task_table.go +++ b/pkg/disttask/framework/storage/task_table.go @@ -49,8 +49,8 @@ const ( // InsertTaskColumns is the columns used in insert task. InsertTaskColumns = `task_key, type, state, priority, concurrency, step, meta, create_time` - subtaskColumns = `id, step, task_key, type, exec_id, state, concurrency, create_time, - start_time, state_update_time, meta, summary, ordinal` + basicSubtaskColumns = `id, step, task_key, type, exec_id, state, concurrency, create_time, ordinal` + subtaskColumns = basicSubtaskColumns + `, start_time, state_update_time, meta, summary` // InsertSubtaskColumns is the columns used in insert subtask. InsertSubtaskColumns = `step, task_key, exec_id, meta, state, type, concurrency, ordinal, create_time, checkpoint, summary` ) @@ -84,6 +84,14 @@ type SessionExecutor interface { WithNewTxn(ctx context.Context, fn func(se sessionctx.Context) error) error } +// TaskHandle provides the interface for operations needed by Scheduler. +// Then we can use scheduler's function in Scheduler interface. 
+type TaskHandle interface { + // GetPreviousSubtaskMetas gets previous subtask metas. + GetPreviousSubtaskMetas(taskID int64, step proto.Step) ([][]byte, error) + SessionExecutor +} + // TaskManager is the manager of task and subtask. type TaskManager struct { sePool sessionPool @@ -415,44 +423,50 @@ func (mgr *TaskManager) GetUsedSlotsOnNodes(ctx context.Context) (map[string]int return slots, nil } -// row2SubTask converts a row to a subtask. -func row2SubTask(r chunk.Row) *proto.Subtask { - // subtask defines start/update time as bigint, to ensure backward compatible, - // we keep it that way, and we convert it here. - createTime, _ := r.GetTime(7).GoTime(time.Local) - var startTime, updateTime time.Time - if !r.IsNull(8) { - ts := r.GetInt64(8) - startTime = time.Unix(ts, 0) - } - if !r.IsNull(9) { - ts := r.GetInt64(9) - updateTime = time.Unix(ts, 0) +// row2BasicSubTask converts a row to a subtask with basic info +func row2BasicSubTask(r chunk.Row) *proto.Subtask { + taskIDStr := r.GetString(2) + tid, err := strconv.Atoi(taskIDStr) + if err != nil { + logutil.BgLogger().Warn("unexpected subtask id", zap.String("subtask-id", taskIDStr)) } + createTime, _ := r.GetTime(7).GoTime(time.Local) var ordinal int - if !r.IsNull(12) { - ordinal = int(r.GetInt64(12)) + if !r.IsNull(8) { + ordinal = int(r.GetInt64(8)) } subtask := &proto.Subtask{ ID: r.GetInt64(0), Step: proto.Step(r.GetInt64(1)), + TaskID: int64(tid), Type: proto.Int2Type(int(r.GetInt64(3))), ExecID: r.GetString(4), State: proto.SubtaskState(r.GetString(5)), Concurrency: int(r.GetInt64(6)), CreateTime: createTime, - StartTime: startTime, - UpdateTime: updateTime, - Meta: r.GetBytes(10), - Summary: r.GetJSON(11).String(), Ordinal: ordinal, } - taskIDStr := r.GetString(2) - tid, err := strconv.Atoi(taskIDStr) - if err != nil { - logutil.BgLogger().Warn("unexpected subtask id", zap.String("subtask-id", taskIDStr)) + return subtask +} + +// row2SubTask converts a row to a subtask. +func row2SubTask(r chunk.Row) *proto.Subtask { + subtask := row2BasicSubTask(r) + // subtask defines start/update time as bigint, to ensure backward compatible, + // we keep it that way, and we convert it here. + var startTime, updateTime time.Time + if !r.IsNull(9) { + ts := r.GetInt64(9) + startTime = time.Unix(ts, 0) + } + if !r.IsNull(10) { + ts := r.GetInt64(10) + updateTime = time.Unix(ts, 0) } - subtask.TaskID = int64(tid) + subtask.StartTime = startTime + subtask.UpdateTime = updateTime + subtask.Meta = r.GetBytes(11) + subtask.Summary = r.GetJSON(12).String() return subtask } @@ -550,6 +564,22 @@ func (mgr *TaskManager) UpdateErrorToSubtask(ctx context.Context, execID string, return err1 } +// GetActiveSubtasks implements TaskManager.GetActiveSubtasks. +func (mgr *TaskManager) GetActiveSubtasks(ctx context.Context, taskID int64) ([]*proto.Subtask, error) { + rs, err := mgr.executeSQLWithNewSession(ctx, ` + select `+basicSubtaskColumns+` from mysql.tidb_background_subtask + where task_key = %? and state in (%?, %?)`, + taskID, proto.TaskStatePending, proto.TaskStateRunning) + if err != nil { + return nil, err + } + subtasks := make([]*proto.Subtask, 0, len(rs)) + for _, r := range rs { + subtasks = append(subtasks, row2BasicSubTask(r)) + } + return subtasks, nil +} + // GetSubtasksByStepAndState gets the subtask by step and state. 
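// An illustrative aside on the column split above (a sketch, not patch code):
// the shared prefix occupies result indices 0..8, and the full subtask query
// appends four more columns at indices 9..12, which is why row2BasicSubTask
// reads create_time at 7 and ordinal at 8, while row2SubTask continues with
// start_time at 9, state_update_time at 10, meta at 11 and summary at 12.
package main

import "fmt"

func main() {
	basic := []string{
		"id", "step", "task_key", "type", "exec_id",
		"state", "concurrency", "create_time", "ordinal",
	}
	full := append(basic, "start_time", "state_update_time", "meta", "summary")
	for i, col := range full {
		fmt.Printf("%2d -> %s\n", i, col)
	}
}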
func (mgr *TaskManager) GetSubtasksByStepAndState(ctx context.Context, taskID int64, step proto.Step, state proto.TaskState) ([]*proto.Subtask, error) { rs, err := mgr.executeSQLWithNewSession(ctx, `select `+subtaskColumns+` from mysql.tidb_background_subtask @@ -798,19 +828,18 @@ func (mgr *TaskManager) IsTaskExecutorCanceled(ctx context.Context, execID strin } // UpdateSubtasksExecIDs update subtasks' execID. -func (mgr *TaskManager) UpdateSubtasksExecIDs(ctx context.Context, taskID int64, subtasks []*proto.Subtask) error { +func (mgr *TaskManager) UpdateSubtasksExecIDs(ctx context.Context, subtasks []*proto.Subtask) error { // skip the update process. if len(subtasks) == 0 { return nil } err := mgr.WithNewTxn(ctx, func(se sessionctx.Context) error { for _, subtask := range subtasks { - _, err := sqlexec.ExecSQL(ctx, se, - "update mysql.tidb_background_subtask set exec_id = %? where id = %? and state = %? and task_key = %?", - subtask.ExecID, - subtask.ID, - subtask.State, - taskID) + _, err := sqlexec.ExecSQL(ctx, se, ` + update mysql.tidb_background_subtask + set exec_id = %? + where id = %? and state = %?`, + subtask.ExecID, subtask.ID, subtask.State) if err != nil { return err } diff --git a/pkg/disttask/framework/taskexecutor/slot.go b/pkg/disttask/framework/taskexecutor/slot.go index 6f78f6e0d144b..219ffce50851a 100644 --- a/pkg/disttask/framework/taskexecutor/slot.go +++ b/pkg/disttask/framework/taskexecutor/slot.go @@ -35,6 +35,7 @@ type slotManager struct { available int } +// subtasks inside a task will be run in serial, so they takes task.Concurrency slots. func (sm *slotManager) alloc(task *proto.Task) { sm.Lock() defer sm.Unlock() diff --git a/pkg/disttask/framework/testutil/disttest_util.go b/pkg/disttask/framework/testutil/disttest_util.go index ff37547df23bc..f75c68dc416eb 100644 --- a/pkg/disttask/framework/testutil/disttest_util.go +++ b/pkg/disttask/framework/testutil/disttest_util.go @@ -62,8 +62,8 @@ func registerTaskMetaInner(t *testing.T, taskType proto.TaskType, mockExtension taskexecutor.ClearTaskExecutors() }) scheduler.RegisterSchedulerFactory(taskType, - func(ctx context.Context, taskMgr scheduler.TaskManager, nodeMgr *scheduler.NodeManager, task *proto.Task) scheduler.Scheduler { - baseScheduler := scheduler.NewBaseScheduler(ctx, taskMgr, nodeMgr, task) + func(ctx context.Context, task *proto.Task, param scheduler.Param) scheduler.Scheduler { + baseScheduler := scheduler.NewBaseScheduler(ctx, task, param) baseScheduler.Extension = schedulerHandle return baseScheduler }) diff --git a/pkg/disttask/framework/testutil/scheduler_util.go b/pkg/disttask/framework/testutil/scheduler_util.go index 99f5e2e0af987..aab32f6eb7440 100644 --- a/pkg/disttask/framework/testutil/scheduler_util.go +++ b/pkg/disttask/framework/testutil/scheduler_util.go @@ -21,6 +21,7 @@ import ( "github.com/pingcap/tidb/pkg/disttask/framework/proto" "github.com/pingcap/tidb/pkg/disttask/framework/scheduler" mockDispatch "github.com/pingcap/tidb/pkg/disttask/framework/scheduler/mock" + "github.com/pingcap/tidb/pkg/disttask/framework/storage" "go.uber.org/mock/gomock" ) @@ -47,7 +48,7 @@ func GetMockBasicSchedulerExt(ctrl *gomock.Controller) scheduler.Extension { }, ).AnyTimes() mockScheduler.EXPECT().OnNextSubtasksBatch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( - func(_ context.Context, _ scheduler.TaskHandle, task *proto.Task, _ []string, _ proto.Step) (metas [][]byte, err error) { + func(_ context.Context, _ storage.TaskHandle, task *proto.Task, _ 
[]string, _ proto.Step) (metas [][]byte, err error) { if task.Step == proto.StepInit { return [][]byte{ []byte("task1"), @@ -91,7 +92,7 @@ func GetMockHATestSchedulerExt(ctrl *gomock.Controller) scheduler.Extension { }, ).AnyTimes() mockScheduler.EXPECT().OnNextSubtasksBatch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( - func(_ context.Context, _ scheduler.TaskHandle, task *proto.Task, _ []string, _ proto.Step) (metas [][]byte, err error) { + func(_ context.Context, _ storage.TaskHandle, task *proto.Task, _ []string, _ proto.Step) (metas [][]byte, err error) { if task.Step == proto.StepInit { return [][]byte{ []byte("task1"), @@ -143,7 +144,7 @@ func GetPlanNotRetryableErrSchedulerExt(ctrl *gomock.Controller) scheduler.Exten }, ).AnyTimes() mockScheduler.EXPECT().OnNextSubtasksBatch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( - func(_ context.Context, _ scheduler.TaskHandle, _ *proto.Task, _ []string, _ proto.Step) (metas [][]byte, err error) { + func(_ context.Context, _ storage.TaskHandle, _ *proto.Task, _ []string, _ proto.Step) (metas [][]byte, err error) { return nil, errors.New("not retryable err") }, ).AnyTimes() @@ -175,7 +176,7 @@ func GetPlanErrSchedulerExt(ctrl *gomock.Controller, testContext *TestContext) s }, ).AnyTimes() mockScheduler.EXPECT().OnNextSubtasksBatch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( - func(_ context.Context, _ scheduler.TaskHandle, task *proto.Task, _ []string, _ proto.Step) (metas [][]byte, err error) { + func(_ context.Context, _ storage.TaskHandle, task *proto.Task, _ []string, _ proto.Step) (metas [][]byte, err error) { if task.Step == proto.StepInit { if testContext.CallTime == 0 { testContext.CallTime++ @@ -224,7 +225,7 @@ func GetMockRollbackSchedulerExt(ctrl *gomock.Controller) scheduler.Extension { }, ).AnyTimes() mockScheduler.EXPECT().OnNextSubtasksBatch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( - func(_ context.Context, _ scheduler.TaskHandle, task *proto.Task, _ []string, _ proto.Step) (metas [][]byte, err error) { + func(_ context.Context, _ storage.TaskHandle, task *proto.Task, _ []string, _ proto.Step) (metas [][]byte, err error) { if task.Step == proto.StepInit { return [][]byte{ []byte("task1"), diff --git a/pkg/disttask/importinto/scheduler.go b/pkg/disttask/importinto/scheduler.go index d7a1997b68762..0dff829893565 100644 --- a/pkg/disttask/importinto/scheduler.go +++ b/pkg/disttask/importinto/scheduler.go @@ -199,7 +199,7 @@ func (sch *ImportSchedulerExt) unregisterTask(ctx context.Context, task *proto.T // OnNextSubtasksBatch generate batch of next stage's plan. func (sch *ImportSchedulerExt) OnNextSubtasksBatch( ctx context.Context, - taskHandle scheduler.TaskHandle, + taskHandle storage.TaskHandle, task *proto.Task, execIDs []string, nextStep proto.Step, @@ -309,7 +309,7 @@ func (sch *ImportSchedulerExt) OnNextSubtasksBatch( } // OnDone implements scheduler.Extension interface. 
-func (sch *ImportSchedulerExt) OnDone(ctx context.Context, handle scheduler.TaskHandle, task *proto.Task) error { +func (sch *ImportSchedulerExt) OnDone(ctx context.Context, handle storage.TaskHandle, task *proto.Task) error { logger := logutil.BgLogger().With( zap.Stringer("type", task.Type), zap.Int64("task-id", task.ID), @@ -405,12 +405,11 @@ type importScheduler struct { *scheduler.BaseScheduler } -func newImportScheduler(ctx context.Context, taskMgr scheduler.TaskManager, - nodeMgr *scheduler.NodeManager, task *proto.Task) scheduler.Scheduler { +func newImportScheduler(ctx context.Context, task *proto.Task, param scheduler.Param) scheduler.Scheduler { metrics := metricsManager.getOrCreateMetrics(task.ID) subCtx := metric.WithCommonMetric(ctx, metrics) sch := importScheduler{ - BaseScheduler: scheduler.NewBaseScheduler(subCtx, taskMgr, nodeMgr, task), + BaseScheduler: scheduler.NewBaseScheduler(subCtx, task, param), } return &sch } @@ -419,11 +418,11 @@ func (sch *importScheduler) Init() (err error) { defer func() { if err != nil { // if init failed, close is not called, so we need to unregister here. - metricsManager.unregister(sch.Task.ID) + metricsManager.unregister(sch.GetTask().ID) } }() taskMeta := &TaskMeta{} - if err = json.Unmarshal(sch.BaseScheduler.Task.Meta, taskMeta); err != nil { + if err = json.Unmarshal(sch.BaseScheduler.GetTask().Meta, taskMeta); err != nil { return errors.Annotate(err, "unmarshal task meta failed") } @@ -434,12 +433,12 @@ func (sch *importScheduler) Init() (err error) { } func (sch *importScheduler) Close() { - metricsManager.unregister(sch.Task.ID) + metricsManager.unregister(sch.GetTask().ID) sch.BaseScheduler.Close() } // nolint:deadcode -func dropTableIndexes(ctx context.Context, handle scheduler.TaskHandle, taskMeta *TaskMeta, logger *zap.Logger) error { +func dropTableIndexes(ctx context.Context, handle storage.TaskHandle, taskMeta *TaskMeta, logger *zap.Logger) error { tblInfo := taskMeta.Plan.TableInfo tableName := common.UniqueTable(taskMeta.Plan.DBName, tblInfo.Name.L) @@ -536,7 +535,7 @@ func getStepOfEncode(globalSort bool) proto.Step { } // we will update taskMeta in place and make task.Meta point to the new taskMeta. 
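// A sketch of the new factory shape used throughout this patch (assumed
// wiring, not code taken from the patch): factories now receive the task and a
// scheduler.Param instead of separate TaskManager/NodeManager arguments, and
// set Extension on the base scheduler, mirroring the registrations in the test
// utilities earlier in this patch.
package example

import (
	"context"

	"github.com/pingcap/tidb/pkg/disttask/framework/proto"
	"github.com/pingcap/tidb/pkg/disttask/framework/scheduler"
)

func registerExampleScheduler(ext scheduler.Extension) {
	scheduler.RegisterSchedulerFactory(proto.TaskTypeExample,
		func(ctx context.Context, task *proto.Task, param scheduler.Param) scheduler.Scheduler {
			base := scheduler.NewBaseScheduler(ctx, task, param)
			base.Extension = ext
			return base
		})
}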
-func updateResult(handle scheduler.TaskHandle, task *proto.Task, taskMeta *TaskMeta, globalSort bool) error { +func updateResult(handle storage.TaskHandle, task *proto.Task, taskMeta *TaskMeta, globalSort bool) error { stepOfEncode := getStepOfEncode(globalSort) metas, err := handle.GetPreviousSubtaskMetas(task.ID, stepOfEncode) if err != nil { @@ -570,7 +569,7 @@ func updateResult(handle scheduler.TaskHandle, task *proto.Task, taskMeta *TaskM return updateMeta(task, taskMeta) } -func getLoadedRowCountOnGlobalSort(handle scheduler.TaskHandle, task *proto.Task) (uint64, error) { +func getLoadedRowCountOnGlobalSort(handle storage.TaskHandle, task *proto.Task) (uint64, error) { metas, err := handle.GetPreviousSubtaskMetas(task.ID, StepWriteAndIngest) if err != nil { return 0, err @@ -587,7 +586,7 @@ func getLoadedRowCountOnGlobalSort(handle scheduler.TaskHandle, task *proto.Task return loadedRowCount, nil } -func startJob(ctx context.Context, logger *zap.Logger, taskHandle scheduler.TaskHandle, taskMeta *TaskMeta, jobStep string) error { +func startJob(ctx context.Context, logger *zap.Logger, taskHandle storage.TaskHandle, taskMeta *TaskMeta, jobStep string) error { failpoint.Inject("syncBeforeJobStarted", func() { TestSyncChan <- struct{}{} <-TestSyncChan @@ -631,7 +630,7 @@ func job2Step(ctx context.Context, logger *zap.Logger, taskMeta *TaskMeta, step } func (sch *ImportSchedulerExt) finishJob(ctx context.Context, logger *zap.Logger, - taskHandle scheduler.TaskHandle, task *proto.Task, taskMeta *TaskMeta) error { + taskHandle storage.TaskHandle, task *proto.Task, taskMeta *TaskMeta) error { // we have already switch import-mode when switch to post-process step. sch.unregisterTask(ctx, task) summary := &importer.JobSummary{ImportedRows: taskMeta.Result.LoadedRowCnt} @@ -647,7 +646,7 @@ func (sch *ImportSchedulerExt) finishJob(ctx context.Context, logger *zap.Logger ) } -func (sch *ImportSchedulerExt) failJob(ctx context.Context, taskHandle scheduler.TaskHandle, task *proto.Task, +func (sch *ImportSchedulerExt) failJob(ctx context.Context, taskHandle storage.TaskHandle, task *proto.Task, taskMeta *TaskMeta, logger *zap.Logger, errorMsg string) error { sch.switchTiKV2NormalMode(ctx, task, logger) sch.unregisterTask(ctx, task) @@ -663,7 +662,7 @@ func (sch *ImportSchedulerExt) failJob(ctx context.Context, taskHandle scheduler ) } -func (sch *ImportSchedulerExt) cancelJob(ctx context.Context, taskHandle scheduler.TaskHandle, task *proto.Task, +func (sch *ImportSchedulerExt) cancelJob(ctx context.Context, taskHandle storage.TaskHandle, task *proto.Task, meta *TaskMeta, logger *zap.Logger) error { sch.switchTiKV2NormalMode(ctx, task, logger) sch.unregisterTask(ctx, task) diff --git a/pkg/disttask/importinto/scheduler_test.go b/pkg/disttask/importinto/scheduler_test.go index b0835bbb853bf..00a082d21287a 100644 --- a/pkg/disttask/importinto/scheduler_test.go +++ b/pkg/disttask/importinto/scheduler_test.go @@ -94,11 +94,9 @@ func (s *importIntoSuite) TestSchedulerInit() { bytes, err := json.Marshal(meta) s.NoError(err) sch := importScheduler{ - BaseScheduler: &scheduler.BaseScheduler{ - Task: &proto.Task{ - Meta: bytes, - }, - }, + BaseScheduler: scheduler.NewBaseScheduler(context.Background(), &proto.Task{ + Meta: bytes, + }, scheduler.Param{}), } s.NoError(sch.Init()) s.False(sch.Extension.(*ImportSchedulerExt).GlobalSort) @@ -107,11 +105,9 @@ func (s *importIntoSuite) TestSchedulerInit() { bytes, err = json.Marshal(meta) s.NoError(err) sch = importScheduler{ - BaseScheduler: 
&scheduler.BaseScheduler{ - Task: &proto.Task{ - Meta: bytes, - }, - }, + BaseScheduler: scheduler.NewBaseScheduler(context.Background(), &proto.Task{ + Meta: bytes, + }, scheduler.Param{}), } s.NoError(sch.Init()) s.True(sch.Extension.(*ImportSchedulerExt).GlobalSort) From 136ae4649eaadc072e265d1a58455c25fe6a0876 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E8=B6=85?= Date: Tue, 9 Jan 2024 17:57:23 +0800 Subject: [PATCH 2/3] expression: refine code in `handleInvalidTimeError` and `handleAllowedPacketOverflowed` (#50180) close pingcap/tidb#50178 --- pkg/expression/bench_test.go | 11 +++++++++-- pkg/expression/builtin_cast_vec_test.go | 2 +- pkg/expression/errors.go | 23 ++++++----------------- 3 files changed, 16 insertions(+), 20 deletions(-) diff --git a/pkg/expression/bench_test.go b/pkg/expression/bench_test.go index f08453e84295e..f9a7665796719 100644 --- a/pkg/expression/bench_test.go +++ b/pkg/expression/bench_test.go @@ -1300,7 +1300,7 @@ func genVecExprBenchCase(ctx sessionctx.Context, funcName string, testCase vecEx // testVectorizedEvalOneVec is used to verify that the vectorized // expression is evaluated correctly during projection func testVectorizedEvalOneVec(t *testing.T, vecExprCases vecExprBenchCases) { - ctx := mock.NewContext() + ctx := createContext(t) for funcName, testCases := range vecExprCases { for _, testCase := range testCases { expr, fts, input, output := genVecExprBenchCase(ctx, funcName, testCase) @@ -1509,7 +1509,7 @@ func testVectorizedBuiltinFunc(t *testing.T, vecExprCases vecExprBenchCases) { } for funcName, testCases := range vecExprCases { for _, testCase := range testCases { - ctx := mock.NewContext() + ctx := createContext(t) if testCase.aesModes == "" { testCase.aesModes = "aes-128-ecb" } @@ -1675,6 +1675,13 @@ func testVectorizedBuiltinFunc(t *testing.T, vecExprCases vecExprBenchCases) { // check warnings totalWarns := ctx.GetSessionVars().StmtCtx.WarningCount() require.Equal(t, totalWarns, 2*vecWarnCnt) + + if _, ok := baseFunc.(*builtinAddSubDateAsStringSig); ok { + // skip check warnings for `builtinAddSubDateAsStringSig` for issue https://github.com/pingcap/tidb/issues/50197 + // TODO: fix this issue + continue + } + warns := ctx.GetSessionVars().StmtCtx.GetWarnings() for i := 0; i < int(vecWarnCnt); i++ { require.True(t, terror.ErrorEqual(warns[i].Err, warns[i+int(vecWarnCnt)].Err)) diff --git a/pkg/expression/builtin_cast_vec_test.go b/pkg/expression/builtin_cast_vec_test.go index ca61ad421616f..aeecc19b11e45 100644 --- a/pkg/expression/builtin_cast_vec_test.go +++ b/pkg/expression/builtin_cast_vec_test.go @@ -158,7 +158,7 @@ func TestVectorizedBuiltinCastFunc(t *testing.T) { func TestVectorizedCastRealAsTime(t *testing.T) { col := &Column{RetType: types.NewFieldType(mysql.TypeDouble), Index: 0} - ctx := mock.NewContext() + ctx := createContext(t) baseFunc, err := newBaseBuiltinFunc(ctx, "", []Expression{col}, types.NewFieldType(mysql.TypeDatetime)) if err != nil { panic(err) diff --git a/pkg/expression/errors.go b/pkg/expression/errors.go index e4888391df99b..9aa5c0b132253 100644 --- a/pkg/expression/errors.go +++ b/pkg/expression/errors.go @@ -76,12 +76,8 @@ func handleInvalidTimeError(ctx EvalContext, err error) error { types.ErrDatetimeFunctionOverflow.Equal(err) || types.ErrIncorrectDatetimeValue.Equal(err)) { return err } - sc := ctx.GetSessionVars().StmtCtx - err = sc.HandleTruncate(err) - if ctx.GetSessionVars().StrictSQLMode && (sc.InInsertStmt || sc.InUpdateStmt || sc.InDeleteStmt) { - return err - } - return nil + ec := 
ctx.GetSessionVars().StmtCtx.ErrCtx() + return ec.HandleError(err) } // handleDivisionByZeroError reports error or warning depend on the context. @@ -93,17 +89,10 @@ func handleDivisionByZeroError(ctx EvalContext) error { // handleAllowedPacketOverflowed reports error or warning depend on the context. func handleAllowedPacketOverflowed(ctx EvalContext, exprName string, maxAllowedPacketSize uint64) error { err := errWarnAllowedPacketOverflowed.FastGenByArgs(exprName, maxAllowedPacketSize) - sc := ctx.GetSessionVars().StmtCtx - - // insert|update|delete ignore ... - if sc.TypeFlags().TruncateAsWarning() { - sc.AppendWarning(err) + tc := ctx.GetSessionVars().StmtCtx.TypeCtx() + if f := tc.Flags(); f.TruncateAsWarning() || f.IgnoreTruncateErr() { + tc.AppendWarning(err) return nil } - - if ctx.GetSessionVars().StrictSQLMode && (sc.InInsertStmt || sc.InUpdateStmt || sc.InDeleteStmt) { - return errors.Trace(err) - } - sc.AppendWarning(err) - return nil + return errors.Trace(err) } From 56c619f1f5ea8e3f9970fe664729074b5123683a Mon Sep 17 00:00:00 2001 From: EasonBall <592838129@qq.com> Date: Tue, 9 Jan 2024 18:49:53 +0800 Subject: [PATCH 3/3] disttask: remove useless task state (#50071) ref pingcap/tidb#48795 --- pkg/disttask/framework/framework_test.go | 2 +- pkg/disttask/framework/handle/handle_test.go | 8 +-- pkg/disttask/framework/proto/BUILD.bazel | 4 +- pkg/disttask/framework/proto/subtask.go | 2 +- pkg/disttask/framework/proto/task.go | 57 ++++--------------- pkg/disttask/framework/proto/task_test.go | 1 - pkg/disttask/framework/proto/type.go | 52 +++++++++++++++++ pkg/disttask/framework/proto/type_test.go | 40 +++++++++++++ pkg/disttask/framework/scheduler/main_test.go | 10 ++-- .../framework/scheduler/scheduler_test.go | 3 - .../framework/scheduler/state_transform.go | 6 +- pkg/disttask/framework/storage/BUILD.bazel | 2 +- pkg/disttask/framework/storage/table_test.go | 20 ++++++- pkg/disttask/framework/storage/task_table.go | 4 +- .../framework/taskexecutor/manager_test.go | 2 +- 15 files changed, 138 insertions(+), 75 deletions(-) create mode 100644 pkg/disttask/framework/proto/type.go create mode 100644 pkg/disttask/framework/proto/type_test.go diff --git a/pkg/disttask/framework/framework_test.go b/pkg/disttask/framework/framework_test.go index 7dba395cd87eb..f937907f8ad68 100644 --- a/pkg/disttask/framework/framework_test.go +++ b/pkg/disttask/framework/framework_test.go @@ -140,7 +140,7 @@ func TestFrameworkWithQuery(t *testing.T) { distContext.Close() } -func TestFrameworkCancelGTask(t *testing.T) { +func TestFrameworkCancelTask(t *testing.T) { ctx, ctrl, testContext, distContext := testutil.InitTestContext(t, 2) defer ctrl.Finish() diff --git a/pkg/disttask/framework/handle/handle_test.go b/pkg/disttask/framework/handle/handle_test.go index e8792a9156db6..7a51944056a04 100644 --- a/pkg/disttask/framework/handle/handle_test.go +++ b/pkg/disttask/framework/handle/handle_test.go @@ -113,9 +113,7 @@ func TestRunWithRetry(t *testing.T) { ) require.Error(t, err) }() - require.Eventually(t, func() bool { - return end.Load() - }, 5*time.Second, 100*time.Millisecond) + require.Eventually(t, end.Load, 5*time.Second, 100*time.Millisecond) // fail with retryable error once, then success end.Store(false) @@ -134,9 +132,7 @@ func TestRunWithRetry(t *testing.T) { ) require.NoError(t, err) }() - require.Eventually(t, func() bool { - return end.Load() - }, 5*time.Second, 100*time.Millisecond) + require.Eventually(t, end.Load, 5*time.Second, 100*time.Millisecond) // context done subctx, cancel := 
context.WithCancel(ctx) diff --git a/pkg/disttask/framework/proto/BUILD.bazel b/pkg/disttask/framework/proto/BUILD.bazel index 8359f0783ea8d..caad86e2c2091 100644 --- a/pkg/disttask/framework/proto/BUILD.bazel +++ b/pkg/disttask/framework/proto/BUILD.bazel @@ -6,6 +6,7 @@ go_library( "node.go", "subtask.go", "task.go", + "type.go", ], importpath = "github.com/pingcap/tidb/pkg/disttask/framework/proto", visibility = ["//visibility:public"], @@ -17,9 +18,10 @@ go_test( srcs = [ "subtask_test.go", "task_test.go", + "type_test.go", ], embed = [":proto"], flaky = True, - shard_count = 4, + shard_count = 5, deps = ["@com_github_stretchr_testify//require"], ) diff --git a/pkg/disttask/framework/proto/subtask.go b/pkg/disttask/framework/proto/subtask.go index a33c44902e08b..9deea5d818d3d 100644 --- a/pkg/disttask/framework/proto/subtask.go +++ b/pkg/disttask/framework/proto/subtask.go @@ -54,10 +54,10 @@ const ( SubtaskStateFailed SubtaskState = "failed" SubtaskStateCanceled SubtaskState = "canceled" SubtaskStatePaused SubtaskState = "paused" + SubtaskStateRevertPending SubtaskState = "revert_pending" SubtaskStateReverting SubtaskState = "reverting" SubtaskStateReverted SubtaskState = "reverted" SubtaskStateRevertFailed SubtaskState = "revert_failed" - SubtaskStateRevertPending SubtaskState = "revert_pending" ) type ( diff --git a/pkg/disttask/framework/proto/task.go b/pkg/disttask/framework/proto/task.go index 2a8e3da716fda..2aab2059a1450 100644 --- a/pkg/disttask/framework/proto/task.go +++ b/pkg/disttask/framework/proto/task.go @@ -43,19 +43,17 @@ import ( // // TODO: we don't have revert_failed task for now. const ( - TaskStatePending TaskState = "pending" - TaskStateRunning TaskState = "running" - TaskStateSucceed TaskState = "succeed" - TaskStateReverting TaskState = "reverting" - TaskStateFailed TaskState = "failed" - TaskStateRevertFailed TaskState = "revert_failed" - TaskStateCancelling TaskState = "cancelling" - TaskStateCanceled TaskState = "canceled" - TaskStatePausing TaskState = "pausing" - TaskStatePaused TaskState = "paused" - TaskStateResuming TaskState = "resuming" - TaskStateRevertPending TaskState = "revert_pending" - TaskStateReverted TaskState = "reverted" + TaskStatePending TaskState = "pending" + TaskStateRunning TaskState = "running" + TaskStateSucceed TaskState = "succeed" + TaskStateFailed TaskState = "failed" + TaskStateReverting TaskState = "reverting" + TaskStateReverted TaskState = "reverted" + TaskStateRevertFailed TaskState = "revert_failed" + TaskStateCancelling TaskState = "cancelling" + TaskStatePausing TaskState = "pausing" + TaskStatePaused TaskState = "paused" + TaskStateResuming TaskState = "resuming" ) type ( @@ -148,36 +146,3 @@ func (t *Task) Compare(other *Task) int { } return int(t.ID - other.ID) } - -const ( - // TaskTypeExample is TaskType of Example. - TaskTypeExample TaskType = "Example" - // ImportInto is TaskType of ImportInto. - ImportInto TaskType = "ImportInto" - // Backfill is TaskType of add index Backfilling process. - Backfill TaskType = "backfill" -) - -// Type2Int converts task type to int. -func Type2Int(t TaskType) int { - switch t { - case TaskTypeExample: - return 1 - case ImportInto: - return 2 - default: - return 0 - } -} - -// Int2Type converts int to task type. 
-func Int2Type(i int) TaskType { - switch i { - case 1: - return TaskTypeExample - case 2: - return ImportInto - default: - return "" - } -} diff --git a/pkg/disttask/framework/proto/task_test.go b/pkg/disttask/framework/proto/task_test.go index f8b6861afa07a..a59981da333dd 100644 --- a/pkg/disttask/framework/proto/task_test.go +++ b/pkg/disttask/framework/proto/task_test.go @@ -39,7 +39,6 @@ func TestTaskIsDone(t *testing.T) { {TaskStateFailed, true}, {TaskStateRevertFailed, false}, {TaskStateCancelling, false}, - {TaskStateCanceled, false}, {TaskStatePausing, false}, {TaskStatePaused, false}, {TaskStateReverted, true}, diff --git a/pkg/disttask/framework/proto/type.go b/pkg/disttask/framework/proto/type.go new file mode 100644 index 0000000000000..b3283dc74947b --- /dev/null +++ b/pkg/disttask/framework/proto/type.go @@ -0,0 +1,52 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proto + +const ( + // TaskTypeExample is TaskType of Example. + TaskTypeExample TaskType = "Example" + // ImportInto is TaskType of ImportInto. + ImportInto TaskType = "ImportInto" + // Backfill is TaskType of add index Backfilling process. + Backfill TaskType = "backfill" +) + +// Type2Int converts task type to int. +func Type2Int(t TaskType) int { + switch t { + case TaskTypeExample: + return 1 + case ImportInto: + return 2 + case Backfill: + return 3 + default: + return 0 + } +} + +// Int2Type converts int to task type. +func Int2Type(i int) TaskType { + switch i { + case 1: + return TaskTypeExample + case 2: + return ImportInto + case 3: + return Backfill + default: + return "" + } +} diff --git a/pkg/disttask/framework/proto/type_test.go b/pkg/disttask/framework/proto/type_test.go new file mode 100644 index 0000000000000..5354e10e0d54f --- /dev/null +++ b/pkg/disttask/framework/proto/type_test.go @@ -0,0 +1,40 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package proto + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestTaskType(t *testing.T) { + cases := []struct { + tp TaskType + val int + }{ + {TaskTypeExample, 1}, + {ImportInto, 2}, + {Backfill, 3}, + {"", 0}, + } + for _, c := range cases { + require.Equal(t, c.val, Type2Int(c.tp)) + } + + for _, c := range cases { + require.Equal(t, c.tp, Int2Type(c.val)) + } +} diff --git a/pkg/disttask/framework/scheduler/main_test.go b/pkg/disttask/framework/scheduler/main_test.go index fadbe0d4adbef..20e1f438ad003 100644 --- a/pkg/disttask/framework/scheduler/main_test.go +++ b/pkg/disttask/framework/scheduler/main_test.go @@ -21,12 +21,12 @@ import ( "go.uber.org/goleak" ) -// GetRunningGTaskCnt implements Scheduler.GetRunningGTaskCnt interface. +// GetRunningTaskCnt implements Scheduler.GetRunningTaskCnt interface. func (sm *Manager) GetRunningTaskCnt() int { return sm.getSchedulerCount() } -// DelRunningGTask implements Scheduler.DelRunningGTask interface. +// DelRunningTask implements Scheduler.DelRunningTask interface. func (sm *Manager) DelRunningTask(id int64) { sm.delScheduler(id) } @@ -48,9 +48,9 @@ func TestMain(m *testing.M) { testsetup.SetupForCommonTest() // Make test more fast. - checkTaskRunningInterval = checkTaskRunningInterval / 10 - checkTaskFinishedInterval = checkTaskFinishedInterval / 10 - RetrySQLInterval = RetrySQLInterval / 20 + checkTaskRunningInterval /= 10 + checkTaskFinishedInterval /= 10 + RetrySQLInterval /= 20 opts := []goleak.Option{ goleak.IgnoreTopFunction("github.com/golang/glog.(*fileSink).flushDaemon"), diff --git a/pkg/disttask/framework/scheduler/scheduler_test.go b/pkg/disttask/framework/scheduler/scheduler_test.go index 14749d92f1634..606fd3dafa44f 100644 --- a/pkg/disttask/framework/scheduler/scheduler_test.go +++ b/pkg/disttask/framework/scheduler/scheduler_test.go @@ -48,8 +48,6 @@ const ( subtaskCnt = 3 ) -var mockedAllServerInfos = []*infosync.ServerInfo{} - func getTestSchedulerExt(ctrl *gomock.Controller) scheduler.Extension { mockScheduler := mockDispatch.NewMockExtension(ctrl) mockScheduler.EXPECT().OnTick(gomock.Any(), gomock.Any()).Return().AnyTimes() @@ -486,7 +484,6 @@ func TestVerifyTaskStateTransform(t *testing.T) { {proto.TaskStateRunning, proto.TaskStatePausing, true}, {proto.TaskStateRunning, proto.TaskStateResuming, false}, {proto.TaskStateCancelling, proto.TaskStateRunning, false}, - {proto.TaskStateCanceled, proto.TaskStateRunning, false}, } for _, tc := range testCases { require.Equal(t, tc.expect, scheduler.VerifyTaskStateTransform(tc.oldState, tc.newState)) diff --git a/pkg/disttask/framework/scheduler/state_transform.go b/pkg/disttask/framework/scheduler/state_transform.go index 1e3a39a2614b5..1faaea344d396 100644 --- a/pkg/disttask/framework/scheduler/state_transform.go +++ b/pkg/disttask/framework/scheduler/state_transform.go @@ -45,10 +45,7 @@ func VerifyTaskStateTransform(from, to proto.TaskState) bool { proto.TaskStateRevertFailed: {}, proto.TaskStateCancelling: { proto.TaskStateReverting, - // no canceled now - // proto.TaskStateCanceled, }, - proto.TaskStateCanceled: {}, proto.TaskStatePausing: { proto.TaskStatePaused, }, @@ -58,8 +55,7 @@ func VerifyTaskStateTransform(from, to proto.TaskState) bool { proto.TaskStateResuming: { proto.TaskStateRunning, }, - proto.TaskStateRevertPending: {}, - proto.TaskStateReverted: {}, + proto.TaskStateReverted: {}, } if from == to { diff --git a/pkg/disttask/framework/storage/BUILD.bazel b/pkg/disttask/framework/storage/BUILD.bazel index 
ee7f49eb4d793..0af0ea5551dd9 100644 --- a/pkg/disttask/framework/storage/BUILD.bazel +++ b/pkg/disttask/framework/storage/BUILD.bazel @@ -40,7 +40,7 @@ go_test( embed = [":storage"], flaky = True, race = "on", - shard_count = 18, + shard_count = 19, deps = [ "//pkg/config", "//pkg/disttask/framework/proto", diff --git a/pkg/disttask/framework/storage/table_test.go b/pkg/disttask/framework/storage/table_test.go index a39026b92ff2c..8434a7ea7826b 100644 --- a/pkg/disttask/framework/storage/table_test.go +++ b/pkg/disttask/framework/storage/table_test.go @@ -816,7 +816,7 @@ func TestBothTaskAndSubTaskTable(t *testing.T) { // test transactional require.NoError(t, sm.DeleteSubtasksByTaskID(ctx, 1)) - failpoint.Enable("github.com/pingcap/tidb/pkg/disttask/framework/storage/MockUpdateTaskErr", "1*return(true)") + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/disttask/framework/storage/MockUpdateTaskErr", "1*return(true)")) defer func() { require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/disttask/framework/storage/MockUpdateTaskErr")) }() @@ -987,7 +987,7 @@ func TestSubtaskHistoryTable(t *testing.T) { require.Len(t, subTasks, 3) // test GC history table. - failpoint.Enable("github.com/pingcap/tidb/pkg/disttask/framework/storage/subtaskHistoryKeepSeconds", "return(1)") + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/disttask/framework/storage/subtaskHistoryKeepSeconds", "return(1)")) defer func() { require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/disttask/framework/storage/subtaskHistoryKeepSeconds")) }() @@ -1148,3 +1148,19 @@ func TestInitMeta(t *testing.T) { tk.MustExec(`set global tidb_service_scope="background"`) tk.MustQuery("select @@global.tidb_service_scope").Check(testkit.Rows("background")) } + +func TestSubtaskType(t *testing.T) { + _, sm, ctx := testutil.InitTableTest(t) + cases := []proto.TaskType{ + proto.TaskTypeExample, + proto.ImportInto, + proto.Backfill, + "", + } + for i, c := range cases { + testutil.InsertSubtask(t, sm, int64(i+1), proto.StepOne, "tidb-1", []byte(""), proto.SubtaskStateRunning, c, 12) + subtask, err := sm.GetFirstSubtaskInStates(ctx, "tidb-1", int64(i+1), proto.StepOne, proto.SubtaskStateRunning) + require.NoError(t, err) + require.Equal(t, c, subtask.Type) + } +} diff --git a/pkg/disttask/framework/storage/task_table.go b/pkg/disttask/framework/storage/task_table.go index b913e5426016a..e73278f4043f4 100644 --- a/pkg/disttask/framework/storage/task_table.go +++ b/pkg/disttask/framework/storage/task_table.go @@ -1064,9 +1064,9 @@ func (mgr *TaskManager) UpdateTaskAndAddSubTasks(ctx context.Context, task *prot } }) if len(subtasks) > 0 { - subtaskState := proto.TaskStatePending + subtaskState := proto.SubtaskStatePending if task.State == proto.TaskStateReverting { - subtaskState = proto.TaskStateRevertPending + subtaskState = proto.SubtaskStateRevertPending } sql := new(strings.Builder) diff --git a/pkg/disttask/framework/taskexecutor/manager_test.go b/pkg/disttask/framework/taskexecutor/manager_test.go index efafb5ced13bd..0fa7c246edcf2 100644 --- a/pkg/disttask/framework/taskexecutor/manager_test.go +++ b/pkg/disttask/framework/taskexecutor/manager_test.go @@ -94,7 +94,7 @@ func TestManageTask(t *testing.T) { ctx4, cancel4 := context.WithCancelCause(context.Background()) m.registerCancelFunc(1, cancel4) mockTaskTable.EXPECT().PauseSubtasks(m.ctx, "test", int64(1)).Return(nil) - m.onPausingTasks([]*proto.Task{{ID: 1}}) + require.NoError(t, m.onPausingTasks([]*proto.Task{{ID: 1}})) 
require.Equal(t, context.Canceled, ctx4.Err()) }
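// An illustrative usage sketch for the trimmed task-state transition map
// earlier in this patch (assumed wiring, not code from the patch); the two
// cases mirror entries kept in TestVerifyTaskStateTransform.
package example

import (
	"fmt"

	"github.com/pingcap/tidb/pkg/disttask/framework/proto"
	"github.com/pingcap/tidb/pkg/disttask/framework/scheduler"
)

func demoTaskStateTransitions() {
	// running -> pausing stays allowed.
	fmt.Println(scheduler.VerifyTaskStateTransform(proto.TaskStateRunning, proto.TaskStatePausing)) // true
	// cancelling can only move to reverting, so cancelling -> running is rejected.
	fmt.Println(scheduler.VerifyTaskStateTransform(proto.TaskStateCancelling, proto.TaskStateRunning)) // false
}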