Skip to content

Commit

Permalink
planner/core: separate aggPrune from aggPushDown (#7676)
Browse files Browse the repository at this point in the history
  • Loading branch information
winoros authored Oct 8, 2018
1 parent f33e04f commit 78303cb
Show file tree
Hide file tree
Showing 9 changed files with 185 additions and 124 deletions.
17 changes: 8 additions & 9 deletions cmd/explaintest/r/explain_easy.result
Original file line number Diff line number Diff line change
Expand Up @@ -96,15 +96,14 @@ TableReader_11 2.00 root data:TableScan_10
└─TableScan_10 2.00 cop table:t1, range:[0,0], [1,1], keep order:false, stats:pseudo
explain select (select count(1) k from t1 s where s.c1 = t1.c1 having k != 0) from t1;
id count task operator info
Projection_13 10000.00 root k
└─Projection_14 10000.00 root test.t1.c1, ifnull(5_col_0, 0)
└─MergeJoin_15 10000.00 root left outer join, left key:test.t1.c1, right key:s.c1
├─TableReader_18 10000.00 root data:TableScan_17
│ └─TableScan_17 10000.00 cop table:t1, range:[-inf,+inf], keep order:true, stats:pseudo
└─Selection_20 8000.00 root ne(k, 0)
└─Projection_21 10000.00 root 1, s.c1
└─TableReader_23 10000.00 root data:TableScan_22
└─TableScan_22 10000.00 cop table:s, range:[-inf,+inf], keep order:true, stats:pseudo
Projection_12 10000.00 root k
└─Projection_13 10000.00 root test.t1.c1, ifnull(5_col_0, 0)
└─MergeJoin_14 10000.00 root left outer join, left key:test.t1.c1, right key:s.c1
├─TableReader_17 10000.00 root data:TableScan_16
│ └─TableScan_16 10000.00 cop table:t1, range:[-inf,+inf], keep order:true, stats:pseudo
└─Projection_19 8000.00 root 1, s.c1
└─TableReader_21 10000.00 root data:TableScan_20
└─TableScan_20 10000.00 cop table:s, range:[-inf,+inf], keep order:true, stats:pseudo
explain select * from information_schema.columns;
id count task operator info
MemTableScan_4 10000.00 root
Expand Down
12 changes: 10 additions & 2 deletions expression/aggregation/descriptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,10 +298,14 @@ func (a *AggFuncDesc) typeInfer4Sum(ctx sessionctx.Context) {
a.RetTp.Decimal = mysql.MaxDecimalScale
}
// TODO: a.Args[0] = expression.WrapWithCastAsDecimal(ctx, a.Args[0])
default:
case mysql.TypeDouble, mysql.TypeFloat:
a.RetTp = types.NewFieldType(mysql.TypeDouble)
a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxRealWidth, a.Args[0].GetType().Decimal
//TODO: a.Args[0] = expression.WrapWithCastAsReal(ctx, a.Args[0])
default:
a.RetTp = types.NewFieldType(mysql.TypeDouble)
a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxRealWidth, types.UnspecifiedLength
// TODO: a.Args[0] = expression.WrapWithCastAsReal(ctx, a.Args[0])
}
types.SetBinChsClnFlag(a.RetTp)
}
Expand All @@ -319,10 +323,14 @@ func (a *AggFuncDesc) typeInfer4Avg(ctx sessionctx.Context) {
}
a.RetTp.Flen = mysql.MaxDecimalWidth
// TODO: a.Args[0] = expression.WrapWithCastAsDecimal(ctx, a.Args[0])
default:
case mysql.TypeDouble, mysql.TypeFloat:
a.RetTp = types.NewFieldType(mysql.TypeDouble)
a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxRealWidth, a.Args[0].GetType().Decimal
// TODO: a.Args[0] = expression.WrapWithCastAsReal(ctx, a.Args[0])
default:
a.RetTp = types.NewFieldType(mysql.TypeDouble)
a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxRealWidth, types.UnspecifiedLength
// TODO: a.Args[0] = expression.WrapWithCastAsReal(ctx, a.Args[0])
}
types.SetBinChsClnFlag(a.RetTp)
}
Expand Down
4 changes: 2 additions & 2 deletions expression/typeinfer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -822,14 +822,14 @@ func (s *testInferTypeSuite) createTestCase4Aggregations() []typeInferTestCase {
{"sum(c_decimal)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxDecimalWidth, 3},
{"sum(1.0)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxDecimalWidth, 1},
{"sum(1.2e2)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, types.UnspecifiedLength},
{"sum(c_char)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, 0},
{"sum(c_char)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, types.UnspecifiedLength},
{"avg(c_int_d)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxDecimalWidth, 4},
{"avg(c_float_d)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, types.UnspecifiedLength},
{"avg(c_double_d)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, types.UnspecifiedLength},
{"avg(c_decimal)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxDecimalWidth, 7},
{"avg(1.0)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxDecimalWidth, 5},
{"avg(1.2e2)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, types.UnspecifiedLength},
{"avg(c_char)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, 0},
{"avg(c_char)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, types.UnspecifiedLength},
{"group_concat(c_int_d)", mysql.TypeVarString, charset.CharsetUTF8, 0, mysql.MaxBlobWidth, 0},
}
}
Expand Down
5 changes: 3 additions & 2 deletions planner/core/logical_plan_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,14 @@ func (la *LogicalAggregation) collectGroupByColumns() {

func (b *planBuilder) buildAggregation(p LogicalPlan, aggFuncList []*ast.AggregateFuncExpr, gbyItems []expression.Expression) (LogicalPlan, map[int]int, error) {
b.optFlag = b.optFlag | flagBuildKeyInfo
b.optFlag = b.optFlag | flagAggregationOptimize
b.optFlag = b.optFlag | flagPushDownAgg
// We may apply aggregation eliminate optimization.
// So we add the flagMaxMinEliminate to try to convert max/min to topn and flagPushDownTopN to handle the newly added topn operator.
b.optFlag = b.optFlag | flagMaxMinEliminate
b.optFlag = b.optFlag | flagPushDownTopN
// when we eliminate the max and min we may add `is not null` filter.
b.optFlag = b.optFlag | flagPredicatePushDown
b.optFlag = b.optFlag | flagEliminateAgg

plan4Agg := LogicalAggregation{AggFuncs: make([]*aggregation.AggFuncDesc, 0, len(aggFuncList))}.init(b.ctx)
schema4Agg := expression.NewSchema(make([]*expression.Column, 0, len(aggFuncList)+p.Schema().Len())...)
Expand Down Expand Up @@ -605,7 +606,7 @@ func (b *planBuilder) buildProjection(p LogicalPlan, fields []*ast.SelectField,

func (b *planBuilder) buildDistinct(child LogicalPlan, length int) *LogicalAggregation {
b.optFlag = b.optFlag | flagBuildKeyInfo
b.optFlag = b.optFlag | flagAggregationOptimize
b.optFlag = b.optFlag | flagPushDownAgg
plan4Agg := LogicalAggregation{
AggFuncs: make([]*aggregation.AggFuncDesc, 0, child.Schema().Len()),
GroupByItems: expression.Column2Exprs(child.Schema().Clone().Columns[:length]),
Expand Down
14 changes: 6 additions & 8 deletions planner/core/logical_plan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1090,6 +1090,10 @@ func (s *testPlanSuite) TestEagerAggregation(c *C) {
sql: "select max(a.c) from t a join t b on a.a=b.a and a.b=b.b group by a.b",
best: "Join{DataScan(a)->DataScan(b)}(a.a,b.a)(a.b,b.b)->Aggr(max(a.c))->Projection",
},
{
sql: "select t1.a, count(t2.b) from t t1, t t2 where t1.a = t2.a group by t1.a",
best: "Join{DataScan(t1)->DataScan(t2)}(t1.a,t2.a)->Projection->Projection",
},
}
s.ctx.GetSessionVars().AllowAggPushDown = true
for _, tt := range tests {
Expand All @@ -1099,7 +1103,7 @@ func (s *testPlanSuite) TestEagerAggregation(c *C) {

p, err := BuildLogicalPlan(s.ctx, stmt, s.is)
c.Assert(err, IsNil)
p, err = logicalOptimize(flagBuildKeyInfo|flagPredicatePushDown|flagPrunColumns|flagAggregationOptimize, p.(LogicalPlan))
p, err = logicalOptimize(flagBuildKeyInfo|flagPredicatePushDown|flagPrunColumns|flagPushDownAgg, p.(LogicalPlan))
c.Assert(err, IsNil)
c.Assert(ToString(p), Equals, tt.best, Commentf("for %s", tt.sql))
}
Expand Down Expand Up @@ -1530,10 +1534,6 @@ func (s *testPlanSuite) TestAggPrune(c *C) {
sql: "select sum(b) from t group by c, d, e",
best: "DataScan(t)->Aggr(sum(test.t.b))->Projection",
},
{
sql: "select t1.a, count(t2.b) from t t1, t t2 where t1.a = t2.a group by t1.a",
best: "Join{DataScan(t1)->DataScan(t2)}(t1.a,t2.a)->Projection->Projection",
},
{
sql: "select tt.a, sum(tt.b) from (select a, b from t) tt group by tt.a",
best: "DataScan(t)->Projection->Projection->Projection",
Expand All @@ -1543,7 +1543,6 @@ func (s *testPlanSuite) TestAggPrune(c *C) {
best: "DataScan(t)->Projection->Projection->Projection->Projection",
},
}
s.ctx.GetSessionVars().AllowAggPushDown = true
for _, tt := range tests {
comment := Commentf("for %s", tt.sql)
stmt, err := s.ParseOneStmt(tt.sql, "", "")
Expand All @@ -1552,11 +1551,10 @@ func (s *testPlanSuite) TestAggPrune(c *C) {
p, err := BuildLogicalPlan(s.ctx, stmt, s.is)
c.Assert(err, IsNil)

p, err = logicalOptimize(flagPredicatePushDown|flagPrunColumns|flagBuildKeyInfo|flagAggregationOptimize, p.(LogicalPlan))
p, err = logicalOptimize(flagPredicatePushDown|flagPrunColumns|flagBuildKeyInfo|flagEliminateAgg, p.(LogicalPlan))
c.Assert(err, IsNil)
c.Assert(ToString(p), Equals, tt.best, comment)
}
s.ctx.GetSessionVars().AllowAggPushDown = false
}

func (s *testPlanSuite) TestVisitInfo(c *C) {
Expand Down
6 changes: 4 additions & 2 deletions planner/core/optimizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,11 @@ const (
flagEliminateProjection
flagBuildKeyInfo
flagDecorrelate
flagEliminateAgg
flagMaxMinEliminate
flagPredicatePushDown
flagPartitionProcessor
flagAggregationOptimize
flagPushDownAgg
flagPushDownTopN
)

Expand All @@ -45,10 +46,11 @@ var optRuleList = []logicalOptRule{
&projectionEliminater{},
&buildKeySolver{},
&decorrelateSolver{},
&aggregationEliminator{},
&maxMinEliminator{},
&ppdSolver{},
&partitionProcessor{},
&aggregationOptimizer{},
&aggregationPushDownSolver{},
&pushDownTopNOptimizer{},
}

Expand Down
2 changes: 1 addition & 1 deletion planner/core/physical_plan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -851,7 +851,7 @@ func (s *testPlanSuite) TestDAGPlanBuilderAgg(c *C) {
},
{
sql: "select (select count(1) k from t s where s.a = t.a having k != 0) from t",
best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t)->StreamAgg)->StreamAgg->Sel([ne(k, 0)])}(test.t.a,s.a)->Projection->Projection",
best: "MergeLeftOuterJoin{TableReader(Table(t))->TableReader(Table(t))->Projection}(test.t.a,s.a)->Projection->Projection",
},
// Test stream agg with multi group by columns.
{
Expand Down
134 changes: 134 additions & 0 deletions planner/core/rule_aggregation_elimination.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
// Copyright 2018 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package core

import (
"math"

"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/expression/aggregation"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
)

type aggregationEliminator struct {
aggregationEliminateChecker
}

type aggregationEliminateChecker struct {
}

// tryToEliminateAggregation will eliminate aggregation grouped by unique key.
// e.g. select min(b) from t group by a. If a is a unique key, then this sql is equal to `select b from t group by a`.
// For count(expr), sum(expr), avg(expr), count(distinct expr, [expr...]) we may need to rewrite the expr. Details are shown below.
// If we can eliminate agg successful, we return a projection. Else we return a nil pointer.
func (a *aggregationEliminateChecker) tryToEliminateAggregation(agg *LogicalAggregation) *LogicalProjection {
schemaByGroupby := expression.NewSchema(agg.groupByCols...)
coveredByUniqueKey := false
for _, key := range agg.children[0].Schema().Keys {
if schemaByGroupby.ColumnsIndices(key) != nil {
coveredByUniqueKey = true
break
}
}
if coveredByUniqueKey {
// GroupByCols has unique key, so this aggregation can be removed.
proj := a.convertAggToProj(agg)
proj.SetChildren(agg.children[0])
return proj
}
return nil
}

func (a *aggregationEliminateChecker) convertAggToProj(agg *LogicalAggregation) *LogicalProjection {
proj := LogicalProjection{
Exprs: make([]expression.Expression, 0, len(agg.AggFuncs)),
}.init(agg.ctx)
for _, fun := range agg.AggFuncs {
expr := a.rewriteExpr(agg.ctx, fun)
proj.Exprs = append(proj.Exprs, expr)
}
proj.SetSchema(agg.schema.Clone())
return proj
}

// rewriteExpr will rewrite the aggregate function to expression doesn't contain aggregate function.
func (a *aggregationEliminateChecker) rewriteExpr(ctx sessionctx.Context, aggFunc *aggregation.AggFuncDesc) expression.Expression {
switch aggFunc.Name {
case ast.AggFuncCount:
if aggFunc.Mode == aggregation.FinalMode {
return a.wrapCastFunction(ctx, aggFunc.Args[0], aggFunc.RetTp)
}
return a.rewriteCount(ctx, aggFunc.Args, aggFunc.RetTp)
case ast.AggFuncSum, ast.AggFuncAvg, ast.AggFuncFirstRow, ast.AggFuncMax, ast.AggFuncMin, ast.AggFuncGroupConcat:
return a.wrapCastFunction(ctx, aggFunc.Args[0], aggFunc.RetTp)
case ast.AggFuncBitAnd, ast.AggFuncBitOr, ast.AggFuncBitXor:
return a.rewriteBitFunc(ctx, aggFunc.Name, aggFunc.Args[0], aggFunc.RetTp)
default:
panic("Unsupported function")
}
}

func (a *aggregationEliminateChecker) rewriteCount(ctx sessionctx.Context, exprs []expression.Expression, targetTp *types.FieldType) expression.Expression {
// If is count(expr), we will change it to if(isnull(expr), 0, 1).
// If is count(distinct x, y, z) we will change it to if(isnull(x) or isnull(y) or isnull(z), 0, 1).
isNullExprs := make([]expression.Expression, 0, len(exprs))
for _, expr := range exprs {
isNullExpr := expression.NewFunctionInternal(ctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), expr)
isNullExprs = append(isNullExprs, isNullExpr)
}
innerExpr := expression.ComposeDNFCondition(ctx, isNullExprs...)
newExpr := expression.NewFunctionInternal(ctx, ast.If, targetTp, innerExpr, expression.Zero, expression.One)
return newExpr
}

func (a *aggregationEliminateChecker) rewriteBitFunc(ctx sessionctx.Context, funcType string, arg expression.Expression, targetTp *types.FieldType) expression.Expression {
// For not integer type. We need to cast(cast(arg as signed) as unsigned) to make the bit function work.
innerCast := expression.WrapWithCastAsInt(ctx, arg)
outerCast := a.wrapCastFunction(ctx, innerCast, targetTp)
var finalExpr expression.Expression
if funcType != ast.AggFuncBitAnd {
finalExpr = expression.NewFunctionInternal(ctx, ast.Ifnull, targetTp, outerCast, expression.Zero.Clone())
} else {
finalExpr = expression.NewFunctionInternal(ctx, ast.Ifnull, outerCast.GetType(), outerCast, &expression.Constant{Value: types.NewUintDatum(math.MaxUint64), RetType: targetTp})
}
return finalExpr
}

// wrapCastFunction will wrap a cast if the targetTp is not equal to the arg's.
func (a *aggregationEliminateChecker) wrapCastFunction(ctx sessionctx.Context, arg expression.Expression, targetTp *types.FieldType) expression.Expression {
if arg.GetType() == targetTp {
return arg
}
return expression.BuildCastFunction(ctx, arg, targetTp)
}

func (a *aggregationEliminator) optimize(p LogicalPlan) (LogicalPlan, error) {
newChildren := make([]LogicalPlan, 0, len(p.Children()))
for _, child := range p.Children() {
newChild, _ := a.optimize(child)
newChildren = append(newChildren, newChild)
}
p.SetChildren(newChildren...)
agg, ok := p.(*LogicalAggregation)
if !ok {
return p, nil
}
if proj := a.tryToEliminateAggregation(agg); proj != nil {
return proj, nil
}
return p, nil
}
Loading

0 comments on commit 78303cb

Please sign in to comment.