Skip to content

Commit

Permalink
expression: support push-down INTDIV to TiKV (#49051)
Browse files Browse the repository at this point in the history
close #49050
  • Loading branch information
wjhuang2016 authored Dec 27, 2023
1 parent aeefec8 commit 75b451c
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 96 deletions.
108 changes: 65 additions & 43 deletions pkg/expression/builtin_arithmetic.go
Original file line number Diff line number Diff line change
Expand Up @@ -787,14 +787,14 @@ func (s *builtinArithmeticIntDivideDecimalSig) Clone() builtinFunc {
}

func (s *builtinArithmeticIntDivideIntSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) {
b, bIsNull, err := s.args[1].EvalInt(ctx, row)
if bIsNull || err != nil {
return 0, bIsNull, err
}
a, aIsNull, err := s.args[0].EvalInt(ctx, row)
if aIsNull || err != nil {
return 0, aIsNull, err
}
b, bIsNull, err := s.args[1].EvalInt(ctx, row)
if bIsNull || err != nil {
return 0, bIsNull, err
}

if b == 0 {
return 0, true, handleDivisionByZeroError(ctx)
Expand Down Expand Up @@ -970,18 +970,22 @@ func (s *builtinArithmeticModRealSig) Clone() builtinFunc {
}

func (s *builtinArithmeticModRealSig) evalReal(ctx EvalContext, row chunk.Row) (float64, bool, error) {
b, isNull, err := s.args[1].EvalReal(ctx, row)
if isNull || err != nil {
return 0, isNull, err
a, aIsNull, err := s.args[0].EvalReal(ctx, row)
if err != nil {
return 0, false, err
}

if b == 0 {
return 0, true, handleDivisionByZeroError(ctx)
b, bIsNull, err := s.args[1].EvalReal(ctx, row)
if err != nil {
return 0, false, err
}

a, isNull, err := s.args[0].EvalReal(ctx, row)
if isNull || err != nil {
return 0, isNull, err
if aIsNull || bIsNull {
return 0, true, nil
}

if b == 0 {
return 0, true, handleDivisionByZeroError(ctx)
}

return math.Mod(a, b), false, nil
Expand Down Expand Up @@ -1025,18 +1029,22 @@ func (s *builtinArithmeticModIntUnsignedUnsignedSig) Clone() builtinFunc {
}

func (s *builtinArithmeticModIntUnsignedUnsignedSig) evalInt(ctx EvalContext, row chunk.Row) (val int64, isNull bool, err error) {
b, isNull, err := s.args[1].EvalInt(ctx, row)
if isNull || err != nil {
return 0, isNull, err
a, aIsNull, err := s.args[0].EvalInt(ctx, row)
if err != nil {
return 0, false, err
}

if b == 0 {
return 0, true, handleDivisionByZeroError(ctx)
b, bIsNull, err := s.args[1].EvalInt(ctx, row)
if err != nil {
return 0, false, err
}

a, isNull, err := s.args[0].EvalInt(ctx, row)
if isNull || err != nil {
return 0, isNull, err
if aIsNull || bIsNull {
return 0, true, nil
}

if b == 0 {
return 0, true, handleDivisionByZeroError(ctx)
}

ret := int64(uint64(a) % uint64(b))
Expand All @@ -1055,17 +1063,23 @@ func (s *builtinArithmeticModIntUnsignedSignedSig) Clone() builtinFunc {
}

func (s *builtinArithmeticModIntUnsignedSignedSig) evalInt(ctx EvalContext, row chunk.Row) (val int64, isNull bool, err error) {
b, isNull, err := s.args[1].EvalInt(ctx, row)
if isNull || err != nil {
return 0, isNull, err
a, aIsNull, err := s.args[0].EvalInt(ctx, row)
if err != nil {
return 0, false, err
}

b, bIsNull, err := s.args[1].EvalInt(ctx, row)
if err != nil {
return 0, false, err
}

if aIsNull || bIsNull {
return 0, true, nil
}

if b == 0 {
return 0, true, handleDivisionByZeroError(ctx)
}
a, isNull, err := s.args[0].EvalInt(ctx, row)
if isNull || err != nil {
return 0, isNull, err
}

var ret int64
if b < 0 {
Expand All @@ -1088,18 +1102,22 @@ func (s *builtinArithmeticModIntSignedUnsignedSig) Clone() builtinFunc {
}

func (s *builtinArithmeticModIntSignedUnsignedSig) evalInt(ctx EvalContext, row chunk.Row) (val int64, isNull bool, err error) {
b, isNull, err := s.args[1].EvalInt(ctx, row)
if isNull || err != nil {
return 0, isNull, err
a, aIsNull, err := s.args[0].EvalInt(ctx, row)
if err != nil {
return 0, false, err
}

if b == 0 {
return 0, true, handleDivisionByZeroError(ctx)
b, bIsNull, err := s.args[1].EvalInt(ctx, row)
if err != nil {
return 0, false, err
}

a, isNull, err := s.args[0].EvalInt(ctx, row)
if isNull || err != nil {
return 0, isNull, err
if aIsNull || bIsNull {
return 0, true, nil
}

if b == 0 {
return 0, true, handleDivisionByZeroError(ctx)
}

var ret int64
Expand All @@ -1123,18 +1141,22 @@ func (s *builtinArithmeticModIntSignedSignedSig) Clone() builtinFunc {
}

func (s *builtinArithmeticModIntSignedSignedSig) evalInt(ctx EvalContext, row chunk.Row) (val int64, isNull bool, err error) {
b, isNull, err := s.args[1].EvalInt(ctx, row)
if isNull || err != nil {
return 0, isNull, err
a, aIsNull, err := s.args[0].EvalInt(ctx, row)
if err != nil {
return 0, false, err
}

if b == 0 {
return 0, true, handleDivisionByZeroError(ctx)
b, bIsNull, err := s.args[1].EvalInt(ctx, row)
if err != nil {
return 0, false, err
}

a, isNull, err := s.args[0].EvalInt(ctx, row)
if isNull || err != nil {
return 0, isNull, err
if aIsNull || bIsNull {
return 0, true, nil
}

if b == 0 {
return 0, true, handleDivisionByZeroError(ctx)
}

return a % b, false, nil
Expand Down
52 changes: 20 additions & 32 deletions pkg/expression/builtin_arithmetic_vec.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,20 +139,17 @@ func (b *builtinArithmeticModIntUnsignedUnsignedSig) vecEvalInt(ctx EvalContext,
rhi64s := rh.Int64s()

for i := 0; i < len(lhi64s); i++ {
if lh.IsNull(i) || rh.IsNull(i) {
result.SetNull(i, true)
continue
}
if rhi64s[i] == 0 {
if rh.IsNull(i) {
continue
}
if err := handleDivisionByZeroError(ctx); err != nil {
return err
}
rh.SetNull(i, true)
continue
}
if lh.IsNull(i) {
rh.SetNull(i, true)
continue
}
lhVar, rhVar := lhi64s[i], rhi64s[i]
rhi64s[i] = int64(uint64(lhVar) % uint64(rhVar))
}
Expand Down Expand Up @@ -183,20 +180,17 @@ func (b *builtinArithmeticModIntUnsignedSignedSig) vecEvalInt(ctx EvalContext, i
rhi64s := rh.Int64s()

for i := 0; i < len(lhi64s); i++ {
if lh.IsNull(i) || rh.IsNull(i) {
result.SetNull(i, true)
continue
}
if rhi64s[i] == 0 {
if rh.IsNull(i) {
continue
}
if err := handleDivisionByZeroError(ctx); err != nil {
return err
}
rh.SetNull(i, true)
continue
}
if lh.IsNull(i) {
rh.SetNull(i, true)
continue
}
lhVar, rhVar := lhi64s[i], rhi64s[i]
if rhVar < 0 {
rhi64s[i] = int64(uint64(lhVar) % uint64(-rhVar))
Expand Down Expand Up @@ -231,20 +225,17 @@ func (b *builtinArithmeticModIntSignedUnsignedSig) vecEvalInt(ctx EvalContext, i
rhi64s := rh.Int64s()

for i := 0; i < len(lhi64s); i++ {
if lh.IsNull(i) || rh.IsNull(i) {
result.SetNull(i, true)
continue
}
if rhi64s[i] == 0 {
if rh.IsNull(i) {
continue
}
if err := handleDivisionByZeroError(ctx); err != nil {
return err
}
rh.SetNull(i, true)
continue
}
if lh.IsNull(i) {
rh.SetNull(i, true)
continue
}
lhVar, rhVar := lhi64s[i], rhi64s[i]
if lhVar < 0 {
rhi64s[i] = -int64(uint64(-lhVar) % uint64(rhVar))
Expand Down Expand Up @@ -279,20 +270,17 @@ func (b *builtinArithmeticModIntSignedSignedSig) vecEvalInt(ctx EvalContext, inp
rhi64s := rh.Int64s()

for i := 0; i < len(lhi64s); i++ {
if lh.IsNull(i) || rh.IsNull(i) {
result.SetNull(i, true)
continue
}
if rhi64s[i] == 0 {
if rh.IsNull(i) {
continue
}
if err := handleDivisionByZeroError(ctx); err != nil {
return err
}
rh.SetNull(i, true)
continue
}
if lh.IsNull(i) {
rh.SetNull(i, true)
continue
}
lhVar, rhVar := lhi64s[i], rhi64s[i]
rhi64s[i] = lhVar % rhVar
}
Expand Down Expand Up @@ -432,17 +420,17 @@ func (b *builtinArithmeticModRealSig) vecEvalReal(ctx EvalContext, input *chunk.
return err
}
defer b.bufAllocator.put(buf)
if err := b.args[1].VecEvalReal(ctx, input, buf); err != nil {
if err := b.args[0].VecEvalReal(ctx, input, result); err != nil {
return err
}
if err := b.args[0].VecEvalReal(ctx, input, result); err != nil {
if err := b.args[1].VecEvalReal(ctx, input, buf); err != nil {
return err
}
result.MergeNulls(buf)
x := result.Float64s()
y := buf.Float64s()
for i := 0; i < n; i++ {
if buf.IsNull(i) {
if buf.IsNull(i) || result.IsNull(i) {
continue
}
if y[i] == 0 {
Expand Down Expand Up @@ -481,7 +469,7 @@ func (b *builtinArithmeticModDecimalSig) vecEvalDecimal(ctx EvalContext, input *
y := buf.Decimals()
var to types.MyDecimal
for i := 0; i < n; i++ {
if result.IsNull(i) {
if result.IsNull(i) || buf.IsNull(i) {
continue
}
err = types.DecimalMod(&x[i], &y[i], &to)
Expand Down
7 changes: 4 additions & 3 deletions pkg/expression/distsql_builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -1145,7 +1145,7 @@ func PBToExpr(ctx sessionctx.Context, expr *tipb.Expr, tps []*types.FieldType) (
case tipb.ExprType_Float64:
return convertFloat(expr.Val, false)
case tipb.ExprType_MysqlDecimal:
return convertDecimal(expr.Val)
return convertDecimal(expr.Val, expr.FieldType)
case tipb.ExprType_MysqlDuration:
return convertDuration(expr.Val)
case tipb.ExprType_MysqlTime:
Expand Down Expand Up @@ -1263,7 +1263,8 @@ func convertFloat(val []byte, f32 bool) (*Constant, error) {
return &Constant{Value: d, RetType: types.NewFieldType(mysql.TypeDouble)}, nil
}

func convertDecimal(val []byte) (*Constant, error) {
func convertDecimal(val []byte, ftPB *tipb.FieldType) (*Constant, error) {
ft := PbTypeToFieldType(ftPB)
_, dec, precision, frac, err := codec.DecodeDecimal(val)
var d types.Datum
d.SetMysqlDecimal(dec)
Expand All @@ -1272,7 +1273,7 @@ func convertDecimal(val []byte) (*Constant, error) {
if err != nil {
return nil, errors.Errorf("invalid decimal % x", val)
}
return &Constant{Value: d, RetType: types.NewFieldType(mysql.TypeNewDecimal)}, nil
return &Constant{Value: d, RetType: ft}, nil
}

func convertDuration(val []byte) (*Constant, error) {
Expand Down
1 change: 1 addition & 0 deletions pkg/expression/distsql_builtin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -897,6 +897,7 @@ func datumExpr(t *testing.T, d types.Datum) *tipb.Expr {
expr.Val = codec.EncodeInt(nil, int64(d.GetMysqlDuration().Duration))
case types.KindMysqlDecimal:
expr.Tp = tipb.ExprType_MysqlDecimal
expr.FieldType = toPBFieldType(types.NewFieldType(mysql.TypeNewDecimal))
var err error
expr.Val, err = codec.EncodeDecimal(nil, d.GetMysqlDecimal(), d.Length(), d.Frac())
require.NoError(t, err)
Expand Down
26 changes: 14 additions & 12 deletions pkg/expression/expr_to_pb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -322,18 +322,20 @@ func TestArithmeticalFunc2Pb(t *testing.T) {
require.Equalf(t, jsons[funcNames[i]], string(js), "%v\n", funcNames[i])
}

funcNames = []string{ast.IntDiv} // cannot be pushed down
for _, funcName := range funcNames {
fc, err := NewFunction(
mock.NewContext(),
funcName,
types.NewFieldType(mysql.TypeUnspecified),
genColumn(mysql.TypeDouble, 1),
genColumn(mysql.TypeDouble, 2))
require.NoError(t, err)
_, err = ExpressionsToPBList(ctx, []Expression{fc}, client)
require.Error(t, err)
}
// IntDiv
fc, err := NewFunction(
mock.NewContext(),
ast.IntDiv,
types.NewFieldType(mysql.TypeUnspecified),
genColumn(mysql.TypeLonglong, 1),
genColumn(mysql.TypeLonglong, 2))
require.NoError(t, err)
pbExprs, err = ExpressionsToPBList(ctx, []Expression{fc}, client)
require.NoError(t, err)
js, err := json.Marshal(pbExprs[0])
require.NoError(t, err)
expectedJs := "{\"tp\":10000,\"children\":[{\"tp\":201,\"val\":\"gAAAAAAAAAE=\",\"sig\":0,\"field_type\":{\"tp\":8,\"flag\":0,\"flen\":20,\"decimal\":0,\"collate\":-63,\"charset\":\"binary\",\"array\":false},\"has_distinct\":false},{\"tp\":201,\"val\":\"gAAAAAAAAAI=\",\"sig\":0,\"field_type\":{\"tp\":8,\"flag\":0,\"flen\":20,\"decimal\":0,\"collate\":-63,\"charset\":\"binary\",\"array\":false},\"has_distinct\":false}],\"sig\":213,\"field_type\":{\"tp\":8,\"flag\":128,\"flen\":20,\"decimal\":0,\"collate\":-63,\"charset\":\"binary\",\"array\":false},\"has_distinct\":false}"
require.Equalf(t, expectedJs, string(js), "%v\n", ast.IntDiv)
}

func TestDateFunc2Pb(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion pkg/expression/infer_pushdown.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ func scalarExprSupportedByTiKV(sf *ScalarFunction) bool {

// arithmetical functions.
ast.PI, /* ast.Truncate */
ast.Plus, ast.Minus, ast.Mul, ast.Div, ast.Abs, ast.Mod,
ast.Plus, ast.Minus, ast.Mul, ast.Div, ast.Abs, ast.Mod, ast.IntDiv,

// math functions.
ast.Ceil, ast.Ceiling, ast.Floor, ast.Sqrt, ast.Sign, ast.Ln, ast.Log, ast.Log2, ast.Log10, ast.Exp, ast.Pow,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,13 +200,12 @@ Level Code Message
explain format = 'brief' select /*+ LIMIT_TO_COP() */ a from tn where a div 2 order by a limit 1;
id estRows task access object operator info
Limit 1.00 root offset:0, count:1
└─Selection 1.00 root intdiv(planner__core__casetest__physicalplantest__physical_plan.tn.a, 2)
└─IndexReader 1.00 root index:IndexFullScan
└─IndexFullScan 1.00 cop[tikv] table:tn, index:a(a, b, c, d) keep order:true, stats:pseudo
└─IndexReader 1.00 root index:Limit
└─Limit 1.00 cop[tikv] offset:0, count:1
└─Selection 1.00 cop[tikv] intdiv(planner__core__casetest__physicalplantest__physical_plan.tn.a, 2)
└─IndexFullScan 1.25 cop[tikv] table:tn, index:a(a, b, c, d) keep order:true, stats:pseudo
show warnings;
Level Code Message
Warning 1105 Scalar function 'intdiv'(signature: IntDivideInt, return type: bigint(20)) is not supported to push down to storage layer now.
Warning 1815 Optimizer Hint LIMIT_TO_COP is inapplicable
explain format = 'brief' select /*+ LIMIT_TO_COP() */ a from tn where a > 10 limit 1;
id estRows task access object operator info
Limit 1.00 root offset:0, count:1
Expand Down

0 comments on commit 75b451c

Please sign in to comment.