diff --git a/dbms/src/Functions/FunctionsTiDBConversion.h b/dbms/src/Functions/FunctionsTiDBConversion.h index 506374aae21..fde9cbca1d4 100644 --- a/dbms/src/Functions/FunctionsTiDBConversion.h +++ b/dbms/src/Functions/FunctionsTiDBConversion.h @@ -816,21 +816,27 @@ struct TiDBConvertToDecimal using FromFieldType = typename FromDataType::FieldType; template - static U toTiDBDecimalInternal(T value, PrecType prec, ScaleType scale, const Context & context) + static U toTiDBDecimalInternal(T int_value, PrecType prec, ScaleType scale, const Context & context) { + // int_value is the value that exposes to user. Such as cast(val to decimal), val is the int_value which used by user. + // And val * scale_mul is the scaled_value, which is stored in ColumnDecimal internally. + static_assert(std::is_integral_v); using UType = typename U::NativeType; - auto max_value = DecimalMaxValue::get(prec); - if (value > max_value || value < -max_value) + UType scale_mul = getScaleMultiplier(scale); + + Int256 scaled_value = static_cast(int_value) * static_cast(scale_mul); + Int256 scaled_max_value = DecimalMaxValue::get(prec); + + if (scaled_value > scaled_max_value || scaled_value < -scaled_max_value) { context.getDAGContext()->handleOverflowError("cast to decimal", Errors::Types::Truncated); - if (value > 0) - return static_cast(max_value); + if (int_value > 0) + return static_cast(scaled_max_value); else - return static_cast(-max_value); + return static_cast(-scaled_max_value); } - UType scale_mul = getScaleMultiplier(scale); - U result = static_cast(value) * scale_mul; - return result; + + return static_cast(scaled_value); } template diff --git a/dbms/src/Functions/tests/gtest_tidb_conversion.cpp b/dbms/src/Functions/tests/gtest_tidb_conversion.cpp index 3f9abf6471f..887e896168f 100644 --- a/dbms/src/Functions/tests/gtest_tidb_conversion.cpp +++ b/dbms/src/Functions/tests/gtest_tidb_conversion.cpp @@ -670,6 +670,28 @@ try executeFunction(func_name, {createColumn>({MAX_INT64, {}}), createCastTypeConstColumn("Nullable(Decimal(65,0))")})); + + ASSERT_THROW(executeFunction(func_name, + {createColumn>({9999}), createCastTypeConstColumn("Nullable(Decimal(4, 1))")}), + TiFlashException); + + ASSERT_THROW(executeFunction(func_name, + {createColumn>({-9999}), createCastTypeConstColumn("Nullable(Decimal(4, 1))")}), + TiFlashException); + + ASSERT_COLUMN_EQ( + createColumn>( + std::make_tuple(4, 1), + {DecimalField32(static_cast(9990), 1)}), + executeFunction(func_name, + {createColumn>({999}), createCastTypeConstColumn("Nullable(Decimal(4, 1))")})); + + ASSERT_COLUMN_EQ( + createColumn>( + std::make_tuple(4, 1), + {DecimalField32(static_cast(-9990), 1)}), + executeFunction(func_name, + {createColumn>({-999}), createCastTypeConstColumn("Nullable(Decimal(4, 1))")})); } CATCH diff --git a/tests/fullstack-test/expr/cast_as_decimal.test b/tests/fullstack-test/expr/cast_as_decimal.test index f0985e3cad3..2ce5db81a0d 100644 --- a/tests/fullstack-test/expr/cast_as_decimal.test +++ b/tests/fullstack-test/expr/cast_as_decimal.test @@ -1,4 +1,21 @@ mysql> drop table if exists test.t1; +mysql> create table test.t1(c1 int); +mysql> insert into test.t1 values(9999), (-9999), (99), (-99); +mysql> alter table test.t1 set tiflash replica 1; +func> wait_table test t1 +mysql> set @@tidb_isolation_read_engines='tiflash'; select cast(c1 as decimal(4, 1)) from test.t1 order by 1; +cast(c1 as decimal(4, 1)) +-999.9 +-99.0 +99.0 +999.9 +mysql> set @@tidb_isolation_read_engines='tiflash'; select cast(c1 as decimal(2, 2)) from test.t1 order by 1; +cast(c1 as decimal(2, 2)) +-0.99 +-0.99 + 0.99 + 0.99 +mysql> drop table if exists test.t1; mysql> create table test.t1(c1 datetime(5)); mysql> insert into test.t1 values('2022-10-10 10:10:10.12345'); mysql> alter table test.t1 set tiflash replica 1; @@ -21,3 +38,4 @@ mysql> set @@tidb_isolation_read_engines='tiflash'; set @@tidb_enforce_mpp = 1; +------------------------------------+ | 20221010101010.123 | +------------------------------------+ +