diff --git a/pkg/expression/builtin_arithmetic.go b/pkg/expression/builtin_arithmetic.go index f5d3b1c47ea46..596e493414d13 100644 --- a/pkg/expression/builtin_arithmetic.go +++ b/pkg/expression/builtin_arithmetic.go @@ -787,14 +787,14 @@ func (s *builtinArithmeticIntDivideDecimalSig) Clone() builtinFunc { } func (s *builtinArithmeticIntDivideIntSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { - b, bIsNull, err := s.args[1].EvalInt(ctx, row) - if bIsNull || err != nil { - return 0, bIsNull, err - } a, aIsNull, err := s.args[0].EvalInt(ctx, row) if aIsNull || err != nil { return 0, aIsNull, err } + b, bIsNull, err := s.args[1].EvalInt(ctx, row) + if bIsNull || err != nil { + return 0, bIsNull, err + } if b == 0 { return 0, true, handleDivisionByZeroError(ctx) @@ -970,18 +970,22 @@ func (s *builtinArithmeticModRealSig) Clone() builtinFunc { } func (s *builtinArithmeticModRealSig) evalReal(ctx EvalContext, row chunk.Row) (float64, bool, error) { - b, isNull, err := s.args[1].EvalReal(ctx, row) - if isNull || err != nil { - return 0, isNull, err + a, aIsNull, err := s.args[0].EvalReal(ctx, row) + if err != nil { + return 0, false, err } - if b == 0 { - return 0, true, handleDivisionByZeroError(ctx) + b, bIsNull, err := s.args[1].EvalReal(ctx, row) + if err != nil { + return 0, false, err } - a, isNull, err := s.args[0].EvalReal(ctx, row) - if isNull || err != nil { - return 0, isNull, err + if aIsNull || bIsNull { + return 0, true, nil + } + + if b == 0 { + return 0, true, handleDivisionByZeroError(ctx) } return math.Mod(a, b), false, nil @@ -1025,18 +1029,22 @@ func (s *builtinArithmeticModIntUnsignedUnsignedSig) Clone() builtinFunc { } func (s *builtinArithmeticModIntUnsignedUnsignedSig) evalInt(ctx EvalContext, row chunk.Row) (val int64, isNull bool, err error) { - b, isNull, err := s.args[1].EvalInt(ctx, row) - if isNull || err != nil { - return 0, isNull, err + a, aIsNull, err := s.args[0].EvalInt(ctx, row) + if err != nil { + return 0, false, err } - if b == 0 { - return 0, true, handleDivisionByZeroError(ctx) + b, bIsNull, err := s.args[1].EvalInt(ctx, row) + if err != nil { + return 0, false, err } - a, isNull, err := s.args[0].EvalInt(ctx, row) - if isNull || err != nil { - return 0, isNull, err + if aIsNull || bIsNull { + return 0, true, nil + } + + if b == 0 { + return 0, true, handleDivisionByZeroError(ctx) } ret := int64(uint64(a) % uint64(b)) @@ -1055,17 +1063,23 @@ func (s *builtinArithmeticModIntUnsignedSignedSig) Clone() builtinFunc { } func (s *builtinArithmeticModIntUnsignedSignedSig) evalInt(ctx EvalContext, row chunk.Row) (val int64, isNull bool, err error) { - b, isNull, err := s.args[1].EvalInt(ctx, row) - if isNull || err != nil { - return 0, isNull, err + a, aIsNull, err := s.args[0].EvalInt(ctx, row) + if err != nil { + return 0, false, err } + + b, bIsNull, err := s.args[1].EvalInt(ctx, row) + if err != nil { + return 0, false, err + } + + if aIsNull || bIsNull { + return 0, true, nil + } + if b == 0 { return 0, true, handleDivisionByZeroError(ctx) } - a, isNull, err := s.args[0].EvalInt(ctx, row) - if isNull || err != nil { - return 0, isNull, err - } var ret int64 if b < 0 { @@ -1088,18 +1102,22 @@ func (s *builtinArithmeticModIntSignedUnsignedSig) Clone() builtinFunc { } func (s *builtinArithmeticModIntSignedUnsignedSig) evalInt(ctx EvalContext, row chunk.Row) (val int64, isNull bool, err error) { - b, isNull, err := s.args[1].EvalInt(ctx, row) - if isNull || err != nil { - return 0, isNull, err + a, aIsNull, err := s.args[0].EvalInt(ctx, row) + if err != nil { + return 0, false, err } - if b == 0 { - return 0, true, handleDivisionByZeroError(ctx) + b, bIsNull, err := s.args[1].EvalInt(ctx, row) + if err != nil { + return 0, false, err } - a, isNull, err := s.args[0].EvalInt(ctx, row) - if isNull || err != nil { - return 0, isNull, err + if aIsNull || bIsNull { + return 0, true, nil + } + + if b == 0 { + return 0, true, handleDivisionByZeroError(ctx) } var ret int64 @@ -1123,18 +1141,22 @@ func (s *builtinArithmeticModIntSignedSignedSig) Clone() builtinFunc { } func (s *builtinArithmeticModIntSignedSignedSig) evalInt(ctx EvalContext, row chunk.Row) (val int64, isNull bool, err error) { - b, isNull, err := s.args[1].EvalInt(ctx, row) - if isNull || err != nil { - return 0, isNull, err + a, aIsNull, err := s.args[0].EvalInt(ctx, row) + if err != nil { + return 0, false, err } - if b == 0 { - return 0, true, handleDivisionByZeroError(ctx) + b, bIsNull, err := s.args[1].EvalInt(ctx, row) + if err != nil { + return 0, false, err } - a, isNull, err := s.args[0].EvalInt(ctx, row) - if isNull || err != nil { - return 0, isNull, err + if aIsNull || bIsNull { + return 0, true, nil + } + + if b == 0 { + return 0, true, handleDivisionByZeroError(ctx) } return a % b, false, nil diff --git a/pkg/expression/builtin_arithmetic_vec.go b/pkg/expression/builtin_arithmetic_vec.go index 8aa4df472004e..efffc82a6544e 100644 --- a/pkg/expression/builtin_arithmetic_vec.go +++ b/pkg/expression/builtin_arithmetic_vec.go @@ -139,20 +139,17 @@ func (b *builtinArithmeticModIntUnsignedUnsignedSig) vecEvalInt(ctx EvalContext, rhi64s := rh.Int64s() for i := 0; i < len(lhi64s); i++ { + if lh.IsNull(i) || rh.IsNull(i) { + result.SetNull(i, true) + continue + } if rhi64s[i] == 0 { - if rh.IsNull(i) { - continue - } if err := handleDivisionByZeroError(ctx); err != nil { return err } rh.SetNull(i, true) continue } - if lh.IsNull(i) { - rh.SetNull(i, true) - continue - } lhVar, rhVar := lhi64s[i], rhi64s[i] rhi64s[i] = int64(uint64(lhVar) % uint64(rhVar)) } @@ -183,20 +180,17 @@ func (b *builtinArithmeticModIntUnsignedSignedSig) vecEvalInt(ctx EvalContext, i rhi64s := rh.Int64s() for i := 0; i < len(lhi64s); i++ { + if lh.IsNull(i) || rh.IsNull(i) { + result.SetNull(i, true) + continue + } if rhi64s[i] == 0 { - if rh.IsNull(i) { - continue - } if err := handleDivisionByZeroError(ctx); err != nil { return err } rh.SetNull(i, true) continue } - if lh.IsNull(i) { - rh.SetNull(i, true) - continue - } lhVar, rhVar := lhi64s[i], rhi64s[i] if rhVar < 0 { rhi64s[i] = int64(uint64(lhVar) % uint64(-rhVar)) @@ -231,20 +225,17 @@ func (b *builtinArithmeticModIntSignedUnsignedSig) vecEvalInt(ctx EvalContext, i rhi64s := rh.Int64s() for i := 0; i < len(lhi64s); i++ { + if lh.IsNull(i) || rh.IsNull(i) { + result.SetNull(i, true) + continue + } if rhi64s[i] == 0 { - if rh.IsNull(i) { - continue - } if err := handleDivisionByZeroError(ctx); err != nil { return err } rh.SetNull(i, true) continue } - if lh.IsNull(i) { - rh.SetNull(i, true) - continue - } lhVar, rhVar := lhi64s[i], rhi64s[i] if lhVar < 0 { rhi64s[i] = -int64(uint64(-lhVar) % uint64(rhVar)) @@ -279,20 +270,17 @@ func (b *builtinArithmeticModIntSignedSignedSig) vecEvalInt(ctx EvalContext, inp rhi64s := rh.Int64s() for i := 0; i < len(lhi64s); i++ { + if lh.IsNull(i) || rh.IsNull(i) { + result.SetNull(i, true) + continue + } if rhi64s[i] == 0 { - if rh.IsNull(i) { - continue - } if err := handleDivisionByZeroError(ctx); err != nil { return err } rh.SetNull(i, true) continue } - if lh.IsNull(i) { - rh.SetNull(i, true) - continue - } lhVar, rhVar := lhi64s[i], rhi64s[i] rhi64s[i] = lhVar % rhVar } @@ -432,17 +420,17 @@ func (b *builtinArithmeticModRealSig) vecEvalReal(ctx EvalContext, input *chunk. return err } defer b.bufAllocator.put(buf) - if err := b.args[1].VecEvalReal(ctx, input, buf); err != nil { + if err := b.args[0].VecEvalReal(ctx, input, result); err != nil { return err } - if err := b.args[0].VecEvalReal(ctx, input, result); err != nil { + if err := b.args[1].VecEvalReal(ctx, input, buf); err != nil { return err } result.MergeNulls(buf) x := result.Float64s() y := buf.Float64s() for i := 0; i < n; i++ { - if buf.IsNull(i) { + if buf.IsNull(i) || result.IsNull(i) { continue } if y[i] == 0 { @@ -481,7 +469,7 @@ func (b *builtinArithmeticModDecimalSig) vecEvalDecimal(ctx EvalContext, input * y := buf.Decimals() var to types.MyDecimal for i := 0; i < n; i++ { - if result.IsNull(i) { + if result.IsNull(i) || buf.IsNull(i) { continue } err = types.DecimalMod(&x[i], &y[i], &to) diff --git a/pkg/expression/distsql_builtin.go b/pkg/expression/distsql_builtin.go index ecc289a3486f0..3022b506c51d2 100644 --- a/pkg/expression/distsql_builtin.go +++ b/pkg/expression/distsql_builtin.go @@ -1145,7 +1145,7 @@ func PBToExpr(ctx sessionctx.Context, expr *tipb.Expr, tps []*types.FieldType) ( case tipb.ExprType_Float64: return convertFloat(expr.Val, false) case tipb.ExprType_MysqlDecimal: - return convertDecimal(expr.Val) + return convertDecimal(expr.Val, expr.FieldType) case tipb.ExprType_MysqlDuration: return convertDuration(expr.Val) case tipb.ExprType_MysqlTime: @@ -1263,7 +1263,8 @@ func convertFloat(val []byte, f32 bool) (*Constant, error) { return &Constant{Value: d, RetType: types.NewFieldType(mysql.TypeDouble)}, nil } -func convertDecimal(val []byte) (*Constant, error) { +func convertDecimal(val []byte, ftPB *tipb.FieldType) (*Constant, error) { + ft := PbTypeToFieldType(ftPB) _, dec, precision, frac, err := codec.DecodeDecimal(val) var d types.Datum d.SetMysqlDecimal(dec) @@ -1272,7 +1273,7 @@ func convertDecimal(val []byte) (*Constant, error) { if err != nil { return nil, errors.Errorf("invalid decimal % x", val) } - return &Constant{Value: d, RetType: types.NewFieldType(mysql.TypeNewDecimal)}, nil + return &Constant{Value: d, RetType: ft}, nil } func convertDuration(val []byte) (*Constant, error) { diff --git a/pkg/expression/distsql_builtin_test.go b/pkg/expression/distsql_builtin_test.go index ec89e6d3a56fd..47214f4dd14e8 100644 --- a/pkg/expression/distsql_builtin_test.go +++ b/pkg/expression/distsql_builtin_test.go @@ -897,6 +897,7 @@ func datumExpr(t *testing.T, d types.Datum) *tipb.Expr { expr.Val = codec.EncodeInt(nil, int64(d.GetMysqlDuration().Duration)) case types.KindMysqlDecimal: expr.Tp = tipb.ExprType_MysqlDecimal + expr.FieldType = toPBFieldType(types.NewFieldType(mysql.TypeNewDecimal)) var err error expr.Val, err = codec.EncodeDecimal(nil, d.GetMysqlDecimal(), d.Length(), d.Frac()) require.NoError(t, err) diff --git a/pkg/expression/expr_to_pb_test.go b/pkg/expression/expr_to_pb_test.go index 319c4f6ac9084..a54bca90d8f2a 100644 --- a/pkg/expression/expr_to_pb_test.go +++ b/pkg/expression/expr_to_pb_test.go @@ -322,18 +322,20 @@ func TestArithmeticalFunc2Pb(t *testing.T) { require.Equalf(t, jsons[funcNames[i]], string(js), "%v\n", funcNames[i]) } - funcNames = []string{ast.IntDiv} // cannot be pushed down - for _, funcName := range funcNames { - fc, err := NewFunction( - mock.NewContext(), - funcName, - types.NewFieldType(mysql.TypeUnspecified), - genColumn(mysql.TypeDouble, 1), - genColumn(mysql.TypeDouble, 2)) - require.NoError(t, err) - _, err = ExpressionsToPBList(ctx, []Expression{fc}, client) - require.Error(t, err) - } + // IntDiv + fc, err := NewFunction( + mock.NewContext(), + ast.IntDiv, + types.NewFieldType(mysql.TypeUnspecified), + genColumn(mysql.TypeLonglong, 1), + genColumn(mysql.TypeLonglong, 2)) + require.NoError(t, err) + pbExprs, err = ExpressionsToPBList(ctx, []Expression{fc}, client) + require.NoError(t, err) + js, err := json.Marshal(pbExprs[0]) + require.NoError(t, err) + expectedJs := "{\"tp\":10000,\"children\":[{\"tp\":201,\"val\":\"gAAAAAAAAAE=\",\"sig\":0,\"field_type\":{\"tp\":8,\"flag\":0,\"flen\":20,\"decimal\":0,\"collate\":-63,\"charset\":\"binary\",\"array\":false},\"has_distinct\":false},{\"tp\":201,\"val\":\"gAAAAAAAAAI=\",\"sig\":0,\"field_type\":{\"tp\":8,\"flag\":0,\"flen\":20,\"decimal\":0,\"collate\":-63,\"charset\":\"binary\",\"array\":false},\"has_distinct\":false}],\"sig\":213,\"field_type\":{\"tp\":8,\"flag\":128,\"flen\":20,\"decimal\":0,\"collate\":-63,\"charset\":\"binary\",\"array\":false},\"has_distinct\":false}" + require.Equalf(t, expectedJs, string(js), "%v\n", ast.IntDiv) } func TestDateFunc2Pb(t *testing.T) { diff --git a/pkg/expression/infer_pushdown.go b/pkg/expression/infer_pushdown.go index 51137840e99a7..d91ad38f746b3 100644 --- a/pkg/expression/infer_pushdown.go +++ b/pkg/expression/infer_pushdown.go @@ -179,7 +179,7 @@ func scalarExprSupportedByTiKV(sf *ScalarFunction) bool { // arithmetical functions. ast.PI, /* ast.Truncate */ - ast.Plus, ast.Minus, ast.Mul, ast.Div, ast.Abs, ast.Mod, + ast.Plus, ast.Minus, ast.Mul, ast.Div, ast.Abs, ast.Mod, ast.IntDiv, // math functions. ast.Ceil, ast.Ceiling, ast.Floor, ast.Sqrt, ast.Sign, ast.Ln, ast.Log, ast.Log2, ast.Log10, ast.Exp, ast.Pow, diff --git a/tests/integrationtest/r/planner/core/casetest/physicalplantest/physical_plan.result b/tests/integrationtest/r/planner/core/casetest/physicalplantest/physical_plan.result index 541e196db8f36..47c63d2efe4f9 100644 --- a/tests/integrationtest/r/planner/core/casetest/physicalplantest/physical_plan.result +++ b/tests/integrationtest/r/planner/core/casetest/physicalplantest/physical_plan.result @@ -200,13 +200,12 @@ Level Code Message explain format = 'brief' select /*+ LIMIT_TO_COP() */ a from tn where a div 2 order by a limit 1; id estRows task access object operator info Limit 1.00 root offset:0, count:1 -└─Selection 1.00 root intdiv(planner__core__casetest__physicalplantest__physical_plan.tn.a, 2) - └─IndexReader 1.00 root index:IndexFullScan - └─IndexFullScan 1.00 cop[tikv] table:tn, index:a(a, b, c, d) keep order:true, stats:pseudo +└─IndexReader 1.00 root index:Limit + └─Limit 1.00 cop[tikv] offset:0, count:1 + └─Selection 1.00 cop[tikv] intdiv(planner__core__casetest__physicalplantest__physical_plan.tn.a, 2) + └─IndexFullScan 1.25 cop[tikv] table:tn, index:a(a, b, c, d) keep order:true, stats:pseudo show warnings; Level Code Message -Warning 1105 Scalar function 'intdiv'(signature: IntDivideInt, return type: bigint(20)) is not supported to push down to storage layer now. -Warning 1815 Optimizer Hint LIMIT_TO_COP is inapplicable explain format = 'brief' select /*+ LIMIT_TO_COP() */ a from tn where a > 10 limit 1; id estRows task access object operator info Limit 1.00 root offset:0, count:1