Skip to content

Commit

Permalink
fix: channel unbalance during stopping balance progress (#38971)
Browse files Browse the repository at this point in the history
issue: #38970
Cause: stopping-balance for channels still used the row_count_based policy,
which may cause channel imbalance in the multi-collection case.

This PR implements a score-based stopping-balance channel policy.

Signed-off-by: Wei Liu <[email protected]>
  • Loading branch information
weiliu1031 authored Jan 13, 2025
1 parent 640a49f commit cc5d593
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 13 deletions.
11 changes: 7 additions & 4 deletions internal/querycoordv2/balance/balance.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,13 @@ func (segPlan *SegmentAssignPlan) String() string {
}

// ChannelAssignPlan describes moving one DM channel within a replica from
// node From to node To. The score fields capture the balancer's view at
// assignment time and are kept for balance-report/debug output.
type ChannelAssignPlan struct {
	Channel *meta.DmChannel
	Replica *meta.Replica
	From    int64
	To      int64
	// FromScore/ToScore are the priorities of the source/target nodes when
	// the plan was generated; ChannelScore is the score delta this channel
	// contributes. NOTE(review): semantics inferred from assignChannel's
	// population of these fields — confirm against the balance report code.
	FromScore    int64
	ToScore      int64
	ChannelScore int64
}

func (chanPlan *ChannelAssignPlan) String() string {
Expand Down
2 changes: 1 addition & 1 deletion internal/querycoordv2/balance/rowcount_based_balancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ func newNodeItem(currentScore int, nodeID int64) nodeItem {

// getPriority reports this node's balance priority as the (rounded-up) gap
// between its current score and its assigned score: the more a node falls
// short of its assigned score, the smaller (more urgent) the value.
// math.Ceil prevents a small positive fractional surplus from truncating to
// zero and colliding with exactly-balanced nodes.
// NOTE(review): assumes currentScore/assignedScore are float64 — math.Ceil
// requires it; confirm against the nodeItem definition.
func (b *nodeItem) getPriority() int {
	// if node lacks more score between assignedScore and currentScore, then higher priority
	return int(math.Ceil(b.currentScore - b.assignedScore))
}

func (b *nodeItem) setPriority(priority int) {
Expand Down
30 changes: 22 additions & 8 deletions internal/querycoordv2/balance/score_based_balancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,19 +191,19 @@ func (b *ScoreBasedBalancer) assignChannel(br *balanceReport, collectionID int64
}

from := int64(-1)
// fromScore := int64(0)
fromScore := int64(0)
if sourceNode != nil {
from = sourceNode.nodeID
// fromScore = int64(sourceNode.getPriority())
fromScore = int64(sourceNode.getPriority())
}

plan := ChannelAssignPlan{
From: from,
To: targetNode.nodeID,
Channel: ch,
// FromScore: fromScore,
// ToScore: int64(targetNode.getPriority()),
// SegmentScore: int64(scoreChanges),
From: from,
To: targetNode.nodeID,
Channel: ch,
FromScore: fromScore,
ToScore: int64(targetNode.getPriority()),
ChannelScore: int64(scoreChanges),
}
br.AddRecord(StrRecordf("add segment plan %s", plan))
plans = append(plans, plan)
Expand Down Expand Up @@ -487,6 +487,20 @@ func (b *ScoreBasedBalancer) BalanceReplica(ctx context.Context, replica *meta.R
return segmentPlans, channelPlans
}

// genStoppingChannelPlan builds channel move-out plans for every read-only
// (stopping) node in the replica: each channel the node currently serves is
// reassigned onto the remaining read-write nodes via the score-based
// AssignChannel, and the plan's From/Replica fields are filled in.
func (b *ScoreBasedBalancer) genStoppingChannelPlan(ctx context.Context, replica *meta.Replica, rwNodes []int64, roNodes []int64) []ChannelAssignPlan {
	channelPlans := make([]ChannelAssignPlan, 0)
	for _, roNode := range roNodes {
		// Channels of this collection still hosted on the stopping node.
		hosted := b.dist.ChannelDistManager.GetByCollectionAndFilter(replica.GetCollectionID(), meta.WithNodeID2Channel(roNode))
		assigned := b.AssignChannel(ctx, replica.GetCollectionID(), hosted, rwNodes, false)
		for idx := range assigned {
			assigned[idx].From = roNode
			assigned[idx].Replica = replica
		}
		channelPlans = append(channelPlans, assigned...)
	}
	return channelPlans
}

func (b *ScoreBasedBalancer) genStoppingSegmentPlan(ctx context.Context, replica *meta.Replica, onlineNodes []int64, offlineNodes []int64) []SegmentAssignPlan {
segmentPlans := make([]SegmentAssignPlan, 0)
for _, nodeID := range offlineNodes {
Expand Down
125 changes: 125 additions & 0 deletions internal/querycoordv2/balance/score_based_balancer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/samber/lo"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
"go.uber.org/atomic"

etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/metastore/kv/querycoord"
Expand Down Expand Up @@ -1470,3 +1471,127 @@ func (suite *ScoreBasedBalancerTestSuite) TestBalanceChannelOnChannelExclusive()
_, channelPlans = suite.getCollectionBalancePlans(balancer, 3)
suite.Len(channelPlans, 2)
}

// TestBalanceChannelOnStoppingNode verifies that when a node enters the
// stopping state, the score-based stopping balance spreads its channels
// evenly across the remaining nodes (5/5 over two nodes for 10 collections)
// instead of piling them onto one node — the multi-collection imbalance
// fixed by this change.
func (suite *ScoreBasedBalancerTestSuite) TestBalanceChannelOnStoppingNode() {
	ctx := context.Background()
	balancer := suite.balancer

	// mock 10 collections with each collection has 1 channel
	collectionNum := 10
	channelNum := 1
	for i := 1; i <= collectionNum; i++ {
		collectionID := int64(i)
		collection := utils.CreateTestCollection(collectionID, int32(1))
		collection.LoadPercentage = 100
		collection.Status = querypb.LoadStatus_Loaded
		balancer.meta.CollectionManager.PutCollection(ctx, collection)
		balancer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(collectionID, collectionID))
		balancer.meta.ReplicaManager.Spawn(ctx, collectionID, map[string]int{meta.DefaultResourceGroupName: 1}, nil)

		channels := make([]*datapb.VchannelInfo, channelNum)
		// use j to avoid shadowing the outer collection index i
		for j := 0; j < channelNum; j++ {
			channels[j] = &datapb.VchannelInfo{CollectionID: collectionID, ChannelName: fmt.Sprintf("channel-%d-%d", collectionID, j)}
		}
		suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, collectionID).Return(
			channels, nil, nil)
		suite.broker.EXPECT().GetPartitions(mock.Anything, collectionID).Return([]int64{collectionID}, nil).Maybe()
		balancer.targetMgr.UpdateCollectionNextTarget(ctx, collectionID)
		balancer.targetMgr.UpdateCollectionCurrentTarget(ctx, collectionID)
	}

	// mock querynode-1 to node manager
	nodeInfo := session.NewNodeInfo(session.ImmutableNodeInfo{
		NodeID:   1,
		Address:  "127.0.0.1:0",
		Hostname: "localhost",
		Version:  common.Version,
	})
	nodeInfo.SetState(session.NodeStateNormal)
	suite.balancer.nodeManager.Add(nodeInfo)
	suite.balancer.meta.ResourceManager.HandleNodeUp(ctx, 1)
	utils.RecoverAllCollection(balancer.meta)

	// mock channel distribution: every channel lives on querynode-1
	channelDist := make([]*meta.DmChannel, 0)
	for i := 1; i <= collectionNum; i++ {
		collectionID := int64(i)
		for j := 0; j < channelNum; j++ {
			channelDist = append(channelDist, &meta.DmChannel{
				VchannelInfo: &datapb.VchannelInfo{CollectionID: collectionID, ChannelName: fmt.Sprintf("channel-%d-%d", collectionID, j)}, Node: 1,
			})
		}
	}
	balancer.dist.ChannelDistManager.Update(1, channelDist...)

	// assert balance channel won't happen on a single querynode
	ret := make([]ChannelAssignPlan, 0)
	for i := 1; i <= collectionNum; i++ {
		collectionID := int64(i)
		_, channelPlans := suite.getCollectionBalancePlans(balancer, collectionID)
		ret = append(ret, channelPlans...)
	}
	suite.Len(ret, 0)

	// mock querynode-2 to node manager
	nodeInfo2 := session.NewNodeInfo(session.ImmutableNodeInfo{
		NodeID:   2,
		Address:  "127.0.0.1:0",
		Hostname: "localhost",
		Version:  common.Version,
	})
	suite.balancer.nodeManager.Add(nodeInfo2)
	suite.balancer.meta.ResourceManager.HandleNodeUp(ctx, 2)
	// mock querynode-3 to node manager
	nodeInfo3 := session.NewNodeInfo(session.ImmutableNodeInfo{
		NodeID:   3,
		Address:  "127.0.0.1:0",
		Hostname: "localhost",
		Version:  common.Version,
	})
	suite.balancer.nodeManager.Add(nodeInfo3)
	suite.balancer.meta.ResourceManager.HandleNodeUp(ctx, 3)
	utils.RecoverAllCollection(balancer.meta)
	// mock querynode-1 to stopping, trigger stopping balance; expect 10 balance
	// channel tasks in total: 5 for node-2 and 5 for node-3
	nodeInfo.SetState(session.NodeStateStopping)
	suite.balancer.meta.ResourceManager.HandleNodeDown(ctx, 1)
	utils.RecoverAllCollection(balancer.meta)

	node2Counter := atomic.NewInt32(0)
	node3Counter := atomic.NewInt32(0)

	// feed the per-node channel-task delta back into the balancer so each
	// generated plan raises the target node's score for the next round
	suite.mockScheduler.ExpectedCalls = nil
	suite.mockScheduler.EXPECT().GetSegmentTaskDelta(mock.Anything, mock.Anything).Return(0).Maybe()
	suite.mockScheduler.EXPECT().GetChannelTaskDelta(mock.Anything, mock.Anything).RunAndReturn(func(nodeID, collection int64) int {
		if collection == -1 {
			if nodeID == 2 {
				return int(node2Counter.Load())
			}

			if nodeID == 3 {
				return int(node3Counter.Load())
			}
		}
		return 0
	})
	suite.mockScheduler.EXPECT().GetSegmentTaskNum(mock.Anything, mock.Anything).Return(0).Maybe()
	suite.mockScheduler.EXPECT().GetChannelTaskNum(mock.Anything, mock.Anything).Return(0).Maybe()

	for i := 1; i <= collectionNum; i++ {
		collectionID := int64(i)
		_, channelPlans := suite.getCollectionBalancePlans(balancer, collectionID)
		suite.Len(channelPlans, 1)
		if channelPlans[0].To == 2 {
			node2Counter.Inc()
		}

		if channelPlans[0].To == 3 {
			node3Counter.Inc()
		}

		// after every even round the two nodes must have received equally many
		if i%2 == 0 {
			suite.Equal(node2Counter.Load(), node3Counter.Load())
		}
	}
	// testify convention: expected value first, actual second
	suite.Equal(int32(5), node2Counter.Load())
	suite.Equal(int32(5), node3Counter.Load())
}

0 comments on commit cc5d593

Please sign in to comment.