Skip to content

Commit

Permalink
Fix wrong result of cast(float as decimal) when overflow happens (#4380)
Browse files Browse the repository at this point in the history
close #3998
  • Loading branch information
guo-shaoge authored Mar 23, 2022
1 parent 2b71569 commit 68906ed
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 36 deletions.
3 changes: 2 additions & 1 deletion dbms/src/Functions/FunctionsTiDBConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -1233,7 +1233,8 @@ struct TiDBConvertToDecimal
static_assert(std::is_floating_point_v<FromFieldType>);
/// cast real as decimal
for (size_t i = 0; i < size; ++i)
vec_to[i] = toTiDBDecimal<FromFieldType, ToFieldType>(vec_from[i], prec, scale, context);
// Always use Float64 to avoid overflow for vec_from[i] * 10^scale.
vec_to[i] = toTiDBDecimal<Float64, ToFieldType>(static_cast<Float64>(vec_from[i]), prec, scale, context);
}
}
else
Expand Down
97 changes: 62 additions & 35 deletions dbms/src/Functions/tests/bench_function_cast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ class CastToDecimalBench : public benchmark::Fixture
DataTypePtr from_type_dec_60_5 = std::make_shared<DataTypeDecimal256>(60, 5);
DataTypePtr from_type_date = std::make_shared<DataTypeMyDate>();
DataTypePtr from_type_datetime_fsp5 = std::make_shared<DataTypeMyDateTime>(5);
DataTypePtr from_type_float32 = std::make_shared<DataTypeFloat32>();
DataTypePtr from_type_float64 = std::make_shared<DataTypeFloat64>();

auto tmp_col_int8 = from_type_int8->createColumn();
auto tmp_col_int16 = from_type_int16->createColumn();
Expand All @@ -95,6 +97,8 @@ class CastToDecimalBench : public benchmark::Fixture
auto tmp_col_dec_60_5 = from_type_dec_60_5->createColumn();
auto tmp_col_date = from_type_date->createColumn();
auto tmp_col_datetime_fsp5 = from_type_date->createColumn();
auto tmp_col_float32 = ColumnFloat32::create();
auto tmp_col_float64 = ColumnFloat64::create();

std::uniform_int_distribution<int64_t> dist64(std::numeric_limits<int64_t>::min(), std::numeric_limits<int64_t>::max());

Expand All @@ -120,6 +124,8 @@ class CastToDecimalBench : public benchmark::Fixture
tmp_col_uint16->insert(Field(static_cast<Int64>(static_cast<UInt16>(dist64(mt)))));
tmp_col_uint32->insert(Field(static_cast<Int64>(static_cast<UInt16>(dist64(mt)))));
tmp_col_uint64->insert(Field(static_cast<Int64>(static_cast<UInt16>(dist64(mt)))));
tmp_col_float32->insert(static_cast<Float32>(dist64(mt)));
tmp_col_float64->insert(static_cast<Float64>(dist64(mt)));

tmp_col_dec_2_1->insert(DecimalField(Decimal(static_cast<Int32>(dist64(mt) % 100)), 1));
tmp_col_dec_2_1_small->insert(DecimalField(Decimal(static_cast<Int32>(dist64(mt) % 10)), 1));
Expand All @@ -145,6 +151,8 @@ class CastToDecimalBench : public benchmark::Fixture
from_col_uint16 = ColumnWithTypeAndName(std::move(tmp_col_uint16), from_type_uint16, "from_col_uint16");
from_col_uint32 = ColumnWithTypeAndName(std::move(tmp_col_uint32), from_type_uint32, "from_col_uint32");
from_col_uint64 = ColumnWithTypeAndName(std::move(tmp_col_uint64), from_type_uint64, "from_col_uint64");
from_col_float32 = ColumnWithTypeAndName(std::move(tmp_col_float32), from_type_float32, "from_col_float32");
from_col_float64 = ColumnWithTypeAndName(std::move(tmp_col_float64), from_type_float64, "from_col_float64");

from_col_dec_2_1 = ColumnWithTypeAndName(std::move(tmp_col_dec_2_1), from_type_dec_2_1, "from_col_dec_2_1");
from_col_dec_2_1_small = ColumnWithTypeAndName(std::move(tmp_col_dec_2_1_small), from_type_dec_2_1_small, "from_col_dec_2_1_small");
Expand Down Expand Up @@ -203,16 +211,22 @@ class CastToDecimalBench : public benchmark::Fixture
from_int64_vec = std::vector<Int64>(row_num);
from_int128_vec = std::vector<Int128>(row_num);
from_int256_vec = std::vector<Int256>(row_num);
from_float32_vec = std::vector<Float32>(row_num);
from_float64_vec = std::vector<Float64>(row_num);
dest_int64_vec = std::vector<Int64>(row_num);
dest_int128_vec = std::vector<Int128>(row_num);
dest_int256_vec = std::vector<Int256>(row_num);
dest_float32_vec = std::vector<Float32>(row_num);
dest_float64_vec = std::vector<Float64>(row_num);
const Int256 mod_prec_19 = getScaleMultiplier<Decimal256>(19);
const Int256 mod_prec_38 = getScaleMultiplier<Decimal256>(38);
for (auto i = 0; i < row_num; ++i)
{
from_int64_vec[i] = dist64(mt);
from_int128_vec[i] = static_cast<Int128>(dist256(mt) % (std::numeric_limits<Int128>::max() % mod_prec_19));
from_int256_vec[i] = static_cast<Int256>(dist256(mt) % (std::numeric_limits<Int256>::max()) % mod_prec_38);
from_float32_vec[i] = static_cast<Float32>(from_int64_vec[i]);
from_float64_vec[i] = static_cast<Float64>(from_int64_vec[i]);
}
}

Expand All @@ -227,6 +241,8 @@ class CastToDecimalBench : public benchmark::Fixture
ColumnWithTypeAndName from_col_uint16;
ColumnWithTypeAndName from_col_uint32;
ColumnWithTypeAndName from_col_uint64;
ColumnWithTypeAndName from_col_float32;
ColumnWithTypeAndName from_col_float64;
ColumnWithTypeAndName from_col_dec_2_1;
ColumnWithTypeAndName from_col_dec_2_1_small;
ColumnWithTypeAndName from_col_dec_3_0;
Expand Down Expand Up @@ -267,9 +283,13 @@ class CastToDecimalBench : public benchmark::Fixture
std::vector<Int64> from_int64_vec;
std::vector<Int128> from_int128_vec;
std::vector<Int256> from_int256_vec;
std::vector<Float32> from_float32_vec;
std::vector<Float64> from_float64_vec;
std::vector<Int64> dest_int64_vec;
std::vector<Int128> dest_int128_vec;
std::vector<Int256> dest_int256_vec;
std::vector<Float32> dest_float32_vec;
std::vector<Float64> dest_float64_vec;
};

#define CAST_BENCHMARK(CLASS_NAME, CASE_NAME, FROM_COL, DEST_TYPE) \
Expand Down Expand Up @@ -334,6 +354,9 @@ CAST_BENCHMARK(CastToDecimalBench, int32_to_decimal_60_0, from_col_int32, dest_c
// no; Int64; Int256
CAST_BENCHMARK(CastToDecimalBench, int32_to_decimal_60_4, from_col_int32, dest_col_dec_60_4);

CAST_BENCHMARK(CastToDecimalBench, float32_to_decimal_60_30, from_col_float32, dest_col_dec_60_30);
CAST_BENCHMARK(CastToDecimalBench, float64_to_decimal_60_30, from_col_float64, dest_col_dec_60_30);

// need; Int128; Int32
CAST_BENCHMARK(CastToDecimalBench, int64_to_decimal_8_0, from_col_int64, dest_col_dec_8_0);
// need; Int128; Int64
Expand Down Expand Up @@ -410,44 +433,48 @@ STATIC_CAST_BENCHMARK(CastToDecimalBench, 64, 256);
STATIC_CAST_BENCHMARK(CastToDecimalBench, 128, 128);
STATIC_CAST_BENCHMARK(CastToDecimalBench, 128, 256);

#define DIV_BENCHMARK(CLASS_NAME, TYPE) \
BENCHMARK_DEFINE_F(CastToDecimalBench, div_##TYPE) \
(benchmark::State & state) \
try \
{ \
for (auto _ : state) \
{ \
for (int i = 0; i < row_num; ++i) \
{ \
dest_int##TYPE##_vec[i] = from_int##TYPE##_vec[i] / from_int##TYPE##_vec[0]; \
} \
} \
} \
CATCH \
#define DIV_BENCHMARK(CLASS_NAME, TYPE) \
BENCHMARK_DEFINE_F(CastToDecimalBench, div_##TYPE) \
(benchmark::State & state) \
try \
{ \
for (auto _ : state) \
{ \
for (int i = 0; i < row_num; ++i) \
{ \
dest_##TYPE##_vec[i] = from_##TYPE##_vec[i] / from_##TYPE##_vec[0]; \
} \
} \
} \
CATCH \
BENCHMARK_REGISTER_F(CastToDecimalBench, div_##TYPE)->Iterations(1000);

DIV_BENCHMARK(CastToDecimalBench, 64);
DIV_BENCHMARK(CastToDecimalBench, 128);
DIV_BENCHMARK(CastToDecimalBench, 256);

#define MUL_BENCHMARK(CLASS_NAME, TYPE) \
BENCHMARK_DEFINE_F(CastToDecimalBench, mul_##TYPE) \
(benchmark::State & state) \
try \
{ \
for (auto _ : state) \
{ \
for (int i = 0; i < row_num; ++i) \
{ \
dest_int##TYPE##_vec[i] = from_int##TYPE##_vec[i] * from_int##TYPE##_vec[0]; \
} \
} \
} \
CATCH \
DIV_BENCHMARK(CastToDecimalBench, int64);
DIV_BENCHMARK(CastToDecimalBench, int128);
DIV_BENCHMARK(CastToDecimalBench, int256);
DIV_BENCHMARK(CastToDecimalBench, float32);
DIV_BENCHMARK(CastToDecimalBench, float64);

#define MUL_BENCHMARK(CLASS_NAME, TYPE) \
BENCHMARK_DEFINE_F(CastToDecimalBench, mul_##TYPE) \
(benchmark::State & state) \
try \
{ \
for (auto _ : state) \
{ \
for (int i = 0; i < row_num; ++i) \
{ \
dest_##TYPE##_vec[i] = from_##TYPE##_vec[i] * from_##TYPE##_vec[0]; \
} \
} \
} \
CATCH \
BENCHMARK_REGISTER_F(CastToDecimalBench, mul_##TYPE)->Iterations(1000);

MUL_BENCHMARK(CastToDecimalBench, 64);
MUL_BENCHMARK(CastToDecimalBench, 128);
MUL_BENCHMARK(CastToDecimalBench, 256);
MUL_BENCHMARK(CastToDecimalBench, int64);
MUL_BENCHMARK(CastToDecimalBench, int128);
MUL_BENCHMARK(CastToDecimalBench, int256);
MUL_BENCHMARK(CastToDecimalBench, float32);
MUL_BENCHMARK(CastToDecimalBench, float64);
} // namespace tests
} // namespace DB
15 changes: 15 additions & 0 deletions dbms/src/Functions/tests/gtest_tidb_conversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1164,6 +1164,21 @@ try
testNotOnlyNull<Float64, Decimal256>(12.215, DecimalField256(static_cast<Int256>(1222), 2), std::make_tuple(65, 2));
testNotOnlyNull<Float64, Decimal256>(-12.215, DecimalField256(static_cast<Int256>(-1222), 2), std::make_tuple(65, 2));

// Not compatible with MySQL/TiDB.
// MySQL/TiDB: 34028199169636080000000000000000000000.00
// TiFlash: 34028199169636079590747176440761942016.00
testNotOnlyNull<Float32, Decimal256>(3.40282e+37f, DecimalField256(Decimal256(Int256("3402819916963607959074717644076194201600")), 2), std::make_tuple(50, 2));
// MySQL/TiDB: 34028200000000000000000000000000000000.00
// TiFlash: 34028200000000004441521809130870213181.44
testNotOnlyNull<Float64, Decimal256>(3.40282e+37, DecimalField256(Decimal256(Int256("3402820000000000444152180913087021318144")), 2), std::make_tuple(50, 2));

// MySQL/TiDB: 123.12345886230469000000
// TiFlash: 123.12345886230470197248
testNotOnlyNull<Float32, Decimal256>(123.123456789123456789f, DecimalField256(Decimal256(Int256("12312345886230470197248")), 20), std::make_tuple(50, 20));
// MySQL/TiDB: 123.12345886230469000000
// TiFlash: 123.12345678912344293376
testNotOnlyNull<Float64, Decimal256>(123.123456789123456789, DecimalField256(Decimal256(Int256("12312345678912344293376")), 20), std::make_tuple(50, 20));

dag_context->setFlags(ori_flags);
dag_context->clearWarnings();
}
Expand Down

0 comments on commit 68906ed

Please sign in to comment.