Skip to content

Commit

Permalink
planner: fix agg elimination logic after agg pushed down through a jo…
Browse files Browse the repository at this point in the history
…in (#44941) (#45096)

close #44795
  • Loading branch information
ti-chi-bot authored Aug 14, 2023
1 parent 60af2a7 commit 361fe40
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 6 deletions.
48 changes: 48 additions & 0 deletions executor/aggregate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1636,6 +1636,54 @@ func TestIssue27751(t *testing.T) {
tk.MustQuery("select group_concat(nname order by 1 desc separator '#' ) from t;").Check(testkit.Rows("33#2"))
}

func TestIssue44795(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
tk.MustExec(`use test`)
tk.MustExec(`DROP TABLE IF EXISTS c`)

// case from tcph.
tk.MustExec("CREATE TABLE `customer` (" +
" `C_CUSTKEY` bigint(20) NOT NULL," +
" `C_NAME` varchar(25) NOT NULL," +
" `C_ADDRESS` varchar(40) NOT NULL," +
" `C_NATIONKEY` bigint(20) NOT NULL," +
" `C_PHONE` char(15) NOT NULL," +
" `C_ACCTBAL` decimal(15,2) NOT NULL," +
" `C_MKTSEGMENT` char(10) NOT NULL," +
" `C_COMMENT` varchar(117) NOT NULL," +
" PRIMARY KEY (`C_CUSTKEY`) /*T![clustered_index] CLUSTERED */" +
") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin")

tk.MustExec("CREATE TABLE `orders` (" +
" `O_ORDERKEY` bigint(20) NOT NULL," +
" `O_CUSTKEY` bigint(20) NOT NULL," +
" `O_ORDERSTATUS` char(1) NOT NULL," +
" `O_TOTALPRICE` decimal(15,2) NOT NULL," +
" `O_ORDERDATE` date NOT NULL," +
" `O_ORDERPRIORITY` char(15) NOT NULL," +
" `O_CLERK` char(15) NOT NULL," +
" `O_SHIPPRIORITY` bigint(20) NOT NULL," +
" `O_COMMENT` varchar(79) NOT NULL," +
" PRIMARY KEY (`O_ORDERKEY`) /*T![clustered_index] CLUSTERED */" +
") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin")

tk.MustExec("set tidb_opt_agg_push_down=ON;")

tk.MustQuery("explain format='brief' SELECT /*+ hash_join_build(customer) */ c_custkey, count(o_orderkey) as c_count from customer " +
"left join orders on c_custkey = o_custkey and o_comment not like '%special%requests%' group by c_custkey;").Check(testkit.Rows(
"Projection 8000.00 root test.customer.c_custkey, Column#18",
"└─HashAgg 8000.00 root group by:test.customer.c_custkey, funcs:count(Column#19)->Column#18, funcs:firstrow(test.customer.c_custkey)->test.customer.c_custkey",
" └─HashJoin 10000.00 root left outer join, equal:[eq(test.customer.c_custkey, test.orders.o_custkey)]",
" ├─TableReader(Build) 10000.00 root data:TableFullScan",
" │ └─TableFullScan 10000.00 cop[tikv] table:customer keep order:false, stats:pseudo",
" └─HashAgg(Probe) 6400.00 root group by:test.orders.o_custkey, funcs:count(Column#20)->Column#19, funcs:firstrow(test.orders.o_custkey)->test.orders.o_custkey",
" └─TableReader 6400.00 root data:HashAgg",
" └─HashAgg 6400.00 cop[tikv] group by:test.orders.o_custkey, funcs:count(test.orders.o_orderkey)->Column#20",
" └─Selection 8000.00 cop[tikv] not(like(test.orders.o_comment, \"%special%requests%\", 92))",
" └─TableFullScan 10000.00 cop[tikv] table:orders keep order:false, stats:pseudo"))
}

func TestIssue26885(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
Expand Down
8 changes: 4 additions & 4 deletions planner/core/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,8 @@ func TestAggPushDownLeftJoin(t *testing.T) {
"on c_custkey = o_custkey group by c_custkey").Check(testkit.Rows("6 0"))
tk.MustQuery("explain format='brief' select c_custkey, count(o_orderkey) as c_count from customer left outer join orders " +
"on c_custkey = o_custkey group by c_custkey").Check(testkit.Rows(
"Projection 10000.00 root test.customer.c_custkey, Column#7",
"└─Projection 10000.00 root if(isnull(Column#8), 0, 1)->Column#7, test.customer.c_custkey",
"Projection 8000.00 root test.customer.c_custkey, Column#7",
"└─HashAgg 8000.00 root group by:test.customer.c_custkey, funcs:count(Column#8)->Column#7, funcs:firstrow(test.customer.c_custkey)->test.customer.c_custkey",
" └─HashJoin 10000.00 root left outer join, equal:[eq(test.customer.c_custkey, test.orders.o_custkey)]",
" ├─HashAgg(Build) 8000.00 root group by:test.orders.o_custkey, funcs:count(Column#9)->Column#8, funcs:firstrow(test.orders.o_custkey)->test.orders.o_custkey",
" │ └─TableReader 8000.00 root data:HashAgg",
Expand All @@ -152,8 +152,8 @@ func TestAggPushDownLeftJoin(t *testing.T) {
"on c_custkey = o_custkey group by c_custkey").Check(testkit.Rows("6 0"))
tk.MustQuery("explain format='brief' select c_custkey, count(o_orderkey) as c_count from orders right outer join customer " +
"on c_custkey = o_custkey group by c_custkey").Check(testkit.Rows(
"Projection 10000.00 root test.customer.c_custkey, Column#7",
"└─Projection 10000.00 root if(isnull(Column#8), 0, 1)->Column#7, test.customer.c_custkey",
"Projection 8000.00 root test.customer.c_custkey, Column#7",
"└─HashAgg 8000.00 root group by:test.customer.c_custkey, funcs:count(Column#8)->Column#7, funcs:firstrow(test.customer.c_custkey)->test.customer.c_custkey",
" └─HashJoin 10000.00 root right outer join, equal:[eq(test.orders.o_custkey, test.customer.c_custkey)]",
" ├─HashAgg(Build) 8000.00 root group by:test.orders.o_custkey, funcs:count(Column#9)->Column#8, funcs:firstrow(test.orders.o_custkey)->test.orders.o_custkey",
" │ └─TableReader 8000.00 root data:HashAgg",
Expand Down
43 changes: 42 additions & 1 deletion planner/core/rule_aggregation_elimination.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,23 @@ type aggregationEliminator struct {
}

type aggregationEliminateChecker struct {
// used for agg pushed down cases, for example:
// agg -> join -> datasource1
// -> datasource2
// we just make a new agg upon datasource1 or datasource2, while the old agg is still existed and waiting for elimination.
// Note when the old agg is like below, and join is an outer join type, rewriting old agg in elimination logic has some problem.
// eg:
// count(a) -> ifnull(col#x, 0, 1) in rewriteExpr of agg function, since col#x is already the final pushed-down aggregation's
// result from new join schema, we don't need to take every row as count 1 when they don't have not-null flag in a.tryToEliminateAggregation(oldAgg, opt),
// which is not suitable here.
oldAggEliminationCheck bool
}

// tryToEliminateAggregation will eliminate aggregation grouped by unique key.
// e.g. select min(b) from t group by a. If a is a unique key, then this sql is equal to `select b from t group by a`.
// For count(expr), sum(expr), avg(expr), count(distinct expr, [expr...]) we may need to rewrite the expr. Details are shown below.
// If we can eliminate agg successful, we return a projection. Else we return a nil pointer.
func (*aggregationEliminateChecker) tryToEliminateAggregation(agg *LogicalAggregation, opt *logicalOptimizeOp) *LogicalProjection {
func (a *aggregationEliminateChecker) tryToEliminateAggregation(agg *LogicalAggregation, opt *logicalOptimizeOp) *LogicalProjection {
for _, af := range agg.AggFuncs {
// TODO(issue #9968): Actually, we can rewrite GROUP_CONCAT when all the
// arguments it accepts are promised to be NOT-NULL.
Expand All @@ -64,6 +74,9 @@ func (*aggregationEliminateChecker) tryToEliminateAggregation(agg *LogicalAggreg
}
}
if coveredByUniqueKey {
if a.oldAggEliminationCheck && !CheckCanConvertAggToProj(agg) {
return nil
}
// GroupByCols has unique key, so this aggregation can be removed.
if ok, proj := ConvertAggToProj(agg, agg.schema); ok {
proj.SetChildren(agg.children[0])
Expand Down Expand Up @@ -138,6 +151,34 @@ func appendDistinctEliminateTraceStep(agg *LogicalAggregation, uniqueKey express
opt.appendStepToCurrent(agg.ID(), agg.TP(), reason, action)
}

// CheckCanConvertAggToProj check whether a special old aggregation (which has already been pushed down) to projection.
// link: issue#44795
func CheckCanConvertAggToProj(agg *LogicalAggregation) bool {
var mayNullSchema *expression.Schema
if join, ok := agg.Children()[0].(*LogicalJoin); ok {
if join.JoinType == LeftOuterJoin {
mayNullSchema = join.Children()[1].Schema()
}
if join.JoinType == RightOuterJoin {
mayNullSchema = join.Children()[0].Schema()
}
if mayNullSchema == nil {
return true
}
// once agg function args has intersection with mayNullSchema, return nil (means elimination fail)
for _, fun := range agg.AggFuncs {
mayNullCols := expression.ExtractColumnsFromExpressions(nil, fun.Args, func(column *expression.Column) bool {
// collect may-null cols.
return mayNullSchema.Contains(column)
})
if len(mayNullCols) != 0 {
return false
}
}
}
return true
}

// ConvertAggToProj convert aggregation to projection.
func ConvertAggToProj(agg *LogicalAggregation, schema *expression.Schema) (bool, *LogicalProjection) {
proj := LogicalProjection{
Expand Down
31 changes: 30 additions & 1 deletion planner/core/rule_aggregation_push_down.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ func (a *aggregationPushDownSolver) tryToPushDownAgg(oldAgg *LogicalAggregation,
}
tmpSchema := expression.NewSchema(gbyCols...)
for _, key := range child.Schema().Keys {
if tmpSchema.ColumnsIndices(key) != nil {
if tmpSchema.ColumnsIndices(key) != nil { // gby item need to be covered by key.
return child, nil
}
}
Expand Down Expand Up @@ -504,10 +504,39 @@ func (a *aggregationPushDownSolver) aggPushDown(p LogicalPlan, opt *logicalOptim
resetNotNullFlag(join.schema, 0, lChild.Schema().Len())
}
buildKeyInfo(join)
// count(a) -> ifnull(col#x, 0, 1) in rewriteExpr of agg function, since col#x is already the final
// pushed-down aggregation's result, we don't need to take every row as count 1 when they don't have
// not-null flag in a.tryToEliminateAggregation(oldAgg, opt), which is not suitable here.
oldCheck := a.oldAggEliminationCheck
a.oldAggEliminationCheck = true
proj := a.tryToEliminateAggregation(agg, opt)
if proj != nil {
p = proj
}
a.oldAggEliminationCheck = oldCheck

// Combine the aggregation elimination logic below since new agg's child key info has changed.
// Notice that even if we eliminate new agg below if possible, the agg's schema is inherited by proj.
// Therefore, we don't need to set the join's schema again, just build the keyInfo again.
changed := false
if newAgg, ok1 := lChild.(*LogicalAggregation); ok1 {
proj := a.tryToEliminateAggregation(newAgg, opt)
if proj != nil {
lChild = proj
changed = true
}
}
if newAgg, ok2 := rChild.(*LogicalAggregation); ok2 {
proj := a.tryToEliminateAggregation(newAgg, opt)
if proj != nil {
rChild = proj
changed = true
}
}
if changed {
join.SetChildren(lChild, rChild)
buildKeyInfo(join)
}
}
} else if proj, ok1 := child.(*LogicalProjection); ok1 {
// push aggregation across projection
Expand Down

0 comments on commit 361fe40

Please sign in to comment.