
disttask: init capacity and check concurrency using cpu count of managed node (#49875)

ref #49008
D3Hunter authored Jan 2, 2024
1 parent 7045202 commit 3d939d4
Showing 22 changed files with 307 additions and 144 deletions.
2 changes: 2 additions & 0 deletions pkg/ddl/index.go
@@ -2118,6 +2118,8 @@ func (w *worker) executeDistTask(reorgInfo *reorgInfo) error {

job := reorgInfo.Job
workerCntLimit := int(variable.GetDDLReorgWorkerCounter())
// we're using cpu count of current node, not of framework managed nodes,
// but it seems more intuitive.
concurrency := min(workerCntLimit, cpu.GetCPUCount())
logutil.BgLogger().Info("adjusted add-index task concurrency",
zap.Int("worker-cnt", workerCntLimit), zap.Int("task-concurrency", concurrency),
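The hunk above caps the add-index task concurrency at both the configured DDL reorg worker limit and the CPU count of the node submitting the task. A minimal, self-contained sketch of the same clamping idea, with illustrative names (in TiDB the two inputs come from variable.GetDDLReorgWorkerCounter and cpu.GetCPUCount):

package main

import "fmt"

// clampConcurrency mirrors the hunk above: use the configured worker limit,
// but never exceed the CPU count visible to the current node.
func clampConcurrency(workerCntLimit, nodeCPUs int) int {
	if workerCntLimit < nodeCPUs {
		return workerCntLimit
	}
	return nodeCPUs
}

func main() {
	// e.g. a worker limit of 16 on an 8-core node yields concurrency 8
	fmt.Println(clampConcurrency(16, 8))
}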
8 changes: 4 additions & 4 deletions pkg/disttask/framework/mock/scheduler_mock.go

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions pkg/disttask/framework/proto/BUILD.bazel
@@ -3,6 +3,7 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
go_library(
name = "proto",
srcs = [
"node.go",
"subtask.go",
"task.go",
],
25 changes: 25 additions & 0 deletions pkg/disttask/framework/proto/node.go
@@ -0,0 +1,25 @@
// Copyright 2023 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package proto

// ManagedNode is a TiDB node that is managed by the framework.
type ManagedNode struct {
// ID see GenerateExecID, it's named as host in the meta table.
ID string
// Role of the node, either "" or "background"
// all managed node should have the same role
Role string
CPUCount int
}
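For context, a hypothetical usage sketch of the new type (not part of this commit): node values as the task manager might return them, with host:port IDs and a shared role.

package example

import "github.com/pingcap/tidb/pkg/disttask/framework/proto"

// exampleNodes builds a slice of managed nodes as the task manager might
// return it. Values are illustrative only; the host:port IDs follow the
// GenerateExecID convention referenced in the ID field comment, and both
// nodes carry the same role, matching the comment on Role.
func exampleNodes() []proto.ManagedNode {
	return []proto.ManagedNode{
		{ID: "127.0.0.1:4000", Role: "background", CPUCount: 16},
		{ID: "127.0.0.1:4001", Role: "background", CPUCount: 16},
	}
}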
3 changes: 2 additions & 1 deletion pkg/disttask/framework/scheduler/BUILD.bazel
@@ -50,7 +50,7 @@ go_test(
embed = [":scheduler"],
flaky = True,
race = "off",
shard_count = 24,
shard_count = 25,
deps = [
"//pkg/config",
"//pkg/disttask/framework/mock",
Expand All @@ -63,6 +63,7 @@ go_test(
"//pkg/sessionctx",
"//pkg/testkit",
"//pkg/testkit/testsetup",
"//pkg/util/cpu",
"//pkg/util/disttask",
"//pkg/util/logutil",
"//pkg/util/sqlexec",
4 changes: 2 additions & 2 deletions pkg/disttask/framework/scheduler/interface.go
@@ -33,7 +33,7 @@ type TaskManager interface {
GetTaskByID(ctx context.Context, taskID int64) (task *proto.Task, err error)
UpdateTaskAndAddSubTasks(ctx context.Context, task *proto.Task, subtasks []*proto.Subtask, prevState proto.TaskState) (bool, error)
GCSubtasks(ctx context.Context) error
GetAllNodes(ctx context.Context) ([]string, error)
GetAllNodes(ctx context.Context) ([]proto.ManagedNode, error)
DeleteDeadNodes(ctx context.Context, nodes []string) error
TransferTasks2History(ctx context.Context, tasks []*proto.Task) error
CancelTask(ctx context.Context, taskID int64) error
@@ -68,7 +68,7 @@ type TaskManager interface {
// to execute tasks. If there are any nodes with background role, we use them,
// else we use nodes without role.
// returned nodes are sorted by node id(host:port).
GetManagedNodes(ctx context.Context) ([]string, error)
GetManagedNodes(ctx context.Context) ([]proto.ManagedNode, error)
GetTaskExecutorIDsByTaskID(ctx context.Context, taskID int64) ([]string, error)
GetSubtasksByStepAndState(ctx context.Context, taskID int64, step proto.Step, state proto.TaskState) ([]*proto.Subtask, error)
GetSubtasksByExecIdsAndStepAndState(ctx context.Context, tidbIDs []string, taskID int64, step proto.Step, state proto.SubtaskState) ([]*proto.Subtask, error)
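The doc comment above GetManagedNodes describes the selection rule: if any node has the background role, only those nodes are used; otherwise nodes without a role are used, and the result is sorted by node ID (host:port). A hedged, standalone sketch of that rule over the new proto.ManagedNode type — the helper and package name are illustrative, and the actual implementation is not shown in this diff:

package example

import (
	"slices"
	"strings"

	"github.com/pingcap/tidb/pkg/disttask/framework/proto"
)

// selectManagedNodes restates the documented rule: prefer nodes with the
// "background" role, otherwise fall back to nodes with no role, and return
// them sorted by node ID (host:port).
func selectManagedNodes(all []proto.ManagedNode) []proto.ManagedNode {
	var background, noRole []proto.ManagedNode
	for _, n := range all {
		switch n.Role {
		case "background":
			background = append(background, n)
		case "":
			noRole = append(noRole, n)
		}
	}
	picked := noRole
	if len(background) > 0 {
		picked = background
	}
	slices.SortFunc(picked, func(a, b proto.ManagedNode) int {
		return strings.Compare(a.ID, b.ID)
	})
	return picked
}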
24 changes: 15 additions & 9 deletions pkg/disttask/framework/scheduler/nodes.go
@@ -91,9 +91,9 @@ func (nm *NodeManager) maintainLiveNodes(ctx context.Context, taskMgr TaskManage
}

deadNodes := make([]string, 0)
for _, nodeID := range oldNodes {
if _, ok := currLiveNodes[nodeID]; !ok {
deadNodes = append(deadNodes, nodeID)
for _, node := range oldNodes {
if _, ok := currLiveNodes[node.ID]; !ok {
deadNodes = append(deadNodes, node.ID)
}
}
if len(deadNodes) == 0 {
@@ -110,30 +110,36 @@ func (nm *NodeManager) maintainLiveNodes(ctx context.Context, taskMgr TaskManage
nm.prevLiveNodes = currLiveNodes
}

func (nm *NodeManager) refreshManagedNodesLoop(ctx context.Context, taskMgr TaskManager) {
func (nm *NodeManager) refreshManagedNodesLoop(ctx context.Context, taskMgr TaskManager, slotMgr *slotManager) {
ticker := time.NewTicker(nodesCheckInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
nm.refreshManagedNodes(ctx, taskMgr)
nm.refreshManagedNodes(ctx, taskMgr, slotMgr)
}
}
}

// refreshManagedNodes maintains the nodes managed by the framework.
func (nm *NodeManager) refreshManagedNodes(ctx context.Context, taskMgr TaskManager) {
func (nm *NodeManager) refreshManagedNodes(ctx context.Context, taskMgr TaskManager, slotMgr *slotManager) {
newNodes, err := taskMgr.GetManagedNodes(ctx)
if err != nil {
logutil.BgLogger().Warn("get managed nodes met error", log.ShortError(err))
return
}
if newNodes == nil {
newNodes = []string{}
nodeIDs := make([]string, 0, len(newNodes))
var cpuCount int
for _, node := range newNodes {
nodeIDs = append(nodeIDs, node.ID)
if node.CPUCount > 0 {
cpuCount = node.CPUCount
}
}
nm.managedNodes.Store(&newNodes)
slotMgr.updateCapacity(cpuCount)
nm.managedNodes.Store(&nodeIDs)
}

// GetManagedNodes returns the nodes managed by the framework.
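refreshManagedNodes now derives a single CPU count from whichever managed node reports a positive value and pushes it into the slot manager, relying on the framework's existing assumption that managed nodes are isomorphic (the TODO removed from slots.go further down). A small illustrative helper capturing just that step, with hypothetical names:

package example

import "github.com/pingcap/tidb/pkg/disttask/framework/proto"

// representativeCPUCount mirrors the loop in refreshManagedNodes above: any
// node reporting a positive CPUCount is taken as representative, since
// framework-managed nodes are assumed to be isomorphic. It returns 0 when no
// node reports a CPU count, in which case updateCapacity leaves the current
// capacity unchanged.
func representativeCPUCount(nodes []proto.ManagedNode) int {
	cpuCount := 0
	for _, node := range nodes {
		if node.CPUCount > 0 {
			cpuCount = node.CPUCount
		}
	}
	return cpuCount
}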
25 changes: 17 additions & 8 deletions pkg/disttask/framework/scheduler/nodes_test.go
@@ -21,7 +21,9 @@ import (
"github.com/pingcap/errors"
"github.com/pingcap/failpoint"
"github.com/pingcap/tidb/pkg/disttask/framework/mock"
"github.com/pingcap/tidb/pkg/disttask/framework/proto"
"github.com/pingcap/tidb/pkg/domain/infosync"
"github.com/pingcap/tidb/pkg/util/cpu"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
)
@@ -47,7 +49,7 @@ func TestMaintainLiveNodes(t *testing.T) {
require.Empty(t, nodeMgr.prevLiveNodes)
require.True(t, ctrl.Satisfied())
// no change
mockTaskMgr.EXPECT().GetAllNodes(gomock.Any()).Return([]string{":4000"}, nil)
mockTaskMgr.EXPECT().GetAllNodes(gomock.Any()).Return([]proto.ManagedNode{{ID: ":4000"}}, nil)
nodeMgr.maintainLiveNodes(ctx, mockTaskMgr)
require.Equal(t, map[string]struct{}{":4000": {}}, nodeMgr.prevLiveNodes)
require.True(t, ctrl.Satisfied())
@@ -63,13 +65,13 @@ func TestMaintainLiveNodes(t *testing.T) {
}

// fail on clean
mockTaskMgr.EXPECT().GetAllNodes(gomock.Any()).Return([]string{":4000", ":4001", ":4002"}, nil)
mockTaskMgr.EXPECT().GetAllNodes(gomock.Any()).Return([]proto.ManagedNode{{ID: ":4000"}, {ID: ":4001"}, {ID: ":4002"}}, nil)
mockTaskMgr.EXPECT().DeleteDeadNodes(gomock.Any(), gomock.Any()).Return(errors.New("mock error"))
nodeMgr.maintainLiveNodes(ctx, mockTaskMgr)
require.Equal(t, map[string]struct{}{":4000": {}}, nodeMgr.prevLiveNodes)
require.True(t, ctrl.Satisfied())
// remove 1 node
mockTaskMgr.EXPECT().GetAllNodes(gomock.Any()).Return([]string{":4000", ":4001", ":4002"}, nil)
mockTaskMgr.EXPECT().GetAllNodes(gomock.Any()).Return([]proto.ManagedNode{{ID: ":4000"}, {ID: ":4001"}, {ID: ":4002"}}, nil)
mockTaskMgr.EXPECT().DeleteDeadNodes(gomock.Any(), gomock.Any()).Return(nil)
nodeMgr.maintainLiveNodes(ctx, mockTaskMgr)
require.Equal(t, map[string]struct{}{":4000": {}, ":4001": {}}, nodeMgr.prevLiveNodes)
@@ -84,7 +86,7 @@ func TestMaintainLiveNodes(t *testing.T) {
{Port: 4000},
}

mockTaskMgr.EXPECT().GetAllNodes(gomock.Any()).Return([]string{":4000", ":4001", ":4002"}, nil)
mockTaskMgr.EXPECT().GetAllNodes(gomock.Any()).Return([]proto.ManagedNode{{ID: ":4000"}, {ID: ":4001"}, {ID: ":4002"}}, nil)
mockTaskMgr.EXPECT().DeleteDeadNodes(gomock.Any(), gomock.Any()).Return(nil)
nodeMgr.maintainLiveNodes(ctx, mockTaskMgr)
require.Equal(t, map[string]struct{}{":4000": {}}, nodeMgr.prevLiveNodes)
@@ -102,18 +104,25 @@ func TestMaintainManagedNodes(t *testing.T) {
mockTaskMgr := mock.NewMockTaskManager(ctrl)
nodeMgr := newNodeManager()

slotMgr := newSlotManager()
mockTaskMgr.EXPECT().GetManagedNodes(gomock.Any()).Return(nil, errors.New("mock error"))
nodeMgr.refreshManagedNodes(ctx, mockTaskMgr)
nodeMgr.refreshManagedNodes(ctx, mockTaskMgr, slotMgr)
require.Equal(t, cpu.GetCPUCount(), int(slotMgr.capacity.Load()))
require.Empty(t, nodeMgr.getManagedNodes())
require.True(t, ctrl.Satisfied())

mockTaskMgr.EXPECT().GetManagedNodes(gomock.Any()).Return([]string{":4000", ":4001"}, nil)
nodeMgr.refreshManagedNodes(ctx, mockTaskMgr)
mockTaskMgr.EXPECT().GetManagedNodes(gomock.Any()).Return([]proto.ManagedNode{
{ID: ":4000", CPUCount: 100},
{ID: ":4001", CPUCount: 100},
}, nil)
nodeMgr.refreshManagedNodes(ctx, mockTaskMgr, slotMgr)
require.Equal(t, []string{":4000", ":4001"}, nodeMgr.getManagedNodes())
require.Equal(t, 100, int(slotMgr.capacity.Load()))
require.True(t, ctrl.Satisfied())
mockTaskMgr.EXPECT().GetManagedNodes(gomock.Any()).Return(nil, nil)
nodeMgr.refreshManagedNodes(ctx, mockTaskMgr)
nodeMgr.refreshManagedNodes(ctx, mockTaskMgr, slotMgr)
require.NotNil(t, nodeMgr.getManagedNodes())
require.Empty(t, nodeMgr.getManagedNodes())
require.Equal(t, 100, int(slotMgr.capacity.Load()))
require.True(t, ctrl.Satisfied())
}
4 changes: 2 additions & 2 deletions pkg/disttask/framework/scheduler/scheduler_manager.go
@@ -123,7 +123,7 @@ func (sm *Manager) Start() {
failpoint.Return()
})
// init cached managed nodes
sm.nodeMgr.refreshManagedNodes(sm.ctx, sm.taskMgr)
sm.nodeMgr.refreshManagedNodes(sm.ctx, sm.taskMgr, sm.slotMgr)

sm.wg.Run(sm.scheduleTaskLoop)
sm.wg.Run(sm.gcSubtaskHistoryTableLoop)
@@ -132,7 +132,7 @@ func (sm *Manager) Start() {
sm.nodeMgr.maintainLiveNodesLoop(sm.ctx, sm.taskMgr)
})
sm.wg.Run(func() {
sm.nodeMgr.refreshManagedNodesLoop(sm.ctx, sm.taskMgr)
sm.nodeMgr.refreshManagedNodesLoop(sm.ctx, sm.taskMgr, sm.slotMgr)
})
sm.initialized = true
}
36 changes: 27 additions & 9 deletions pkg/disttask/framework/scheduler/slots.go
@@ -18,9 +18,12 @@ import (
"context"
"slices"
"sync"
"sync/atomic"

"github.com/pingcap/tidb/pkg/disttask/framework/proto"
"github.com/pingcap/tidb/pkg/util/cpu"
"github.com/pingcap/tidb/pkg/util/logutil"
"go.uber.org/zap"
)

type taskStripes struct {
@@ -47,10 +50,7 @@ type taskStripes struct {
// quota to subtask, but subtask can determine what to conform.
type slotManager struct {
// Capacity is the total number of slots and stripes.
// TODO: we assume that all nodes managed by dist framework are isomorphic,
// but dist owner might run on normal node where the capacity might not be
// able to run any task.
capacity int
capacity atomic.Int32

mu sync.RWMutex
// represents the number of stripes reserved by task, when we reserve by the
@@ -75,12 +75,16 @@ type slotManager struct {

// newSlotManager creates a new slotManager.
func newSlotManager() *slotManager {
return &slotManager{
capacity: cpu.GetCPUCount(),
s := &slotManager{
task2Index: make(map[int64]int),
reservedSlots: make(map[string]int),
usedSlots: make(map[string]int),
}
// this node might not be the managed node of the framework, but we initialize
// capacity with the cpu count of this node, it will be updated when node
// manager starts.
s.updateCapacity(cpu.GetCPUCount())
return s
}

// Update updates the used slots on each node.
@@ -96,7 +100,7 @@ func (sm *slotManager) update(ctx context.Context, taskMgr TaskManager) error {
}
newUsedSlots := make(map[string]int, len(nodes))
for _, node := range nodes {
newUsedSlots[node] = slotsOnNodes[node]
newUsedSlots[node.ID] = slotsOnNodes[node.ID]
}
sm.mu.Lock()
defer sm.mu.Unlock()
@@ -111,6 +115,7 @@ func (sm *slotManager) update(ctx context.Context, taskMgr TaskManager) error {
// are enough resources, or return true on resource shortage when some task
// scheduled subtasks.
func (sm *slotManager) canReserve(task *proto.Task) (execID string, ok bool) {
capacity := int(sm.capacity.Load())
sm.mu.RLock()
defer sm.mu.RUnlock()
if len(sm.usedSlots) == 0 {
@@ -125,12 +130,12 @@ func (sm *slotManager) canReserve(task *proto.Task) (execID string, ok bool) {
}
reservedForHigherPriority += s.stripes
}
if task.Concurrency+reservedForHigherPriority <= sm.capacity {
if task.Concurrency+reservedForHigherPriority <= capacity {
return "", true
}

for id, count := range sm.usedSlots {
if count+sm.reservedSlots[id]+task.Concurrency <= sm.capacity {
if count+sm.reservedSlots[id]+task.Concurrency <= capacity {
return id, true
}
}
@@ -178,3 +183,16 @@ func (sm *slotManager) unReserve(task *proto.Task, execID string) {
}
}
}

func (sm *slotManager) updateCapacity(cpuCount int) {
old := sm.capacity.Load()
if cpuCount > 0 && cpuCount != int(old) {
sm.capacity.Store(int32(cpuCount))
if old == 0 {
logutil.BgLogger().Info("initialize slot capacity", zap.Int("capacity", cpuCount))
} else {
logutil.BgLogger().Info("update slot capacity",
zap.Int("old", int(old)), zap.Int("new", cpuCount))
}
}
}
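Taken together, the slots.go changes turn the capacity into an atomic value that starts from the local CPU count and is later overwritten with the managed nodes' CPU count. A hedged sketch of the per-node admission check that canReserve performs against that capacity — the standalone form and names are illustrative, and the real method additionally accounts for stripes reserved by higher-priority tasks:

package example

// canFitOnSomeNode is an illustrative reduction of the last loop in
// canReserve above: a task with the given concurrency fits on a node when
// the node's used slots plus already-reserved slots plus the task's
// concurrency do not exceed the shared capacity.
func canFitOnSomeNode(usedSlots, reservedSlots map[string]int, concurrency, capacity int) (execID string, ok bool) {
	for id, used := range usedSlots {
		if used+reservedSlots[id]+concurrency <= capacity {
			return id, true
		}
	}
	return "", false
}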
