Skip to content

Commit

Permalink
fix the issue that decimal divide not round. (pingcap#6471)
Browse files Browse the repository at this point in the history
  • Loading branch information
LittleFall authored and guo-shaoge committed Feb 10, 2023
1 parent 0d25d4d commit 206b807
Show file tree
Hide file tree
Showing 4 changed files with 314 additions and 7 deletions.
37 changes: 32 additions & 5 deletions dbms/src/Functions/divide.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,36 @@ struct TiDBDivideFloatingImpl<A, B, false>
using ResultType = typename NumberTraits::ResultOfFloatingPointDivision<A, B>::Type;

template <typename Result = ResultType>
static Result apply(A a, B b)
static Result apply(A x, B d)
{
return static_cast<Result>(a) / b;
/// ref https://github.com/pingcap/tiflash/issues/6462
/// For division of Decimal/Decimal or Int/Decimal or Decimal/Int, we should round the result to make compatible with TiDB.
/// basically refer to https://stackoverflow.com/a/71634489
if constexpr (std::is_integral_v<Result> || std::is_same_v<Result, Int256>)
{
/// 1. do division first, get the quotient and mod, todo:(perf) find a unified `divmod` function to speed up this.
Result quotient = x / d;
Result mod = x % d;
/// 2. get the half of divisor, which is threshold to decide whether to round up or down.
/// note: don't directly use bit operation here, it may cause unexpected result.
Result half = (d / 2) + (d % 2);

/// 3. compare the abstract values of mod and half, if mod >= half, then round up.
Result abs_m = mod < 0 ? -mod : mod;
Result abs_h = half < 0 ? -half : half;
if (abs_m >= abs_h)
{
/// 4. now we need to round up, i.e., add 1 to the quotient's absolute value.
/// if the signs of dividend and divisor are the same, then the quotient should be positive, otherwise negative.
if ((x < 0) == (d < 0)) // same_sign, i.e., quotient >= 0
quotient = quotient + 1;
else
quotient = quotient - 1;
}
return quotient;
}
else
return static_cast<Result>(x) / d;
}
template <typename Result = ResultType>
static Result apply(A a, B b, UInt8 & res_null)
Expand All @@ -75,7 +102,7 @@ struct TiDBDivideFloatingImpl<A, B, false>
res_null = 1;
return static_cast<Result>(0);
}
return static_cast<Result>(a) / b;
return apply<Result>(a, b);
}
};

Expand All @@ -102,7 +129,7 @@ struct TiDBDivideFloatingImpl<A, B, true>
res_null = 1;
return static_cast<Result>(0);
}
return static_cast<Result>(a) / static_cast<Result>(b);
return apply<Result>(a, b);
}
};

Expand Down Expand Up @@ -332,4 +359,4 @@ void registerFunctionDivideIntegralOrZero(FunctionFactory & factory)
factory.registerFunction<FunctionDivideIntegralOrZero>();
}

} // namespace DB
} // namespace DB
137 changes: 137 additions & 0 deletions dbms/src/Functions/tests/gtest_arithmetic_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
#include <Interpreters/Context.h>
#include <TestUtils/FunctionTestUtils.h>
#include <TestUtils/TiFlashTestBasic.h>
#include <gtest/gtest.h>

#include <Functions/divide.cpp>
#include <string>
#include <unordered_map>
#include <vector>
Expand Down Expand Up @@ -103,6 +105,141 @@ class TestBinaryArithmeticFunctions : public DB::tests::FunctionTest
}
};

template <typename TYPE>
void doTiDBDivideDecimalRoundInternalTest()
{
auto apply = static_cast<TYPE (*)(TYPE, TYPE)>(&TiDBDivideFloatingImpl<TYPE, TYPE, false>::apply);

constexpr TYPE max = std::numeric_limits<TYPE>::max();
// note: Int256's min is not equal to -max-1
// according to https://www.boost.org/doc/libs/1_60_0/libs/multiprecision/doc/html/boost_multiprecision/tut/ints/cpp_int.html
constexpr TYPE min = std::numeric_limits<TYPE>::min();

// clang-format off
const std::vector<std::array<TYPE, 3>> cases = {
{1, 2, 1}, {1, -2, -1}, {-1, 2, -1}, {-1, -2, 1},

{0, 3, 0}, {0, -3, 0}, {0, 3, 0}, {0, -3, 0},
{1, 3, 0}, {1, -3, 0}, {-1, 3, 0}, {-1, -3, 0},
{2, 3, 1}, {2, -3, -1}, {-2, 3, -1}, {-2, -3, 1},
{3, 3, 1}, {3, -3, -1}, {-3, 3, -1}, {-3, -3, 1},
{4, 3, 1}, {4, -3, -1}, {-4, 3, -1}, {-4, -3, 1},
{5, 3, 2}, {5, -3, -2}, {-5, 3, -2}, {-5, -3, 2},

// ±max as divisor
{0, max, 0}, {max/2-1, max, 0}, {max/2, max, 0}, {max/2+1, max, 1}, {max-1, max, 1}, {max, max, 1},
{-1, max, 0}, {-max/2+1, max, 0}, {-max/2, max, 0}, {-max/2-1, max, -1}, {-max+1, max, -1}, {-max, max, -1}, {min, max, -1},
{0, -max, 0}, {max/2-1, -max, 0}, {max/2, -max, 0}, {max/2+1, -max, -1}, {max-1, -max, -1}, {max, -max, -1},
{-1, -max, 0}, {-max/2+1, -max, 0}, {-max/2, -max, 0}, {-max/2-1, -max, 1}, {-max+1, -max, 1}, {-max, -max, 1}, {min, -max, 1},

// ±max as dividend
{max, 1, max}, {max, 2, max/2+1}, {max, max/2-1, 2}, {max, max/2, 2}, {max, max/2+1, 2}, {max, max-1, 1},
{max, -1, -max}, {max, -2, -max/2-1}, {max, -max/2+1, -2}, {max, -max/2, -2}, {max, -max/2-1, -2}, {max, -max+1, -1},
{-max, 1, -max}, {-max, 2, -max/2-1}, {-max, max/2+1, -2}, {-max, max/2, -2}, {-max, max/2-1, -2}, {-max, max-1, -1},
{-max, -1, max}, {-max, -2, max/2+1}, {-max, -max/2-1, 2}, {-max, -max/2, 2}, {-max, -max/2+1, 2}, {-max, -max+1, 1},
};
// clang-format on

for (const auto & expect : cases)
{
std::array<TYPE, 3> actual = {expect[0], expect[1], apply(expect[0], expect[1])};
ASSERT_EQ(expect, actual);
}
}

TEST_F(TestBinaryArithmeticFunctions, TiDBDivideDecimalRoundInternal)
try
{
doTiDBDivideDecimalRoundInternalTest<Int32>();
doTiDBDivideDecimalRoundInternalTest<Int64>();
doTiDBDivideDecimalRoundInternalTest<Int128>();
doTiDBDivideDecimalRoundInternalTest<Int256>();
}
CATCH

TEST_F(TestBinaryArithmeticFunctions, TiDBDivideDecimalRound)
try
{
const String func_name = "tidbDivide";

// decimal32
{
// int and decimal
ASSERT_COLUMN_EQ(
createColumn<Nullable<Decimal64>>(std::make_tuple(18, 4), {DecimalField64(1, 4), DecimalField64(1, 4), DecimalField64(1, 4), DecimalField64(1, 4), DecimalField64(0, 4)}),
executeFunction(
func_name,
createColumn<Int32>({1, 1, 1, 1, 1}),
createColumn<Decimal32>(std::make_tuple(20, 4), {DecimalField32(100000000, 4), DecimalField32(100010000, 4), DecimalField32(199990000, 4), DecimalField32(200000000, 4), DecimalField32(200010000, 4)})));

// decimal and decimal
ASSERT_COLUMN_EQ(
createColumn<Nullable<Decimal128>>(std::make_tuple(26, 8), {DecimalField128(10000, 8), DecimalField128(9999, 8), DecimalField128(5000, 8), DecimalField128(5000, 8), DecimalField128(5000, 8)}),
executeFunction(
func_name,
createColumn<Decimal32>(std::make_tuple(18, 4), {DecimalField32(10000, 4), DecimalField32(10000, 4), DecimalField32(10000, 4), DecimalField32(10000, 4), DecimalField32(10000, 4)}),
createColumn<Decimal32>(std::make_tuple(18, 4), {DecimalField32(100000000, 4), DecimalField32(100010000, 4), DecimalField32(199990000, 4), DecimalField32(200000000, 4), DecimalField32(200010000, 4)})));
}

// decimal64
{
// int and decimal
ASSERT_COLUMN_EQ(
createColumn<Nullable<Decimal64>>(std::make_tuple(18, 4), {DecimalField64(1, 4), DecimalField64(1, 4), DecimalField64(1, 4), DecimalField64(1, 4), DecimalField64(0, 4)}),
executeFunction(
func_name,
createColumn<Int32>({1, 1, 1, 1, 1}),
createColumn<Decimal64>(std::make_tuple(20, 4), {DecimalField64(100000000, 4), DecimalField64(100010000, 4), DecimalField64(199990000, 4), DecimalField64(200000000, 4), DecimalField64(200010000, 4)})));

// decimal and decimal
ASSERT_COLUMN_EQ(
createColumn<Nullable<Decimal128>>(std::make_tuple(26, 8), {DecimalField128(10000, 8), DecimalField128(9999, 8), DecimalField128(5000, 8), DecimalField128(5000, 8), DecimalField128(5000, 8)}),
executeFunction(
func_name,
createColumn<Decimal64>(std::make_tuple(18, 4), {DecimalField64(10000, 4), DecimalField64(10000, 4), DecimalField64(10000, 4), DecimalField64(10000, 4), DecimalField64(10000, 4)}),
createColumn<Decimal64>(std::make_tuple(18, 4), {DecimalField64(100000000, 4), DecimalField64(100010000, 4), DecimalField64(199990000, 4), DecimalField64(200000000, 4), DecimalField64(200010000, 4)})));
}

// decimal128
{
// int and decimal
ASSERT_COLUMN_EQ(
createColumn<Nullable<Decimal64>>(std::make_tuple(18, 4), {DecimalField64(1, 4), DecimalField64(1, 4), DecimalField64(1, 4), DecimalField64(1, 4), DecimalField64(0, 4)}),
executeFunction(
func_name,
createColumn<Int32>({1, 1, 1, 1, 1}),
createColumn<Decimal128>(std::make_tuple(20, 4), {DecimalField128(100000000, 4), DecimalField128(100010000, 4), DecimalField128(199990000, 4), DecimalField128(200000000, 4), DecimalField128(200010000, 4)})));

// decimal and decimal
ASSERT_COLUMN_EQ(
createColumn<Nullable<Decimal128>>(std::make_tuple(26, 8), {DecimalField128(10000, 8), DecimalField128(9999, 8), DecimalField128(5000, 8), DecimalField128(5000, 8), DecimalField128(5000, 8)}),
executeFunction(
func_name,
createColumn<Decimal128>(std::make_tuple(18, 4), {DecimalField128(10000, 4), DecimalField128(10000, 4), DecimalField128(10000, 4), DecimalField128(10000, 4), DecimalField128(10000, 4)}),
createColumn<Decimal128>(std::make_tuple(18, 4), {DecimalField128(100000000, 4), DecimalField128(100010000, 4), DecimalField128(199990000, 4), DecimalField128(200000000, 4), DecimalField128(200010000, 4)})));
}

// decimal256
{
// int and decimal
ASSERT_COLUMN_EQ(
createColumn<Nullable<Decimal64>>(std::make_tuple(18, 4), {DecimalField64(1, 4), DecimalField64(1, 4), DecimalField64(1, 4), DecimalField64(1, 4), DecimalField64(0, 4)}),
executeFunction(
func_name,
createColumn<Int32>({1, 1, 1, 1, 1}),
createColumn<Decimal256>(std::make_tuple(20, 4), {DecimalField256(Int256(100000000), 4), DecimalField256(Int256(100010000), 4), DecimalField256(Int256(199990000), 4), DecimalField256(Int256(200000000), 4), DecimalField256(Int256(200010000), 4)})));

// decimal and decimal
ASSERT_COLUMN_EQ(
createColumn<Nullable<Decimal128>>(std::make_tuple(26, 8), {DecimalField128(10000, 8), DecimalField128(9999, 8), DecimalField128(5000, 8), DecimalField128(5000, 8), DecimalField128(5000, 8)}),
executeFunction(
func_name,
createColumn<Decimal256>(std::make_tuple(18, 4), {DecimalField256(Int256(10000), 4), DecimalField256(Int256(10000), 4), DecimalField256(Int256(10000), 4), DecimalField256(Int256(10000), 4), DecimalField256(Int256(10000), 4)}),
createColumn<Decimal256>(std::make_tuple(18, 4), {DecimalField256(Int256(100000000), 4), DecimalField256(Int256(100010000), 4), DecimalField256(Int256(199990000), 4), DecimalField256(Int256(200000000), 4), DecimalField256(Int256(200010000), 4)})));
}
}
CATCH

TEST_F(TestBinaryArithmeticFunctions, TiDBDivideDecimal)
try
{
Expand Down
137 changes: 137 additions & 0 deletions tests/fullstack-test/expr/decimal_divide.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# Copyright 2023 PingCAP, Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# decimal / decimal
mysql> drop table if exists test.t;
mysql> create table test.t(a decimal(4,0), b decimal(40, 20));
mysql> alter table test.t set tiflash replica 1
mysql> insert into test.t values (1, 10000), (1, 10001), (1, 20000), (1, 20001);
func> wait_table test t
mysql> set tidb_enforce_mpp=1; select a, b, a/b from test.t order by b;
+------+----------------------------+--------+
| a | b | a/b |
+------+----------------------------+--------+
| 1 | 10000.00000000000000000000 | 0.0001 |
| 1 | 10001.00000000000000000000 | 0.0001 |
| 1 | 20000.00000000000000000000 | 0.0001 |
| 1 | 20001.00000000000000000000 | 0.0000 |
+------+----------------------------+--------+

# int / decimal
mysql> drop table if exists test.t;
mysql> create table test.t(a int, b decimal(40, 20));
mysql> alter table test.t set tiflash replica 1
mysql> insert into test.t values (1, 10000), (1, 10001), (1, 20000), (1, 20001);
func> wait_table test t
mysql> set tidb_enforce_mpp=1; select a, b, a/b from test.t order by b;
+------+----------------------------+--------+
| a | b | a/b |
+------+----------------------------+--------+
| 1 | 10000.00000000000000000000 | 0.0001 |
| 1 | 10001.00000000000000000000 | 0.0001 |
| 1 | 20000.00000000000000000000 | 0.0001 |
| 1 | 20001.00000000000000000000 | 0.0000 |
+------+----------------------------+--------+

# decimal / int
mysql> drop table if exists test.t;
mysql> create table test.t(a int, b decimal(40, 20));
mysql> alter table test.t set tiflash replica 1
mysql> insert into test.t values (1, 10000), (1, 10001), (1, 20000), (1, 20001);
func> wait_table test t
mysql> set tidb_enforce_mpp=1; select a, b, a/b from test.t order by b;
+------+----------------------------+--------+
| a | b | a/b |
+------+----------------------------+--------+
| 1 | 10000.00000000000000000000 | 0.0001 |
| 1 | 10001.00000000000000000000 | 0.0001 |
| 1 | 20000.00000000000000000000 | 0.0001 |
| 1 | 20001.00000000000000000000 | 0.0000 |
+------+----------------------------+--------+

# int / int
mysql> drop table if exists test.t;
mysql> create table test.t(a int, b int);
mysql> alter table test.t set tiflash replica 1
mysql> insert into test.t values (1, 10000), (1, 10001), (1, 20000), (1, 20001);
func> wait_table test t
mysql> set tidb_enforce_mpp=1; select a, b, a/b from test.t order by b;
+------+-------+--------+
| a | b | a/b |
+------+-------+--------+
| 1 | 10000 | 0.0001 |
| 1 | 10001 | 0.0001 |
| 1 | 20000 | 0.0001 |
| 1 | 20001 | 0.0000 |
+------+-------+--------+

mysql> drop table if exists test.t;
mysql> create table test.t(a decimal(10,0), b decimal(10,0));
mysql> alter table test.t set tiflash replica 1
mysql> insert into test.t values (2147483647, 1), (2147483647, 1073741823), (2147483647, 1073741824), (2147483647, 2147483646), (2147483647, 2147483647);
mysql> insert into test.t values (-2147483647, 1), (-2147483647, 1073741823), (-2147483647, 1073741824), (-2147483647, 2147483646), (-2147483647, 2147483647);
mysql> insert into test.t values (-2147483647, -1), (-2147483647, -1073741823), (-2147483647, -1073741824), (-2147483647, -2147483646), (-2147483647, -2147483647);
mysql> insert into test.t values (2147483647, -1), (2147483647, -1073741823), (2147483647, -1073741824), (2147483647, -2147483646), (2147483647, -2147483647);
func> wait_table test t
mysql> set tidb_enforce_mpp=1; select b, a, b/(a*10000) from test.t where a/b order by b;
+-------------+-------------+-------------+
| b | a | b/(a*10000) |
+-------------+-------------+-------------+
| -2147483647 | 2147483647 | -0.0001 |
| -2147483647 | -2147483647 | 0.0001 |
| -2147483646 | 2147483647 | -0.0001 |
| -2147483646 | -2147483647 | 0.0001 |
| -1073741824 | 2147483647 | -0.0001 |
| -1073741824 | -2147483647 | 0.0001 |
| -1073741823 | -2147483647 | 0.0000 |
| -1073741823 | 2147483647 | 0.0000 |
| -1 | 2147483647 | 0.0000 |
| -1 | -2147483647 | 0.0000 |
| 1 | -2147483647 | 0.0000 |
| 1 | 2147483647 | 0.0000 |
| 1073741823 | -2147483647 | 0.0000 |
| 1073741823 | 2147483647 | 0.0000 |
| 1073741824 | -2147483647 | -0.0001 |
| 1073741824 | 2147483647 | 0.0001 |
| 2147483646 | -2147483647 | -0.0001 |
| 2147483646 | 2147483647 | 0.0001 |
| 2147483647 | -2147483647 | -0.0001 |
| 2147483647 | 2147483647 | 0.0001 |
+-------------+-------------+-------------+
mysql> delete from test.t;
mysql> insert into test.t values (2147483647, 9999999999), (9999999999, 2147483647), (1, 9999999999), (4999999999, 9999999999), (5000000000, 9999999999);
mysql> insert into test.t values (-2147483647, 9999999999), (-9999999999, 2147483647), (-1, 9999999999), (-4999999999, 9999999999), (-5000000000, 9999999999);
mysql> insert into test.t values (-2147483647, -9999999999), (-9999999999, -2147483647), (-1, -9999999999), (-4999999999, -9999999999), (-5000000000, -9999999999);
mysql> insert into test.t values (2147483647, -9999999999), (9999999999, -2147483647), (1, -9999999999), (4999999999, -9999999999), (5000000000, -9999999999);
mysql> set tidb_enforce_mpp=1; select b, a, b/(a*10000) from test.t where a/b order by b;
+-------------+-------------+-------------+
| b | a | b/(a*10000) |
+-------------+-------------+-------------+
| -9999999999 | 2147483647 | -0.0005 |
| -9999999999 | -4999999999 | 0.0002 |
| -9999999999 | 5000000000 | -0.0002 |
| -9999999999 | 4999999999 | -0.0002 |
| -9999999999 | -2147483647 | 0.0005 |
| -9999999999 | -5000000000 | 0.0002 |
| -2147483647 | -9999999999 | 0.0000 |
| -2147483647 | 9999999999 | 0.0000 |
| 2147483647 | 9999999999 | 0.0000 |
| 2147483647 | -9999999999 | 0.0000 |
| 9999999999 | -4999999999 | -0.0002 |
| 9999999999 | -2147483647 | -0.0005 |
| 9999999999 | -5000000000 | -0.0002 |
| 9999999999 | 2147483647 | 0.0005 |
| 9999999999 | 5000000000 | 0.0002 |
| 9999999999 | 4999999999 | 0.0002 |
+-------------+-------------+-------------+
10 changes: 8 additions & 2 deletions tests/tidb-ci/fullstack-test-dt/issue_1425.test
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,20 @@ mysql> drop table if exists test.t;

mysql> create table test.t (id int, value decimal(7,4), c1 int, c2 int);

mysql> insert into test.t values(1,1.9286,54,28);
mysql> insert into test.t values (1,1.9285,54,28), (1,1.9286,54,28);

mysql> alter table test.t set tiflash replica 1;

func> wait_table test t

# note: ref to https://github.com/pingcap/tiflash/issues/1682,
# The precision of tiflash results is different from that of tidb, which is a compatibility issue
mysql> use test; set session tidb_isolation_read_engines='tiflash'; select * from t where value = 54/28;

mysql> use test; set session tidb_isolation_read_engines='tiflash'; select * from t where value = c1/c2;
+------+--------+------+------+
| id | value | c1 | c2 |
+------+--------+------+------+
| 1 | 1.9286 | 54 | 28 |
+------+--------+------+------+

mysql> drop table if exists test.t;

0 comments on commit 206b807

Please sign in to comment.