Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

disttask: subtasks rebalance during task execution (#48306) #49029

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion pkg/ddl/backfilling_dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,6 @@ func generateNonPartitionPlan(
if err != nil {
return nil, err
}

regionBatch := calculateRegionBatch(len(recordRegionMetas), instanceCnt, !useCloud)

subTaskMetas := make([][]byte, 0, 4)
Expand Down
5 changes: 5 additions & 0 deletions pkg/disttask/framework/dispatcher/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,16 @@ go_test(
"dispatcher_manager_test.go",
"dispatcher_test.go",
"main_test.go",
"rebalance_test.go",
],
embed = [":dispatcher"],
flaky = True,
race = "off",
shard_count = 19,
deps = [
"//pkg/disttask/framework/mock",
"//pkg/disttask/framework/proto",
Expand Down
167 changes: 125 additions & 42 deletions pkg/disttask/framework/dispatcher/dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,15 @@ type BaseDispatcher struct {
// when RegisterDispatcherFactory, the factory MUST initialize this field.
Extension

// for HA
// liveNodes will fetch and store all live nodes every liveNodeInterval ticks.
liveNodes []*infosync.ServerInfo
// For subtasks rebalance.
// LiveNodes will fetch and store all live nodes every liveNodeInterval ticks.
LiveNodes []*infosync.ServerInfo
liveNodeFetchInterval int
// liveNodeFetchTick is the tick variable.
liveNodeFetchTick int
// taskNodes stores the id of current scheduler nodes.
taskNodes []string
// TaskNodes stores the id of current scheduler nodes.
TaskNodes []string

// rand is for generating random selection of nodes.
rand *rand.Rand
}
Expand All @@ -112,10 +113,10 @@ func NewBaseDispatcher(ctx context.Context, taskMgr TaskManager, serverID string
Task: task,
logCtx: logCtx,
serverID: serverID,
liveNodes: nil,
LiveNodes: nil,
liveNodeFetchInterval: DefaultLiveNodesCheckInterval,
liveNodeFetchTick: 0,
taskNodes: nil,
TaskNodes: nil,
rand: rand.New(rand.NewSource(time.Now().UnixNano())),
}
}
Expand Down Expand Up @@ -259,7 +260,7 @@ func (d *BaseDispatcher) onPausing() error {
// MockDMLExecutionOnPausedState is used to mock DML execution when tasks paused.
var MockDMLExecutionOnPausedState func(task *proto.Task)

// handle task in paused state
// handle task in paused state.
func (d *BaseDispatcher) onPaused() error {
logutil.Logger(d.logCtx).Info("on paused state", zap.Stringer("state", d.Task.State), zap.Int64("stage", int64(d.Task.Step)))
failpoint.Inject("mockDMLExecutionOnPausedState", func(val failpoint.Value) {
Expand All @@ -273,7 +274,7 @@ func (d *BaseDispatcher) onPaused() error {
// TestSyncChan is used to sync the test.
var TestSyncChan = make(chan struct{})

// handle task in resuming state
// handle task in resuming state.
func (d *BaseDispatcher) onResuming() error {
logutil.Logger(d.logCtx).Info("on resuming state", zap.Stringer("state", d.Task.State), zap.Int64("stage", int64(d.Task.Step)))
cnt, err := d.taskMgr.GetSubtaskInStatesCnt(d.ctx, d.Task.ID, proto.TaskStatePaused)
Expand Down Expand Up @@ -342,8 +343,8 @@ func (d *BaseDispatcher) onRunning() error {
if cnt == 0 {
return d.onNextStage()
}
// Check if any node are down.
if err := d.replaceDeadNodesIfAny(); err != nil {

if err := d.BalanceSubtasks(); err != nil {
return err
}
// Wait all subtasks in this stage finished.
Expand All @@ -358,16 +359,19 @@ func (d *BaseDispatcher) onFinished() error {
return d.taskMgr.TransferSubTasks2History(d.ctx, d.Task.ID)
}

func (d *BaseDispatcher) replaceDeadNodesIfAny() error {
if len(d.taskNodes) == 0 {
// BalanceSubtasks checks the live-node list every liveNodeFetchInterval ticks and then rebalances subtasks across the live nodes.
func (d *BaseDispatcher) BalanceSubtasks() error {
// 1. init TaskNodes if needed.
if len(d.TaskNodes) == 0 {
var err error
d.taskNodes, err = d.taskMgr.GetSchedulerIDsByTaskIDAndStep(d.ctx, d.Task.ID, d.Task.Step)
d.TaskNodes, err = d.taskMgr.GetSchedulerIDsByTaskIDAndStep(d.ctx, d.Task.ID, d.Task.Step)
if err != nil {
return err
}
}
d.liveNodeFetchTick++
if d.liveNodeFetchTick == d.liveNodeFetchInterval {
// 2. update LiveNodes.
d.liveNodeFetchTick = 0
serverInfos, err := GenerateSchedulerNodes(d.ctx)
if err != nil {
Expand Down Expand Up @@ -397,37 +401,116 @@ func (d *BaseDispatcher) replaceDeadNodesIfAny() error {
newInfos = append(newInfos, m)
}
}
d.liveNodes = newInfos
}
if len(d.liveNodes) > 0 {
replaceNodes := make(map[string]string)
cleanNodes := make([]string, 0)
for _, nodeID := range d.taskNodes {
if ok := disttaskutil.MatchServerInfo(d.liveNodes, nodeID); !ok {
n := d.liveNodes[d.rand.Int()%len(d.liveNodes)] //nolint:gosec
replaceNodes[nodeID] = disttaskutil.GenerateExecID(n.IP, n.Port)
cleanNodes = append(cleanNodes, nodeID)
}
d.LiveNodes = newInfos
// 3. balance subtasks.
if len(d.LiveNodes) > 0 {
return d.ReDispatchSubtasks()
}
if len(replaceNodes) > 0 {
logutil.Logger(d.logCtx).Info("reschedule subtasks to other nodes", zap.Int("node-cnt", len(replaceNodes)))
if err := d.taskMgr.UpdateFailedSchedulerIDs(d.ctx, d.Task.ID, replaceNodes); err != nil {
return err
}
if err := d.taskMgr.CleanUpMeta(d.ctx, cleanNodes); err != nil {
return err
return nil
}
return nil
}

// replaceTaskNodes rebuilds d.TaskNodes from the current LiveNodes so that
// the cached scheduler exec IDs match the nodes that are actually alive.
func (d *BaseDispatcher) replaceTaskNodes() {
	// Truncate in place (keeping the backing array), then refill with the
	// exec ID of every live node.
	d.TaskNodes = d.TaskNodes[:0]
	for _, info := range d.LiveNodes {
		d.TaskNodes = append(d.TaskNodes, disttaskutil.GenerateExecID(info.IP, info.Port))
	}
}

// ReDispatchSubtasks make count of subtasks on each liveNodes balanced and clean up subtasks on dead nodes.
// TODO(ywqzzy): refine to make it easier for testing.
func (d *BaseDispatcher) ReDispatchSubtasks() error {
// 1. find out nodes need to clean subtasks.
deadNodes := make([]string, 0)
deadNodesMap := make(map[string]bool, 0)
for _, node := range d.TaskNodes {
if !disttaskutil.MatchServerInfo(d.LiveNodes, node) {
deadNodes = append(deadNodes, node)
deadNodesMap[node] = true
}
}
// 2. get subtasks for each node before rebalance.
subtasks, err := d.taskMgr.GetSubtasksByStepAndState(d.ctx, d.Task.ID, d.Task.Step, proto.TaskStatePending)
if err != nil {
return err
}
if len(deadNodes) != 0 {
/// get subtask from deadNodes, since there might be some running subtasks on deadNodes.
/// In this case, all subtasks on deadNodes are in running/pending state.
subtasksOnDeadNodes, err := d.taskMgr.GetSubtasksByExecIdsAndStepAndState(d.ctx, deadNodes, d.Task.ID, d.Task.Step, proto.TaskStateRunning)
if err != nil {
return err
}
subtasks = append(subtasks, subtasksOnDeadNodes...)
}
// 3. group subtasks for each scheduler.
subtasksOnScheduler := make(map[string][]*proto.Subtask, len(d.LiveNodes)+len(deadNodes))
for _, node := range d.LiveNodes {
execID := disttaskutil.GenerateExecID(node.IP, node.Port)
subtasksOnScheduler[execID] = make([]*proto.Subtask, 0)
}
for _, subtask := range subtasks {
subtasksOnScheduler[subtask.SchedulerID] = append(
subtasksOnScheduler[subtask.SchedulerID],
subtask)
}
// 4. prepare subtasks that need to rebalance to other nodes.
averageSubtaskCnt := len(subtasks) / len(d.LiveNodes)
rebalanceSubtasks := make([]*proto.Subtask, 0)
for k, v := range subtasksOnScheduler {
if ok := deadNodesMap[k]; ok {
rebalanceSubtasks = append(rebalanceSubtasks, v...)
continue
}
// When no tidb scale-in/out and averageSubtaskCnt*len(d.LiveNodes) < len(subtasks),
// no need to send subtask to other nodes.
// eg: tidb1 with 3 subtasks, tidb2 with 2 subtasks, subtasks are balanced now.
if averageSubtaskCnt*len(d.LiveNodes) < len(subtasks) && len(d.TaskNodes) == len(d.LiveNodes) {
if len(v) > averageSubtaskCnt+1 {
rebalanceSubtasks = append(rebalanceSubtasks, v[0:len(v)-averageSubtaskCnt]...)
}
// replace local cache.
for k, v := range replaceNodes {
for m, n := range d.taskNodes {
if n == k {
d.taskNodes[m] = v
break
}
continue
}
if len(v) > averageSubtaskCnt {
rebalanceSubtasks = append(rebalanceSubtasks, v[0:len(v)-averageSubtaskCnt]...)
}
}
// 5. skip rebalance.
if len(rebalanceSubtasks) == 0 {
return nil
}
// 6.rebalance subtasks to other nodes.
rebalanceIdx := 0
for k, v := range subtasksOnScheduler {
if ok := deadNodesMap[k]; !ok {
if len(v) < averageSubtaskCnt {
for i := 0; i < averageSubtaskCnt-len(v) && rebalanceIdx < len(rebalanceSubtasks); i++ {
rebalanceSubtasks[rebalanceIdx].SchedulerID = k
rebalanceIdx++
}
}
}
}
// 7. rebalance rest subtasks evenly to liveNodes.
liveNodeIdx := 0
for rebalanceIdx < len(rebalanceSubtasks) {
node := d.LiveNodes[liveNodeIdx]
rebalanceSubtasks[rebalanceIdx].SchedulerID = disttaskutil.GenerateExecID(node.IP, node.Port)
rebalanceIdx++
liveNodeIdx++
}

// 8. update subtasks and do clean up logic.
if err = d.taskMgr.UpdateSubtasksSchedulerIDs(d.ctx, d.Task.ID, subtasks); err != nil {
return err
}
logutil.Logger(d.logCtx).Info("rebalance subtasks",
zap.Stringers("subtasks-rebalanced", subtasks))
if err = d.taskMgr.CleanUpMeta(d.ctx, deadNodes); err != nil {
return err
}
d.replaceTaskNodes()
return nil
}

Expand Down Expand Up @@ -602,9 +685,9 @@ func (d *BaseDispatcher) dispatchSubTask(
metas [][]byte,
serverNodes []*infosync.ServerInfo) error {
logutil.Logger(d.logCtx).Info("dispatch subtasks", zap.Stringer("state", d.Task.State), zap.Int64("step", int64(d.Task.Step)), zap.Uint64("concurrency", d.Task.Concurrency), zap.Int("subtasks", len(metas)))
d.taskNodes = make([]string, len(serverNodes))
d.TaskNodes = make([]string, len(serverNodes))
for i := range serverNodes {
d.taskNodes[i] = disttaskutil.GenerateExecID(serverNodes[i].IP, serverNodes[i].Port)
d.TaskNodes[i] = disttaskutil.GenerateExecID(serverNodes[i].IP, serverNodes[i].Port)
}
subTasks := make([]*proto.Subtask, 0, len(metas))
for i, meta := range metas {
Expand Down Expand Up @@ -708,7 +791,7 @@ func (d *BaseDispatcher) GetAllSchedulerIDs(ctx context.Context, task *proto.Tas

// GetPreviousSubtaskMetas get subtask metas from specific step.
func (d *BaseDispatcher) GetPreviousSubtaskMetas(taskID int64, step proto.Step) ([][]byte, error) {
previousSubtasks, err := d.taskMgr.GetSucceedSubtasksByStep(d.ctx, taskID, step)
previousSubtasks, err := d.taskMgr.GetSubtasksByStepAndState(d.ctx, taskID, step, proto.TaskStateSucceed)
if err != nil {
logutil.Logger(d.logCtx).Warn("get previous succeed subtask failed", zap.Int64("step", int64(step)))
return nil, err
Expand Down
5 changes: 3 additions & 2 deletions pkg/disttask/framework/dispatcher/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,11 @@ type TaskManager interface {
ResumeSubtasks(ctx context.Context, taskID int64) error
CollectSubTaskError(ctx context.Context, taskID int64) ([]error, error)
TransferSubTasks2History(ctx context.Context, taskID int64) error
UpdateFailedSchedulerIDs(ctx context.Context, taskID int64, replaceNodes map[string]string) error
UpdateSubtasksSchedulerIDs(ctx context.Context, taskID int64, subtasks []*proto.Subtask) error
GetNodesByRole(ctx context.Context, role string) (map[string]bool, error)
GetSchedulerIDsByTaskID(ctx context.Context, taskID int64) ([]string, error)
GetSucceedSubtasksByStep(ctx context.Context, taskID int64, step proto.Step) ([]*proto.Subtask, error)
GetSubtasksByStepAndState(ctx context.Context, taskID int64, step proto.Step, state proto.TaskState) ([]*proto.Subtask, error)
GetSubtasksByExecIdsAndStepAndState(ctx context.Context, tidbIDs []string, taskID int64, step proto.Step, state proto.TaskState) ([]*proto.Subtask, error)
GetSchedulerIDsByTaskIDAndStep(ctx context.Context, taskID int64, step proto.Step) ([]string, error)

WithNewSession(fn func(se sessionctx.Context) error) error
Expand Down
Loading