From 361fe40be87c200b1cb28d232f03545c2056f854 Mon Sep 17 00:00:00 2001 From: Ti Chi Robot Date: Mon, 14 Aug 2023 19:07:59 +0800 Subject: [PATCH] planner: fix agg elimination logic after agg pushed down through a join (#44941) (#45096) close pingcap/tidb#44795 --- executor/aggregate_test.go | 48 ++++++++++++++++++++ planner/core/integration_test.go | 8 ++-- planner/core/rule_aggregation_elimination.go | 43 +++++++++++++++++- planner/core/rule_aggregation_push_down.go | 31 ++++++++++++- 4 files changed, 124 insertions(+), 6 deletions(-) diff --git a/executor/aggregate_test.go b/executor/aggregate_test.go index bd6366d6d115d..9e961ee8df157 100644 --- a/executor/aggregate_test.go +++ b/executor/aggregate_test.go @@ -1636,6 +1636,54 @@ func TestIssue27751(t *testing.T) { tk.MustQuery("select group_concat(nname order by 1 desc separator '#' ) from t;").Check(testkit.Rows("33#2")) } +func TestIssue44795(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec(`use test`) + tk.MustExec(`DROP TABLE IF EXISTS c`) + + // case from tcph. + tk.MustExec("CREATE TABLE `customer` (" + + " `C_CUSTKEY` bigint(20) NOT NULL," + + " `C_NAME` varchar(25) NOT NULL," + + " `C_ADDRESS` varchar(40) NOT NULL," + + " `C_NATIONKEY` bigint(20) NOT NULL," + + " `C_PHONE` char(15) NOT NULL," + + " `C_ACCTBAL` decimal(15,2) NOT NULL," + + " `C_MKTSEGMENT` char(10) NOT NULL," + + " `C_COMMENT` varchar(117) NOT NULL," + + " PRIMARY KEY (`C_CUSTKEY`) /*T![clustered_index] CLUSTERED */" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin") + + tk.MustExec("CREATE TABLE `orders` (" + + " `O_ORDERKEY` bigint(20) NOT NULL," + + " `O_CUSTKEY` bigint(20) NOT NULL," + + " `O_ORDERSTATUS` char(1) NOT NULL," + + " `O_TOTALPRICE` decimal(15,2) NOT NULL," + + " `O_ORDERDATE` date NOT NULL," + + " `O_ORDERPRIORITY` char(15) NOT NULL," + + " `O_CLERK` char(15) NOT NULL," + + " `O_SHIPPRIORITY` bigint(20) NOT NULL," + + " `O_COMMENT` varchar(79) NOT NULL," + + " PRIMARY KEY (`O_ORDERKEY`) /*T![clustered_index] CLUSTERED */" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin") + + tk.MustExec("set tidb_opt_agg_push_down=ON;") + + tk.MustQuery("explain format='brief' SELECT /*+ hash_join_build(customer) */ c_custkey, count(o_orderkey) as c_count from customer " + + "left join orders on c_custkey = o_custkey and o_comment not like '%special%requests%' group by c_custkey;").Check(testkit.Rows( + "Projection 8000.00 root test.customer.c_custkey, Column#18", + "└─HashAgg 8000.00 root group by:test.customer.c_custkey, funcs:count(Column#19)->Column#18, funcs:firstrow(test.customer.c_custkey)->test.customer.c_custkey", + " └─HashJoin 10000.00 root left outer join, equal:[eq(test.customer.c_custkey, test.orders.o_custkey)]", + " ├─TableReader(Build) 10000.00 root data:TableFullScan", + " │ └─TableFullScan 10000.00 cop[tikv] table:customer keep order:false, stats:pseudo", + " └─HashAgg(Probe) 6400.00 root group by:test.orders.o_custkey, funcs:count(Column#20)->Column#19, funcs:firstrow(test.orders.o_custkey)->test.orders.o_custkey", + " └─TableReader 6400.00 root data:HashAgg", + " └─HashAgg 6400.00 cop[tikv] group by:test.orders.o_custkey, funcs:count(test.orders.o_orderkey)->Column#20", + " └─Selection 8000.00 cop[tikv] not(like(test.orders.o_comment, \"%special%requests%\", 92))", + " └─TableFullScan 10000.00 cop[tikv] table:orders keep order:false, stats:pseudo")) +} + func TestIssue26885(t *testing.T) { store := testkit.CreateMockStore(t) tk := testkit.NewTestKit(t, store) diff --git a/planner/core/integration_test.go b/planner/core/integration_test.go index 9428b7c43a28f..4dba51a1c97b4 100644 --- a/planner/core/integration_test.go +++ b/planner/core/integration_test.go @@ -138,8 +138,8 @@ func TestAggPushDownLeftJoin(t *testing.T) { "on c_custkey = o_custkey group by c_custkey").Check(testkit.Rows("6 0")) tk.MustQuery("explain format='brief' select c_custkey, count(o_orderkey) as c_count from customer left outer join orders " + "on c_custkey = o_custkey group by c_custkey").Check(testkit.Rows( - "Projection 10000.00 root test.customer.c_custkey, Column#7", - "└─Projection 10000.00 root if(isnull(Column#8), 0, 1)->Column#7, test.customer.c_custkey", + "Projection 8000.00 root test.customer.c_custkey, Column#7", + "└─HashAgg 8000.00 root group by:test.customer.c_custkey, funcs:count(Column#8)->Column#7, funcs:firstrow(test.customer.c_custkey)->test.customer.c_custkey", " └─HashJoin 10000.00 root left outer join, equal:[eq(test.customer.c_custkey, test.orders.o_custkey)]", " ├─HashAgg(Build) 8000.00 root group by:test.orders.o_custkey, funcs:count(Column#9)->Column#8, funcs:firstrow(test.orders.o_custkey)->test.orders.o_custkey", " │ └─TableReader 8000.00 root data:HashAgg", @@ -152,8 +152,8 @@ func TestAggPushDownLeftJoin(t *testing.T) { "on c_custkey = o_custkey group by c_custkey").Check(testkit.Rows("6 0")) tk.MustQuery("explain format='brief' select c_custkey, count(o_orderkey) as c_count from orders right outer join customer " + "on c_custkey = o_custkey group by c_custkey").Check(testkit.Rows( - "Projection 10000.00 root test.customer.c_custkey, Column#7", - "└─Projection 10000.00 root if(isnull(Column#8), 0, 1)->Column#7, test.customer.c_custkey", + "Projection 8000.00 root test.customer.c_custkey, Column#7", + "└─HashAgg 8000.00 root group by:test.customer.c_custkey, funcs:count(Column#8)->Column#7, funcs:firstrow(test.customer.c_custkey)->test.customer.c_custkey", " └─HashJoin 10000.00 root right outer join, equal:[eq(test.orders.o_custkey, test.customer.c_custkey)]", " ├─HashAgg(Build) 8000.00 root group by:test.orders.o_custkey, funcs:count(Column#9)->Column#8, funcs:firstrow(test.orders.o_custkey)->test.orders.o_custkey", " │ └─TableReader 8000.00 root data:HashAgg", diff --git a/planner/core/rule_aggregation_elimination.go b/planner/core/rule_aggregation_elimination.go index 2b0cf91649ef5..b0b937b7aadd9 100644 --- a/planner/core/rule_aggregation_elimination.go +++ b/planner/core/rule_aggregation_elimination.go @@ -32,13 +32,23 @@ type aggregationEliminator struct { } type aggregationEliminateChecker struct { + // used for agg pushed down cases, for example: + // agg -> join -> datasource1 + // -> datasource2 + // we just make a new agg upon datasource1 or datasource2, while the old agg is still existed and waiting for elimination. + // Note when the old agg is like below, and join is an outer join type, rewriting old agg in elimination logic has some problem. + // eg: + // count(a) -> ifnull(col#x, 0, 1) in rewriteExpr of agg function, since col#x is already the final pushed-down aggregation's + // result from new join schema, we don't need to take every row as count 1 when they don't have not-null flag in a.tryToEliminateAggregation(oldAgg, opt), + // which is not suitable here. + oldAggEliminationCheck bool } // tryToEliminateAggregation will eliminate aggregation grouped by unique key. // e.g. select min(b) from t group by a. If a is a unique key, then this sql is equal to `select b from t group by a`. // For count(expr), sum(expr), avg(expr), count(distinct expr, [expr...]) we may need to rewrite the expr. Details are shown below. // If we can eliminate agg successful, we return a projection. Else we return a nil pointer. -func (*aggregationEliminateChecker) tryToEliminateAggregation(agg *LogicalAggregation, opt *logicalOptimizeOp) *LogicalProjection { +func (a *aggregationEliminateChecker) tryToEliminateAggregation(agg *LogicalAggregation, opt *logicalOptimizeOp) *LogicalProjection { for _, af := range agg.AggFuncs { // TODO(issue #9968): Actually, we can rewrite GROUP_CONCAT when all the // arguments it accepts are promised to be NOT-NULL. @@ -64,6 +74,9 @@ func (*aggregationEliminateChecker) tryToEliminateAggregation(agg *LogicalAggreg } } if coveredByUniqueKey { + if a.oldAggEliminationCheck && !CheckCanConvertAggToProj(agg) { + return nil + } // GroupByCols has unique key, so this aggregation can be removed. if ok, proj := ConvertAggToProj(agg, agg.schema); ok { proj.SetChildren(agg.children[0]) @@ -138,6 +151,34 @@ func appendDistinctEliminateTraceStep(agg *LogicalAggregation, uniqueKey express opt.appendStepToCurrent(agg.ID(), agg.TP(), reason, action) } +// CheckCanConvertAggToProj check whether a special old aggregation (which has already been pushed down) to projection. +// link: issue#44795 +func CheckCanConvertAggToProj(agg *LogicalAggregation) bool { + var mayNullSchema *expression.Schema + if join, ok := agg.Children()[0].(*LogicalJoin); ok { + if join.JoinType == LeftOuterJoin { + mayNullSchema = join.Children()[1].Schema() + } + if join.JoinType == RightOuterJoin { + mayNullSchema = join.Children()[0].Schema() + } + if mayNullSchema == nil { + return true + } + // once agg function args has intersection with mayNullSchema, return nil (means elimination fail) + for _, fun := range agg.AggFuncs { + mayNullCols := expression.ExtractColumnsFromExpressions(nil, fun.Args, func(column *expression.Column) bool { + // collect may-null cols. + return mayNullSchema.Contains(column) + }) + if len(mayNullCols) != 0 { + return false + } + } + } + return true +} + // ConvertAggToProj convert aggregation to projection. func ConvertAggToProj(agg *LogicalAggregation, schema *expression.Schema) (bool, *LogicalProjection) { proj := LogicalProjection{ diff --git a/planner/core/rule_aggregation_push_down.go b/planner/core/rule_aggregation_push_down.go index 2d3d0a607e9ce..81a99b2bb7e29 100644 --- a/planner/core/rule_aggregation_push_down.go +++ b/planner/core/rule_aggregation_push_down.go @@ -259,7 +259,7 @@ func (a *aggregationPushDownSolver) tryToPushDownAgg(oldAgg *LogicalAggregation, } tmpSchema := expression.NewSchema(gbyCols...) for _, key := range child.Schema().Keys { - if tmpSchema.ColumnsIndices(key) != nil { + if tmpSchema.ColumnsIndices(key) != nil { // gby item need to be covered by key. return child, nil } } @@ -504,10 +504,39 @@ func (a *aggregationPushDownSolver) aggPushDown(p LogicalPlan, opt *logicalOptim resetNotNullFlag(join.schema, 0, lChild.Schema().Len()) } buildKeyInfo(join) + // count(a) -> ifnull(col#x, 0, 1) in rewriteExpr of agg function, since col#x is already the final + // pushed-down aggregation's result, we don't need to take every row as count 1 when they don't have + // not-null flag in a.tryToEliminateAggregation(oldAgg, opt), which is not suitable here. + oldCheck := a.oldAggEliminationCheck + a.oldAggEliminationCheck = true proj := a.tryToEliminateAggregation(agg, opt) if proj != nil { p = proj } + a.oldAggEliminationCheck = oldCheck + + // Combine the aggregation elimination logic below since new agg's child key info has changed. + // Notice that even if we eliminate new agg below if possible, the agg's schema is inherited by proj. + // Therefore, we don't need to set the join's schema again, just build the keyInfo again. + changed := false + if newAgg, ok1 := lChild.(*LogicalAggregation); ok1 { + proj := a.tryToEliminateAggregation(newAgg, opt) + if proj != nil { + lChild = proj + changed = true + } + } + if newAgg, ok2 := rChild.(*LogicalAggregation); ok2 { + proj := a.tryToEliminateAggregation(newAgg, opt) + if proj != nil { + rChild = proj + changed = true + } + } + if changed { + join.SetChildren(lChild, rChild) + buildKeyInfo(join) + } } } else if proj, ok1 := child.(*LogicalProjection); ok1 { // push aggregation across projection