Skip to content

Commit

Permalink
expression: avoid unnecessary warnings/errors when folding constants …
Browse files Browse the repository at this point in the history
…in control expr (#19675) (#19910)

Signed-off-by: ti-srebot <[email protected]>
  • Loading branch information
ti-srebot authored Sep 16, 2020
1 parent 51d365f commit 8e4b18a
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 73 deletions.
65 changes: 21 additions & 44 deletions expression/builtin_control.go
Original file line number Diff line number Diff line change
Expand Up @@ -537,12 +537,10 @@ func (b *builtinIfIntSig) evalInt(row chunk.Row) (ret int64, isNull bool, err er
if err != nil {
return 0, true, err
}
arg1, isNull1, err := b.args[1].EvalInt(b.ctx, row)
if (!isNull0 && arg0 != 0) || err != nil {
return arg1, isNull1, err
if !isNull0 && arg0 != 0 {
return b.args[1].EvalInt(b.ctx, row)
}
arg2, isNull2, err := b.args[2].EvalInt(b.ctx, row)
return arg2, isNull2, err
return b.args[2].EvalInt(b.ctx, row)
}

type builtinIfRealSig struct {
Expand All @@ -560,12 +558,10 @@ func (b *builtinIfRealSig) evalReal(row chunk.Row) (ret float64, isNull bool, er
if err != nil {
return 0, true, err
}
arg1, isNull1, err := b.args[1].EvalReal(b.ctx, row)
if (!isNull0 && arg0 != 0) || err != nil {
return arg1, isNull1, err
if !isNull0 && arg0 != 0 {
return b.args[1].EvalReal(b.ctx, row)
}
arg2, isNull2, err := b.args[2].EvalReal(b.ctx, row)
return arg2, isNull2, err
return b.args[2].EvalReal(b.ctx, row)
}

type builtinIfDecimalSig struct {
Expand All @@ -583,12 +579,10 @@ func (b *builtinIfDecimalSig) evalDecimal(row chunk.Row) (ret *types.MyDecimal,
if err != nil {
return nil, true, err
}
arg1, isNull1, err := b.args[1].EvalDecimal(b.ctx, row)
if (!isNull0 && arg0 != 0) || err != nil {
return arg1, isNull1, err
if !isNull0 && arg0 != 0 {
return b.args[1].EvalDecimal(b.ctx, row)
}
arg2, isNull2, err := b.args[2].EvalDecimal(b.ctx, row)
return arg2, isNull2, err
return b.args[2].EvalDecimal(b.ctx, row)
}

type builtinIfStringSig struct {
Expand All @@ -606,12 +600,10 @@ func (b *builtinIfStringSig) evalString(row chunk.Row) (ret string, isNull bool,
if err != nil {
return "", true, err
}
arg1, isNull1, err := b.args[1].EvalString(b.ctx, row)
if (!isNull0 && arg0 != 0) || err != nil {
return arg1, isNull1, err
if !isNull0 && arg0 != 0 {
return b.args[1].EvalString(b.ctx, row)
}
arg2, isNull2, err := b.args[2].EvalString(b.ctx, row)
return arg2, isNull2, err
return b.args[2].EvalString(b.ctx, row)
}

type builtinIfTimeSig struct {
Expand All @@ -629,12 +621,10 @@ func (b *builtinIfTimeSig) evalTime(row chunk.Row) (ret types.Time, isNull bool,
if err != nil {
return ret, true, err
}
arg1, isNull1, err := b.args[1].EvalTime(b.ctx, row)
if (!isNull0 && arg0 != 0) || err != nil {
return arg1, isNull1, err
if !isNull0 && arg0 != 0 {
return b.args[1].EvalTime(b.ctx, row)
}
arg2, isNull2, err := b.args[2].EvalTime(b.ctx, row)
return arg2, isNull2, err
return b.args[2].EvalTime(b.ctx, row)
}

type builtinIfDurationSig struct {
Expand All @@ -652,12 +642,10 @@ func (b *builtinIfDurationSig) evalDuration(row chunk.Row) (ret types.Duration,
if err != nil {
return ret, true, err
}
arg1, isNull1, err := b.args[1].EvalDuration(b.ctx, row)
if (!isNull0 && arg0 != 0) || err != nil {
return arg1, isNull1, err
if !isNull0 && arg0 != 0 {
return b.args[1].EvalDuration(b.ctx, row)
}
arg2, isNull2, err := b.args[2].EvalDuration(b.ctx, row)
return arg2, isNull2, err
return b.args[2].EvalDuration(b.ctx, row)
}

type builtinIfJSONSig struct {
Expand All @@ -675,21 +663,10 @@ func (b *builtinIfJSONSig) evalJSON(row chunk.Row) (ret json.BinaryJSON, isNull
if err != nil {
return ret, true, err
}
arg1, isNull1, err := b.args[1].EvalJSON(b.ctx, row)
if err != nil {
return ret, true, err
}
arg2, isNull2, err := b.args[2].EvalJSON(b.ctx, row)
if err != nil {
return ret, true, err
}
switch {
case isNull0 || arg0 == 0:
ret, isNull = arg2, isNull2
case arg0 != 0:
ret, isNull = arg1, isNull1
if !isNull0 && arg0 != 0 {
return b.args[1].EvalJSON(b.ctx, row)
}
return
return b.args[2].EvalJSON(b.ctx, row)
}

type ifNullFunctionClass struct {
Expand Down
35 changes: 10 additions & 25 deletions expression/constant_fold.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,8 @@ func ifFoldHandler(expr *ScalarFunction) (Expression, bool) {
}
return foldConstant(args[2])
}
var isDeferred, isDeferredConst bool
expr.GetArgs()[1], isDeferred = foldConstant(args[1])
isDeferredConst = isDeferredConst || isDeferred
expr.GetArgs()[2], isDeferred = foldConstant(args[2])
isDeferredConst = isDeferredConst || isDeferred
return expr, isDeferredConst
// if the condition is not const, which branch is unknown to run, so directly return.
return expr, false
}

func ifNullFoldHandler(expr *ScalarFunction) (Expression, bool) {
Expand All @@ -76,18 +72,17 @@ func ifNullFoldHandler(expr *ScalarFunction) (Expression, bool) {
}
return constArg, isDeferred
}
var isDeferredConst bool
expr.GetArgs()[1], isDeferredConst = foldConstant(args[1])
return expr, isDeferredConst
// if the condition is not const, which branch is unknown to run, so directly return.
return expr, false
}

func caseWhenHandler(expr *ScalarFunction) (Expression, bool) {
args, l := expr.GetArgs(), len(expr.GetArgs())
var isDeferred, isDeferredConst, hasNonConstCondition bool
var isDeferred, isDeferredConst bool
for i := 0; i < l-1; i += 2 {
expr.GetArgs()[i], isDeferred = foldConstant(args[i])
isDeferredConst = isDeferredConst || isDeferred
if _, isConst := expr.GetArgs()[i].(*Constant); isConst && !hasNonConstCondition {
if _, isConst := expr.GetArgs()[i].(*Constant); isConst {
// If the condition is const and true, and the previous conditions
// has no expr, then the folded execution body is returned, otherwise
// the arguments of the casewhen are folded and replaced.
Expand All @@ -105,20 +100,14 @@ func caseWhenHandler(expr *ScalarFunction) (Expression, bool) {
return BuildCastFunction(expr.GetCtx(), foldedExpr, foldedExpr.GetType()), isDeferredConst
}
} else {
hasNonConstCondition = true
// for no-const, here should return directly, because the following branches are unknown to be run or not
return expr, false
}
expr.GetArgs()[i+1], isDeferred = foldConstant(args[i+1])
isDeferredConst = isDeferredConst || isDeferred
}

if l%2 == 0 {
return expr, isDeferredConst
}

// If the number of arguments in casewhen is odd, and the previous conditions
// is const and false, then the folded else execution body is returned. otherwise
// is false, then the folded else execution body is returned. otherwise
// the execution body of the else are folded and replaced.
if !hasNonConstCondition {
if l%2 == 1 {
foldedExpr, isDeferred := foldConstant(args[l-1])
isDeferredConst = isDeferredConst || isDeferred
if _, isConst := foldedExpr.(*Constant); isConst {
Expand All @@ -127,10 +116,6 @@ func caseWhenHandler(expr *ScalarFunction) (Expression, bool) {
}
return BuildCastFunction(expr.GetCtx(), foldedExpr, foldedExpr.GetType()), isDeferredConst
}

expr.GetArgs()[l-1], isDeferred = foldConstant(args[l-1])
isDeferredConst = isDeferredConst || isDeferred

return expr, isDeferredConst
}

Expand Down
9 changes: 9 additions & 0 deletions expression/function_traits.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,15 @@ var DisableFoldFunctions = map[string]struct{}{
ast.Benchmark: {},
}

// TryFoldFunctions stores functions which try to fold constant in child scope functions if without errors/warnings,
// otherwise, the child functions do not fold constant.
// Note: the function itself should fold constant.
var TryFoldFunctions = map[string]struct{}{
ast.If: {},
ast.Ifnull: {},
ast.Case: {},
}

// IllegalFunctions4GeneratedColumns stores functions that is illegal for generated columns.
// See https://github.com/mysql/mysql-server/blob/5.7/mysql-test/suite/gcol/inc/gcol_blocked_sql_funcs_main.inc for details
var IllegalFunctions4GeneratedColumns = map[string]struct{}{
Expand Down
18 changes: 18 additions & 0 deletions expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2808,6 +2808,24 @@ func (s *testIntegrationSuite2) TestBuiltin(c *C) {
tk.MustQuery("select ifnull(b, b/0) from t")
tk.MustQuery("show warnings").Check(testkit.Rows())

tk.MustQuery("select case when 1 then 1 else 1/0 end")
tk.MustQuery("show warnings").Check(testkit.Rows())
tk.MustQuery(" select if(1,1,1/0)")
tk.MustQuery("show warnings").Check(testkit.Rows())
tk.MustQuery("select ifnull(1, 1/0)")
tk.MustQuery("show warnings").Check(testkit.Rows())

tk.MustExec("delete from t")
tk.MustExec("insert t values ('str2', 0)")
tk.MustQuery("select case when b < 1 then 1 else 1/0 end from t")
tk.MustQuery("show warnings").Check(testkit.Rows())
tk.MustQuery("select case when b < 1 then 1 when 1/0 then b else 1/0 end from t")
tk.MustQuery("show warnings").Check(testkit.Rows())
tk.MustQuery("select if(b < 1 , 1, 1/0) from t")
tk.MustQuery("show warnings").Check(testkit.Rows())
tk.MustQuery("select ifnull(b, 1/0) from t")
tk.MustQuery("show warnings").Check(testkit.Rows())

tk.MustQuery("select case 2.0 when 2.0 then 3.0 when 3.0 then 2.0 end").Check(testkit.Rows("3.0"))
tk.MustQuery("select case 2.0 when 3.0 then 2.0 when 4.0 then 3.0 else 5.0 end").Check(testkit.Rows("5.0"))
tk.MustQuery("select case cast('2011-01-01' as date) when cast('2011-01-01' as date) then cast('2011-02-02' as date) end").Check(testkit.Rows("2011-02-02"))
Expand Down
26 changes: 22 additions & 4 deletions expression/scalar_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,9 @@ func typeInferForNull(args []Expression) {
}

// newFunctionImpl creates a new scalar function or constant.
func newFunctionImpl(ctx sessionctx.Context, fold bool, funcName string, retType *types.FieldType, args ...Expression) (Expression, error) {
// fold: 1 means folding constants, while 0 means not,
// -1 means try to fold constants if without errors/warnings, otherwise not.
func newFunctionImpl(ctx sessionctx.Context, fold int, funcName string, retType *types.FieldType, args ...Expression) (Expression, error) {
if retType == nil {
return nil, errors.Errorf("RetType cannot be nil for ScalarFunction.")
}
Expand Down Expand Up @@ -210,20 +212,36 @@ func newFunctionImpl(ctx sessionctx.Context, fold bool, funcName string, retType
RetType: retType,
Function: f,
}
if fold {
if fold == 1 {
return FoldConstant(sf), nil
} else if fold == -1 {
// try to fold constants, and return the original function if errors/warnings occur
sc := ctx.GetSessionVars().StmtCtx
beforeWarns := sc.WarningCount()
newSf := FoldConstant(sf)
afterWarns := sc.WarningCount()
if afterWarns > beforeWarns {
sc.TruncateWarnings(int(beforeWarns))
return sf, nil
}
return newSf, nil
}
return sf, nil
}

// NewFunction creates a new scalar function or constant via a constant folding.
func NewFunction(ctx sessionctx.Context, funcName string, retType *types.FieldType, args ...Expression) (Expression, error) {
return newFunctionImpl(ctx, true, funcName, retType, args...)
return newFunctionImpl(ctx, 1, funcName, retType, args...)
}

// NewFunctionBase creates a new scalar function with no constant folding.
func NewFunctionBase(ctx sessionctx.Context, funcName string, retType *types.FieldType, args ...Expression) (Expression, error) {
return newFunctionImpl(ctx, false, funcName, retType, args...)
return newFunctionImpl(ctx, 0, funcName, retType, args...)
}

// NewFunctionTryFold creates a new scalar function with trying constant folding.
func NewFunctionTryFold(ctx sessionctx.Context, funcName string, retType *types.FieldType, args ...Expression) (Expression, error) {
return newFunctionImpl(ctx, -1, funcName, retType, args...)
}

// NewFunctionInternal is similar to NewFunction, but do not returns error, should only be used internally.
Expand Down
24 changes: 24 additions & 0 deletions planner/core/expression_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ func (b *PlanBuilder) getExpressionRewriter(ctx context.Context, p LogicalPlan)
rewriter.preprocess = nil
rewriter.insertPlan = nil
rewriter.disableFoldCounter = 0
rewriter.tryFoldCounter = 0
rewriter.ctxStack = rewriter.ctxStack[:0]
rewriter.ctxNameStk = rewriter.ctxNameStk[:0]
rewriter.ctx = ctx
Expand Down Expand Up @@ -226,6 +227,7 @@ type expressionRewriter struct {
// leaving the scope(enable again), the counter will -1.
// NOTE: This value can be changed during expression rewritten.
disableFoldCounter int
tryFoldCounter int
}

func (er *expressionRewriter) ctxStackLen() int {
Expand Down Expand Up @@ -401,6 +403,16 @@ func (er *expressionRewriter) Enter(inNode ast.Node) (ast.Node, bool) {
if _, ok := expression.DisableFoldFunctions[v.FnName.L]; ok {
er.disableFoldCounter++
}
if _, ok := expression.TryFoldFunctions[v.FnName.L]; ok {
er.tryFoldCounter++
}
case *ast.CaseExpr:
if _, ok := expression.DisableFoldFunctions["case"]; ok {
er.disableFoldCounter++
}
if _, ok := expression.TryFoldFunctions["case"]; ok {
er.tryFoldCounter++
}
case *ast.SetCollationExpr:
// Do nothing
default:
Expand Down Expand Up @@ -944,6 +956,9 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok
case *ast.VariableExpr:
er.rewriteVariable(v)
case *ast.FuncCallExpr:
if _, ok := expression.TryFoldFunctions[v.FnName.L]; ok {
er.tryFoldCounter--
}
er.funcCallToExpression(v)
if _, ok := expression.DisableFoldFunctions[v.FnName.L]; ok {
er.disableFoldCounter--
Expand All @@ -959,7 +974,13 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok
case *ast.BetweenExpr:
er.betweenToExpression(v)
case *ast.CaseExpr:
if _, ok := expression.TryFoldFunctions["case"]; ok {
er.tryFoldCounter--
}
er.caseToExpression(v)
if _, ok := expression.DisableFoldFunctions["case"]; ok {
er.disableFoldCounter--
}
case *ast.FuncCastExpr:
arg := er.ctxStack[len(er.ctxStack)-1]
er.err = expression.CheckArgsNotMultiColumnRow(arg)
Expand Down Expand Up @@ -1052,6 +1073,9 @@ func (er *expressionRewriter) newFunction(funcName string, retType *types.FieldT
if er.disableFoldCounter > 0 {
return expression.NewFunctionBase(er.sctx, funcName, retType, args...)
}
if er.tryFoldCounter > 0 {
return expression.NewFunctionTryFold(er.sctx, funcName, retType, args...)
}
return expression.NewFunction(er.sctx, funcName, retType, args...)
}

Expand Down

0 comments on commit 8e4b18a

Please sign in to comment.