Skip to content

Commit

Permalink
remove unnecessary merge scope in broadcast join (#17808)
Browse files Browse the repository at this point in the history
在某些场景下,broadcast join 的 probe 端会 merge 不止一次。这个 PR 的目标是保证一定只 merge 一次。
之前测试场景覆盖不够充分,现在已经重新跑了多个测试场景。

Approved by: @m-schen, @ouyuanning, @aunjgr
  • Loading branch information
badboynt1 authored Jul 31, 2024
1 parent 9571846 commit e31b2ed
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 55 deletions.
83 changes: 41 additions & 42 deletions pkg/sql/compile/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -1814,6 +1814,7 @@ func (c *Compile) compileExternScan(n *plan.Node) ([]*Scope, error) {

if len(fileList) == 0 {
ret := newScope(Normal)
ret.NodeInfo = getEngineNode(c)
ret.DataSource = &Source{isConst: true, node: n}

currentFirstFlag := c.anal.isFirst
Expand Down Expand Up @@ -1994,6 +1995,7 @@ func (c *Compile) compileTableFunction(n *plan.Node, ss []*Scope) []*Scope {

func (c *Compile) compileValueScan(n *plan.Node) ([]*Scope, error) {
ds := newScope(Normal)
ds.NodeInfo = getEngineNode(c)
ds.DataSource = &Source{isConst: true, node: n}
ds.NodeInfo = engine.Node{Addr: c.addr, Mcpu: 1}
ds.Proc = process.NewFromProc(c.proc, c.proc.Ctx, 0)
Expand Down Expand Up @@ -2457,10 +2459,6 @@ func (c *Compile) compileBroadcastJoin(node, left, right *plan.Node, ns []*plan.
leftTyps[i] = dupType(&expr.Typ)
}

if plan2.IsShuffleChildren(left, ns) {
probeScopes = c.mergeShuffleJoinScopeList(probeScopes)
}

switch node.JoinType {
case plan.Node_INNER:
rs = c.newBroadcastJoinScopeList(probeScopes, buildScopes, node)
Expand Down Expand Up @@ -3688,59 +3686,60 @@ func (c *Compile) newJoinScopeListWithBucket(rs, left, right []*Scope, n *plan.N
return rs
}

// newMergeRemoteScopeByCN groups the scopes in ss by the CN they are bound to
// and builds one merge-remote scope per CN through newMergeRemoteScope.
//
// CNs are visited in c.cnList order, and within a CN the scopes keep their
// relative order from ss. A CN with no matching scope contributes nothing, so
// the result may be shorter than c.cnList.
func (c *Compile) newMergeRemoteScopeByCN(ss []*Scope) []*Scope {
	merged := make([]*Scope, 0, len(c.cnList))
	for _, node := range c.cnList {
		// Collect every scope whose address matches this CN.
		owned := make([]*Scope, 0, node.Mcpu)
		for _, s := range ss {
			if !isSameCN(s.NodeInfo.Addr, node.Addr) {
				continue
			}
			owned = append(owned, s)
		}
		if len(owned) == 0 {
			continue
		}
		merged = append(merged, c.newMergeRemoteScope(owned, node))
	}

	return merged
}

// newBroadcastJoinScopeList builds the scopes that execute a broadcast join:
// every returned scope is flagged IsJoin, and the build side (buildScopes) is
// attached to one of them as a pre-scope.
//
// NOTE(review): this listing appears to be a rendered diff in which removed
// and added lines are shown together without +/- markers (e.g. rs is declared
// twice and two loop headers share one closing brace). Read it as two
// overlapping versions of the function rather than as compilable code.
func (c *Compile) newBroadcastJoinScopeList(probeScopes []*Scope, buildScopes []*Scope, n *plan.Node) []*Scope {
length := len(probeScopes)
rs := make([]*Scope, length)
idx := 0
for i := range probeScopes {
rs[i] = newScope(Remote)
rs := c.newMergeRemoteScopeByCN(probeScopes)
for i := range rs {
rs[i].IsJoin = true
rs[i].NodeInfo = probeScopes[i].NodeInfo
rs[i].BuildIdx = 1
if isSameCN(rs[i].NodeInfo.Addr, c.addr) {
idx = i
rs[i].NodeInfo.Mcpu = c.generateCPUNumber(ncpu, int(n.Stats.BlockNum))
// BuildIdx is the slot that the WaitRegister appended below will occupy.
rs[i].BuildIdx = len(rs[i].Proc.Reg.MergeReceivers)
w := &process.WaitRegister{
Ctx: rs[i].Proc.Ctx,
Ch: make(chan *process.RegisterMessage, 10),
}
rs[i].PreScopes = []*Scope{probeScopes[i]}
rs[i].Proc = process.NewFromProc(c.proc, c.proc.Ctx, 2)
probeScopes[i].setRootOperator(
connector.NewArgument().
WithReg(rs[i].Proc.Reg.MergeReceivers[0]))
rs[i].Proc.Reg.MergeReceivers = append(rs[i].Proc.Reg.MergeReceivers, w)
}

// Every join's first flag is set in newLeftScope and newRightScope,
// so we set it to false now.
if c.IsTpQuery() {
// TP query: the first build scope is attached directly to rs[0].
rs[0].PreScopes = append(rs[0].PreScopes, buildScopes[0])
} else {
c.anal.isFirst = false
mergeChildren := c.newMergeScope(buildScopes)

mergeChildren.setRootOperator(constructDispatch(1, rs, c.addr, n, false))
mergeChildren.IsEnd = true
rs[idx].PreScopes = append(rs[idx].PreScopes, mergeChildren)
}

for i := range rs {
mergeOp := merge.NewArgument()
rs[i].setRootOperator(mergeOp)
for i := range rs {
// Only the first scope located on the current CN (c.addr) receives the build side.
if isSameCN(rs[i].NodeInfo.Addr, c.addr) {
mergeBuild := buildScopes[0]
// Multiple build scopes are first combined with newMergeScope.
if len(buildScopes) > 1 {
mergeBuild = c.newMergeScope(buildScopes)
}
// Presumably dispatches the build output to receiver BuildIdx of every scope in rs — confirm against constructDispatch.
mergeBuild.setRootOperator(constructDispatch(rs[i].BuildIdx, rs, c.addr, n, false))
mergeBuild.IsEnd = true
rs[i].PreScopes = append(rs[i].PreScopes, mergeBuild)
break
}
}
}

return rs
}

// mergeShuffleJoinScopeList splits child into len(c.cnList) consecutive,
// equally sized groups and merges each group into one remote scope bound to
// the CN at the same position in c.cnList.
//
// The group size is len(child) / len(c.cnList) using integer division, so any
// trailing scopes beyond an exact multiple are not included in a group, and
// an empty c.cnList causes a division-by-zero panic. Callers are expected to
// pass child laid out CN by CN with the same number of scopes per CN.
func (c *Compile) mergeShuffleJoinScopeList(child []*Scope) []*Scope {
	cnCount := len(c.cnList)
	perCN := len(child) / cnCount
	merged := make([]*Scope, 0, cnCount)
	for idx := range c.cnList {
		lo := idx * perCN
		group := child[lo : lo+perCN]
		merged = append(merged, c.newMergeRemoteScope(group, c.cnList[idx]))
	}
	return merged
}

func (c *Compile) newShuffleJoinScopeList(left, right []*Scope, n *plan.Node) ([]*Scope, []*Scope) {
single := len(c.cnList) <= 1
if single {
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/compile/debugTools.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ func debugShowScopes(ss []*Scope, gap int, rmp map[*process.WaitRegister]int) st
if ss[i].Proc != nil {
receiverStr = getReceiverStr(ss[i], ss[i].Proc.Reg.MergeReceivers)
}
str += fmt.Sprintf("Scope %d (Magic: %s, Receiver: %s): [", i+1, magicShow(ss[i].Magic), receiverStr)
str += fmt.Sprintf("Scope %d (Magic: %s, mcpu: %v, Receiver: %s): [", i+1, magicShow(ss[i].Magic), ss[i].NodeInfo.Mcpu, receiverStr)

vm.HandleAllOp(ss[i].RootOp, func(parentOp vm.Operator, op vm.Operator) error {
if op.GetOperatorBase().NumChildren() != 0 {
Expand Down
12 changes: 0 additions & 12 deletions pkg/sql/plan/shuffle.go
Original file line number Diff line number Diff line change
Expand Up @@ -581,15 +581,3 @@ func shouldUseShuffleRanges(s *pb.ShuffleRange) []float64 {
}
return s.Result
}

// IsShuffleChildren reports whether n is a join marked as shuffle in its
// hashmap stats, looking through any chain of filter nodes on the way.
//
// Starting at n, each FILTER node is replaced by its first child (looked up in
// ns via Children[0]). The first non-filter node decides the result: true only
// if it is a JOIN whose Stats.HashmapStats.Shuffle is set, false otherwise.
func IsShuffleChildren(n *plan.Node, ns []*plan.Node) bool {
	cur := n
	// Skip over filters; they do not change how the child is distributed.
	for cur.NodeType == plan.Node_FILTER {
		cur = ns[cur.Children[0]]
	}
	return cur.NodeType == plan.Node_JOIN && cur.Stats.HashmapStats.Shuffle
}

0 comments on commit e31b2ed

Please sign in to comment.