diff --git a/configs/milvus.yaml b/configs/milvus.yaml index de101742d6d5a..ee26df210e22b 100644 --- a/configs/milvus.yaml +++ b/configs/milvus.yaml @@ -347,7 +347,7 @@ queryCoord: balanceCostThreshold: 0.001 # the threshold of balance cost, if the difference of cluster's cost after executing the balance plan is less than this value, the plan will not be executed checkSegmentInterval: 1000 checkChannelInterval: 1000 - checkBalanceInterval: 10000 + checkBalanceInterval: 3000 checkIndexInterval: 10000 channelTaskTimeout: 60000 # 1 minute segmentTaskTimeout: 120000 # 2 minute diff --git a/internal/querycoordv2/balance/balance.go b/internal/querycoordv2/balance/balance.go index ca03c74b18134..0ccdc90ddafea 100644 --- a/internal/querycoordv2/balance/balance.go +++ b/internal/querycoordv2/balance/balance.go @@ -141,6 +141,14 @@ func (b *RoundRobinBalancer) BalanceReplica(replica *meta.Replica) ([]SegmentAss return nil, nil } +func (b *RoundRobinBalancer) permitBalanceChannel(collectionID int64) bool { + return b.scheduler.GetSegmentTaskNum(task.WithCollectionID2TaskFilter(collectionID), task.WithTaskTypeFilter(task.TaskTypeMove)) == 0 +} + +func (b *RoundRobinBalancer) permitBalanceSegment(collectionID int64) bool { + return b.scheduler.GetChannelTaskNum(task.WithCollectionID2TaskFilter(collectionID), task.WithTaskTypeFilter(task.TaskTypeMove)) == 0 +} + func (b *RoundRobinBalancer) getNodes(nodes []int64) []*session.NodeInfo { ret := make([]*session.NodeInfo, 0, len(nodes)) for _, n := range nodes { diff --git a/internal/querycoordv2/balance/channel_level_score_balancer.go b/internal/querycoordv2/balance/channel_level_score_balancer.go index c03e3e3221abf..bee771d58f6ff 100644 --- a/internal/querycoordv2/balance/channel_level_score_balancer.go +++ b/internal/querycoordv2/balance/channel_level_score_balancer.go @@ -121,16 +121,19 @@ func (b *ChannelLevelScoreBalancer) BalanceReplica(replica *meta.Replica) (segme zap.Any("available nodes", rwNodes), ) // handle stopped 
nodes here, have to assign segments on stopping nodes to nodes with the smallest score - channelPlans = append(channelPlans, b.genStoppingChannelPlan(replica, channelName, rwNodes, roNodes)...) - if len(channelPlans) == 0 { + if b.permitBalanceChannel(replica.GetCollectionID()) { + channelPlans = append(channelPlans, b.genStoppingChannelPlan(replica, channelName, rwNodes, roNodes)...) + } + + if len(channelPlans) == 0 && b.permitBalanceSegment(replica.GetCollectionID()) { segmentPlans = append(segmentPlans, b.genStoppingSegmentPlan(replica, channelName, rwNodes, roNodes)...) } } else { - if paramtable.Get().QueryCoordCfg.AutoBalanceChannel.GetAsBool() { + if paramtable.Get().QueryCoordCfg.AutoBalanceChannel.GetAsBool() && b.permitBalanceChannel(replica.GetCollectionID()) { channelPlans = append(channelPlans, b.genChannelPlan(replica, channelName, rwNodes)...) } - if len(channelPlans) == 0 { + if len(channelPlans) == 0 && b.permitBalanceSegment(replica.GetCollectionID()) { segmentPlans = append(segmentPlans, b.genSegmentPlan(br, replica, channelName, rwNodes)...) 
} } diff --git a/internal/querycoordv2/balance/channel_level_score_balancer_test.go b/internal/querycoordv2/balance/channel_level_score_balancer_test.go index 0e5c94ce63788..9d0af873002cd 100644 --- a/internal/querycoordv2/balance/channel_level_score_balancer_test.go +++ b/internal/querycoordv2/balance/channel_level_score_balancer_test.go @@ -76,6 +76,8 @@ func (suite *ChannelLevelScoreBalancerTestSuite) SetupTest() { suite.mockScheduler.EXPECT().GetSegmentTaskDelta(mock.Anything, mock.Anything).Return(0).Maybe() suite.mockScheduler.EXPECT().GetChannelTaskDelta(mock.Anything, mock.Anything).Return(0).Maybe() + suite.mockScheduler.EXPECT().GetSegmentTaskNum(mock.Anything, mock.Anything).Return(0).Maybe() + suite.mockScheduler.EXPECT().GetChannelTaskNum(mock.Anything, mock.Anything).Return(0).Maybe() } func (suite *ChannelLevelScoreBalancerTestSuite) TearDownTest() { diff --git a/internal/querycoordv2/balance/multi_target_balance.go b/internal/querycoordv2/balance/multi_target_balance.go index 854a6373867b9..9f1de2b026cfc 100644 --- a/internal/querycoordv2/balance/multi_target_balance.go +++ b/internal/querycoordv2/balance/multi_target_balance.go @@ -497,16 +497,18 @@ func (b *MultiTargetBalancer) BalanceReplica(replica *meta.Replica) (segmentPlan zap.Any("available nodes", rwNodes), ) // handle stopped nodes here, have to assign segments on stopping nodes to nodes with the smallest score - channelPlans = append(channelPlans, b.genStoppingChannelPlan(replica, rwNodes, roNodes)...) - if len(channelPlans) == 0 { + if b.permitBalanceChannel(replica.GetCollectionID()) { + channelPlans = append(channelPlans, b.genStoppingChannelPlan(replica, rwNodes, roNodes)...) + } + if len(channelPlans) == 0 && b.permitBalanceSegment(replica.GetCollectionID()) { segmentPlans = append(segmentPlans, b.genStoppingSegmentPlan(replica, rwNodes, roNodes)...) 
} } else { - if paramtable.Get().QueryCoordCfg.AutoBalanceChannel.GetAsBool() { + if paramtable.Get().QueryCoordCfg.AutoBalanceChannel.GetAsBool() && b.permitBalanceChannel(replica.GetCollectionID()) { channelPlans = append(channelPlans, b.genChannelPlan(br, replica, rwNodes)...) } - if len(channelPlans) == 0 { + if len(channelPlans) == 0 && b.permitBalanceSegment(replica.GetCollectionID()) { segmentPlans = b.genSegmentPlan(replica, rwNodes) } } diff --git a/internal/querycoordv2/balance/rowcount_based_balancer.go b/internal/querycoordv2/balance/rowcount_based_balancer.go index 697a35b4d58af..53d09d89f3fa4 100644 --- a/internal/querycoordv2/balance/rowcount_based_balancer.go +++ b/internal/querycoordv2/balance/rowcount_based_balancer.go @@ -205,16 +205,19 @@ func (b *RowCountBasedBalancer) BalanceReplica(replica *meta.Replica) (segmentPl zap.Any("available nodes", rwNodes), ) // handle stopped nodes here, have to assign segments on stopping nodes to nodes with the smallest score - channelPlans = append(channelPlans, b.genStoppingChannelPlan(replica, rwNodes, roNodes)...) - if len(channelPlans) == 0 { + if b.permitBalanceChannel(replica.GetCollectionID()) { + channelPlans = append(channelPlans, b.genStoppingChannelPlan(replica, rwNodes, roNodes)...) + } + + if len(channelPlans) == 0 && b.permitBalanceSegment(replica.GetCollectionID()) { segmentPlans = append(segmentPlans, b.genStoppingSegmentPlan(replica, rwNodes, roNodes)...) } } else { - if paramtable.Get().QueryCoordCfg.AutoBalanceChannel.GetAsBool() { + if paramtable.Get().QueryCoordCfg.AutoBalanceChannel.GetAsBool() && b.permitBalanceChannel(replica.GetCollectionID()) { channelPlans = append(channelPlans, b.genChannelPlan(br, replica, rwNodes)...) } - if len(channelPlans) == 0 { + if len(channelPlans) == 0 && b.permitBalanceSegment(replica.GetCollectionID()) { segmentPlans = append(segmentPlans, b.genSegmentPlan(replica, rwNodes)...) 
} } diff --git a/internal/querycoordv2/balance/rowcount_based_balancer_test.go b/internal/querycoordv2/balance/rowcount_based_balancer_test.go index 53096f8a2925a..a435dead519c5 100644 --- a/internal/querycoordv2/balance/rowcount_based_balancer_test.go +++ b/internal/querycoordv2/balance/rowcount_based_balancer_test.go @@ -81,6 +81,8 @@ func (suite *RowCountBasedBalancerTestSuite) SetupTest() { suite.mockScheduler.EXPECT().GetSegmentTaskDelta(mock.Anything, mock.Anything).Return(0).Maybe() suite.mockScheduler.EXPECT().GetChannelTaskDelta(mock.Anything, mock.Anything).Return(0).Maybe() + suite.mockScheduler.EXPECT().GetSegmentTaskNum(mock.Anything, mock.Anything).Return(0).Maybe() + suite.mockScheduler.EXPECT().GetChannelTaskNum(mock.Anything, mock.Anything).Return(0).Maybe() } func (suite *RowCountBasedBalancerTestSuite) TearDownTest() { diff --git a/internal/querycoordv2/balance/score_based_balancer.go b/internal/querycoordv2/balance/score_based_balancer.go index 57548ccd08dcc..0e3aad1f78efd 100644 --- a/internal/querycoordv2/balance/score_based_balancer.go +++ b/internal/querycoordv2/balance/score_based_balancer.go @@ -308,16 +308,18 @@ func (b *ScoreBasedBalancer) BalanceReplica(replica *meta.Replica) (segmentPlans ) br.AddRecord(StrRecordf("executing stopping balance: %v", roNodes)) // handle stopped nodes here, have to assign segments on stopping nodes to nodes with the smallest score - channelPlans = append(channelPlans, b.genStoppingChannelPlan(replica, rwNodes, roNodes)...) - if len(channelPlans) == 0 { + if b.permitBalanceChannel(replica.GetCollectionID()) { + channelPlans = append(channelPlans, b.genStoppingChannelPlan(replica, rwNodes, roNodes)...) + } + if len(channelPlans) == 0 && b.permitBalanceSegment(replica.GetCollectionID()) { segmentPlans = append(segmentPlans, b.genStoppingSegmentPlan(replica, rwNodes, roNodes)...) 
} } else { - if paramtable.Get().QueryCoordCfg.AutoBalanceChannel.GetAsBool() { + if paramtable.Get().QueryCoordCfg.AutoBalanceChannel.GetAsBool() && b.permitBalanceChannel(replica.GetCollectionID()) { channelPlans = append(channelPlans, b.genChannelPlan(br, replica, rwNodes)...) } - if len(channelPlans) == 0 { + if len(channelPlans) == 0 && b.permitBalanceSegment(replica.GetCollectionID()) { segmentPlans = append(segmentPlans, b.genSegmentPlan(br, replica, rwNodes)...) } } diff --git a/internal/querycoordv2/balance/score_based_balancer_test.go b/internal/querycoordv2/balance/score_based_balancer_test.go index 2401ae482de60..99a6b09e142eb 100644 --- a/internal/querycoordv2/balance/score_based_balancer_test.go +++ b/internal/querycoordv2/balance/score_based_balancer_test.go @@ -76,6 +76,8 @@ func (suite *ScoreBasedBalancerTestSuite) SetupTest() { suite.mockScheduler.EXPECT().GetSegmentTaskDelta(mock.Anything, mock.Anything).Return(0).Maybe() suite.mockScheduler.EXPECT().GetChannelTaskDelta(mock.Anything, mock.Anything).Return(0).Maybe() + suite.mockScheduler.EXPECT().GetSegmentTaskNum(mock.Anything, mock.Anything).Return(0).Maybe() + suite.mockScheduler.EXPECT().GetChannelTaskNum(mock.Anything, mock.Anything).Return(0).Maybe() } func (suite *ScoreBasedBalancerTestSuite) TearDownTest() { @@ -605,6 +607,8 @@ func (suite *ScoreBasedBalancerTestSuite) TestBalanceWithExecutingTask() { for i, node := range c.nodes { suite.mockScheduler.EXPECT().GetSegmentTaskDelta(node, int64(1)).Return(c.deltaCounts[i]).Maybe() suite.mockScheduler.EXPECT().GetSegmentTaskDelta(node, int64(-1)).Return(c.deltaCounts[i]).Maybe() + suite.mockScheduler.EXPECT().GetSegmentTaskNum(mock.Anything, mock.Anything).Return(0).Maybe() + suite.mockScheduler.EXPECT().GetChannelTaskNum(mock.Anything, mock.Anything).Return(0).Maybe() } // 4. 
balance and verify result @@ -1123,3 +1127,90 @@ func (suite *ScoreBasedBalancerTestSuite) getCollectionBalancePlans(balancer *Sc } return segmentPlans, channelPlans } + +func (suite *ScoreBasedBalancerTestSuite) TestBalanceSegmentAndChannel() { + nodes := []int64{1, 2, 3} + collectionID := int64(1) + replicaID := int64(1) + collectionsSegments := []*datapb.SegmentInfo{ + {ID: 1, PartitionID: 1}, {ID: 2, PartitionID: 1}, {ID: 3, PartitionID: 1}, + } + states := []session.State{session.NodeStateNormal, session.NodeStateNormal, session.NodeStateNormal} + + balancer := suite.balancer + + collection := utils.CreateTestCollection(collectionID, int32(replicaID)) + suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, collectionID).Return( + nil, collectionsSegments, nil) + suite.broker.EXPECT().GetPartitions(mock.Anything, collectionID).Return([]int64{collectionID}, nil).Maybe() + collection.LoadPercentage = 100 + collection.Status = querypb.LoadStatus_Loaded + balancer.meta.CollectionManager.PutCollection(collection) + balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(collectionID, collectionID)) + balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(replicaID, collectionID, nodes)) + balancer.targetMgr.UpdateCollectionNextTarget(collectionID) + balancer.targetMgr.UpdateCollectionCurrentTarget(collectionID) + + for i := range nodes { + nodeInfo := session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: nodes[i], + Address: "127.0.0.1:0", + Hostname: "localhost", + Version: common.Version, + }) + nodeInfo.SetState(states[i]) + suite.balancer.nodeManager.Add(nodeInfo) + suite.balancer.meta.ResourceManager.HandleNodeUp(nodes[i]) + } + utils.RecoverAllCollection(balancer.meta) + + // set unbalance segment distribution + balancer.dist.SegmentDistManager.Update(1, []*meta.Segment{ + {SegmentInfo: &datapb.SegmentInfo{ID: 1, CollectionID: 1, NumOfRows: 10}, Node: 1}, + {SegmentInfo: &datapb.SegmentInfo{ID: 2, CollectionID: 1, NumOfRows: 10}, Node: 1}, 
+ {SegmentInfo: &datapb.SegmentInfo{ID: 3, CollectionID: 1, NumOfRows: 10}, Node: 1}, + }...) + + // expect to generate 2 balance segment task + suite.mockScheduler.ExpectedCalls = nil + suite.mockScheduler.EXPECT().GetChannelTaskNum(mock.Anything, mock.Anything).Return(0).Maybe() + suite.mockScheduler.EXPECT().GetSegmentTaskNum(mock.Anything, mock.Anything).Return(0).Maybe() + suite.mockScheduler.EXPECT().GetSegmentTaskDelta(mock.Anything, mock.Anything).Return(0).Maybe() + suite.mockScheduler.EXPECT().GetChannelTaskDelta(mock.Anything, mock.Anything).Return(0).Maybe() + segmentPlans, _ := suite.getCollectionBalancePlans(balancer, collectionID) + suite.Equal(len(segmentPlans), 2) + + // mock balance channel is executing, expect to generate 0 balance segment task + suite.mockScheduler.ExpectedCalls = nil + suite.mockScheduler.EXPECT().GetChannelTaskNum(mock.Anything, mock.Anything).Return(1).Maybe() + suite.mockScheduler.EXPECT().GetSegmentTaskNum(mock.Anything, mock.Anything).Return(0).Maybe() + suite.mockScheduler.EXPECT().GetSegmentTaskDelta(mock.Anything, mock.Anything).Return(0).Maybe() + suite.mockScheduler.EXPECT().GetChannelTaskDelta(mock.Anything, mock.Anything).Return(0).Maybe() + segmentPlans, _ = suite.getCollectionBalancePlans(balancer, collectionID) + suite.Equal(len(segmentPlans), 0) + + // set unbalance channel distribution + balancer.dist.ChannelDistManager.Update(1, []*meta.DmChannel{ + {VchannelInfo: &datapb.VchannelInfo{CollectionID: 1, ChannelName: "channel1"}, Node: 1}, + {VchannelInfo: &datapb.VchannelInfo{CollectionID: 1, ChannelName: "channel2"}, Node: 1}, + {VchannelInfo: &datapb.VchannelInfo{CollectionID: 1, ChannelName: "channel3"}, Node: 1}, + }...) 
+ + // expect to generate 2 balance channel task + suite.mockScheduler.ExpectedCalls = nil + suite.mockScheduler.EXPECT().GetChannelTaskNum(mock.Anything, mock.Anything).Return(0).Maybe() + suite.mockScheduler.EXPECT().GetSegmentTaskNum(mock.Anything, mock.Anything).Return(0).Maybe() + suite.mockScheduler.EXPECT().GetSegmentTaskDelta(mock.Anything, mock.Anything).Return(0).Maybe() + suite.mockScheduler.EXPECT().GetChannelTaskDelta(mock.Anything, mock.Anything).Return(0).Maybe() + _, channelPlans := suite.getCollectionBalancePlans(balancer, collectionID) + suite.Equal(len(channelPlans), 2) + + // mock balance segment is executing, expect to generate 0 balance channel task + suite.mockScheduler.ExpectedCalls = nil + suite.mockScheduler.EXPECT().GetChannelTaskNum(mock.Anything, mock.Anything).Return(0).Maybe() + suite.mockScheduler.EXPECT().GetSegmentTaskNum(mock.Anything, mock.Anything).Return(1).Maybe() + suite.mockScheduler.EXPECT().GetSegmentTaskDelta(mock.Anything, mock.Anything).Return(0).Maybe() + suite.mockScheduler.EXPECT().GetChannelTaskDelta(mock.Anything, mock.Anything).Return(0).Maybe() + _, channelPlans = suite.getCollectionBalancePlans(balancer, collectionID) + suite.Equal(len(channelPlans), 0) +} diff --git a/internal/querycoordv2/task/mock_scheduler.go b/internal/querycoordv2/task/mock_scheduler.go index f3eb7bd69eb5f..1e0f1cf746d11 100644 --- a/internal/querycoordv2/task/mock_scheduler.go +++ b/internal/querycoordv2/task/mock_scheduler.go @@ -168,13 +168,19 @@ func (_c *MockScheduler_GetChannelTaskDelta_Call) RunAndReturn(run func(int64, i return _c } -// GetChannelTaskNum provides a mock function with given fields: -func (_m *MockScheduler) GetChannelTaskNum() int { - ret := _m.Called() +// GetChannelTaskNum provides a mock function with given fields: filters +func (_m *MockScheduler) GetChannelTaskNum(filters ...TaskFilter) int { + _va := make([]interface{}, len(filters)) + for _i := range filters { + _va[_i] = filters[_i] + } + var _ca 
[]interface{} + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) var r0 int - if rf, ok := ret.Get(0).(func() int); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(...TaskFilter) int); ok { + r0 = rf(filters...) } else { r0 = ret.Get(0).(int) } @@ -188,13 +194,21 @@ type MockScheduler_GetChannelTaskNum_Call struct { } // GetChannelTaskNum is a helper method to define mock.On call -func (_e *MockScheduler_Expecter) GetChannelTaskNum() *MockScheduler_GetChannelTaskNum_Call { - return &MockScheduler_GetChannelTaskNum_Call{Call: _e.mock.On("GetChannelTaskNum")} +// - filters ...TaskFilter +func (_e *MockScheduler_Expecter) GetChannelTaskNum(filters ...interface{}) *MockScheduler_GetChannelTaskNum_Call { + return &MockScheduler_GetChannelTaskNum_Call{Call: _e.mock.On("GetChannelTaskNum", + append([]interface{}{}, filters...)...)} } -func (_c *MockScheduler_GetChannelTaskNum_Call) Run(run func()) *MockScheduler_GetChannelTaskNum_Call { +func (_c *MockScheduler_GetChannelTaskNum_Call) Run(run func(filters ...TaskFilter)) *MockScheduler_GetChannelTaskNum_Call { _c.Call.Run(func(args mock.Arguments) { - run() + variadicArgs := make([]TaskFilter, len(args)-0) + for i, a := range args[0:] { + if a != nil { + variadicArgs[i] = a.(TaskFilter) + } + } + run(variadicArgs...) 
}) return _c } @@ -204,7 +218,7 @@ func (_c *MockScheduler_GetChannelTaskNum_Call) Return(_a0 int) *MockScheduler_G return _c } -func (_c *MockScheduler_GetChannelTaskNum_Call) RunAndReturn(run func() int) *MockScheduler_GetChannelTaskNum_Call { +func (_c *MockScheduler_GetChannelTaskNum_Call) RunAndReturn(run func(...TaskFilter) int) *MockScheduler_GetChannelTaskNum_Call { _c.Call.Return(run) return _c } @@ -296,13 +310,19 @@ func (_c *MockScheduler_GetSegmentTaskDelta_Call) RunAndReturn(run func(int64, i return _c } -// GetSegmentTaskNum provides a mock function with given fields: -func (_m *MockScheduler) GetSegmentTaskNum() int { - ret := _m.Called() +// GetSegmentTaskNum provides a mock function with given fields: filters +func (_m *MockScheduler) GetSegmentTaskNum(filters ...TaskFilter) int { + _va := make([]interface{}, len(filters)) + for _i := range filters { + _va[_i] = filters[_i] + } + var _ca []interface{} + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) var r0 int - if rf, ok := ret.Get(0).(func() int); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(...TaskFilter) int); ok { + r0 = rf(filters...) 
} else { r0 = ret.Get(0).(int) } @@ -316,13 +336,21 @@ type MockScheduler_GetSegmentTaskNum_Call struct { } // GetSegmentTaskNum is a helper method to define mock.On call -func (_e *MockScheduler_Expecter) GetSegmentTaskNum() *MockScheduler_GetSegmentTaskNum_Call { - return &MockScheduler_GetSegmentTaskNum_Call{Call: _e.mock.On("GetSegmentTaskNum")} +// - filters ...TaskFilter +func (_e *MockScheduler_Expecter) GetSegmentTaskNum(filters ...interface{}) *MockScheduler_GetSegmentTaskNum_Call { + return &MockScheduler_GetSegmentTaskNum_Call{Call: _e.mock.On("GetSegmentTaskNum", + append([]interface{}{}, filters...)...)} } -func (_c *MockScheduler_GetSegmentTaskNum_Call) Run(run func()) *MockScheduler_GetSegmentTaskNum_Call { +func (_c *MockScheduler_GetSegmentTaskNum_Call) Run(run func(filters ...TaskFilter)) *MockScheduler_GetSegmentTaskNum_Call { _c.Call.Run(func(args mock.Arguments) { - run() + variadicArgs := make([]TaskFilter, len(args)-0) + for i, a := range args[0:] { + if a != nil { + variadicArgs[i] = a.(TaskFilter) + } + } + run(variadicArgs...) 
}) return _c } @@ -332,7 +360,7 @@ func (_c *MockScheduler_GetSegmentTaskNum_Call) Return(_a0 int) *MockScheduler_G return _c } -func (_c *MockScheduler_GetSegmentTaskNum_Call) RunAndReturn(run func() int) *MockScheduler_GetSegmentTaskNum_Call { +func (_c *MockScheduler_GetSegmentTaskNum_Call) RunAndReturn(run func(...TaskFilter) int) *MockScheduler_GetSegmentTaskNum_Call { _c.Call.Return(run) return _c } diff --git a/internal/querycoordv2/task/scheduler.go b/internal/querycoordv2/task/scheduler.go index aab1a30f3ceb9..4cc328959a733 100644 --- a/internal/querycoordv2/task/scheduler.go +++ b/internal/querycoordv2/task/scheduler.go @@ -142,8 +142,8 @@ type Scheduler interface { Dispatch(node int64) RemoveByNode(node int64) GetExecutedFlag(nodeID int64) <-chan struct{} - GetChannelTaskNum() int - GetSegmentTaskNum() int + GetChannelTaskNum(filters ...TaskFilter) int + GetSegmentTaskNum(filters ...TaskFilter) int GetSegmentTaskDelta(nodeID int64, collectionID int64) int GetChannelTaskDelta(nodeID int64, collectionID int64) int @@ -553,18 +553,68 @@ func (scheduler *taskScheduler) GetExecutedFlag(nodeID int64) <-chan struct{} { return executor.GetExecutedFlag() } -func (scheduler *taskScheduler) GetChannelTaskNum() int { +type TaskFilter func(task Task) bool + +func WithCollectionID2TaskFilter(collectionID int64) TaskFilter { + return func(task Task) bool { + return task.CollectionID() == collectionID + } +} + +func WithTaskTypeFilter(taskType Type) TaskFilter { + return func(task Task) bool { + return GetTaskType(task) == taskType + } +} + +func (scheduler *taskScheduler) GetChannelTaskNum(filters ...TaskFilter) int { scheduler.rwmutex.RLock() defer scheduler.rwmutex.RUnlock() - return len(scheduler.channelTasks) + if len(filters) == 0 { + return len(scheduler.channelTasks) + } + + // count channel tasks that match all filters + counter := 0 + for _, task := range scheduler.channelTasks { + allMatch := true + for _, filter := range filters { + if !filter(task) { + allMatch = false + 
break + } + } + if allMatch { + counter++ + } + } + return counter } -func (scheduler *taskScheduler) GetSegmentTaskNum() int { +func (scheduler *taskScheduler) GetSegmentTaskNum(filters ...TaskFilter) int { scheduler.rwmutex.RLock() defer scheduler.rwmutex.RUnlock() - return len(scheduler.segmentTasks) + if len(filters) == 0 { + return len(scheduler.segmentTasks) + } + + // count segment tasks that match all filters + counter := 0 + for _, task := range scheduler.segmentTasks { + allMatch := true + for _, filter := range filters { + if !filter(task) { + allMatch = false + break + } + } + if allMatch { + counter++ + } + } + return counter } // schedule selects some tasks to execute, follow these steps for each started selected tasks: diff --git a/pkg/util/paramtable/component_param_test.go b/pkg/util/paramtable/component_param_test.go index bb0cc532caf20..a05f2cabaa42a 100644 --- a/pkg/util/paramtable/component_param_test.go +++ b/pkg/util/paramtable/component_param_test.go @@ -321,8 +321,8 @@ func TestComponentParam(t *testing.T) { assert.Equal(t, 1000, Params.SegmentCheckInterval.GetAsInt()) assert.Equal(t, 1000, Params.ChannelCheckInterval.GetAsInt()) - params.Save(Params.BalanceCheckInterval.Key, "10000") - assert.Equal(t, 10000, Params.BalanceCheckInterval.GetAsInt()) + params.Save(Params.BalanceCheckInterval.Key, "3000") + assert.Equal(t, 3000, Params.BalanceCheckInterval.GetAsInt()) assert.Equal(t, 10000, Params.IndexCheckInterval.GetAsInt()) assert.Equal(t, 3, Params.CollectionRecoverTimesLimit.GetAsInt()) assert.Equal(t, true, Params.AutoBalance.GetAsBool())