diff --git a/expression/constant_fold.go b/expression/constant_fold.go index 4c0c906744b83..b421bcd30f173 100644 --- a/expression/constant_fold.go +++ b/expression/constant_fold.go @@ -27,6 +27,7 @@ func init() { specialFoldHandler = map[string]func(*ScalarFunction) (Expression, bool){ ast.If: ifFoldHandler, ast.Ifnull: ifNullFoldHandler, + ast.Case: caseWhenHandler, } } @@ -78,6 +79,59 @@ func ifNullFoldHandler(expr *ScalarFunction) (Expression, bool) { return expr, isDeferredConst } +func caseWhenHandler(expr *ScalarFunction) (Expression, bool) { + args, l := expr.GetArgs(), len(expr.GetArgs()) + var isDeferred, isDeferredConst, hasNonConstCondition bool + for i := 0; i < l-1; i += 2 { + expr.GetArgs()[i], isDeferred = foldConstant(args[i]) + isDeferredConst = isDeferredConst || isDeferred + if _, isConst := expr.GetArgs()[i].(*Constant); isConst && !hasNonConstCondition { + // If the condition is const and true, and the previous conditions + // has no expr, then the folded execution body is returned, otherwise + // the arguments of the casewhen are folded and replaced. + val, isNull, err := args[i].EvalInt(expr.GetCtx(), chunk.Row{}) + if err != nil { + return expr, false + } + if val != 0 && !isNull { + foldedExpr, isDeferred := foldConstant(args[i+1]) + isDeferredConst = isDeferredConst || isDeferred + if _, isConst := foldedExpr.(*Constant); isConst { + foldedExpr.GetType().Decimal = expr.GetType().Decimal + return foldedExpr, isDeferredConst + } + return BuildCastFunction(expr.GetCtx(), foldedExpr, foldedExpr.GetType()), isDeferredConst + } + } else { + hasNonConstCondition = true + } + expr.GetArgs()[i+1], isDeferred = foldConstant(args[i+1]) + isDeferredConst = isDeferredConst || isDeferred + } + + if l%2 == 0 { + return expr, isDeferredConst + } + + // If the number of arguments in casewhen is odd, and the previous conditions + // is const and false, then the folded else execution body is returned. otherwise + // the execution body of the else are folded and replaced. + if !hasNonConstCondition { + foldedExpr, isDeferred := foldConstant(args[l-1]) + isDeferredConst = isDeferredConst || isDeferred + if _, isConst := foldedExpr.(*Constant); isConst { + foldedExpr.GetType().Decimal = expr.GetType().Decimal + return foldedExpr, isDeferredConst + } + return BuildCastFunction(expr.GetCtx(), foldedExpr, foldedExpr.GetType()), isDeferredConst + } + + expr.GetArgs()[l-1], isDeferred = foldConstant(args[l-1]) + isDeferredConst = isDeferredConst || isDeferred + + return expr, isDeferredConst +} + func foldConstant(expr Expression) (Expression, bool) { switch x := expr.(type) { case *ScalarFunction: diff --git a/expression/integration_test.go b/expression/integration_test.go index a13edd67c2fb0..bbec68f771067 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -4648,5 +4648,17 @@ func (s *testIntegrationSuite) TestDatetimeMicrosecond(c *C) { testkit.Rows("2007-03-28 22:06:28")) tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.-2" MICROSECOND);`).Check( testkit.Rows("2007-03-28 22:08:27.999998")) +} + +func (s *testIntegrationSuite) TestFuncCaseWithLeftJoin(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + + tk.MustExec("create table kankan1(id int, name text)") + tk.MustExec("insert into kankan1 values(1, 'a')") + tk.MustExec("insert into kankan1 values(2, 'a')") + + tk.MustExec("create table kankan2(id int, h1 text)") + tk.MustExec("insert into kankan2 values(2, 'z')") + tk.MustQuery("select t1.id from kankan1 t1 left join kankan2 t2 on t1.id = t2.id where (case when t1.name='b' then 'case2' when t1.name='a' then 'case1' else NULL end) = 'case1' order by t1.id").Check(testkit.Rows("1", "2")) }