Skip to content

Commit

Permalink
enhance: reduce the cpu usage when collection number is high
Browse files Browse the repository at this point in the history
Signed-off-by: xiaofanluan <[email protected]>
  • Loading branch information
xiaofan-luan committed Apr 18, 2024
1 parent 70beec3 commit bf48a25
Show file tree
Hide file tree
Showing 20 changed files with 138 additions and 76 deletions.
2 changes: 1 addition & 1 deletion configs/milvus.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ queryCoord:
checkIndexInterval: 10000
channelTaskTimeout: 60000 # 1 minute
segmentTaskTimeout: 120000 # 2 minutes
distPullInterval: 500
distPullInterval: 1000
heartbeatAvailableInterval: 10000 # 10s; only QueryNodes whose heartbeat was fetched within this duration are considered available
loadTimeoutSeconds: 600
distRequestTimeout: 5000 # the request timeout for querycoord fetching data distribution from querynodes, in milliseconds
Expand Down
2 changes: 1 addition & 1 deletion internal/querycoordv2/balance/multi_target_balance.go
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,7 @@ func (b *MultiTargetBalancer) genSegmentPlan(replica *meta.Replica) []SegmentAss
nodeSegments := make(map[int64][]*meta.Segment)
globalNodeSegments := make(map[int64][]*meta.Segment)
for _, node := range replica.GetNodes() {
dist := b.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(replica.GetCollectionID()), meta.WithNodeID(node))
dist := b.dist.SegmentDistManager.GetByCollectionAndFilter(replica.GetCollectionID(), meta.WithNodeID(node))
segments := lo.Filter(dist, func(segment *meta.Segment, _ int) bool {
return b.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.CurrentTarget) != nil &&
b.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.NextTarget) != nil &&
Expand Down
8 changes: 4 additions & 4 deletions internal/querycoordv2/balance/rowcount_based_balancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ func (b *RowCountBasedBalancer) BalanceReplica(replica *meta.Replica) ([]Segment
func (b *RowCountBasedBalancer) genStoppingSegmentPlan(replica *meta.Replica, onlineNodes []int64, offlineNodes []int64) []SegmentAssignPlan {
segmentPlans := make([]SegmentAssignPlan, 0)
for _, nodeID := range offlineNodes {
dist := b.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(replica.GetCollectionID()), meta.WithNodeID(nodeID))
dist := b.dist.SegmentDistManager.GetByCollectionAndFilter(replica.GetCollectionID(), meta.WithNodeID(nodeID))
segments := lo.Filter(dist, func(segment *meta.Segment, _ int) bool {
return b.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.CurrentTarget) != nil &&
b.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.NextTarget) != nil &&
Expand All @@ -253,7 +253,7 @@ func (b *RowCountBasedBalancer) genSegmentPlan(replica *meta.Replica, onlineNode
segmentDist := make(map[int64][]*meta.Segment)
totalRowCount := 0
for _, node := range onlineNodes {
dist := b.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(replica.GetCollectionID()), meta.WithNodeID(node))
dist := b.dist.SegmentDistManager.GetByCollectionAndFilter(replica.GetCollectionID(), meta.WithNodeID(node))
segments := lo.Filter(dist, func(segment *meta.Segment, _ int) bool {
return b.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.CurrentTarget) != nil &&
b.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.NextTarget) != nil &&
Expand Down Expand Up @@ -316,7 +316,7 @@ func (b *RowCountBasedBalancer) genSegmentPlan(replica *meta.Replica, onlineNode
func (b *RowCountBasedBalancer) genStoppingChannelPlan(replica *meta.Replica, onlineNodes []int64, offlineNodes []int64) []ChannelAssignPlan {
channelPlans := make([]ChannelAssignPlan, 0)
for _, nodeID := range offlineNodes {
dmChannels := b.dist.ChannelDistManager.GetByFilter(meta.WithCollectionID2Channel(replica.GetCollectionID()), meta.WithNodeID2Channel(nodeID))
dmChannels := b.dist.ChannelDistManager.GetByCollectionAndFilter(replica.GetCollectionID(), meta.WithNodeID2Channel(nodeID))
plans := b.AssignChannel(dmChannels, onlineNodes, false)
for i := range plans {
plans[i].From = nodeID
Expand All @@ -341,7 +341,7 @@ func (b *RowCountBasedBalancer) genChannelPlan(replica *meta.Replica, onlineNode
nodeWithLessChannel := make([]int64, 0)
channelsToMove := make([]*meta.DmChannel, 0)
for _, node := range onlineNodes {
channels := b.dist.ChannelDistManager.GetByFilter(meta.WithCollectionID2Channel(replica.GetCollectionID()), meta.WithNodeID2Channel(node))
channels := b.dist.ChannelDistManager.GetByCollectionAndFilter(replica.GetCollectionID(), meta.WithNodeID2Channel(node))

if len(channels) <= average {
nodeWithLessChannel = append(nodeWithLessChannel, node)
Expand Down
6 changes: 3 additions & 3 deletions internal/querycoordv2/balance/score_based_balancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ func (b *ScoreBasedBalancer) calculateScore(collectionID, nodeID int64) int {

collectionRowCount := 0
// calculate collection sealed segment row count
collectionSegments := b.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(collectionID), meta.WithNodeID(nodeID))
collectionSegments := b.dist.SegmentDistManager.GetByCollectionAndFilter(collectionID, meta.WithNodeID(nodeID))
for _, s := range collectionSegments {
collectionRowCount += int(s.GetNumOfRows())
}
Expand Down Expand Up @@ -254,7 +254,7 @@ func (b *ScoreBasedBalancer) BalanceReplica(replica *meta.Replica) ([]SegmentAss
func (b *ScoreBasedBalancer) genStoppingSegmentPlan(replica *meta.Replica, onlineNodes []int64, offlineNodes []int64) []SegmentAssignPlan {
segmentPlans := make([]SegmentAssignPlan, 0)
for _, nodeID := range offlineNodes {
dist := b.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(replica.GetCollectionID()), meta.WithNodeID(nodeID))
dist := b.dist.SegmentDistManager.GetByCollectionAndFilter(replica.GetCollectionID(), meta.WithNodeID(nodeID))
segments := lo.Filter(dist, func(segment *meta.Segment, _ int) bool {
return b.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.CurrentTarget) != nil &&
b.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.NextTarget) != nil &&
Expand All @@ -277,7 +277,7 @@ func (b *ScoreBasedBalancer) genSegmentPlan(replica *meta.Replica, onlineNodes [

// list all segment which could be balanced, and calculate node's score
for _, node := range onlineNodes {
dist := b.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(replica.GetCollectionID()), meta.WithNodeID(node))
dist := b.dist.SegmentDistManager.GetByCollectionAndFilter(replica.GetCollectionID(), meta.WithNodeID(node))
segments := lo.Filter(dist, func(segment *meta.Segment, _ int) bool {
return b.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.CurrentTarget) != nil &&
b.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.NextTarget) != nil &&
Expand Down
4 changes: 2 additions & 2 deletions internal/querycoordv2/balance/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ func PrintCurrentReplicaDist(replica *meta.Replica,
// 3. print stopping nodes channel distribution
distInfo += "[stoppingNodesChannelDist:"
for stoppingNodeID := range stoppingNodesSegments {
stoppingNodeChannels := channelManager.GetByFilter(meta.WithCollectionID2Channel(replica.GetCollectionID()), meta.WithNodeID2Channel(stoppingNodeID))
stoppingNodeChannels := channelManager.GetByCollectionAndFilter(replica.GetCollectionID(), meta.WithNodeID2Channel(stoppingNodeID))
distInfo += fmt.Sprintf("[nodeID:%d, count:%d,", stoppingNodeID, len(stoppingNodeChannels))
distInfo += "channels:["
for _, stoppingChan := range stoppingNodeChannels {
Expand All @@ -189,7 +189,7 @@ func PrintCurrentReplicaDist(replica *meta.Replica,
// 4. print normal nodes channel distribution
distInfo += "[normalNodesChannelDist:"
for normalNodeID := range nodeSegments {
normalNodeChannels := channelManager.GetByFilter(meta.WithCollectionID2Channel(replica.GetCollectionID()), meta.WithNodeID2Channel(normalNodeID))
normalNodeChannels := channelManager.GetByCollectionAndFilter(replica.GetCollectionID(), meta.WithNodeID2Channel(normalNodeID))
distInfo += fmt.Sprintf("[nodeID:%d, count:%d,", normalNodeID, len(normalNodeChannels))
distInfo += "channels:["
for _, normalNodeChan := range normalNodeChannels {
Expand Down
8 changes: 2 additions & 6 deletions internal/querycoordv2/checkers/channel_checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,11 +160,7 @@ func (c *ChannelChecker) getDmChannelDiff(collectionID int64,
}

func (c *ChannelChecker) getChannelDist(replica *meta.Replica) []*meta.DmChannel {
dist := make([]*meta.DmChannel, 0)
for _, nodeID := range replica.GetNodes() {
dist = append(dist, c.dist.ChannelDistManager.GetByFilter(meta.WithCollectionID2Channel(replica.GetCollectionID()), meta.WithNodeID2Channel(nodeID))...)
}
return dist
return c.dist.ChannelDistManager.GetByCollectionAndFilter(replica.GetCollectionID(), meta.WithNodeIDsFilter(replica.GetNodes()))
}

func (c *ChannelChecker) findRepeatedChannels(ctx context.Context, replicaID int64) []*meta.DmChannel {
Expand All @@ -183,7 +179,7 @@ func (c *ChannelChecker) findRepeatedChannels(ctx context.Context, replicaID int
for _, ch := range dist {
leaderView := c.dist.LeaderViewManager.GetLeaderShardView(ch.Node, ch.GetChannelName())
if leaderView == nil {
log.Info("shard leadview is not ready, skip",
log.Info("shard leader view is not ready, skip",
zap.Int64("collectionID", replica.GetCollectionID()),
zap.Int64("replicaID", replicaID),
zap.Int64("leaderID", ch.Node),
Expand Down
2 changes: 1 addition & 1 deletion internal/querycoordv2/checkers/index_checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ func (c *IndexChecker) checkSegment(ctx context.Context, segment *meta.Segment,
func (c *IndexChecker) getSealedSegmentsDist(replica *meta.Replica) []*meta.Segment {
var ret []*meta.Segment
for _, node := range replica.GetNodes() {
ret = append(ret, c.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(replica.GetCollectionID()), meta.WithNodeID(node))...)
ret = append(ret, c.dist.SegmentDistManager.GetByCollectionAndFilter(replica.GetCollectionID(), meta.WithNodeID(node))...)
}
return ret
}
Expand Down
2 changes: 1 addition & 1 deletion internal/querycoordv2/checkers/segment_checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ func (c *SegmentChecker) getSealedSegmentDiff(
func (c *SegmentChecker) getSealedSegmentsDist(replica *meta.Replica) []*meta.Segment {
ret := make([]*meta.Segment, 0)
for _, node := range replica.GetNodes() {
ret = append(ret, c.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(replica.GetCollectionID()), meta.WithNodeID(node))...)
ret = append(ret, c.dist.SegmentDistManager.GetByCollectionAndFilter(replica.GetCollectionID(), meta.WithNodeID(node))...)
}
return ret
}
Expand Down
1 change: 0 additions & 1 deletion internal/querycoordv2/dist/dist_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ func (dh *distHandler) handleDistResp(resp *querypb.GetDataDistributionResponse)
}
node.SetLastHeartbeat(time.Now())
}

dh.updateSegmentsDistribution(resp)
dh.updateChannelsDistribution(resp)
dh.updateLeaderView(resp)
Expand Down
4 changes: 2 additions & 2 deletions internal/querycoordv2/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func (s *Server) checkAnyReplicaAvailable(collectionID int64) bool {
}

func (s *Server) getCollectionSegmentInfo(collection int64) []*querypb.SegmentInfo {
segments := s.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(collection))
segments := s.dist.SegmentDistManager.GetByCollectionAndFilter(collection)
currentTargetSegmentsMap := s.targetMgr.GetSealedSegmentsByCollection(collection, meta.CurrentTarget)
infos := make(map[int64]*querypb.SegmentInfo)
for _, segment := range segments {
Expand Down Expand Up @@ -390,7 +390,7 @@ func (s *Server) fillReplicaInfo(replica *meta.Replica, withShardNodes bool) (*m
}
var segments []*meta.Segment
if withShardNodes {
segments = s.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(replica.GetCollectionID()))
segments = s.dist.SegmentDistManager.GetByCollectionAndFilter(replica.GetCollectionID())
}

for _, channel := range channels {
Expand Down
8 changes: 3 additions & 5 deletions internal/querycoordv2/job/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,14 @@ import (
func waitCollectionReleased(dist *meta.DistributionManager, checkerController *checkers.CheckerController, collection int64, partitions ...int64) {
partitionSet := typeutil.NewUniqueSet(partitions...)
for {
var (
channels []*meta.DmChannel
segments []*meta.Segment = dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(collection))
)
var channels []*meta.DmChannel
segments := dist.SegmentDistManager.GetByCollectionAndFilter(collection)
if partitionSet.Len() > 0 {
segments = lo.Filter(segments, func(segment *meta.Segment, _ int) bool {
return partitionSet.Contain(segment.GetPartitionID())
})
} else {
channels = dist.ChannelDistManager.GetByFilter(meta.WithCollectionID2Channel(collection))
channels = dist.ChannelDistManager.GetByCollectionAndFilter(collection)
}

if len(channels)+len(segments) == 0 {
Expand Down
61 changes: 50 additions & 11 deletions internal/querycoordv2/meta/channel_dist_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,27 +27,26 @@ import (

// ChannelDistFilter is a predicate over a distributed DmChannel; it reports
// whether the channel should be kept in a query result.
type ChannelDistFilter = func(ch *DmChannel) bool

// WithCollectionID2Channel builds a ChannelDistFilter that accepts only
// channels whose collection ID equals collectionID.
func WithCollectionID2Channel(collectionID int64) ChannelDistFilter {
	inCollection := func(ch *DmChannel) bool {
		return collectionID == ch.GetCollectionID()
	}
	return inCollection
}

// WithNodeID2Channel builds a ChannelDistFilter that accepts only channels
// whose Node field equals nodeID.
func WithNodeID2Channel(nodeID int64) ChannelDistFilter {
	onNode := func(ch *DmChannel) bool {
		return nodeID == ch.Node
	}
	return onNode
}

func WithReplica2Channel(replica *Replica) ChannelDistFilter {
// WithNodeIDsFilter builds a ChannelDistFilter that accepts a channel when its
// Node field equals any element of nodeIDs. With an empty nodeIDs the filter
// rejects every channel.
func WithNodeIDsFilter(nodeIDs []int64) ChannelDistFilter {
	onAnyNode := func(ch *DmChannel) bool {
		found := false
		// Stop scanning as soon as one ID matches.
		for i := 0; i < len(nodeIDs) && !found; i++ {
			found = nodeIDs[i] == ch.Node
		}
		return found
	}
	return onAnyNode
}

func WithChannelName2Channel(channelName string) ChannelDistFilter {
func WithReplica2Channel(replica *Replica) ChannelDistFilter {
return func(ch *DmChannel) bool {
return ch.GetChannelName() == channelName
return ch.GetCollectionID() == replica.GetCollectionID() && replica.Contains(ch.Node)
}
}

Expand Down Expand Up @@ -76,11 +75,15 @@ type ChannelDistManager struct {

// NodeID -> Channels
channels map[UniqueID][]*DmChannel

// CollectionID -> ChannelName -> Channel (secondary index for per-collection lookups)
collectionIndex map[int64]map[string]*DmChannel
}

func NewChannelDistManager() *ChannelDistManager {
return &ChannelDistManager{
channels: make(map[UniqueID][]*DmChannel),
channels: make(map[UniqueID][]*DmChannel),
collectionIndex: make(map[int64]map[string]*DmChannel),
}
}

Expand Down Expand Up @@ -146,13 +149,49 @@ func (m *ChannelDistManager) GetByFilter(filters ...ChannelDistFilter) []*DmChan
return ret
}

// GetByCollectionAndFilter returns the channels indexed under collectionID
// that satisfy every non-nil filter. The lookup goes through collectionIndex
// rather than scanning all nodes, and runs under the manager's read lock.
//
// The result is always a non-nil slice (empty when nothing matches). Its
// order follows Go map iteration and is therefore unspecified.
func (m *ChannelDistManager) GetByCollectionAndFilter(collectionID int64, filters ...ChannelDistFilter) []*DmChannel {
	m.rwmutex.RLock()
	defer m.rwmutex.RUnlock()

	matched := make([]*DmChannel, 0)

candidates:
	for _, candidate := range m.collectionIndex[collectionID] {
		for _, accept := range filters {
			// nil filters are tolerated and simply ignored.
			if accept == nil {
				continue
			}
			if !accept(candidate) {
				continue candidates
			}
		}
		matched = append(matched, candidate)
	}
	return matched
}

// Update replaces the channel distribution recorded for nodeID with channels
// and keeps the per-collection secondary index in sync with it.
//
// Every channel passed in has its Node field set to nodeID. Calling Update
// with no channels clears the node's distribution.
//
// Before the new report is indexed, index entries that still point at this
// node's previous report are removed. Without that step, channels the node no
// longer serves would stay in collectionIndex forever and keep being returned
// by GetByCollectionAndFilter.
//
// NOTE(review): the index holds one entry per channel name within a
// collection, so when several nodes report the same channel name the most
// recent Update wins. An entry owned by another node is left untouched here.
func (m *ChannelDistManager) Update(nodeID UniqueID, channels ...*DmChannel) {
	m.rwmutex.Lock()
	defer m.rwmutex.Unlock()

	// Drop stale index entries that belong to this node's previous report.
	for _, old := range m.channels[nodeID] {
		collectionID := old.GetCollectionID()
		index := m.collectionIndex[collectionID]
		name := old.GetChannelName()
		// Only remove the entry if this node still owns it; another node may
		// have overwritten it with its own copy of the same channel.
		if indexed, ok := index[name]; ok && indexed.Node == nodeID {
			delete(index, name)
			if len(index) == 0 {
				// Do not keep empty per-collection maps around.
				delete(m.collectionIndex, collectionID)
			}
		}
	}

	for _, channel := range channels {
		channel.Node = nodeID
		m.updateCollectionIndex(channel)
	}

	m.channels[nodeID] = channels
}

// updateCollectionIndex records channel in the secondary index under its
// collection ID and channel name, creating the per-collection map on first
// use. An existing entry with the same channel name is overwritten.
// It takes no lock itself.
func (m *ChannelDistManager) updateCollectionIndex(channel *DmChannel) {
	key := channel.GetCollectionID()
	if _, exists := m.collectionIndex[key]; !exists {
		m.collectionIndex[key] = make(map[string]*DmChannel)
	}
	m.collectionIndex[key][channel.GetChannelName()] = channel
}
10 changes: 5 additions & 5 deletions internal/querycoordv2/meta/channel_dist_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,26 +76,26 @@ func (suite *ChannelDistManagerSuite) TestGetBy() {
}

// Test GetByCollection
channels = dist.GetByFilter(WithCollectionID2Channel(suite.collection))
channels = dist.GetByCollectionAndFilter(suite.collection)
suite.Len(channels, 4)
suite.AssertCollection(channels, suite.collection)
channels = dist.GetByFilter(WithCollectionID2Channel(-1))
channels = dist.GetByCollectionAndFilter(-1)
suite.Len(channels, 0)

// Test GetByNodeAndCollection
// 1. Valid node and valid collection
for _, node := range suite.nodes {
channels := dist.GetByFilter(WithCollectionID2Channel(suite.collection), WithNodeID2Channel(node))
channels := dist.GetByCollectionAndFilter(suite.collection, WithNodeID2Channel(node))
suite.AssertNode(channels, node)
suite.AssertCollection(channels, suite.collection)
}

// 2. Valid node and invalid collection
channels = dist.GetByFilter(WithCollectionID2Channel(-1), WithNodeID2Channel(suite.nodes[1]))
channels = dist.GetByCollectionAndFilter(-1, WithNodeID2Channel(suite.nodes[1]))
suite.Len(channels, 0)

// 3. Invalid node and valid collection
channels = dist.GetByFilter(WithCollectionID2Channel(suite.collection), WithNodeID2Channel(-1))
channels = dist.GetByCollectionAndFilter(suite.collection, WithNodeID2Channel(-1))
suite.Len(channels, 0)
}

Expand Down
Loading

0 comments on commit bf48a25

Please sign in to comment.