diff --git a/pkg/ddl/backfilling_dist_scheduler.go b/pkg/ddl/backfilling_dist_scheduler.go index 06eac3461f443..abd0423e2eaa1 100644 --- a/pkg/ddl/backfilling_dist_scheduler.go +++ b/pkg/ddl/backfilling_dist_scheduler.go @@ -30,7 +30,6 @@ import ( "github.com/pingcap/tidb/pkg/ddl/ingest" "github.com/pingcap/tidb/pkg/disttask/framework/proto" "github.com/pingcap/tidb/pkg/disttask/framework/scheduler" - "github.com/pingcap/tidb/pkg/domain/infosync" "github.com/pingcap/tidb/pkg/kv" "github.com/pingcap/tidb/pkg/meta" "github.com/pingcap/tidb/pkg/parser/model" @@ -70,7 +69,7 @@ func (sch *BackfillingSchedulerExt) OnNextSubtasksBatch( ctx context.Context, taskHandle scheduler.TaskHandle, task *proto.Task, - serverInfo []*infosync.ServerInfo, + execIDs []string, nextStep proto.Step, ) (taskMeta [][]byte, err error) { logger := logutil.BgLogger().With( @@ -96,7 +95,7 @@ func (sch *BackfillingSchedulerExt) OnNextSubtasksBatch( if tblInfo.Partition != nil { return generatePartitionPlan(tblInfo) } - return generateNonPartitionPlan(sch.d, tblInfo, job, sch.GlobalSort, len(serverInfo)) + return generateNonPartitionPlan(sch.d, tblInfo, job, sch.GlobalSort, len(execIDs)) case StepMergeSort: res, err := generateMergePlan(taskHandle, task, logger) if err != nil { @@ -181,12 +180,8 @@ func (*BackfillingSchedulerExt) OnDone(_ context.Context, _ scheduler.TaskHandle } // GetEligibleInstances implements scheduler.Extension interface. -func (*BackfillingSchedulerExt) GetEligibleInstances(ctx context.Context, _ *proto.Task) ([]*infosync.ServerInfo, bool, error) { - serverInfos, err := scheduler.GenerateTaskExecutorNodes(ctx) - if err != nil { - return nil, true, err - } - return serverInfos, true, nil +func (*BackfillingSchedulerExt) GetEligibleInstances(_ context.Context, _ *proto.Task) ([]string, error) { + return nil, nil } // IsRetryableErr implements scheduler.Extension.IsRetryableErr interface. 
@@ -201,10 +196,10 @@ type LitBackfillScheduler struct { } func newLitBackfillScheduler(ctx context.Context, d *ddl, taskMgr scheduler.TaskManager, - serverID string, task *proto.Task) scheduler.Scheduler { + nodeMgr *scheduler.NodeManager, task *proto.Task) scheduler.Scheduler { sch := LitBackfillScheduler{ d: d, - BaseScheduler: scheduler.NewBaseScheduler(ctx, taskMgr, serverID, task), + BaseScheduler: scheduler.NewBaseScheduler(ctx, taskMgr, nodeMgr, task), } return &sch } diff --git a/pkg/ddl/backfilling_dist_scheduler_test.go b/pkg/ddl/backfilling_dist_scheduler_test.go index 9d16f0b7f2a0b..3542300983aad 100644 --- a/pkg/ddl/backfilling_dist_scheduler_test.go +++ b/pkg/ddl/backfilling_dist_scheduler_test.go @@ -69,9 +69,8 @@ func TestBackfillingSchedulerLocalMode(t *testing.T) { // 1.1 OnNextSubtasksBatch task.Step = sch.GetNextStep(task) require.Equal(t, ddl.StepReadIndex, task.Step) - serverInfos, _, err := sch.GetEligibleInstances(context.Background(), task) - require.NoError(t, err) - metas, err := sch.OnNextSubtasksBatch(context.Background(), nil, task, serverInfos, task.Step) + execIDs := []string{":4000"} + metas, err := sch.OnNextSubtasksBatch(context.Background(), nil, task, execIDs, task.Step) require.NoError(t, err) require.Equal(t, len(tblInfo.Partition.Definitions), len(metas)) for i, par := range tblInfo.Partition.Definitions { @@ -84,7 +83,7 @@ func TestBackfillingSchedulerLocalMode(t *testing.T) { task.State = proto.TaskStateRunning task.Step = sch.GetNextStep(task) require.Equal(t, proto.StepDone, task.Step) - metas, err = sch.OnNextSubtasksBatch(context.Background(), nil, task, serverInfos, task.Step) + metas, err = sch.OnNextSubtasksBatch(context.Background(), nil, task, execIDs, task.Step) require.NoError(t, err) require.Len(t, metas, 0) @@ -96,7 +95,7 @@ func TestBackfillingSchedulerLocalMode(t *testing.T) { // 2.1 empty table tk.MustExec("create table t1(id int primary key, v int)") task = createAddIndexTask(t, dom, "test", "t1", proto.Backfill, false) - metas, err = sch.OnNextSubtasksBatch(context.Background(), nil, task, serverInfos, task.Step) + metas, err = sch.OnNextSubtasksBatch(context.Background(), nil, task, execIDs, task.Step) require.NoError(t, err) require.Equal(t, 0, len(metas)) // 2.2 non empty table. @@ -108,7 +107,7 @@ func TestBackfillingSchedulerLocalMode(t *testing.T) { task = createAddIndexTask(t, dom, "test", "t2", proto.Backfill, false) // 2.2.1 stepInit task.Step = sch.GetNextStep(task) - metas, err = sch.OnNextSubtasksBatch(context.Background(), nil, task, serverInfos, task.Step) + metas, err = sch.OnNextSubtasksBatch(context.Background(), nil, task, execIDs, task.Step) require.NoError(t, err) require.Equal(t, 1, len(metas)) require.Equal(t, ddl.StepReadIndex, task.Step) @@ -116,7 +115,7 @@ func TestBackfillingSchedulerLocalMode(t *testing.T) { task.State = proto.TaskStateRunning task.Step = sch.GetNextStep(task) require.Equal(t, proto.StepDone, task.Step) - metas, err = sch.OnNextSubtasksBatch(context.Background(), nil, task, serverInfos, task.Step) + metas, err = sch.OnNextSubtasksBatch(context.Background(), nil, task, execIDs, task.Step) require.NoError(t, err) require.Equal(t, 0, len(metas)) } @@ -173,11 +172,10 @@ func TestBackfillingSchedulerGlobalSortMode(t *testing.T) { taskID, err := mgr.CreateTask(ctx, task.Key, proto.Backfill, 1, task.Meta) require.NoError(t, err) task.ID = taskID - serverInfos, _, err := sch.GetEligibleInstances(context.Background(), task) - require.NoError(t, err) + execIDs := []string{":4000"} // 1. 
to read-index stage - subtaskMetas, err := sch.OnNextSubtasksBatch(ctx, sch, task, serverInfos, sch.GetNextStep(task)) + subtaskMetas, err := sch.OnNextSubtasksBatch(ctx, sch, task, execIDs, sch.GetNextStep(task)) require.NoError(t, err) require.Len(t, subtaskMetas, 1) task.Step = ext.GetNextStep(task) @@ -217,7 +215,7 @@ func TestBackfillingSchedulerGlobalSortMode(t *testing.T) { t.Cleanup(func() { require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/ddl/forceMergeSort")) }) - subtaskMetas, err = ext.OnNextSubtasksBatch(ctx, sch, task, serverInfos, ext.GetNextStep(task)) + subtaskMetas, err = ext.OnNextSubtasksBatch(ctx, sch, task, execIDs, ext.GetNextStep(task)) require.NoError(t, err) require.Len(t, subtaskMetas, 1) task.Step = ext.GetNextStep(task) @@ -256,13 +254,13 @@ func TestBackfillingSchedulerGlobalSortMode(t *testing.T) { t.Cleanup(func() { require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/ddl/mockWriteIngest")) }) - subtaskMetas, err = ext.OnNextSubtasksBatch(ctx, sch, task, serverInfos, ext.GetNextStep(task)) + subtaskMetas, err = ext.OnNextSubtasksBatch(ctx, sch, task, execIDs, ext.GetNextStep(task)) require.NoError(t, err) require.Len(t, subtaskMetas, 1) task.Step = ext.GetNextStep(task) require.Equal(t, ddl.StepWriteAndIngest, task.Step) // 4. to done stage. - subtaskMetas, err = ext.OnNextSubtasksBatch(ctx, sch, task, serverInfos, ext.GetNextStep(task)) + subtaskMetas, err = ext.OnNextSubtasksBatch(ctx, sch, task, execIDs, ext.GetNextStep(task)) require.NoError(t, err) require.Len(t, subtaskMetas, 0) task.Step = ext.GetNextStep(task) diff --git a/pkg/ddl/ddl.go b/pkg/ddl/ddl.go index 556dc5641405b..c0be4a0c6509c 100644 --- a/pkg/ddl/ddl.go +++ b/pkg/ddl/ddl.go @@ -684,8 +684,8 @@ func newDDL(ctx context.Context, options ...Option) *ddl { ) scheduler.RegisterSchedulerFactory(proto.Backfill, - func(ctx context.Context, taskMgr scheduler.TaskManager, serverID string, task *proto.Task) scheduler.Scheduler { - return newLitBackfillScheduler(ctx, d, taskMgr, serverID, task) + func(ctx context.Context, taskMgr scheduler.TaskManager, nodeMgr *scheduler.NodeManager, task *proto.Task) scheduler.Scheduler { + return newLitBackfillScheduler(ctx, d, taskMgr, nodeMgr, task) }) scheduler.RegisterSchedulerCleanUpFactory(proto.Backfill, newBackfillCleanUpS3) // Register functions for enable/disable ddl when changing system variable `tidb_enable_ddl`. diff --git a/pkg/disttask/framework/mock/scheduler_mock.go b/pkg/disttask/framework/mock/scheduler_mock.go index 796b4551fcbbf..8dc25566712ec 100644 --- a/pkg/disttask/framework/mock/scheduler_mock.go +++ b/pkg/disttask/framework/mock/scheduler_mock.go @@ -152,20 +152,6 @@ func (mr *MockTaskManagerMockRecorder) CancelTask(arg0, arg1 any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelTask", reflect.TypeOf((*MockTaskManager)(nil).CancelTask), arg0, arg1) } -// CleanUpMeta mocks base method. -func (m *MockTaskManager) CleanUpMeta(arg0 context.Context, arg1 []string) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CleanUpMeta", arg0, arg1) - ret0, _ := ret[0].(error) - return ret0 -} - -// CleanUpMeta indicates an expected call of CleanUpMeta. -func (mr *MockTaskManagerMockRecorder) CleanUpMeta(arg0, arg1 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanUpMeta", reflect.TypeOf((*MockTaskManager)(nil).CleanUpMeta), arg0, arg1) -} - // CollectSubTaskError mocks base method. 
func (m *MockTaskManager) CollectSubTaskError(arg0 context.Context, arg1 int64) ([]error, error) { m.ctrl.T.Helper() @@ -181,6 +167,20 @@ func (mr *MockTaskManagerMockRecorder) CollectSubTaskError(arg0, arg1 any) *gomo return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CollectSubTaskError", reflect.TypeOf((*MockTaskManager)(nil).CollectSubTaskError), arg0, arg1) } +// DeleteDeadNodes mocks base method. +func (m *MockTaskManager) DeleteDeadNodes(arg0 context.Context, arg1 []string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteDeadNodes", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteDeadNodes indicates an expected call of DeleteDeadNodes. +func (mr *MockTaskManagerMockRecorder) DeleteDeadNodes(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteDeadNodes", reflect.TypeOf((*MockTaskManager)(nil).DeleteDeadNodes), arg0, arg1) +} + // FailTask mocks base method. func (m *MockTaskManager) FailTask(arg0 context.Context, arg1 int64, arg2 proto.TaskState, arg3 error) error { m.ctrl.T.Helper() diff --git a/pkg/disttask/framework/scheduler/BUILD.bazel b/pkg/disttask/framework/scheduler/BUILD.bazel index 90a0e4aca1394..2e3c7dac4caf5 100644 --- a/pkg/disttask/framework/scheduler/BUILD.bazel +++ b/pkg/disttask/framework/scheduler/BUILD.bazel @@ -4,6 +4,7 @@ go_library( name = "scheduler", srcs = [ "interface.go", + "nodes.go", "scheduler.go", "scheduler_manager.go", "slots.go", @@ -12,6 +13,7 @@ go_library( importpath = "github.com/pingcap/tidb/pkg/disttask/framework/scheduler", visibility = ["//visibility:public"], deps = [ + "//br/pkg/lightning/log", "//pkg/disttask/framework/handle", "//pkg/disttask/framework/proto", "//pkg/disttask/framework/storage", @@ -39,6 +41,7 @@ go_test( timeout = "short", srcs = [ "main_test.go", + "nodes_test.go", "rebalance_test.go", "scheduler_manager_test.go", "scheduler_test.go", @@ -47,7 +50,7 @@ go_test( embed = [":scheduler"], flaky = True, race = "off", - shard_count = 23, + shard_count = 24, deps = [ "//pkg/config", "//pkg/disttask/framework/mock", diff --git a/pkg/disttask/framework/scheduler/interface.go b/pkg/disttask/framework/scheduler/interface.go index ea1e65472dc0b..ebe9e8d014be8 100644 --- a/pkg/disttask/framework/scheduler/interface.go +++ b/pkg/disttask/framework/scheduler/interface.go @@ -18,7 +18,6 @@ import ( "context" "github.com/pingcap/tidb/pkg/disttask/framework/proto" - "github.com/pingcap/tidb/pkg/domain/infosync" "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/util/syncutil" ) @@ -35,7 +34,7 @@ type TaskManager interface { UpdateTaskAndAddSubTasks(ctx context.Context, task *proto.Task, subtasks []*proto.Subtask, prevState proto.TaskState) (bool, error) GCSubtasks(ctx context.Context) error GetAllNodes(ctx context.Context) ([]string, error) - CleanUpMeta(ctx context.Context, nodes []string) error + DeleteDeadNodes(ctx context.Context, nodes []string) error TransferTasks2History(ctx context.Context, tasks []*proto.Task) error CancelTask(ctx context.Context, taskID int64) error // FailTask updates task state to Failed and updates task error. @@ -95,7 +94,7 @@ type Extension interface { // 1. task is pending and entering it's first step. // 2. subtasks scheduled has all finished with no error. // when next step is StepDone, it should return nil, nil. 
- OnNextSubtasksBatch(ctx context.Context, h TaskHandle, task *proto.Task, serverInfo []*infosync.ServerInfo, step proto.Step) (subtaskMetas [][]byte, err error) + OnNextSubtasksBatch(ctx context.Context, h TaskHandle, task *proto.Task, execIDs []string, step proto.Step) (subtaskMetas [][]byte, err error) // OnDone is called when task is done, either finished successfully or failed // with error. @@ -105,8 +104,10 @@ type Extension interface { // GetEligibleInstances is used to get the eligible instances for the task. // on certain condition we may want to use some instances to do the task, such as instances with more disk. - // The bool return value indicates whether filter instances by role. - GetEligibleInstances(ctx context.Context, task *proto.Task) ([]*infosync.ServerInfo, bool, error) + // if returned instances is empty, it means all instances are eligible. + // TODO: run import from server disk using framework makes this logic complicated, + // the instance might not be managed by framework. + GetEligibleInstances(ctx context.Context, task *proto.Task) ([]string, error) // IsRetryableErr is used to check whether the error occurred in scheduler is retryable. IsRetryableErr(err error) bool @@ -118,7 +119,7 @@ type Extension interface { } // schedulerFactoryFn is used to create a scheduler. -type schedulerFactoryFn func(ctx context.Context, taskMgr TaskManager, serverID string, task *proto.Task) Scheduler +type schedulerFactoryFn func(ctx context.Context, taskMgr TaskManager, nodeMgr *NodeManager, task *proto.Task) Scheduler var schedulerFactoryMap = struct { syncutil.RWMutex diff --git a/pkg/disttask/framework/scheduler/main_test.go b/pkg/disttask/framework/scheduler/main_test.go index effa0f18fc927..7dde19a1bd6ec 100644 --- a/pkg/disttask/framework/scheduler/main_test.go +++ b/pkg/disttask/framework/scheduler/main_test.go @@ -21,32 +21,33 @@ import ( "go.uber.org/goleak" ) -// SchedulerForTest exports for testing. -type SchedulerManagerForTest interface { - GetRunningTaskCnt() int - DelRunningTask(id int64) - DoCleanUpRoutine() -} - // GetRunningGTaskCnt implements Scheduler.GetRunningGTaskCnt interface. -func (dm *Manager) GetRunningTaskCnt() int { - return dm.getSchedulerCount() +func (sm *Manager) GetRunningTaskCnt() int { + return sm.getSchedulerCount() } // DelRunningGTask implements Scheduler.DelRunningGTask interface. -func (dm *Manager) DelRunningTask(id int64) { - dm.delScheduler(id) +func (sm *Manager) DelRunningTask(id int64) { + sm.delScheduler(id) } // DoCleanUpRoutine implements Scheduler.DoCleanUpRoutine interface. 
-func (dm *Manager) DoCleanUpRoutine() { - dm.doCleanUpRoutine() +func (sm *Manager) DoCleanUpRoutine() { + sm.doCleanupTask() } func (s *BaseScheduler) OnNextStage() (err error) { return s.onNextStage() } +func (s *BaseScheduler) DoBalanceSubtasks(eligibleNodes []string) error { + return s.doBalanceSubtasks(eligibleNodes) +} + +func NewNodeManager() *NodeManager { + return newNodeManager() +} + func TestMain(m *testing.M) { testsetup.SetupForCommonTest() diff --git a/pkg/disttask/framework/scheduler/mock/BUILD.bazel b/pkg/disttask/framework/scheduler/mock/BUILD.bazel index 890488b52257b..541b2013c3bdd 100644 --- a/pkg/disttask/framework/scheduler/mock/BUILD.bazel +++ b/pkg/disttask/framework/scheduler/mock/BUILD.bazel @@ -8,7 +8,6 @@ go_library( deps = [ "//pkg/disttask/framework/proto", "//pkg/disttask/framework/scheduler", - "//pkg/domain/infosync", "@org_uber_go_mock//gomock", ], ) diff --git a/pkg/disttask/framework/scheduler/mock/scheduler_mock.go b/pkg/disttask/framework/scheduler/mock/scheduler_mock.go index f7ea143c00c78..c65047b74aeda 100644 --- a/pkg/disttask/framework/scheduler/mock/scheduler_mock.go +++ b/pkg/disttask/framework/scheduler/mock/scheduler_mock.go @@ -14,7 +14,6 @@ import ( proto "github.com/pingcap/tidb/pkg/disttask/framework/proto" scheduler "github.com/pingcap/tidb/pkg/disttask/framework/scheduler" - infosync "github.com/pingcap/tidb/pkg/domain/infosync" gomock "go.uber.org/mock/gomock" ) @@ -42,13 +41,12 @@ func (m *MockExtension) EXPECT() *MockExtensionMockRecorder { } // GetEligibleInstances mocks base method. -func (m *MockExtension) GetEligibleInstances(arg0 context.Context, arg1 *proto.Task) ([]*infosync.ServerInfo, bool, error) { +func (m *MockExtension) GetEligibleInstances(arg0 context.Context, arg1 *proto.Task) ([]string, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetEligibleInstances", arg0, arg1) - ret0, _ := ret[0].([]*infosync.ServerInfo) - ret1, _ := ret[1].(bool) - ret2, _ := ret[2].(error) - return ret0, ret1, ret2 + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 } // GetEligibleInstances indicates an expected call of GetEligibleInstances. @@ -100,7 +98,7 @@ func (mr *MockExtensionMockRecorder) OnDone(arg0, arg1, arg2 any) *gomock.Call { } // OnNextSubtasksBatch mocks base method. -func (m *MockExtension) OnNextSubtasksBatch(arg0 context.Context, arg1 scheduler.TaskHandle, arg2 *proto.Task, arg3 []*infosync.ServerInfo, arg4 proto.Step) ([][]byte, error) { +func (m *MockExtension) OnNextSubtasksBatch(arg0 context.Context, arg1 scheduler.TaskHandle, arg2 *proto.Task, arg3 []string, arg4 proto.Step) ([][]byte, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "OnNextSubtasksBatch", arg0, arg1, arg2, arg3, arg4) ret0, _ := ret[0].([][]byte) diff --git a/pkg/disttask/framework/scheduler/nodes.go b/pkg/disttask/framework/scheduler/nodes.go new file mode 100644 index 0000000000000..47bea22172e52 --- /dev/null +++ b/pkg/disttask/framework/scheduler/nodes.go @@ -0,0 +1,143 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package scheduler + +import ( + "context" + "sync/atomic" + "time" + + "github.com/pingcap/tidb/br/pkg/lightning/log" + disttaskutil "github.com/pingcap/tidb/pkg/util/disttask" + "github.com/pingcap/tidb/pkg/util/logutil" + "go.uber.org/zap" +) + +var ( + // nodesCheckInterval is the tick interval of fetching all server infos from etcd. + nodesCheckInterval = 2 * checkTaskFinishedInterval +) + +// NodeManager maintains live TiDB nodes in the cluster, and maintains the nodes +// managed by the framework. +type NodeManager struct { + // prevLiveNodes is used to record the live nodes seen in the last check. + prevLiveNodes map[string]struct{} + // managedNodes is the cached nodes managed by the framework. + // see TaskManager.GetManagedNodes for more details. + managedNodes atomic.Pointer[[]string] +} + +func newNodeManager() *NodeManager { + nm := &NodeManager{ + prevLiveNodes: make(map[string]struct{}), + } + managedNodes := make([]string, 0, 10) + nm.managedNodes.Store(&managedNodes) + return nm +} + +func (nm *NodeManager) maintainLiveNodesLoop(ctx context.Context, taskMgr TaskManager) { + ticker := time.NewTicker(nodesCheckInterval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + nm.maintainLiveNodes(ctx, taskMgr) + } + } +} + +// maintainLiveNodes manages live node info in the dist_framework_meta table; +// see recoverMetaLoop in the task executor for when a node is inserted into dist_framework_meta. +func (nm *NodeManager) maintainLiveNodes(ctx context.Context, taskMgr TaskManager) { + // Safe to discard errors since this function can be called at regular intervals. + serverInfos, err := GenerateTaskExecutorNodes(ctx) + if err != nil { + logutil.BgLogger().Warn("generate task executor nodes met error", log.ShortError(err)) + return + } + nodeChanged := len(serverInfos) != len(nm.prevLiveNodes) + currLiveNodes := make(map[string]struct{}, len(serverInfos)) + for _, info := range serverInfos { + execID := disttaskutil.GenerateExecID(info) + if _, ok := nm.prevLiveNodes[execID]; !ok { + nodeChanged = true + } + currLiveNodes[execID] = struct{}{} + } + if !nodeChanged { + return + } + + oldNodes, err := taskMgr.GetAllNodes(ctx) + if err != nil { + logutil.BgLogger().Warn("get all nodes met error", log.ShortError(err)) + return + } + + deadNodes := make([]string, 0) + for _, nodeID := range oldNodes { + if _, ok := currLiveNodes[nodeID]; !ok { + deadNodes = append(deadNodes, nodeID) + } + } + if len(deadNodes) == 0 { + nm.prevLiveNodes = currLiveNodes + return + } + logutil.BgLogger().Info("delete dead nodes from dist_framework_meta", + zap.Int("dead-nodes", len(deadNodes))) + err = taskMgr.DeleteDeadNodes(ctx, deadNodes) + if err != nil { + logutil.BgLogger().Warn("delete dead nodes from dist_framework_meta failed", log.ShortError(err)) + return + } + nm.prevLiveNodes = currLiveNodes +} + +func (nm *NodeManager) refreshManagedNodesLoop(ctx context.Context, taskMgr TaskManager) { + ticker := time.NewTicker(nodesCheckInterval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + nm.refreshManagedNodes(ctx, taskMgr) + } + } +} + +// refreshManagedNodes maintains the nodes managed by the framework.
+func (nm *NodeManager) refreshManagedNodes(ctx context.Context, taskMgr TaskManager) { + newNodes, err := taskMgr.GetManagedNodes(ctx) + if err != nil { + logutil.BgLogger().Warn("get managed nodes met error", log.ShortError(err)) + return + } + if newNodes == nil { + newNodes = []string{} + } + nm.managedNodes.Store(&newNodes) +} + +// getManagedNodes returns the nodes managed by the framework. +// The returned slice is read-only, don't modify it. +func (nm *NodeManager) getManagedNodes() []string { + return *nm.managedNodes.Load() +} diff --git a/pkg/disttask/framework/scheduler/nodes_test.go b/pkg/disttask/framework/scheduler/nodes_test.go new file mode 100644 index 0000000000000..e6ff715773de3 --- /dev/null +++ b/pkg/disttask/framework/scheduler/nodes_test.go @@ -0,0 +1,119 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package scheduler + +import ( + "context" + "testing" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/disttask/framework/mock" + "github.com/pingcap/tidb/pkg/domain/infosync" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" +) + +func TestMaintainLiveNodes(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockTaskMgr := mock.NewMockTaskManager(ctrl) + + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/disttask/framework/scheduler/mockTaskExecutorNodes", "return()")) + t.Cleanup(func() { + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/disttask/framework/scheduler/mockTaskExecutorNodes")) + }) + + MockServerInfo = []*infosync.ServerInfo{ + {Port: 4000}, + } + + nodeMgr := newNodeManager() + ctx := context.Background() + mockTaskMgr.EXPECT().GetAllNodes(gomock.Any()).Return(nil, errors.New("mock error")) + nodeMgr.maintainLiveNodes(ctx, mockTaskMgr) + require.Empty(t, nodeMgr.prevLiveNodes) + require.True(t, ctrl.Satisfied()) + // no change + mockTaskMgr.EXPECT().GetAllNodes(gomock.Any()).Return([]string{":4000"}, nil) + nodeMgr.maintainLiveNodes(ctx, mockTaskMgr) + require.Equal(t, map[string]struct{}{":4000": {}}, nodeMgr.prevLiveNodes) + require.True(t, ctrl.Satisfied()) + // run again, return fast + nodeMgr.maintainLiveNodes(ctx, mockTaskMgr) + require.Equal(t, map[string]struct{}{":4000": {}}, nodeMgr.prevLiveNodes) + require.True(t, ctrl.Satisfied()) + + // scale out 1 node + MockServerInfo = []*infosync.ServerInfo{ + {Port: 4000}, + {Port: 4001}, + } + + // fail on clean + mockTaskMgr.EXPECT().GetAllNodes(gomock.Any()).Return([]string{":4000", ":4001", ":4002"}, nil) + mockTaskMgr.EXPECT().DeleteDeadNodes(gomock.Any(), gomock.Any()).Return(errors.New("mock error")) + nodeMgr.maintainLiveNodes(ctx, mockTaskMgr) + require.Equal(t, map[string]struct{}{":4000": {}}, nodeMgr.prevLiveNodes) + require.True(t, ctrl.Satisfied()) + // remove 1 node + mockTaskMgr.EXPECT().GetAllNodes(gomock.Any()).Return([]string{":4000", ":4001", ":4002"}, nil) + mockTaskMgr.EXPECT().DeleteDeadNodes(gomock.Any(),
gomock.Any()).Return(nil) + nodeMgr.maintainLiveNodes(ctx, mockTaskMgr) + require.Equal(t, map[string]struct{}{":4000": {}, ":4001": {}}, nodeMgr.prevLiveNodes) + require.True(t, ctrl.Satisfied()) + // run again, return fast + nodeMgr.maintainLiveNodes(ctx, mockTaskMgr) + require.Equal(t, map[string]struct{}{":4000": {}, ":4001": {}}, nodeMgr.prevLiveNodes) + require.True(t, ctrl.Satisfied()) + + // scale in 1 node + MockServerInfo = []*infosync.ServerInfo{ + {Port: 4000}, + } + + mockTaskMgr.EXPECT().GetAllNodes(gomock.Any()).Return([]string{":4000", ":4001", ":4002"}, nil) + mockTaskMgr.EXPECT().DeleteDeadNodes(gomock.Any(), gomock.Any()).Return(nil) + nodeMgr.maintainLiveNodes(ctx, mockTaskMgr) + require.Equal(t, map[string]struct{}{":4000": {}}, nodeMgr.prevLiveNodes) + require.True(t, ctrl.Satisfied()) + // run again, return fast + nodeMgr.maintainLiveNodes(ctx, mockTaskMgr) + require.Equal(t, map[string]struct{}{":4000": {}}, nodeMgr.prevLiveNodes) + require.True(t, ctrl.Satisfied()) +} + +func TestMaintainManagedNodes(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + ctx := context.Background() + mockTaskMgr := mock.NewMockTaskManager(ctrl) + nodeMgr := newNodeManager() + + mockTaskMgr.EXPECT().GetManagedNodes(gomock.Any()).Return(nil, errors.New("mock error")) + nodeMgr.refreshManagedNodes(ctx, mockTaskMgr) + require.Empty(t, nodeMgr.getManagedNodes()) + require.True(t, ctrl.Satisfied()) + + mockTaskMgr.EXPECT().GetManagedNodes(gomock.Any()).Return([]string{":4000", ":4001"}, nil) + nodeMgr.refreshManagedNodes(ctx, mockTaskMgr) + require.Equal(t, []string{":4000", ":4001"}, nodeMgr.getManagedNodes()) + require.True(t, ctrl.Satisfied()) + mockTaskMgr.EXPECT().GetManagedNodes(gomock.Any()).Return(nil, nil) + nodeMgr.refreshManagedNodes(ctx, mockTaskMgr) + require.NotNil(t, nodeMgr.getManagedNodes()) + require.Empty(t, nodeMgr.getManagedNodes()) + require.True(t, ctrl.Satisfied()) +} diff --git a/pkg/disttask/framework/scheduler/rebalance_test.go b/pkg/disttask/framework/scheduler/rebalance_test.go index f11efc8393c21..61387d5b0ce3d 100644 --- a/pkg/disttask/framework/scheduler/rebalance_test.go +++ b/pkg/disttask/framework/scheduler/rebalance_test.go @@ -25,7 +25,6 @@ import ( "github.com/pingcap/tidb/pkg/disttask/framework/mock" "github.com/pingcap/tidb/pkg/disttask/framework/proto" "github.com/pingcap/tidb/pkg/disttask/framework/scheduler" - "github.com/pingcap/tidb/pkg/domain/infosync" "github.com/pingcap/tidb/pkg/testkit" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" @@ -33,7 +32,7 @@ import ( type scaleTestCase struct { subtasks []*proto.Subtask - liveNodes []*infosync.ServerInfo + liveNodes []string taskNodes []string cleanedNodes []string expectedTaskNodes []string @@ -42,7 +41,7 @@ type scaleTestCase struct { type balanceTestCase struct { subtasks []*proto.Subtask - liveNodes []*infosync.ServerInfo + liveNodes []string taskNodes []string expectedSubtasks []*proto.Subtask } @@ -56,14 +55,14 @@ func scaleTest(t *testing.T, testCase.subtasks, nil) mockTaskMgr.EXPECT().UpdateSubtasksExecIDs(ctx, int64(id), testCase.subtasks).Return(nil).AnyTimes() - mockTaskMgr.EXPECT().CleanUpMeta(ctx, testCase.cleanedNodes).Return(nil).AnyTimes() + mockTaskMgr.EXPECT().DeleteDeadNodes(ctx, testCase.cleanedNodes).Return(nil).AnyTimes() if len(testCase.cleanedNodes) > 0 { mockTaskMgr.EXPECT().GetSubtasksByExecIdsAndStepAndState(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) } - sch := scheduler.NewBaseScheduler(ctx, 
mockTaskMgr, "server", &proto.Task{Step: proto.StepInit, ID: int64(id)}) - sch.LiveNodes = testCase.liveNodes + nodeMgr := scheduler.NewNodeManager() + sch := scheduler.NewBaseScheduler(ctx, mockTaskMgr, nodeMgr, &proto.Task{Step: proto.StepInit, ID: int64(id)}) sch.TaskNodes = testCase.taskNodes - require.NoError(t, sch.ReDispatchSubtasks()) + require.NoError(t, sch.DoBalanceSubtasks(testCase.liveNodes)) slices.SortFunc(sch.TaskNodes, func(i, j string) int { return strings.Compare(i, j) }) @@ -82,13 +81,13 @@ func balanceTest(t *testing.T, mockTaskMgr.EXPECT().GetSubtasksByStepAndState(ctx, int64(id), proto.StepInit, proto.TaskStatePending).Return( testCase.subtasks, nil) - mockTaskMgr.EXPECT().CleanUpMeta(ctx, gomock.Any()).Return(nil).AnyTimes() + mockTaskMgr.EXPECT().DeleteDeadNodes(ctx, gomock.Any()).Return(nil).AnyTimes() + nodeMgr := scheduler.NewNodeManager() mockTaskMgr.EXPECT().UpdateSubtasksExecIDs(ctx, int64(id), testCase.subtasks).Return(nil).AnyTimes() - sch := scheduler.NewBaseScheduler(ctx, mockTaskMgr, "server", &proto.Task{Step: proto.StepInit, ID: int64(id)}) - sch.LiveNodes = testCase.liveNodes + sch := scheduler.NewBaseScheduler(ctx, mockTaskMgr, nodeMgr, &proto.Task{Step: proto.StepInit, ID: int64(id)}) sch.TaskNodes = testCase.taskNodes - require.NoError(t, sch.ReDispatchSubtasks()) + require.NoError(t, sch.DoBalanceSubtasks(testCase.liveNodes)) slices.SortFunc(sch.TaskNodes, func(i, j string) int { return strings.Compare(i, j) }) @@ -116,7 +115,7 @@ func TestScaleOutNodes(t *testing.T) { {ExecID: "1.1.1.1:4000"}, {ExecID: "1.1.1.1:4000"}, {ExecID: "1.1.1.1:4000"}}, - []*infosync.ServerInfo{{IP: "1.1.1.1", Port: 4000}, {IP: "1.1.1.2", Port: 4000}}, + []string{"1.1.1.1:4000", "1.1.1.2:4000"}, []string{"1.1.1.1:4000"}, []string{}, []string{"1.1.1.1:4000", "1.1.1.2:4000"}, @@ -132,7 +131,7 @@ func TestScaleOutNodes(t *testing.T) { {ExecID: "1.1.1.1:4000"}, {ExecID: "1.1.1.1:4000"}, {ExecID: "1.1.1.1:4000"}}, - []*infosync.ServerInfo{{IP: "1.1.1.1", Port: 4000}, {IP: "1.1.1.2", Port: 4000}}, + []string{"1.1.1.1:4000", "1.1.1.2:4000"}, []string{"1.1.1.1:4000"}, []string{}, []string{"1.1.1.1:4000", "1.1.1.2:4000"}, @@ -148,7 +147,7 @@ func TestScaleOutNodes(t *testing.T) { {ExecID: "1.1.1.2:4000"}, {ExecID: "1.1.1.1:4000"}, {ExecID: "1.1.1.1:4000"}}, - []*infosync.ServerInfo{{IP: "1.1.1.1", Port: 4000}, {IP: "1.1.1.2", Port: 4000}, {IP: "1.1.1.3", Port: 4000}, {IP: "1.1.1.4", Port: 4000}}, + []string{"1.1.1.1:4000", "1.1.1.2:4000", "1.1.1.3:4000", "1.1.1.4:4000"}, []string{"1.1.1.1:4000", "1.1.1.2:4000"}, []string{}, []string{"1.1.1.1:4000", "1.1.1.2:4000", "1.1.1.3:4000", "1.1.1.4:4000"}, @@ -170,7 +169,7 @@ func TestScaleOutNodes(t *testing.T) { {ExecID: "1.1.1.1:4000"}, {ExecID: "1.1.1.1:4000"}, {ExecID: "1.1.1.1:4000"}}, - []*infosync.ServerInfo{{IP: "1.1.1.1", Port: 4000}, {IP: "1.1.1.2", Port: 4000}, {IP: "1.1.1.3", Port: 4000}, {IP: "1.1.1.4", Port: 4000}}, + []string{"1.1.1.1:4000", "1.1.1.2:4000", "1.1.1.3:4000", "1.1.1.4:4000"}, []string{"1.1.1.1:4000", "1.1.1.2:4000"}, []string{}, []string{"1.1.1.1:4000", "1.1.1.2:4000", "1.1.1.3:4000", "1.1.1.4:4000"}, @@ -192,7 +191,7 @@ func TestScaleOutNodes(t *testing.T) { {ExecID: "1.1.1.2:4000"}, {ExecID: "1.1.1.1:4000"}, {ExecID: "1.1.1.1:4000"}}, - []*infosync.ServerInfo{{IP: "1.1.1.1", Port: 4000}, {IP: "1.1.1.2", Port: 4000}, {IP: "1.1.1.3", Port: 4000}}, + []string{"1.1.1.1:4000", "1.1.1.2:4000", "1.1.1.3:4000"}, []string{"1.1.1.1:4000", "1.1.1.2:4000"}, []string{}, []string{"1.1.1.1:4000", "1.1.1.2:4000", 
"1.1.1.3:4000"}, @@ -209,7 +208,7 @@ func TestScaleOutNodes(t *testing.T) { {ExecID: "1.1.1.1:4000"}, {ExecID: "1.1.1.1:4000"}, {ExecID: "1.1.1.1:4000"}}, - []*infosync.ServerInfo{{IP: "1.1.1.2", Port: 4000}, {IP: "1.1.1.3", Port: 4000}}, + []string{"1.1.1.2:4000", "1.1.1.3:4000"}, []string{"1.1.1.1:4000"}, []string{"1.1.1.1:4000"}, []string{"1.1.1.2:4000", "1.1.1.3:4000"}, @@ -226,7 +225,7 @@ func TestScaleOutNodes(t *testing.T) { {ExecID: "1.1.1.1:4000"}, {ExecID: "1.1.1.2:4000"}, {ExecID: "1.1.1.2:4000"}}, - []*infosync.ServerInfo{{IP: "1.1.1.2", Port: 4000}, {IP: "1.1.1.3", Port: 4000}}, + []string{"1.1.1.2:4000", "1.1.1.3:4000"}, []string{"1.1.1.1:4000", "1.1.1.2:4000"}, []string{"1.1.1.1:4000"}, []string{"1.1.1.2:4000", "1.1.1.3:4000"}, @@ -243,7 +242,7 @@ func TestScaleOutNodes(t *testing.T) { {ExecID: "1.1.1.1:4000"}, {ExecID: "1.1.1.2:4000"}, {ExecID: "1.1.1.2:4000"}}, - []*infosync.ServerInfo{{IP: "1.1.1.3", Port: 4000}, {IP: "1.1.1.4", Port: 4000}}, + []string{"1.1.1.3:4000", "1.1.1.4:4000"}, []string{"1.1.1.1:4000", "1.1.1.2:4000"}, []string{"1.1.1.1:4000", "1.1.1.2:4000"}, []string{"1.1.1.3:4000", "1.1.1.4:4000"}, @@ -260,7 +259,7 @@ func TestScaleOutNodes(t *testing.T) { {ExecID: "1.1.1.1:4000"}, {ExecID: "1.1.1.2:4000"}, {ExecID: "1.1.1.2:4000"}}, - []*infosync.ServerInfo{{IP: "1.1.1.2", Port: 4000}, {IP: "1.1.1.3", Port: 4000}, {IP: "1.1.1.4", Port: 4000}}, + []string{"1.1.1.2:4000", "1.1.1.3:4000", "1.1.1.4:4000"}, []string{"1.1.1.1:4000", "1.1.1.2:4000"}, []string{"1.1.1.1:4000"}, []string{"1.1.1.2:4000", "1.1.1.3:4000", "1.1.1.4:4000"}, @@ -279,7 +278,7 @@ func TestScaleOutNodes(t *testing.T) { {ExecID: "1.1.1.2:4000"}, {ExecID: "1.1.1.2:4000"}, {ExecID: "1.1.1.1:4000"}}, - []*infosync.ServerInfo{{IP: "1.1.1.2", Port: 4000}, {IP: "1.1.1.3", Port: 4000}, {IP: "1.1.1.4", Port: 4000}}, + []string{"1.1.1.2:4000", "1.1.1.3:4000", "1.1.1.4:4000"}, []string{"1.1.1.1:4000", "1.1.1.2:4000"}, []string{"1.1.1.1:4000"}, []string{"1.1.1.2:4000", "1.1.1.3:4000", "1.1.1.4:4000"}, @@ -298,7 +297,7 @@ func TestScaleOutNodes(t *testing.T) { {ExecID: "1.1.1.2:4000"}, {ExecID: "1.1.1.2:4000"}, {ExecID: "1.1.1.2:4000"}}, - []*infosync.ServerInfo{{IP: "1.1.1.1", Port: 4000}, {IP: "1.1.1.3", Port: 4000}, {IP: "1.1.1.4", Port: 4000}, {IP: "1.1.1.5", Port: 4000}, {IP: "1.1.1.6", Port: 4000}}, + []string{"1.1.1.1:4000", "1.1.1.3:4000", "1.1.1.4:4000", "1.1.1.5:4000", "1.1.1.6:4000"}, []string{"1.1.1.1:4000", "1.1.1.2:4000"}, []string{"1.1.1.2:4000"}, []string{"1.1.1.1:4000", "1.1.1.3:4000", "1.1.1.4:4000", "1.1.1.5:4000", "1.1.1.6:4000"}, @@ -333,7 +332,7 @@ func TestScaleInNodes(t *testing.T) { {ExecID: "1.1.1.1:4000"}, {ExecID: "1.1.1.2:4000"}, {ExecID: "1.1.1.2:4000"}}, - []*infosync.ServerInfo{{IP: "1.1.1.1", Port: 4000}}, + []string{"1.1.1.1:4000"}, []string{"1.1.1.1:4000", "1.1.1.2:4000"}, []string{"1.1.1.2:4000"}, []string{"1.1.1.1:4000"}, @@ -350,7 +349,7 @@ func TestScaleInNodes(t *testing.T) { {ExecID: "1.1.1.1:4000"}, {ExecID: "1.1.1.2:4000"}, {ExecID: "1.1.1.2:4000"}}, - []*infosync.ServerInfo{{IP: "1.1.1.3", Port: 4000}}, + []string{"1.1.1.3:4000"}, []string{"1.1.1.1:4000", "1.1.1.2:4000"}, []string{"1.1.1.1:4000", "1.1.1.2:4000"}, []string{"1.1.1.3:4000"}, @@ -373,7 +372,7 @@ func TestScaleInNodes(t *testing.T) { {ExecID: "1.1.1.8:4000"}, {ExecID: "1.1.1.9:4000"}, {ExecID: "1.1.1.10:4000"}}, - []*infosync.ServerInfo{{IP: "1.1.1.2", Port: 4000}, {IP: "1.1.1.3", Port: 4000}}, + []string{"1.1.1.2:4000", "1.1.1.3:4000"}, []string{ "1.1.1.1:4000", "1.1.1.2:4000", @@ -421,7 +420,7 @@ func 
TestScaleInNodes(t *testing.T) { {ExecID: "1.1.1.2:4000"}, {ExecID: "1.1.1.2:4000"}, {ExecID: "1.1.1.2:4000"}}, - []*infosync.ServerInfo{{IP: "1.1.1.1", Port: 4000}}, + []string{"1.1.1.1:4000"}, []string{"1.1.1.1:4000", "1.1.1.2:4000"}, []string{"1.1.1.2:4000"}, []string{"1.1.1.1:4000"}, @@ -443,7 +442,7 @@ func TestScaleInNodes(t *testing.T) { } } -func TestRebalanceWithoutScale(t *testing.T) { +func TestBalanceWithoutScale(t *testing.T) { store := testkit.CreateMockStore(t) gtk := testkit.NewTestKit(t, store) pool := pools.NewResourcePool(func() (pools.Resource, error) { @@ -461,7 +460,7 @@ func TestRebalanceWithoutScale(t *testing.T) { {ExecID: "1.1.1.2:4000"}, {ExecID: "1.1.1.2:4000"}, {ExecID: "1.1.1.2:4000"}}, - []*infosync.ServerInfo{{IP: "1.1.1.1", Port: 4000}, {IP: "1.1.1.2", Port: 4000}}, + []string{"1.1.1.1:4000", "1.1.1.2:4000"}, []string{"1.1.1.1:4000", "1.1.1.2:4000"}, []*proto.Subtask{ {ExecID: "1.1.1.1:4000"}, @@ -478,7 +477,7 @@ func TestRebalanceWithoutScale(t *testing.T) { {ExecID: "1.1.1.2:4000"}, {ExecID: "1.1.1.2:4000"}, {ExecID: "1.1.1.3:4000"}}, - []*infosync.ServerInfo{{IP: "1.1.1.1", Port: 4000}, {IP: "1.1.1.2", Port: 4000}, {IP: "1.1.1.3", Port: 4000}}, + []string{"1.1.1.1:4000", "1.1.1.2:4000", "1.1.1.3:4000"}, []string{"1.1.1.1:4000", "1.1.1.2:4000", "1.1.1.3:4000"}, []*proto.Subtask{ {ExecID: "1.1.1.1:4000"}, @@ -496,7 +495,7 @@ func TestRebalanceWithoutScale(t *testing.T) { {ExecID: "1.1.1.2:4000"}, {ExecID: "1.1.1.2:4000"}, {ExecID: "1.1.1.2:4000"}}, - []*infosync.ServerInfo{{IP: "1.1.1.1", Port: 4000}, {IP: "1.1.1.2", Port: 4000}}, + []string{"1.1.1.1:4000", "1.1.1.2:4000"}, []string{"1.1.1.1:4000", "1.1.1.2:4000"}, []*proto.Subtask{ {ExecID: "1.1.1.1:4000"}, @@ -513,9 +512,7 @@ func TestRebalanceWithoutScale(t *testing.T) { {ExecID: "1.1.1.1:4000"}, {ExecID: "1.1.1.1:4000"}, {ExecID: "1.1.1.1:4000"}}, - []*infosync.ServerInfo{{IP: "1.1.1.1", Port: 4000}, {IP: "1.1.1.2", Port: 4000}, - {IP: "1.1.1.3", Port: 4000}, {IP: "1.1.1.4", Port: 4000}, - {IP: "1.1.1.5", Port: 4000}, {IP: "1.1.1.6", Port: 4000}}, + []string{"1.1.1.1:4000", "1.1.1.2:4000", "1.1.1.3:4000", "1.1.1.4:4000", "1.1.1.5:4000", "1.1.1.6:4000"}, []string{"1.1.1.1:4000", "1.1.1.2:4000", "1.1.1.3:4000", "1.1.1.4:4000", "1.1.1.5:4000", "1.1.1.6:4000"}, []*proto.Subtask{ {ExecID: "1.1.1.1:4000"}, @@ -532,7 +529,7 @@ func TestRebalanceWithoutScale(t *testing.T) { {ExecID: "1.1.1.2:4000"}, {ExecID: "1.1.1.2:4000"}, {ExecID: "1.1.1.2:4000"}}, - []*infosync.ServerInfo{{IP: "1.1.1.1", Port: 4000}, {IP: "1.1.1.2", Port: 4000}}, + []string{"1.1.1.1:4000", "1.1.1.2:4000"}, []string{"1.1.1.1:4000", "1.1.1.2:4000"}, []*proto.Subtask{ {ExecID: "1.1.1.1:4000"}, @@ -548,7 +545,7 @@ func TestRebalanceWithoutScale(t *testing.T) { {ExecID: "1.1.1.1:4000"}, {ExecID: "1.1.1.2:4000"}, {ExecID: "1.1.1.2:4000"}}, - []*infosync.ServerInfo{{IP: "1.1.1.1", Port: 4000}, {IP: "1.1.1.2", Port: 4000}}, + []string{"1.1.1.1:4000", "1.1.1.2:4000"}, []string{"1.1.1.1:4000", "1.1.1.2:4000"}, []*proto.Subtask{ {ExecID: "1.1.1.1:4000"}, diff --git a/pkg/disttask/framework/scheduler/scheduler.go b/pkg/disttask/framework/scheduler/scheduler.go index bf61b4fe547ee..7976f7de341ab 100644 --- a/pkg/disttask/framework/scheduler/scheduler.go +++ b/pkg/disttask/framework/scheduler/scheduler.go @@ -41,8 +41,8 @@ const ( DefaultSubtaskConcurrency = 16 // MaxSubtaskConcurrency is the maximum concurrency for handling subtask. 
MaxSubtaskConcurrency = 256 - // DefaultLiveNodesCheckInterval is the tick interval of fetching all server infos from etcs. - DefaultLiveNodesCheckInterval = 2 + // defaultBalanceSubtaskTicks is the number of task-check ticks between two rounds of subtask balancing. + defaultBalanceSubtaskTicks = 2 // for a cancelled task, it's terminal state is reverted or reverted_failed, // so we use a special error message to indicate that the task is cancelled // by user. @@ -63,8 +63,6 @@ var ( // TaskHandle provides the interface for operations needed by Scheduler. // Then we can use scheduler's function in Scheduler interface. type TaskHandle interface { - // GetPreviousTaskExecutorIDs gets previous task executor IDs. - GetPreviousTaskExecutorIDs(_ context.Context, taskID int64, step proto.Step) ([]string, error) // GetPreviousSubtaskMetas gets previous subtask metas. GetPreviousSubtaskMetas(taskID int64, step proto.Step) ([][]byte, error) storage.SessionExecutor @@ -88,20 +86,14 @@ type Scheduler interface { type BaseScheduler struct { ctx context.Context taskMgr TaskManager + nodeMgr *NodeManager Task *proto.Task logCtx context.Context - // serverID, it's value is ip:port now. - serverID string // when RegisterSchedulerFactory, the factory MUST initialize this fields. Extension - // For subtasks rebalance. - // LiveNodes will fetch and store all live nodes every liveNodeInterval ticks. - LiveNodes []*infosync.ServerInfo - liveNodeFetchInterval int - // liveNodeFetchTick is the tick variable. - liveNodeFetchTick int - // TaskNodes stores the id of current task executor nodes. + balanceSubtaskTick int + // TaskNodes stores the exec id of current task executor nodes. TaskNodes []string // rand is for generating random selection of nodes. rand *rand.Rand @@ -111,20 +103,17 @@ type BaseScheduler struct { var MockOwnerChange func() // NewBaseScheduler creates a new BaseScheduler. -func NewBaseScheduler(ctx context.Context, taskMgr TaskManager, serverID string, task *proto.Task) *BaseScheduler { +func NewBaseScheduler(ctx context.Context, taskMgr TaskManager, nodeMgr *NodeManager, task *proto.Task) *BaseScheduler { logCtx := logutil.WithFields(context.Background(), zap.Int64("task-id", task.ID), zap.Stringer("task-type", task.Type)) return &BaseScheduler{ - ctx: ctx, - taskMgr: taskMgr, - Task: task, - logCtx: logCtx, - serverID: serverID, - LiveNodes: nil, - liveNodeFetchInterval: DefaultLiveNodesCheckInterval, - liveNodeFetchTick: 0, - TaskNodes: nil, - rand: rand.New(rand.NewSource(time.Now().UnixNano())), + ctx: ctx, + taskMgr: taskMgr, + nodeMgr: nodeMgr, + Task: task, + logCtx: logCtx, + TaskNodes: nil, + rand: rand.New(rand.NewSource(time.Now().UnixNano())), } } @@ -352,7 +341,7 @@ func (s *BaseScheduler) onRunning() error { return s.onNextStage() } - if err := s.BalanceSubtasks(); err != nil { + if err := s.balanceSubtasks(); err != nil { return err } // Wait all subtasks in this stage finishes. @@ -367,9 +356,8 @@ func (s *BaseScheduler) onFinished() error { return s.taskMgr.TransferSubTasks2History(s.ctx, s.Task.ID) } -// BalanceSubtasks check the liveNode num every liveNodeFetchInterval then rebalance subtasks. -func (s *BaseScheduler) BalanceSubtasks() error { - // 1. init TaskNodes if needes. +// balanceSubtasks rebalances subtasks across the eligible nodes every defaultBalanceSubtaskTicks ticks.
+func (s *BaseScheduler) balanceSubtasks() error { if len(s.TaskNodes) == 0 { var err error s.TaskNodes, err = s.taskMgr.GetTaskExecutorIDsByTaskIDAndStep(s.ctx, s.Task.ID, s.Task.Step) @@ -377,63 +365,32 @@ func (s *BaseScheduler) BalanceSubtasks() error { return err } } - s.liveNodeFetchTick++ - if s.liveNodeFetchTick == s.liveNodeFetchInterval { - // 2. update LiveNodes. - s.liveNodeFetchTick = 0 - serverInfos, err := GenerateTaskExecutorNodes(s.ctx) + s.balanceSubtaskTick++ + if s.balanceSubtaskTick == defaultBalanceSubtaskTicks { + s.balanceSubtaskTick = 0 + eligibleNodes, err := s.getEligibleNodes() if err != nil { return err } - - eligibleServerInfos, filter, err := s.GetEligibleInstances(s.ctx, s.Task) - if err != nil { - return err - } - if filter { - eligibleServerInfos, err = s.filterByRole(eligibleServerInfos) - if err != nil { - return err - } + if len(eligibleNodes) > 0 { + return s.doBalanceSubtasks(eligibleNodes) } - newInfos := serverInfos[:0] - for _, m := range serverInfos { - found := false - for _, n := range eligibleServerInfos { - if m.ID == n.ID { - found = true - break - } - } - if found { - newInfos = append(newInfos, m) - } - } - s.LiveNodes = newInfos - // 3. balance subtasks. - if len(s.LiveNodes) > 0 { - return s.ReDispatchSubtasks() - } - return nil } return nil } -func (s *BaseScheduler) replaceTaskNodes() { - s.TaskNodes = s.TaskNodes[:0] - for _, serverInfo := range s.LiveNodes { - s.TaskNodes = append(s.TaskNodes, disttaskutil.GenerateExecID(serverInfo.IP, serverInfo.Port)) - } -} - -// ReDispatchSubtasks make count of subtasks on each liveNodes balanced and clean up subtasks on dead nodes. +// DoBalanceSubtasks make count of subtasks on each liveNodes balanced and clean up subtasks on dead nodes. // TODO(ywqzzy): refine to make it easier for testing. -func (s *BaseScheduler) ReDispatchSubtasks() error { +func (s *BaseScheduler) doBalanceSubtasks(eligibleNodes []string) error { + eligibleNodeMap := make(map[string]struct{}, len(eligibleNodes)) + for _, n := range eligibleNodes { + eligibleNodeMap[n] = struct{}{} + } // 1. find out nodes need to clean subtasks. deadNodes := make([]string, 0) deadNodesMap := make(map[string]bool, 0) for _, node := range s.TaskNodes { - if !disttaskutil.MatchServerInfo(s.LiveNodes, node) { + if _, ok := eligibleNodeMap[node]; !ok { deadNodes = append(deadNodes, node) deadNodesMap[node] = true } @@ -453,10 +410,9 @@ func (s *BaseScheduler) ReDispatchSubtasks() error { subtasks = append(subtasks, subtasksOnDeadNodes...) } // 3. group subtasks for each task executor. - subtasksOnTaskExecutor := make(map[string][]*proto.Subtask, len(s.LiveNodes)+len(deadNodes)) - for _, node := range s.LiveNodes { - execID := disttaskutil.GenerateExecID(node.IP, node.Port) - subtasksOnTaskExecutor[execID] = make([]*proto.Subtask, 0) + subtasksOnTaskExecutor := make(map[string][]*proto.Subtask, len(eligibleNodes)+len(deadNodes)) + for _, node := range eligibleNodes { + subtasksOnTaskExecutor[node] = make([]*proto.Subtask, 0) } for _, subtask := range subtasks { subtasksOnTaskExecutor[subtask.ExecID] = append( @@ -464,17 +420,17 @@ func (s *BaseScheduler) ReDispatchSubtasks() error { subtask) } // 4. prepare subtasks that need to rebalance to other nodes. - averageSubtaskCnt := len(subtasks) / len(s.LiveNodes) + averageSubtaskCnt := len(subtasks) / len(eligibleNodes) rebalanceSubtasks := make([]*proto.Subtask, 0) for k, v := range subtasksOnTaskExecutor { if ok := deadNodesMap[k]; ok { rebalanceSubtasks = append(rebalanceSubtasks, v...) 
continue } - // When no tidb scale-in/out and averageSubtaskCnt*len(s.LiveNodes) < len(subtasks), + // When no tidb scale-in/out and averageSubtaskCnt*len(eligibleNodes) < len(subtasks), // no need to send subtask to other nodes. // eg: tidb1 with 3 subtasks, tidb2 with 2 subtasks, subtasks are balanced now. - if averageSubtaskCnt*len(s.LiveNodes) < len(subtasks) && len(s.TaskNodes) == len(s.LiveNodes) { + if averageSubtaskCnt*len(eligibleNodes) < len(subtasks) && len(s.TaskNodes) == len(eligibleNodes) { if len(v) > averageSubtaskCnt+1 { rebalanceSubtasks = append(rebalanceSubtasks, v[0:len(v)-averageSubtaskCnt]...) } @@ -503,8 +459,7 @@ func (s *BaseScheduler) ReDispatchSubtasks() error { // 7. rebalance rest subtasks evenly to liveNodes. liveNodeIdx := 0 for rebalanceIdx < len(rebalanceSubtasks) { - node := s.LiveNodes[liveNodeIdx] - rebalanceSubtasks[rebalanceIdx].ExecID = disttaskutil.GenerateExecID(node.IP, node.Port) + rebalanceSubtasks[rebalanceIdx].ExecID = eligibleNodes[liveNodeIdx] rebalanceIdx++ liveNodeIdx++ } @@ -513,12 +468,9 @@ func (s *BaseScheduler) ReDispatchSubtasks() error { if err = s.taskMgr.UpdateSubtasksExecIDs(s.ctx, s.Task.ID, subtasks); err != nil { return err } - logutil.Logger(s.logCtx).Info("rebalance subtasks", + logutil.Logger(s.logCtx).Info("balance subtasks", zap.Stringers("subtasks-rebalanced", subtasks)) - if err = s.taskMgr.CleanUpMeta(s.ctx, deadNodes); err != nil { - return err - } - s.replaceTaskNodes() + s.TaskNodes = append([]string{}, eligibleNodes...) return nil } @@ -599,17 +551,10 @@ func (s *BaseScheduler) onNextStage() (err error) { } } - serverNodes, filter, err := s.GetEligibleInstances(s.ctx, s.Task) + serverNodes, err := s.getEligibleNodes() if err != nil { return err } - logutil.Logger(s.logCtx).Debug("eligible instances", zap.Int("num", len(serverNodes))) - if filter { - serverNodes, err = s.filterByRole(serverNodes) - if err != nil { - return err - } - } logutil.Logger(s.logCtx).Info("eligible instances", zap.Int("num", len(serverNodes))) if len(serverNodes) == 0 { return errors.New("no available TiDB node to dispatch subtasks") @@ -624,26 +569,38 @@ func (s *BaseScheduler) onNextStage() (err error) { return s.scheduleSubTask(nextStep, metas, serverNodes) } +// getEligibleNodes returns the eligible(live) nodes for the task. +// if the task can only be scheduled to some specific nodes, return them directly, +// we don't care liveliness of them. +func (s *BaseScheduler) getEligibleNodes() ([]string, error) { + serverNodes, err := s.GetEligibleInstances(s.ctx, s.Task) + if err != nil { + return nil, err + } + logutil.Logger(s.logCtx).Debug("eligible instances", zap.Int("num", len(serverNodes))) + if len(serverNodes) == 0 { + serverNodes = append([]string{}, s.nodeMgr.getManagedNodes()...) + } + return serverNodes, nil +} + func (s *BaseScheduler) scheduleSubTask( subtaskStep proto.Step, metas [][]byte, - serverNodes []*infosync.ServerInfo) error { + serverNodes []string) error { logutil.Logger(s.logCtx).Info("schedule subtasks", zap.Stringer("state", s.Task.State), zap.Int64("step", int64(s.Task.Step)), zap.Int("concurrency", s.Task.Concurrency), zap.Int("subtasks", len(metas))) - s.TaskNodes = make([]string, len(serverNodes)) - for i := range serverNodes { - s.TaskNodes[i] = disttaskutil.GenerateExecID(serverNodes[i].IP, serverNodes[i].Port) - } + s.TaskNodes = serverNodes var size uint64 subTasks := make([]*proto.Subtask, 0, len(metas)) for i, meta := range metas { // we assign the subtask to the instance in a round-robin way. 
// TODO: assign the subtask to the instance according to the system load of each nodes pos := i % len(serverNodes) - instanceID := disttaskutil.GenerateExecID(serverNodes[pos].IP, serverNodes[pos].Port) + instanceID := serverNodes[pos] logutil.Logger(s.logCtx).Debug("create subtasks", zap.String("instanceID", instanceID)) subTasks = append(subTasks, proto.NewSubtask( subtaskStep, s.Task.ID, s.Task.Type, instanceID, s.Task.Concurrency, meta, i+1)) @@ -721,27 +678,6 @@ func GenerateTaskExecutorNodes(ctx context.Context) (serverNodes []*infosync.Ser return serverNodes, nil } -func (s *BaseScheduler) filterByRole(infos []*infosync.ServerInfo) ([]*infosync.ServerInfo, error) { - nodes, err := s.taskMgr.GetManagedNodes(s.ctx) - if err != nil { - return nil, err - } - - nodeMap := make(map[string]struct{}, len(nodes)) - for _, node := range nodes { - nodeMap[node] = struct{}{} - } - - res := make([]*infosync.ServerInfo, 0, len(nodes)) - for _, info := range infos { - _, ok := nodeMap[disttaskutil.GenerateExecID(info.IP, info.Port)] - if ok { - res = append(res, info) - } - } - return res, nil -} - // GetAllTaskExecutorIDs gets all the task executor IDs. func (s *BaseScheduler) GetAllTaskExecutorIDs(ctx context.Context, task *proto.Task) ([]string, error) { // We get all servers instead of eligible servers here @@ -781,11 +717,6 @@ func (s *BaseScheduler) GetPreviousSubtaskMetas(taskID int64, step proto.Step) ( return previousSubtaskMetas, nil } -// GetPreviousTaskExecutorIDs gets task executor IDs that run previous step. -func (s *BaseScheduler) GetPreviousTaskExecutorIDs(_ context.Context, taskID int64, step proto.Step) ([]string, error) { - return s.taskMgr.GetTaskExecutorIDsByTaskIDAndStep(s.ctx, taskID, step) -} - // WithNewSession executes the function with a new session. func (s *BaseScheduler) WithNewSession(fn func(se sessionctx.Context) error) error { return s.taskMgr.WithNewSession(fn) diff --git a/pkg/disttask/framework/scheduler/scheduler_manager.go b/pkg/disttask/framework/scheduler/scheduler_manager.go index b5c5ddb709864..6059496876b2b 100644 --- a/pkg/disttask/framework/scheduler/scheduler_manager.go +++ b/pkg/disttask/framework/scheduler/scheduler_manager.go @@ -25,7 +25,6 @@ import ( "github.com/pingcap/tidb/pkg/resourcemanager/pool/spool" "github.com/pingcap/tidb/pkg/resourcemanager/util" tidbutil "github.com/pingcap/tidb/pkg/util" - disttaskutil "github.com/pingcap/tidb/pkg/util/disttask" "github.com/pingcap/tidb/pkg/util/logutil" "github.com/pingcap/tidb/pkg/util/syncutil" "go.uber.org/zap" @@ -84,6 +83,7 @@ type Manager struct { wg tidbutil.WaitGroupWrapper gPool *spool.Pool slotMgr *slotManager + nodeMgr *NodeManager initialized bool // serverID, it's value is ip:port now. 
serverID string @@ -102,6 +102,7 @@ func NewManager(ctx context.Context, taskMgr TaskManager, serverID string) (*Man taskMgr: taskMgr, serverID: serverID, slotMgr: newSlotManager(), + nodeMgr: newNodeManager(), } gPool, err := spool.NewPool("schedule_pool", int32(proto.MaxConcurrentTask), util.DistTask, spool.WithBlocking(true)) if err != nil { @@ -120,9 +121,18 @@ func (sm *Manager) Start() { failpoint.Inject("disableSchedulerManager", func() { failpoint.Return() }) + // init cached managed nodes + sm.nodeMgr.refreshManagedNodes(sm.ctx, sm.taskMgr) + sm.wg.Run(sm.scheduleTaskLoop) sm.wg.Run(sm.gcSubtaskHistoryTableLoop) - sm.wg.Run(sm.cleanUpLoop) + sm.wg.Run(sm.cleanupTaskLoop) + sm.wg.Run(func() { + sm.nodeMgr.maintainLiveNodesLoop(sm.ctx, sm.taskMgr) + }) + sm.wg.Run(func() { + sm.nodeMgr.refreshManagedNodesLoop(sm.ctx, sm.taskMgr) + }) sm.initialized = true } @@ -167,7 +177,7 @@ func (sm *Manager) scheduleTaskLoop() { continue } - scheduleableTasks := make([]*proto.Task, 0, len(tasks)) + schedulableTasks := make([]*proto.Task, 0, len(tasks)) for _, task := range tasks { if sm.hasScheduler(task.ID) { continue @@ -182,9 +192,9 @@ func (sm *Manager) scheduleTaskLoop() { sm.failTask(task.ID, task.State, errors.New("unknown task type")) continue } - scheduleableTasks = append(scheduleableTasks, task) + schedulableTasks = append(schedulableTasks, task) } - if len(scheduleableTasks) == 0 { + if len(schedulableTasks) == 0 { continue } @@ -192,7 +202,7 @@ func (sm *Manager) scheduleTaskLoop() { logutil.BgLogger().Warn("update used slot failed", zap.Error(err)) continue } - for _, task := range scheduleableTasks { + for _, task := range schedulableTasks { taskCnt = sm.getSchedulerCount() if taskCnt >= proto.MaxConcurrentTask { break @@ -253,7 +263,7 @@ func (sm *Manager) startScheduler(basicTask *proto.Task, reservedExecID string) } schedulerFactory := getSchedulerFactory(task.Type) - scheduler := schedulerFactory(sm.ctx, sm.taskMgr, sm.serverID, task) + scheduler := schedulerFactory(sm.ctx, sm.taskMgr, sm.nodeMgr, task) if err = scheduler.Init(); err != nil { logutil.BgLogger().Error("init scheduler failed", zap.Error(err)) sm.failTask(task.ID, task.State, err) @@ -275,7 +285,7 @@ func (sm *Manager) startScheduler(basicTask *proto.Task, reservedExecID string) }) } -func (sm *Manager) cleanUpLoop() { +func (sm *Manager) cleanupTaskLoop() { logutil.Logger(sm.ctx).Info("cleanUp loop start") ticker := time.NewTicker(defaultCleanUpInterval) defer ticker.Stop() @@ -285,9 +295,9 @@ func (sm *Manager) cleanUpLoop() { logutil.BgLogger().Info("cleanUp loop exits", zap.Error(sm.ctx.Err())) return case <-sm.finishCh: - sm.doCleanUpRoutine() + sm.doCleanupTask() case <-ticker.C: - sm.doCleanUpRoutine() + sm.doCleanupTask() } } } @@ -295,15 +305,11 @@ func (sm *Manager) cleanUpLoop() { // WaitCleanUpFinished is used to sync the test. var WaitCleanUpFinished = make(chan struct{}) -// doCleanUpRoutine processes clean up routine defined by each type of tasks and cleanUpMeta. +// doCleanupTask runs the cleanup routine defined by each type of task. // For example: // // tasks with global sort should clean up tmp files stored on S3.
-func (sm *Manager) doCleanUpRoutine() { - cnt := sm.CleanUpMeta() - if cnt != 0 { - logutil.BgLogger().Info("clean up nodes in framework meta since nodes shutdown", zap.Int("cnt", cnt)) - } +func (sm *Manager) doCleanupTask() { tasks, err := sm.taskMgr.GetTasksInStates( sm.ctx, proto.TaskStateFailed, @@ -329,39 +335,6 @@ func (sm *Manager) doCleanUpRoutine() { logutil.Logger(sm.ctx).Info("cleanUp routine success") } -// CleanUpMeta clean up old node info in dist_framework_meta table. -func (sm *Manager) CleanUpMeta() int { - // Safe to discard errors since this function can be called at regular intervals. - serverInfos, err := GenerateTaskExecutorNodes(sm.ctx) - if err != nil { - logutil.BgLogger().Warn("generate task executor nodes met error") - return 0 - } - - oldNodes, err := sm.taskMgr.GetAllNodes(sm.ctx) - if err != nil { - logutil.BgLogger().Warn("get all nodes met error") - return 0 - } - - cleanNodes := make([]string, 0) - for _, nodeID := range oldNodes { - if ok := disttaskutil.MatchServerInfo(serverInfos, nodeID); !ok { - cleanNodes = append(cleanNodes, nodeID) - } - } - if len(cleanNodes) == 0 { - return 0 - } - logutil.BgLogger().Info("start to clean up dist_framework_meta") - err = sm.taskMgr.CleanUpMeta(sm.ctx, cleanNodes) - if err != nil { - logutil.BgLogger().Warn("clean up dist_framework_meta met error") - return 0 - } - return len(cleanNodes) -} - func (sm *Manager) cleanUpFinishedTasks(tasks []*proto.Task) error { cleanedTasks := make([]*proto.Task, 0) var firstErr error @@ -389,5 +362,5 @@ func (sm *Manager) cleanUpFinishedTasks(tasks []*proto.Task) error { // MockScheduler mock one scheduler for one task, only used for tests. func (sm *Manager) MockScheduler(task *proto.Task) *BaseScheduler { - return NewBaseScheduler(sm.ctx, sm.taskMgr, sm.serverID, task) + return NewBaseScheduler(sm.ctx, sm.taskMgr, sm.nodeMgr, task) } diff --git a/pkg/disttask/framework/scheduler/scheduler_manager_test.go b/pkg/disttask/framework/scheduler/scheduler_manager_test.go index 8d45536ebec4e..003983957c538 100644 --- a/pkg/disttask/framework/scheduler/scheduler_manager_test.go +++ b/pkg/disttask/framework/scheduler/scheduler_manager_test.go @@ -39,13 +39,13 @@ func TestCleanUpRoutine(t *testing.T) { defer ctrl.Finish() ctx := context.Background() ctx = util.WithInternalSourceType(ctx, "scheduler_manager") - mockCleanupRountine := mock.NewMockCleanUpRoutine(ctrl) + mockCleanupRoutine := mock.NewMockCleanUpRoutine(ctrl) - sch, mgr := MockSchedulerManager(t, ctrl, pool, getNumberExampleSchedulerExt(ctrl), mockCleanupRountine) - mockCleanupRountine.EXPECT().CleanUp(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + sch, mgr := MockSchedulerManager(t, ctrl, pool, getNumberExampleSchedulerExt(ctrl), mockCleanupRoutine) + mockCleanupRoutine.EXPECT().CleanUp(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + require.NoError(t, mgr.StartManager(ctx, ":4000", "")) sch.Start() defer sch.Stop() - require.NoError(t, mgr.StartManager(ctx, ":4000", "background")) taskID, err := mgr.CreateTask(ctx, "test", proto.TaskTypeExample, 1, nil) require.NoError(t, err) @@ -81,31 +81,3 @@ func TestCleanUpRoutine(t *testing.T) { return len(tasks) != 0 }, time.Second*10, time.Millisecond*300) } - -func TestCleanUpMeta(t *testing.T) { - store := testkit.CreateMockStore(t) - gtk := testkit.NewTestKit(t, store) - pool := pools.NewResourcePool(func() (pools.Resource, error) { - return gtk.Session(), nil - }, 1, 1, time.Second) - defer pool.Close() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - 
mockTaskMgr := mock.NewMockTaskManager(ctrl) - mockCleanupRountine := mock.NewMockCleanUpRoutine(ctrl) - schMgr := MockSchedulerManagerWithMockTaskMgr(t, ctrl, pool, mockTaskMgr, getNumberExampleSchedulerExt(ctrl), mockCleanupRountine) - - mockTaskMgr.EXPECT().GetAllNodes(gomock.Any()).Return([]string{":4000", ":4001"}, nil) - mockTaskMgr.EXPECT().CleanUpMeta(gomock.Any(), gomock.Any()).Return(nil) - mockCleanupRountine.EXPECT().CleanUp(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() - require.Equal(t, 1, schMgr.CleanUpMeta()) - - mockTaskMgr.EXPECT().GetAllNodes(gomock.Any()).Return([]string{":4000"}, nil) - mockCleanupRountine.EXPECT().CleanUp(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() - require.Equal(t, 0, schMgr.CleanUpMeta()) - - mockTaskMgr.EXPECT().GetAllNodes(gomock.Any()).Return([]string{":4000", ":4001", ":4003"}, nil) - mockTaskMgr.EXPECT().CleanUpMeta(gomock.Any(), gomock.Any()).Return(nil) - mockCleanupRountine.EXPECT().CleanUp(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() - require.Equal(t, 2, schMgr.CleanUpMeta()) -} diff --git a/pkg/disttask/framework/scheduler/scheduler_test.go b/pkg/disttask/framework/scheduler/scheduler_test.go index 682e28483098e..35b00c3ac74db 100644 --- a/pkg/disttask/framework/scheduler/scheduler_test.go +++ b/pkg/disttask/framework/scheduler/scheduler_test.go @@ -54,8 +54,8 @@ func getTestSchedulerExt(ctrl *gomock.Controller) scheduler.Extension { mockScheduler := mockDispatch.NewMockExtension(ctrl) mockScheduler.EXPECT().OnTick(gomock.Any(), gomock.Any()).Return().AnyTimes() mockScheduler.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).DoAndReturn( - func(_ context.Context, _ *proto.Task) ([]*infosync.ServerInfo, bool, error) { - return mockedAllServerInfos, true, nil + func(_ context.Context, _ *proto.Task) ([]string, error) { + return nil, nil }, ).AnyTimes() mockScheduler.EXPECT().IsRetryableErr(gomock.Any()).Return(true).AnyTimes() @@ -65,7 +65,7 @@ func getTestSchedulerExt(ctrl *gomock.Controller) scheduler.Extension { }, ).AnyTimes() mockScheduler.EXPECT().OnNextSubtasksBatch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( - func(_ context.Context, _ scheduler.TaskHandle, _ *proto.Task, _ []*infosync.ServerInfo, _ proto.Step) (metas [][]byte, err error) { + func(_ context.Context, _ scheduler.TaskHandle, _ *proto.Task, _ []string, _ proto.Step) (metas [][]byte, err error) { return nil, nil }, ).AnyTimes() @@ -78,9 +78,8 @@ func getNumberExampleSchedulerExt(ctrl *gomock.Controller) scheduler.Extension { mockScheduler := mockDispatch.NewMockExtension(ctrl) mockScheduler.EXPECT().OnTick(gomock.Any(), gomock.Any()).Return().AnyTimes() mockScheduler.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).DoAndReturn( - func(ctx context.Context, _ *proto.Task) ([]*infosync.ServerInfo, bool, error) { - serverInfo, err := scheduler.GenerateTaskExecutorNodes(ctx) - return serverInfo, true, err + func(ctx context.Context, _ *proto.Task) ([]string, error) { + return nil, nil }, ).AnyTimes() mockScheduler.EXPECT().IsRetryableErr(gomock.Any()).Return(true).AnyTimes() @@ -95,7 +94,7 @@ func getNumberExampleSchedulerExt(ctrl *gomock.Controller) scheduler.Extension { }, ).AnyTimes() mockScheduler.EXPECT().OnNextSubtasksBatch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( - func(_ context.Context, _ scheduler.TaskHandle, task *proto.Task, _ []*infosync.ServerInfo, _ proto.Step) (metas [][]byte, err error) { + func(_ context.Context, _ scheduler.TaskHandle, task 
*proto.Task, _ []string, _ proto.Step) (metas [][]byte, err error) { switch task.Step { case proto.StepInit: for i := 0; i < subtaskCnt; i++ { @@ -123,7 +122,7 @@ func MockSchedulerManager(t *testing.T, ctrl *gomock.Controller, pool *pools.Res sch, err := scheduler.NewManager(util.WithInternalSourceType(ctx, "scheduler"), mgr, "host:port") require.NoError(t, err) scheduler.RegisterSchedulerFactory(proto.TaskTypeExample, - func(ctx context.Context, taskMgr scheduler.TaskManager, serverID string, task *proto.Task) scheduler.Scheduler { + func(ctx context.Context, taskMgr scheduler.TaskManager, nodeMgr *scheduler.NodeManager, task *proto.Task) scheduler.Scheduler { mockScheduler := sch.MockScheduler(task) mockScheduler.Extension = ext return mockScheduler @@ -131,23 +130,6 @@ func MockSchedulerManager(t *testing.T, ctrl *gomock.Controller, pool *pools.Res return sch, mgr } -func MockSchedulerManagerWithMockTaskMgr(t *testing.T, ctrl *gomock.Controller, pool *pools.ResourcePool, taskMgr *mock.MockTaskManager, ext scheduler.Extension, cleanUp scheduler.CleanUpRoutine) *scheduler.Manager { - ctx := context.WithValue(context.Background(), "etcd", true) - sch, err := scheduler.NewManager(util.WithInternalSourceType(ctx, "scheduler"), taskMgr, "host:port") - require.NoError(t, err) - scheduler.RegisterSchedulerFactory(proto.TaskTypeExample, - func(ctx context.Context, taskMgr scheduler.TaskManager, serverID string, task *proto.Task) scheduler.Scheduler { - mockScheduler := sch.MockScheduler(task) - mockScheduler.Extension = ext - return mockScheduler - }) - scheduler.RegisterSchedulerCleanUpFactory(proto.TaskTypeExample, - func() scheduler.CleanUpRoutine { - return cleanUp - }) - return sch -} - func deleteTasks(t *testing.T, store kv.Storage, taskID int64) { tk := testkit.NewTestKit(t, store) tk.MustExec(fmt.Sprintf("delete from mysql.tidb_global_task where id = %d", taskID)) @@ -239,7 +221,7 @@ func TestTaskFailInManager(t *testing.T) { mockScheduler.EXPECT().Init().Return(errors.New("mock scheduler init error")) schManager, mgr := MockSchedulerManager(t, ctrl, pool, getTestSchedulerExt(ctrl), nil) scheduler.RegisterSchedulerFactory(proto.TaskTypeExample, - func(ctx context.Context, taskMgr scheduler.TaskManager, serverID string, task *proto.Task) scheduler.Scheduler { + func(ctx context.Context, taskMgr scheduler.TaskManager, nodeMgr *scheduler.NodeManager, task *proto.Task) scheduler.Scheduler { return mockScheduler }) schManager.Start() @@ -532,7 +514,8 @@ func TestDispatcherOnNextStage(t *testing.T) { Step: proto.StepInit, } cloneTask := task - sch := scheduler.NewBaseScheduler(ctx, taskMgr, ":4000", &cloneTask) + nodeMgr := scheduler.NewNodeManager() + sch := scheduler.NewBaseScheduler(ctx, taskMgr, nodeMgr, &cloneTask) sch.Extension = schExt // test next step is done @@ -551,21 +534,19 @@ func TestDispatcherOnNextStage(t *testing.T) { // GetEligibleInstances err schExt.EXPECT().GetNextStep(gomock.Any()).Return(proto.StepOne) - schExt.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).Return(nil, false, errors.New("GetEligibleInstances err")) + schExt.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).Return(nil, errors.New("GetEligibleInstances err")) require.ErrorContains(t, sch.OnNextStage(), "GetEligibleInstances err") require.True(t, ctrl.Satisfied()) // GetEligibleInstances no instance schExt.EXPECT().GetNextStep(gomock.Any()).Return(proto.StepOne) - schExt.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).Return(nil, false, nil) + 
schExt.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).Return(nil, nil) require.ErrorContains(t, sch.OnNextStage(), "no available TiDB node to dispatch subtasks") require.True(t, ctrl.Satisfied()) - serverNodes := []*infosync.ServerInfo{ - {IP: "", Port: 4000}, - } + serverNodes := []string{":4000"} // OnNextSubtasksBatch err schExt.EXPECT().GetNextStep(gomock.Any()).Return(proto.StepOne) - schExt.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).Return(serverNodes, false, nil) + schExt.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).Return(serverNodes, nil) schExt.EXPECT().OnNextSubtasksBatch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(nil, errors.New("OnNextSubtasksBatch err")) schExt.EXPECT().IsRetryableErr(gomock.Any()).Return(true) @@ -583,7 +564,7 @@ func TestDispatcherOnNextStage(t *testing.T) { []byte(`{"xx": "2"}`), } schExt.EXPECT().GetNextStep(gomock.Any()).Return(proto.StepOne) - schExt.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).Return(serverNodes, false, nil) + schExt.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).Return(serverNodes, nil) schExt.EXPECT().OnNextSubtasksBatch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(subtaskMetas, nil) taskMgr.EXPECT().SwitchTaskStepInBatch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) @@ -592,7 +573,7 @@ func TestDispatcherOnNextStage(t *testing.T) { require.True(t, ctrl.Satisfied()) // met unstable subtasks schExt.EXPECT().GetNextStep(gomock.Any()).Return(proto.StepOne) - schExt.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).Return(serverNodes, false, nil) + schExt.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).Return(serverNodes, nil) schExt.EXPECT().OnNextSubtasksBatch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(subtaskMetas, nil) taskMgr.EXPECT().SwitchTaskStepInBatch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). @@ -608,7 +589,7 @@ func TestDispatcherOnNextStage(t *testing.T) { // dispatch in one txn schExt.EXPECT().GetNextStep(gomock.Any()).Return(proto.StepOne) - schExt.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).Return(serverNodes, false, nil) + schExt.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).Return(serverNodes, nil) schExt.EXPECT().OnNextSubtasksBatch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). 
Return(subtaskMetas, nil) taskMgr.EXPECT().SwitchTaskStep(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) @@ -644,7 +625,7 @@ func TestManagerDispatchLoop(t *testing.T) { serverInfos, err := infosync.GetAllServerInfo(ctx) require.NoError(t, err) for _, s := range serverInfos { - execID := disttaskutil.GenerateExecID(s.IP, s.Port) + execID := disttaskutil.GenerateExecID(s) testutil.InsertSubtask(t, taskMgr, 1000000, proto.StepOne, execID, []byte(""), proto.TaskStatePending, proto.TaskTypeExample, 16) } concurrencies := []int{4, 6, 16, 2, 4, 4} @@ -654,7 +635,7 @@ func TestManagerDispatchLoop(t *testing.T) { } var counter atomic.Int32 scheduler.RegisterSchedulerFactory(proto.TaskTypeExample, - func(ctx context.Context, taskMgr scheduler.TaskManager, serverID string, task *proto.Task) scheduler.Scheduler { + func(ctx context.Context, taskMgr scheduler.TaskManager, nodeMgr *scheduler.NodeManager, task *proto.Task) scheduler.Scheduler { idx := counter.Load() mockScheduler = mock.NewMockScheduler(ctrl) mockScheduler.EXPECT().Init().Return(nil) diff --git a/pkg/disttask/framework/storage/table_test.go b/pkg/disttask/framework/storage/table_test.go index 358cfd6125d25..7b0c6795c85a3 100644 --- a/pkg/disttask/framework/storage/table_test.go +++ b/pkg/disttask/framework/storage/table_test.go @@ -773,17 +773,17 @@ func TestDistFrameworkMeta(t *testing.T) { require.NoError(t, err) require.Equal(t, []string{":4000", ":4001", ":4002", ":4003"}, nodes) - require.NoError(t, sm.CleanUpMeta(ctx, []string{":4000"})) + require.NoError(t, sm.DeleteDeadNodes(ctx, []string{":4000"})) nodes, err = sm.GetManagedNodes(ctx) require.NoError(t, err) require.Equal(t, []string{":4002", ":4003"}, nodes) - require.NoError(t, sm.CleanUpMeta(ctx, []string{":4003"})) + require.NoError(t, sm.DeleteDeadNodes(ctx, []string{":4003"})) nodes, err = sm.GetManagedNodes(ctx) require.NoError(t, err) require.Equal(t, []string{":4002"}, nodes) - require.NoError(t, sm.CleanUpMeta(ctx, []string{":4002"})) + require.NoError(t, sm.DeleteDeadNodes(ctx, []string{":4002"})) nodes, err = sm.GetManagedNodes(ctx) require.NoError(t, err) require.Equal(t, []string{":4001"}, nodes) diff --git a/pkg/disttask/framework/storage/task_table.go b/pkg/disttask/framework/storage/task_table.go index f0d420452d1ee..888d8747124d8 100644 --- a/pkg/disttask/framework/storage/task_table.go +++ b/pkg/disttask/framework/storage/task_table.go @@ -802,8 +802,8 @@ func (stm *TaskManager) UpdateSubtasksExecIDs(ctx context.Context, taskID int64, return err } -// CleanUpMeta cleanup the outdated row in dist_framework_meta when some tidb down. -func (stm *TaskManager) CleanUpMeta(ctx context.Context, nodes []string) error { +// DeleteDeadNodes deletes the dead nodes from mysql.dist_framework_meta. 
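Editor's note: Manager.CleanUpMeta and its unit test are deleted above; judging by the loops started in Start(), the equivalent work presumably moves into the node manager's maintainLiveNodesLoop, with the renamed TaskManager.DeleteDeadNodes (next hunk) as the only writer of mysql.dist_framework_meta. A rough sketch of that path follows, reusing the matching logic from the deleted CleanUpMeta. The method name maintainLiveNodes and its error handling are assumptions; the called helpers (GenerateTaskExecutorNodes, GetAllNodes, MatchServerInfo, DeleteDeadNodes) all appear in this diff or the surrounding packages.

```go
// Sketch only: dead-node cleanup as it could look inside the scheduler
// package after this PR. Everything except the called helpers is assumed.
package scheduler

import (
	"context"

	disttaskutil "github.com/pingcap/tidb/pkg/util/disttask"
	"github.com/pingcap/tidb/pkg/util/logutil"
	"go.uber.org/zap"
)

// maintainLiveNodes compares the nodes recorded in dist_framework_meta with
// the TiDB instances that are actually alive and deletes the difference.
func (nm *NodeManager) maintainLiveNodes(ctx context.Context, taskMgr TaskManager) {
	liveNodes, err := GenerateTaskExecutorNodes(ctx)
	if err != nil {
		logutil.BgLogger().Warn("generate task executor nodes met error", zap.Error(err))
		return
	}
	allNodes, err := taskMgr.GetAllNodes(ctx)
	if err != nil {
		logutil.BgLogger().Warn("get all nodes met error", zap.Error(err))
		return
	}
	deadNodes := make([]string, 0, len(allNodes))
	for _, nodeID := range allNodes {
		if !disttaskutil.MatchServerInfo(liveNodes, nodeID) {
			deadNodes = append(deadNodes, nodeID)
		}
	}
	if len(deadNodes) == 0 {
		return
	}
	if err := taskMgr.DeleteDeadNodes(ctx, deadNodes); err != nil {
		logutil.BgLogger().Warn("delete dead nodes from dist_framework_meta failed", zap.Error(err))
	}
}
```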
+func (stm *TaskManager) DeleteDeadNodes(ctx context.Context, nodes []string) error { if len(nodes) == 0 { return nil } diff --git a/pkg/disttask/framework/testutil/BUILD.bazel b/pkg/disttask/framework/testutil/BUILD.bazel index f5ab5b4bb5436..593b2e3bcbfa0 100644 --- a/pkg/disttask/framework/testutil/BUILD.bazel +++ b/pkg/disttask/framework/testutil/BUILD.bazel @@ -20,7 +20,6 @@ go_library( "//pkg/disttask/framework/scheduler/mock", "//pkg/disttask/framework/storage", "//pkg/disttask/framework/taskexecutor", - "//pkg/domain/infosync", "//pkg/sessionctx", "//pkg/testkit", "@com_github_ngaut_pools//:pools", diff --git a/pkg/disttask/framework/testutil/disttest_util.go b/pkg/disttask/framework/testutil/disttest_util.go index de6778cdf0252..cde60207f78c1 100644 --- a/pkg/disttask/framework/testutil/disttest_util.go +++ b/pkg/disttask/framework/testutil/disttest_util.go @@ -67,8 +67,8 @@ func registerTaskMetaInner(t *testing.T, taskType proto.TaskType, mockExtension taskexecutor.ClearTaskExecutors() }) scheduler.RegisterSchedulerFactory(taskType, - func(ctx context.Context, taskMgr scheduler.TaskManager, serverID string, task *proto.Task) scheduler.Scheduler { - baseScheduler := scheduler.NewBaseScheduler(ctx, taskMgr, serverID, task) + func(ctx context.Context, taskMgr scheduler.TaskManager, nodeMgr *scheduler.NodeManager, task *proto.Task) scheduler.Scheduler { + baseScheduler := scheduler.NewBaseScheduler(ctx, taskMgr, nodeMgr, task) baseScheduler.Extension = schedulerHandle return baseScheduler }) diff --git a/pkg/disttask/framework/testutil/scheduler_util.go b/pkg/disttask/framework/testutil/scheduler_util.go index 4a8bfb43da69d..cc307e4b36c32 100644 --- a/pkg/disttask/framework/testutil/scheduler_util.go +++ b/pkg/disttask/framework/testutil/scheduler_util.go @@ -21,7 +21,6 @@ import ( "github.com/pingcap/tidb/pkg/disttask/framework/proto" "github.com/pingcap/tidb/pkg/disttask/framework/scheduler" mockDispatch "github.com/pingcap/tidb/pkg/disttask/framework/scheduler/mock" - "github.com/pingcap/tidb/pkg/domain/infosync" "go.uber.org/mock/gomock" ) @@ -30,8 +29,8 @@ func GetMockBasicSchedulerExt(ctrl *gomock.Controller) scheduler.Extension { mockScheduler := mockDispatch.NewMockExtension(ctrl) mockScheduler.EXPECT().OnTick(gomock.Any(), gomock.Any()).Return().AnyTimes() mockScheduler.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).DoAndReturn( - func(_ context.Context, _ *proto.Task) ([]*infosync.ServerInfo, bool, error) { - return generateTaskExecutorNodes4Test() + func(_ context.Context, _ *proto.Task) ([]string, error) { + return nil, nil }, ).AnyTimes() mockScheduler.EXPECT().IsRetryableErr(gomock.Any()).Return(true).AnyTimes() @@ -48,7 +47,7 @@ func GetMockBasicSchedulerExt(ctrl *gomock.Controller) scheduler.Extension { }, ).AnyTimes() mockScheduler.EXPECT().OnNextSubtasksBatch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( - func(_ context.Context, _ scheduler.TaskHandle, task *proto.Task, _ []*infosync.ServerInfo, _ proto.Step) (metas [][]byte, err error) { + func(_ context.Context, _ scheduler.TaskHandle, task *proto.Task, _ []string, _ proto.Step) (metas [][]byte, err error) { if task.Step == proto.StepInit { return [][]byte{ []byte("task1"), @@ -74,8 +73,8 @@ func GetMockHATestSchedulerExt(ctrl *gomock.Controller) scheduler.Extension { mockScheduler := mockDispatch.NewMockExtension(ctrl) mockScheduler.EXPECT().OnTick(gomock.Any(), gomock.Any()).Return().AnyTimes() mockScheduler.EXPECT().GetEligibleInstances(gomock.Any(), 
gomock.Any()).DoAndReturn( - func(_ context.Context, _ *proto.Task) ([]*infosync.ServerInfo, bool, error) { - return generateTaskExecutorNodes4Test() + func(_ context.Context, _ *proto.Task) ([]string, error) { + return nil, nil }, ).AnyTimes() mockScheduler.EXPECT().IsRetryableErr(gomock.Any()).Return(true).AnyTimes() @@ -92,7 +91,7 @@ func GetMockHATestSchedulerExt(ctrl *gomock.Controller) scheduler.Extension { }, ).AnyTimes() mockScheduler.EXPECT().OnNextSubtasksBatch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( - func(_ context.Context, _ scheduler.TaskHandle, task *proto.Task, _ []*infosync.ServerInfo, _ proto.Step) (metas [][]byte, err error) { + func(_ context.Context, _ scheduler.TaskHandle, task *proto.Task, _ []string, _ proto.Step) (metas [][]byte, err error) { if task.Step == proto.StepInit { return [][]byte{ []byte("task1"), @@ -125,26 +124,13 @@ func GetMockHATestSchedulerExt(ctrl *gomock.Controller) scheduler.Extension { return mockScheduler } -func generateTaskExecutorNodes4Test() ([]*infosync.ServerInfo, bool, error) { - serverInfos := infosync.MockGlobalServerInfoManagerEntry.GetAllServerInfo() - if len(serverInfos) == 0 { - return nil, true, errors.New("not found instance") - } - - serverNodes := make([]*infosync.ServerInfo, 0, len(serverInfos)) - for _, serverInfo := range serverInfos { - serverNodes = append(serverNodes, serverInfo) - } - return serverNodes, true, nil -} - // GetPlanNotRetryableErrSchedulerExt returns mock scheduler.Extension which will generate non retryable error when planning. func GetPlanNotRetryableErrSchedulerExt(ctrl *gomock.Controller) scheduler.Extension { mockScheduler := mockDispatch.NewMockExtension(ctrl) mockScheduler.EXPECT().OnTick(gomock.Any(), gomock.Any()).Return().AnyTimes() mockScheduler.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).DoAndReturn( - func(_ context.Context, _ *proto.Task) ([]*infosync.ServerInfo, bool, error) { - return generateTaskExecutorNodes4Test() + func(_ context.Context, _ *proto.Task) ([]string, error) { + return nil, nil }, ).AnyTimes() mockScheduler.EXPECT().IsRetryableErr(gomock.Any()).Return(false).AnyTimes() @@ -157,7 +143,7 @@ func GetPlanNotRetryableErrSchedulerExt(ctrl *gomock.Controller) scheduler.Exten }, ).AnyTimes() mockScheduler.EXPECT().OnNextSubtasksBatch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( - func(_ context.Context, _ scheduler.TaskHandle, _ *proto.Task, _ []*infosync.ServerInfo, _ proto.Step) (metas [][]byte, err error) { + func(_ context.Context, _ scheduler.TaskHandle, _ *proto.Task, _ []string, _ proto.Step) (metas [][]byte, err error) { return nil, errors.New("not retryable err") }, ).AnyTimes() @@ -171,8 +157,8 @@ func GetPlanErrSchedulerExt(ctrl *gomock.Controller, testContext *TestContext) s mockScheduler := mockDispatch.NewMockExtension(ctrl) mockScheduler.EXPECT().OnTick(gomock.Any(), gomock.Any()).Return().AnyTimes() mockScheduler.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).DoAndReturn( - func(_ context.Context, _ *proto.Task) ([]*infosync.ServerInfo, bool, error) { - return generateTaskExecutorNodes4Test() + func(_ context.Context, _ *proto.Task) ([]string, error) { + return nil, nil }, ).AnyTimes() mockScheduler.EXPECT().IsRetryableErr(gomock.Any()).Return(true).AnyTimes() @@ -189,7 +175,7 @@ func GetPlanErrSchedulerExt(ctrl *gomock.Controller, testContext *TestContext) s }, ).AnyTimes() mockScheduler.EXPECT().OnNextSubtasksBatch(gomock.Any(), gomock.Any(), gomock.Any(), 
gomock.Any(), gomock.Any()).DoAndReturn( - func(_ context.Context, _ scheduler.TaskHandle, task *proto.Task, _ []*infosync.ServerInfo, _ proto.Step) (metas [][]byte, err error) { + func(_ context.Context, _ scheduler.TaskHandle, task *proto.Task, _ []string, _ proto.Step) (metas [][]byte, err error) { if task.Step == proto.StepInit { if testContext.CallTime == 0 { testContext.CallTime++ @@ -222,8 +208,8 @@ func GetMockRollbackSchedulerExt(ctrl *gomock.Controller) scheduler.Extension { mockScheduler := mockDispatch.NewMockExtension(ctrl) mockScheduler.EXPECT().OnTick(gomock.Any(), gomock.Any()).Return().AnyTimes() mockScheduler.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).DoAndReturn( - func(_ context.Context, _ *proto.Task) ([]*infosync.ServerInfo, bool, error) { - return generateTaskExecutorNodes4Test() + func(_ context.Context, _ *proto.Task) ([]string, error) { + return nil, nil }, ).AnyTimes() mockScheduler.EXPECT().IsRetryableErr(gomock.Any()).Return(true).AnyTimes() @@ -238,7 +224,7 @@ func GetMockRollbackSchedulerExt(ctrl *gomock.Controller) scheduler.Extension { }, ).AnyTimes() mockScheduler.EXPECT().OnNextSubtasksBatch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( - func(_ context.Context, _ scheduler.TaskHandle, task *proto.Task, _ []*infosync.ServerInfo, _ proto.Step) (metas [][]byte, err error) { + func(_ context.Context, _ scheduler.TaskHandle, task *proto.Task, _ []string, _ proto.Step) (metas [][]byte, err error) { if task.Step == proto.StepInit { return [][]byte{ []byte("task1"), @@ -259,8 +245,8 @@ func GetMockDynamicDispatchExt(ctrl *gomock.Controller) scheduler.Extension { mockScheduler := mockDispatch.NewMockExtension(ctrl) mockScheduler.EXPECT().OnTick(gomock.Any(), gomock.Any()).Return().AnyTimes() mockScheduler.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).DoAndReturn( - func(_ context.Context, _ *proto.Task) ([]*infosync.ServerInfo, bool, error) { - return generateTaskExecutorNodes4Test() + func(_ context.Context, _ *proto.Task) ([]string, error) { + return nil, nil }, ).AnyTimes() mockScheduler.EXPECT().IsRetryableErr(gomock.Any()).Return(true).AnyTimes() @@ -277,7 +263,7 @@ func GetMockDynamicDispatchExt(ctrl *gomock.Controller) scheduler.Extension { }, ).AnyTimes() mockScheduler.EXPECT().OnNextSubtasksBatch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( - func(_ context.Context, _ scheduler.TaskHandle, task *proto.Task, _ []*infosync.ServerInfo, _ proto.Step) (metas [][]byte, err error) { + func(_ context.Context, _ scheduler.TaskHandle, task *proto.Task, _ []string, _ proto.Step) (metas [][]byte, err error) { if task.Step == proto.StepInit { return [][]byte{ []byte("task"), diff --git a/pkg/disttask/importinto/BUILD.bazel b/pkg/disttask/importinto/BUILD.bazel index b018ad963b30c..a06e2351ee09c 100644 --- a/pkg/disttask/importinto/BUILD.bazel +++ b/pkg/disttask/importinto/BUILD.bazel @@ -57,6 +57,7 @@ go_library( "//pkg/util", "//pkg/util/backoff", "//pkg/util/dbterror/exeerrors", + "//pkg/util/disttask", "//pkg/util/etcd", "//pkg/util/logutil", "//pkg/util/mathutil", diff --git a/pkg/disttask/importinto/scheduler.go b/pkg/disttask/importinto/scheduler.go index b73cf888d7685..d7a1997b68762 100644 --- a/pkg/disttask/importinto/scheduler.go +++ b/pkg/disttask/importinto/scheduler.go @@ -35,12 +35,12 @@ import ( "github.com/pingcap/tidb/pkg/disttask/framework/proto" "github.com/pingcap/tidb/pkg/disttask/framework/scheduler" 
"github.com/pingcap/tidb/pkg/disttask/framework/storage" - "github.com/pingcap/tidb/pkg/domain/infosync" "github.com/pingcap/tidb/pkg/errno" "github.com/pingcap/tidb/pkg/executor/importer" "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/util/backoff" + disttaskutil "github.com/pingcap/tidb/pkg/util/disttask" "github.com/pingcap/tidb/pkg/util/etcd" "github.com/pingcap/tidb/pkg/util/logutil" "github.com/pingcap/tidb/pkg/util/sqlexec" @@ -201,7 +201,7 @@ func (sch *ImportSchedulerExt) OnNextSubtasksBatch( ctx context.Context, taskHandle scheduler.TaskHandle, task *proto.Task, - serverInfos []*infosync.ServerInfo, + execIDs []string, nextStep proto.Step, ) ( resSubtaskMeta [][]byte, err error) { @@ -290,7 +290,7 @@ func (sch *ImportSchedulerExt) OnNextSubtasksBatch( PreviousSubtaskMetas: previousSubtaskMetas, GlobalSort: sch.GlobalSort, NextTaskStep: nextStep, - ExecuteNodesCnt: len(serverInfos), + ExecuteNodesCnt: len(execIDs), } logicalPlan := &LogicalPlan{} if err := logicalPlan.FromTaskMeta(task.Meta); err != nil { @@ -331,17 +331,17 @@ func (sch *ImportSchedulerExt) OnDone(ctx context.Context, handle scheduler.Task } // GetEligibleInstances implements scheduler.Extension interface. -func (*ImportSchedulerExt) GetEligibleInstances(ctx context.Context, task *proto.Task) ([]*infosync.ServerInfo, bool, error) { +func (*ImportSchedulerExt) GetEligibleInstances(_ context.Context, task *proto.Task) ([]string, error) { taskMeta := &TaskMeta{} err := json.Unmarshal(task.Meta, taskMeta) if err != nil { - return nil, true, errors.Trace(err) + return nil, errors.Trace(err) } - if len(taskMeta.EligibleInstances) > 0 { - return taskMeta.EligibleInstances, false, nil + res := make([]string, 0, len(taskMeta.EligibleInstances)) + for _, instance := range taskMeta.EligibleInstances { + res = append(res, disttaskutil.GenerateExecID(instance)) } - serverInfo, err := scheduler.GenerateTaskExecutorNodes(ctx) - return serverInfo, true, err + return res, nil } // IsRetryableErr implements scheduler.Extension interface. 
@@ -406,11 +406,11 @@ type importScheduler struct { } func newImportScheduler(ctx context.Context, taskMgr scheduler.TaskManager, - serverID string, task *proto.Task) scheduler.Scheduler { + nodeMgr *scheduler.NodeManager, task *proto.Task) scheduler.Scheduler { metrics := metricsManager.getOrCreateMetrics(task.ID) subCtx := metric.WithCommonMetric(ctx, metrics) sch := importScheduler{ - BaseScheduler: scheduler.NewBaseScheduler(subCtx, taskMgr, serverID, task), + BaseScheduler: scheduler.NewBaseScheduler(subCtx, taskMgr, nodeMgr, task), } return &sch } diff --git a/pkg/disttask/importinto/scheduler_test.go b/pkg/disttask/importinto/scheduler_test.go index 53dbf593102d4..b0835bbb853bf 100644 --- a/pkg/disttask/importinto/scheduler_test.go +++ b/pkg/disttask/importinto/scheduler_test.go @@ -17,13 +17,11 @@ package importinto import ( "context" "encoding/json" - "fmt" "testing" "github.com/pingcap/failpoint" "github.com/pingcap/tidb/pkg/disttask/framework/proto" "github.com/pingcap/tidb/pkg/disttask/framework/scheduler" - "github.com/pingcap/tidb/pkg/domain/infosync" "github.com/pingcap/tidb/pkg/executor/importer" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" @@ -45,39 +43,18 @@ func (s *importIntoSuite) enableFailPoint(path, term string) { } func (s *importIntoSuite) TestSchedulerGetEligibleInstances() { - makeFailpointRes := func(v interface{}) string { - bytes, err := json.Marshal(v) - s.NoError(err) - return fmt.Sprintf("return(`%s`)", string(bytes)) - } - uuids := []string{"ddl_id_1", "ddl_id_2"} - serverInfoMap := map[string]*infosync.ServerInfo{ - uuids[0]: { - ID: uuids[0], - }, - uuids[1]: { - ID: uuids[1], - }, - } - mockedAllServerInfos := makeFailpointRes(serverInfoMap) - sch := ImportSchedulerExt{} task := &proto.Task{Meta: []byte("{}")} ctx := context.WithValue(context.Background(), "etcd", true) - s.enableFailPoint("github.com/pingcap/tidb/pkg/domain/infosync/mockGetAllServerInfo", mockedAllServerInfos) - eligibleInstances, _, err := sch.GetEligibleInstances(ctx, task) + eligibleInstances, err := sch.GetEligibleInstances(ctx, task) s.NoError(err) // order of slice is not stable, change to map - resultMap := map[string]*infosync.ServerInfo{} - for _, ins := range eligibleInstances { - resultMap[ins.ID] = ins - } - s.Equal(serverInfoMap, resultMap) + s.Empty(eligibleInstances) task.Meta = []byte(`{"EligibleInstances":[{"ip": "1.1.1.1", "listening_port": 4000}]}`) - eligibleInstances, _, err = sch.GetEligibleInstances(ctx, task) + eligibleInstances, err = sch.GetEligibleInstances(ctx, task) s.NoError(err) - s.Equal([]*infosync.ServerInfo{{IP: "1.1.1.1", Port: 4000}}, eligibleInstances) + s.Equal([]string{"1.1.1.1:4000"}, eligibleInstances) } func (s *importIntoSuite) TestUpdateCurrentTask() { diff --git a/pkg/disttask/importinto/scheduler_testkit_test.go b/pkg/disttask/importinto/scheduler_testkit_test.go index 80cb38fca6ff9..0a3fea180d04d 100644 --- a/pkg/disttask/importinto/scheduler_testkit_test.go +++ b/pkg/disttask/importinto/scheduler_testkit_test.go @@ -92,9 +92,7 @@ func TestSchedulerExtLocalSort(t *testing.T) { // to import stage, job should be running d := sch.MockScheduler(task) ext := importinto.ImportSchedulerExt{} - serverInfos, _, err := ext.GetEligibleInstances(context.Background(), task) - require.NoError(t, err) - subtaskMetas, err := ext.OnNextSubtasksBatch(ctx, d, task, serverInfos, ext.GetNextStep(task)) + subtaskMetas, err := ext.OnNextSubtasksBatch(ctx, d, task, []string{":4000"}, ext.GetNextStep(task)) require.NoError(t, err) 
require.Len(t, subtaskMetas, 1) task.Step = ext.GetNextStep(task) @@ -115,7 +113,7 @@ func TestSchedulerExtLocalSort(t *testing.T) { require.NoError(t, manager.FinishSubtask(ctx, s.ExecID, s.ID, []byte("{}"))) } // to post-process stage, job should be running and in validating step - subtaskMetas, err = ext.OnNextSubtasksBatch(ctx, d, task, serverInfos, ext.GetNextStep(task)) + subtaskMetas, err = ext.OnNextSubtasksBatch(ctx, d, task, []string{":4000"}, ext.GetNextStep(task)) require.NoError(t, err) require.Len(t, subtaskMetas, 1) task.Step = ext.GetNextStep(task) @@ -125,7 +123,7 @@ func TestSchedulerExtLocalSort(t *testing.T) { require.Equal(t, "running", gotJobInfo.Status) require.Equal(t, "validating", gotJobInfo.Step) // on next stage, job should be finished - subtaskMetas, err = ext.OnNextSubtasksBatch(ctx, d, task, serverInfos, ext.GetNextStep(task)) + subtaskMetas, err = ext.OnNextSubtasksBatch(ctx, d, task, []string{":4000"}, ext.GetNextStep(task)) require.NoError(t, err) require.Len(t, subtaskMetas, 0) task.Step = ext.GetNextStep(task) @@ -239,9 +237,7 @@ func TestSchedulerExtGlobalSort(t *testing.T) { ext := importinto.ImportSchedulerExt{ GlobalSort: true, } - serverInfos, _, err := ext.GetEligibleInstances(context.Background(), task) - require.NoError(t, err) - subtaskMetas, err := ext.OnNextSubtasksBatch(ctx, d, task, serverInfos, ext.GetNextStep(task)) + subtaskMetas, err := ext.OnNextSubtasksBatch(ctx, d, task, []string{":4000"}, ext.GetNextStep(task)) require.NoError(t, err) require.Len(t, subtaskMetas, 2) task.Step = ext.GetNextStep(task) @@ -298,7 +294,7 @@ func TestSchedulerExtGlobalSort(t *testing.T) { t.Cleanup(func() { require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/disttask/importinto/forceMergeSort")) }) - subtaskMetas, err = ext.OnNextSubtasksBatch(ctx, d, task, serverInfos, ext.GetNextStep(task)) + subtaskMetas, err = ext.OnNextSubtasksBatch(ctx, d, task, []string{":4000"}, ext.GetNextStep(task)) require.NoError(t, err) require.Len(t, subtaskMetas, 1) task.Step = ext.GetNextStep(task) @@ -336,7 +332,7 @@ func TestSchedulerExtGlobalSort(t *testing.T) { t.Cleanup(func() { require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/disttask/importinto/mockWriteIngestSpecs")) }) - subtaskMetas, err = ext.OnNextSubtasksBatch(ctx, d, task, serverInfos, ext.GetNextStep(task)) + subtaskMetas, err = ext.OnNextSubtasksBatch(ctx, d, task, []string{":4000"}, ext.GetNextStep(task)) require.NoError(t, err) require.Len(t, subtaskMetas, 2) task.Step = ext.GetNextStep(task) @@ -346,7 +342,7 @@ func TestSchedulerExtGlobalSort(t *testing.T) { require.Equal(t, "running", gotJobInfo.Status) require.Equal(t, "importing", gotJobInfo.Step) // on next stage, to post-process stage - subtaskMetas, err = ext.OnNextSubtasksBatch(ctx, d, task, serverInfos, ext.GetNextStep(task)) + subtaskMetas, err = ext.OnNextSubtasksBatch(ctx, d, task, []string{":4000"}, ext.GetNextStep(task)) require.NoError(t, err) require.Len(t, subtaskMetas, 1) task.Step = ext.GetNextStep(task) @@ -356,7 +352,7 @@ func TestSchedulerExtGlobalSort(t *testing.T) { require.Equal(t, "running", gotJobInfo.Status) require.Equal(t, "validating", gotJobInfo.Step) // next stage, done - subtaskMetas, err = ext.OnNextSubtasksBatch(ctx, d, task, serverInfos, ext.GetNextStep(task)) + subtaskMetas, err = ext.OnNextSubtasksBatch(ctx, d, task, []string{":4000"}, ext.GetNextStep(task)) require.NoError(t, err) require.Len(t, subtaskMetas, 0) task.Step = ext.GetNextStep(task) diff --git 
a/pkg/util/disttask/BUILD.bazel b/pkg/util/disttask/BUILD.bazel index dc6f7fb37ff53..e418bc836174e 100644 --- a/pkg/util/disttask/BUILD.bazel +++ b/pkg/util/disttask/BUILD.bazel @@ -14,5 +14,8 @@ go_test( srcs = ["idservice_test.go"], embed = [":disttask"], flaky = True, - deps = ["@com_github_stretchr_testify//require"], + deps = [ + "//pkg/domain/infosync", + "@com_github_stretchr_testify//require", + ], ) diff --git a/pkg/util/disttask/idservice.go b/pkg/util/disttask/idservice.go index 74ecac939b9b1..239760a7d89ca 100644 --- a/pkg/util/disttask/idservice.go +++ b/pkg/util/disttask/idservice.go @@ -25,9 +25,8 @@ import ( // GenerateExecID used to generate IP:port as exec_id value // This function is used by distributed task execution to generate serverID string to // correlated one subtask to on TiDB node to be executed. -func GenerateExecID(ip string, port uint) string { - portstring := fmt.Sprintf("%d", port) - return net.JoinHostPort(ip, portstring) +func GenerateExecID(info *infosync.ServerInfo) string { + return net.JoinHostPort(info.IP, fmt.Sprintf("%d", info.Port)) } // MatchServerInfo will check if the schedulerID matched in all serverInfos. @@ -38,7 +37,7 @@ func MatchServerInfo(serverInfos []*infosync.ServerInfo, schedulerID string) boo // FindServerInfo will find the schedulerID in all serverInfos. func FindServerInfo(serverInfos []*infosync.ServerInfo, schedulerID string) int { for i, serverInfo := range serverInfos { - serverID := GenerateExecID(serverInfo.IP, serverInfo.Port) + serverID := GenerateExecID(serverInfo) if serverID == schedulerID { return i } @@ -63,7 +62,7 @@ func GenerateSubtaskExecID(ctx context.Context, id string) string { return "" } if serverNode, ok := serverInfos[id]; ok { - return GenerateExecID(serverNode.IP, serverNode.Port) + return GenerateExecID(serverNode) } return "" } @@ -75,7 +74,7 @@ func GenerateSubtaskExecID4Test(id string) string { return "" } if serverNode, ok := serverInfos[id]; ok { - return GenerateExecID(serverNode.IP, serverNode.Port) + return GenerateExecID(serverNode) } return "" } diff --git a/pkg/util/disttask/idservice_test.go b/pkg/util/disttask/idservice_test.go index d14b26a1439a1..d991b6b4b80b0 100644 --- a/pkg/util/disttask/idservice_test.go +++ b/pkg/util/disttask/idservice_test.go @@ -17,21 +17,22 @@ package disttaskutil import ( "testing" + "github.com/pingcap/tidb/pkg/domain/infosync" "github.com/stretchr/testify/require" ) // This testCase show GenerateExecID only generate string by input parametas func TestGenServerID(t *testing.T) { var str string - serverIO := GenerateExecID("", 0) + serverIO := GenerateExecID(&infosync.ServerInfo{IP: "", Port: 0}) require.Equal(t, serverIO, ":0") - serverIO = GenerateExecID("10.124.122.25", 3456) + serverIO = GenerateExecID(&infosync.ServerInfo{IP: "10.124.122.25", Port: 3456}) require.Equal(t, serverIO, "10.124.122.25:3456") - serverIO = GenerateExecID("10.124", 3456) + serverIO = GenerateExecID(&infosync.ServerInfo{IP: "10.124", Port: 3456}) require.Equal(t, serverIO, "10.124:3456") - serverIO = GenerateExecID(str, 65537) + serverIO = GenerateExecID(&infosync.ServerInfo{IP: str, Port: 65537}) require.Equal(t, serverIO, ":65537") // IPv6 testcase - serverIO = GenerateExecID("ABCD:EF01:2345:6789:ABCD:EF01:2345:6789", 65537) + serverIO = GenerateExecID(&infosync.ServerInfo{IP: "ABCD:EF01:2345:6789:ABCD:EF01:2345:6789", Port: 65537}) require.Equal(t, serverIO, "[ABCD:EF01:2345:6789:ABCD:EF01:2345:6789]:65537") }
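Editor's note: GenerateExecID now takes the *infosync.ServerInfo directly instead of an (ip, port) pair, which removes the repeated field plumbing at call sites such as TestManagerDispatchLoop above. A small usage sketch under the new signature; the helper name liveExecIDs is made up for this note.

```go
// Illustration only: building the exec-ID list for every live TiDB node
// with the new GenerateExecID signature.
package example

import (
	"context"

	"github.com/pingcap/tidb/pkg/domain/infosync"
	disttaskutil "github.com/pingcap/tidb/pkg/util/disttask"
)

// liveExecIDs turns every live TiDB node's ServerInfo into an "ip:port"
// exec ID, mirroring the call sites updated in this PR.
func liveExecIDs(ctx context.Context) ([]string, error) {
	serverInfos, err := infosync.GetAllServerInfo(ctx)
	if err != nil {
		return nil, err
	}
	execIDs := make([]string, 0, len(serverInfos))
	for _, info := range serverInfos {
		execIDs = append(execIDs, disttaskutil.GenerateExecID(info))
	}
	return execIDs, nil
}
```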