Skip to content

Commit

Permalink
planner/core: change agg cost factor (#25210)
Browse files Browse the repository at this point in the history
  • Loading branch information
hanfei1991 authored Jun 8, 2021
1 parent eb91585 commit a7f3c4d
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 278 deletions.
12 changes: 10 additions & 2 deletions planner/core/physical_plans.go
Original file line number Diff line number Diff line change
Expand Up @@ -1048,7 +1048,7 @@ func (p *basePhysicalAgg) numDistinctFunc() (num int) {
return
}

func (p *basePhysicalAgg) getAggFuncCostFactor() (factor float64) {
func (p *basePhysicalAgg) getAggFuncCostFactor(isMPP bool) (factor float64) {
factor = 0.0
for _, agg := range p.AggFuncs {
if fac, ok := aggFuncFactor[agg.Name]; ok {
Expand All @@ -1058,7 +1058,15 @@ func (p *basePhysicalAgg) getAggFuncCostFactor() (factor float64) {
}
}
if factor == 0 {
factor = 1.0
if isMPP {
// With pseudo stats, the default factor 1.0 leads the planner to pick 1-phase agg.
// In MPP cases 2-phase agg is more common, so we use the first-row factor instead.
// TODO: This is still somewhat hacky and might cause regressions. We should
// calibrate these factors and polish our cost model in the future.
factor = aggFuncFactor[ast.AggFuncFirstRow]
} else {
factor = 1.0
}
}
return
}
Expand Down
18 changes: 9 additions & 9 deletions planner/core/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -1829,7 +1829,7 @@ func (p *PhysicalStreamAgg) attach2Task(tasks ...task) task {

// GetCost computes cost of stream aggregation considering CPU/memory.
func (p *PhysicalStreamAgg) GetCost(inputRows float64, isRoot bool) float64 {
aggFuncFactor := p.getAggFuncCostFactor()
aggFuncFactor := p.getAggFuncCostFactor(false)
var cpuCost float64
sessVars := p.ctx.GetSessionVars()
if isRoot {
Expand Down Expand Up @@ -1876,7 +1876,7 @@ func (p *PhysicalHashAgg) attach2TaskForMpp(tasks ...task) task {
if proj != nil {
attachPlan2Task(proj, mpp)
}
mpp.addCost(p.GetCost(inputRows, false))
mpp.addCost(p.GetCost(inputRows, false, true))
p.cost = mpp.cost()
return mpp
case Mpp2Phase:
Expand Down Expand Up @@ -1909,7 +1909,7 @@ func (p *PhysicalHashAgg) attach2TaskForMpp(tasks ...task) task {
attachPlan2Task(proj, newMpp)
}
// TODO: how to set 2-phase cost?
newMpp.addCost(p.GetCost(inputRows, false))
newMpp.addCost(p.GetCost(inputRows, false, true))
finalAgg.SetCost(mpp.cost())
if proj != nil {
proj.SetCost(mpp.cost())
Expand All @@ -1920,14 +1920,14 @@ func (p *PhysicalHashAgg) attach2TaskForMpp(tasks ...task) task {
if partialAgg != nil {
attachPlan2Task(partialAgg, mpp)
}
mpp.addCost(p.GetCost(inputRows, false))
mpp.addCost(p.GetCost(inputRows, false, true))
if partialAgg != nil {
partialAgg.SetCost(mpp.cost())
}
t = mpp.convertToRootTask(p.ctx)
inputRows = t.count()
attachPlan2Task(finalAgg, t)
t.addCost(p.GetCost(inputRows, true))
t.addCost(p.GetCost(inputRows, true, false))
finalAgg.SetCost(t.cost())
return t
default:
Expand Down Expand Up @@ -1958,7 +1958,7 @@ func (p *PhysicalHashAgg) attach2Task(tasks ...task) task {
partialAgg.SetChildren(cop.indexPlan)
cop.indexPlan = partialAgg
}
cop.addCost(p.GetCost(inputRows, false))
cop.addCost(p.GetCost(inputRows, false, false))
}
// In `newPartialAggregate`, we are using stats of final aggregation as stats
// of `partialAgg`, so the network cost of transferring result rows of `partialAgg`
Expand Down Expand Up @@ -1991,16 +1991,16 @@ func (p *PhysicalHashAgg) attach2Task(tasks ...task) task {
// hash aggregation, it would cause under-estimation for the reason mentioned in the comment above.
// To make it simple, we also treat 2-phase parallel hash aggregation in TiDB layer as
// 1-phase when computing cost.
t.addCost(p.GetCost(inputRows, true))
t.addCost(p.GetCost(inputRows, true, false))
p.cost = t.cost()
return t
}

// GetCost computes the cost of hash aggregation considering CPU/memory.
func (p *PhysicalHashAgg) GetCost(inputRows float64, isRoot bool) float64 {
func (p *PhysicalHashAgg) GetCost(inputRows float64, isRoot bool, isMPP bool) float64 {
cardinality := p.statsInfo().RowCount
numDistinctFunc := p.numDistinctFunc()
aggFuncFactor := p.getAggFuncCostFactor()
aggFuncFactor := p.getAggFuncCostFactor(isMPP)
var cpuCost float64
sessVars := p.ctx.GetSessionVars()
if isRoot {
Expand Down
24 changes: 2 additions & 22 deletions planner/core/testdata/integration_serial_suite_in.json
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@
{
"name": "TestMPPOuterJoinBuildSideForBroadcastJoin",
"cases": [
"explain format = 'brief' select count(*) from a left join b on a.id = b.id",
"explain format = 'brief' select count(*) from b right join a on a.id = b.id"
"explain format = 'brief' select count(*) from a left join b on a.id = b.id",
"explain format = 'brief' select count(*) from b right join a on a.id = b.id"
]
},
{
Expand Down Expand Up @@ -101,26 +101,6 @@
"explain format = 'brief' select count(*) from fact_t where not exists (select 1 from d1_t where d1_k = fact_t.d1_k and value > fact_t.col1)"
]
},
{
"name": "TestBroadcastJoin",
"cases": [
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t) */ count(*) from fact_t, d1_t where fact_t.d1_k = d1_t.d1_k",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t,d2_t,d3_t) */ count(*) from fact_t, d1_t, d2_t, d3_t where fact_t.d1_k = d1_t.d1_k and fact_t.d2_k = d2_t.d2_k and fact_t.d3_k = d3_t.d3_k",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t), broadcast_join_local(d1_t) */ count(*) from fact_t, d1_t where fact_t.d1_k = d1_t.d1_k",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t,d2_t,d3_t), broadcast_join_local(d2_t) */ count(*) from fact_t, d1_t, d2_t, d3_t where fact_t.d1_k = d1_t.d1_k and fact_t.d2_k = d2_t.d2_k and fact_t.d3_k = d3_t.d3_k",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t) */ count(*) from fact_t left join d1_t on fact_t.d1_k = d1_t.d1_k",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t) */ count(*) from fact_t right join d1_t on fact_t.d1_k = d1_t.d1_k",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t) */ count(*) from fact_t join d1_t on fact_t.d1_k = d1_t.d1_k and fact_t.col1 > d1_t.value",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t) */ count(*) from fact_t left join d1_t on fact_t.d1_k = d1_t.d1_k and fact_t.col1 > 10",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t) */ count(*) from fact_t left join d1_t on fact_t.d1_k = d1_t.d1_k and fact_t.col2 > 10 and fact_t.col1 > d1_t.value",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t) */ count(*) from fact_t right join d1_t on fact_t.d1_k = d1_t.d1_k and d1_t.value > 10",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t) */ count(*) from fact_t right join d1_t on fact_t.d1_k = d1_t.d1_k and d1_t.value > 10 and fact_t.col1 > d1_t.value",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t) */ count(*) from fact_t where exists (select 1 from d1_t where d1_k = fact_t.d1_k)",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t) */ count(*) from fact_t where exists (select 1 from d1_t where d1_k = fact_t.d1_k and value > fact_t.col1)",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t) */ count(*) from fact_t where not exists (select 1 from d1_t where d1_k = fact_t.d1_k)",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t) */ count(*) from fact_t where not exists (select 1 from d1_t where d1_k = fact_t.d1_k and value > fact_t.col1)"
]
},
{
"name": "TestJoinNotSupportedByTiFlash",
"cases": [
Expand Down
Loading

0 comments on commit a7f3c4d

Please sign in to comment.