diff --git a/cmd/explaintest/r/explain_easy.result b/cmd/explaintest/r/explain_easy.result index c708739f26a3f..0c989d357bba6 100644 --- a/cmd/explaintest/r/explain_easy.result +++ b/cmd/explaintest/r/explain_easy.result @@ -180,6 +180,35 @@ HashAgg_18 24000.00 root group by:c1, funcs:firstrow(join_agg_0) └─IndexReader_62 8000.00 root index:StreamAgg_53 └─StreamAgg_53 8000.00 cop group by:test.t2.c1, funcs:firstrow(test.t2.c1), firstrow(test.t2.c1) └─IndexScan_60 10000.00 cop table:t2, index:c1, range:[NULL,+inf], keep order:true, stats:pseudo +explain select count(1) from (select count(1) from (select * from t1 where c3 = 100) k) k2; +id count task operator info +StreamAgg_13 1.00 root funcs:count(1) +└─StreamAgg_28 1.00 root funcs:firstrow(col_0) + └─TableReader_29 1.00 root data:StreamAgg_17 + └─StreamAgg_17 1.00 cop funcs:firstrow(1) + └─Selection_27 10.00 cop eq(test.t1.c3, 100) + └─TableScan_26 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo +explain select 1 from (select count(c2), count(c3) from t1) k; +id count task operator info +Projection_5 1.00 root 1 +└─StreamAgg_17 1.00 root funcs:firstrow(col_0) + └─TableReader_18 1.00 root data:StreamAgg_9 + └─StreamAgg_9 1.00 cop funcs:firstrow(1) + └─TableScan_16 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo +explain select count(1) from (select max(c2), count(c3) as m from t1) k; +id count task operator info +StreamAgg_11 1.00 root funcs:count(1) +└─StreamAgg_23 1.00 root funcs:firstrow(col_0) + └─TableReader_24 1.00 root data:StreamAgg_15 + └─StreamAgg_15 1.00 cop funcs:firstrow(1) + └─TableScan_22 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo +explain select count(1) from (select count(c2) from t1 group by c3) k; +id count task operator info +StreamAgg_11 1.00 root funcs:count(1) +└─HashAgg_23 8000.00 root group by:col_1, funcs:firstrow(col_0) + └─TableReader_24 8000.00 root data:HashAgg_20 + └─HashAgg_20 8000.00 cop group by:test.t1.c3, funcs:firstrow(1) + └─TableScan_15 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo set @@session.tidb_opt_insubquery_unfold = 0; explain select sum(t1.c1 in (select c1 from t2)) from t1; id count task operator info diff --git a/cmd/explaintest/t/explain_easy.test b/cmd/explaintest/t/explain_easy.test index 47ced4372309c..25c2e94b79f81 100644 --- a/cmd/explaintest/t/explain_easy.test +++ b/cmd/explaintest/t/explain_easy.test @@ -35,6 +35,12 @@ explain select if(10, t1.c1, t1.c2) from t1; explain select c1 from t2 union select c1 from t2 union all select c1 from t2; explain select c1 from t2 union all select c1 from t2 union select c1 from t2; +# https://github.com/pingcap/tidb/issues/9125 +explain select count(1) from (select count(1) from (select * from t1 where c3 = 100) k) k2; +explain select 1 from (select count(c2), count(c3) from t1) k; +explain select count(1) from (select max(c2), count(c3) as m from t1) k; +explain select count(1) from (select count(c2) from t1 group by c3) k; + set @@session.tidb_opt_insubquery_unfold = 0; explain select sum(t1.c1 in (select c1 from t2)) from t1; diff --git a/planner/core/logical_plan_test.go b/planner/core/logical_plan_test.go index 73da2788b562c..391d7e7e1ce1d 100644 --- a/planner/core/logical_plan_test.go +++ b/planner/core/logical_plan_test.go @@ -1240,6 +1240,24 @@ func (s *testPlanSuite) TestColumnPruning(c *C) { 12: {"test.t4.a"}, }, }, + { + sql: "select 1 from (select count(b) as cnt from t) t1;", + ans: map[int][]string{ + 1: {"test.t.a"}, + }, + }, + { + sql: "select count(1) from (select count(b) as cnt from t) t1;", + ans: map[int][]string{ + 1: {"test.t.a"}, + }, + }, + { + sql: "select count(1) from (select count(b) as cnt from t group by c) t1;", + ans: map[int][]string{ + 1: {"test.t.c"}, + }, + }, } for _, tt := range tests { comment := Commentf("for %s", tt.sql) diff --git a/planner/core/rule_column_pruning.go b/planner/core/rule_column_pruning.go index 508797b27958f..b7e6a72a4e64a 100644 --- a/planner/core/rule_column_pruning.go +++ b/planner/core/rule_column_pruning.go @@ -19,8 +19,10 @@ import ( "github.com/pingcap/parser/ast" "github.com/pingcap/parser/model" "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/expression/aggregation" "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/mysql" + "github.com/pingcap/tidb/types" ) type columnPruner struct { @@ -97,6 +99,21 @@ func (la *LogicalAggregation) PruneColumns(parentUsedCols []*expression.Column) for _, aggrFunc := range la.AggFuncs { selfUsedCols = expression.ExtractColumnsFromExpressions(selfUsedCols, aggrFunc.Args, nil) } + if len(la.AggFuncs) == 0 { + // If all the aggregate functions are pruned, we should add an aggregate function to keep the correctness. + one, err := aggregation.NewAggFuncDesc(la.ctx, ast.AggFuncFirstRow, []expression.Expression{expression.One}, false) + if err != nil { + panic(fmt.Sprintf("error building dummy aggregate function'first(1)': %s", err.Error())) + } + la.AggFuncs = []*aggregation.AggFuncDesc{one} + col := &expression.Column{ + ColName: model.NewCIStr("dummy_agg"), + UniqueID: la.ctx.GetSessionVars().AllocPlanColumnID(), + RetType: types.NewFieldType(mysql.TypeLonglong), + } + la.schema.Columns = []*expression.Column{col} + } + if len(la.GroupByItems) > 0 { for i := len(la.GroupByItems) - 1; i >= 0; i-- { cols := expression.ExtractColumns(la.GroupByItems[i])