From 23398ad762c38f9f4234d220b60b4ab1b3153b39 Mon Sep 17 00:00:00 2001 From: Liqi Geng Date: Thu, 10 Oct 2024 19:22:31 +0800 Subject: [PATCH] This is an automated cherry-pick of #9507 Signed-off-by: ti-chi-bot --- dbms/src/Functions/FunctionsString.cpp | 295 ++++++++++++++++-- .../src/Functions/tests/gtest_string_left.cpp | 43 +-- .../Functions/tests/gtest_strings_right.cpp | 43 +-- dbms/src/Functions/tests/gtest_substring.cpp | 157 +++++++++- dbms/src/TestUtils/FunctionTestUtils.h | 79 +++++ tests/fullstack-test/expr/substring_utf8.test | 20 +- 6 files changed, 566 insertions(+), 71 deletions(-) diff --git a/dbms/src/Functions/FunctionsString.cpp b/dbms/src/Functions/FunctionsString.cpp index 83e5d08307a..60ca493b2b6 100644 --- a/dbms/src/Functions/FunctionsString.cpp +++ b/dbms/src/Functions/FunctionsString.cpp @@ -1673,6 +1673,7 @@ class FunctionSubstringUTF8 : public IFunction bool implicit_length = (arguments.size() == 2); +<<<<<<< HEAD bool is_start_type_valid = getNumberType(block.getByPosition(arguments[1]).type, [&](const auto & start_type, bool) { using StartType = std::decay_t; // Int64 / UInt64 @@ -1692,6 +1693,48 @@ class FunctionSubstringUTF8 : public IFunction length = getValueFromLengthField((*block.getByPosition(arguments[2]).column)[0]); return true; }); +======= + bool is_start_type_valid + = getNumberType(block.getByPosition(arguments[1]).type, [&](const auto & start_type, bool) { + using StartType = std::decay_t; + using StartFieldType = typename StartType::FieldType; + const ColumnVector * column_vector_start + = getInnerColumnVector(column_start); + if unlikely (!column_vector_start) + throw Exception( + fmt::format( + "Illegal type {} of argument 2 of function {}", + block.getByPosition(arguments[1]).type->getName(), + getName()), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + // vector const const + if (!column_string->isColumnConst() && column_start->isColumnConst() + && (implicit_length || block.getByPosition(arguments[2]).column->isColumnConst())) + { + auto [is_positive, start_abs] = getValueFromStartColumn(*column_vector_start, 0); + UInt64 length = 0; + if (!implicit_length) + { + bool is_length_type_valid = getNumberType( + block.getByPosition(arguments[2]).type, + [&](const auto & length_type, bool) { + using LengthType = std::decay_t; + using LengthFieldType = typename LengthType::FieldType; + const ColumnVector * column_vector_length + = getInnerColumnVector(block.getByPosition(arguments[2]).column); + if unlikely (!column_vector_length) + throw Exception( + fmt::format( + "Illegal type {} of argument 3 of function {}", + block.getByPosition(arguments[2]).type->getName(), + getName()), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + length = getValueFromLengthColumn(*column_vector_length, 0); + return true; + }); +>>>>>>> 9f85bb2f16 (Support other integer types for SubstringUTF8 & RightUTF8 functions (#9507)) if (!is_length_type_valid) throw Exception(fmt::format("3nd argument of function {} must have UInt/Int type.", getName())); @@ -1704,6 +1747,7 @@ class FunctionSubstringUTF8 : public IFunction return true; } +<<<<<<< HEAD const auto * col = checkAndGetColumn(column_string.get()); assert(col); auto col_res = ColumnString::create(); @@ -1757,6 +1801,80 @@ class FunctionSubstringUTF8 : public IFunction if (!is_length_type_valid) throw Exception(fmt::format("3nd argument of function {} must have UInt/Int type.", getName())); } +======= + const auto * col = checkAndGetColumn(column_string.get()); + assert(col); + auto col_res = ColumnString::create(); + getVectorConstConstFunc(implicit_length, is_positive)( + col->getChars(), + col->getOffsets(), + start_abs, + length, + col_res->getChars(), + col_res->getOffsets()); + block.getByPosition(result).column = std::move(col_res); + } + else // all other cases are converted to vector vector vector + { + std::function(size_t)> get_start_func; + if (column_start->isColumnConst()) + { + // func always return const value + auto start_const = getValueFromStartColumn(*column_vector_start, 0); + get_start_func = [start_const](size_t) { + return start_const; + }; + } + else + { + get_start_func = [column_vector_start](size_t i) { + return getValueFromStartColumn(*column_vector_start, i); + }; + } + + // if implicit_length, get_length_func be nil is ok. + std::function get_length_func; + if (!implicit_length) + { + const ColumnPtr & column_length = block.getByPosition(arguments[2]).column; + bool is_length_type_valid = getNumberType( + block.getByPosition(arguments[2]).type, + [&](const auto & length_type, bool) { + using LengthType = std::decay_t; + using LengthFieldType = typename LengthType::FieldType; + const ColumnVector * column_vector_length + = getInnerColumnVector(column_length); + if unlikely (!column_vector_length) + throw Exception( + fmt::format( + "Illegal type {} of argument 3 of function {}", + block.getByPosition(arguments[2]).type->getName(), + getName()), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + if (column_length->isColumnConst()) + { + // func always return const value + auto length_const + = getValueFromLengthColumn(*column_vector_length, 0); + get_length_func = [length_const](size_t) { + return length_const; + }; + } + else + { + get_length_func = [column_vector_length](size_t i) { + return getValueFromLengthColumn(*column_vector_length, i); + }; + } + return true; + }); + + if unlikely (!is_length_type_valid) + throw Exception( + fmt::format("3nd argument of function {} must have UInt/Int type.", getName())); + } +>>>>>>> 9f85bb2f16 (Support other integer types for SubstringUTF8 & RightUTF8 functions (#9507)) // convert to vector if string is const. ColumnPtr full_column_string = column_string->isColumnConst() ? column_string->convertToFullColumnIfConst() : column_string; @@ -1777,10 +1895,38 @@ class FunctionSubstringUTF8 : public IFunction return true; }); - if (!is_start_type_valid) + if unlikely (!is_start_type_valid) throw Exception(fmt::format("2nd argument of function {} must have UInt/Int type.", getName())); } + template + static const ColumnVector * getInnerColumnVector(const ColumnPtr & column) + { + if (column->isColumnConst()) + return checkAndGetColumn>( + checkAndGetColumn(column.get())->getDataColumnPtr().get()); + return checkAndGetColumn>(column.get()); + } + + template + static size_t getValueFromLengthColumn(const ColumnVector & column, size_t index) + { + Integer val = column.getElement(index); + if constexpr ( + std::is_same_v || std::is_same_v || std::is_same_v + || std::is_same_v) + { + return val < 0 ? 0 : val; + } + else + { + static_assert( + std::is_same_v || std::is_same_v || std::is_same_v + || std::is_same_v); + return val; + } + } + private: using VectorConstConstFunc = std::function - static size_t getValueFromLengthField(const Field & length_field) - { - if constexpr (std::is_same_v) - { - Int64 signed_length = length_field.get(); - return signed_length < 0 ? 0 : signed_length; - } - else - { - static_assert(std::is_same_v); - return length_field.get(); - } - } - // return {is_positive, abs} template - static std::pair getValueFromStartField(const Field & start_field) + static std::pair getValueFromStartColumn(const ColumnVector & column, size_t index) { - if constexpr (std::is_same_v) + Integer val = column.getElement(index); + if constexpr ( + std::is_same_v || std::is_same_v || std::is_same_v + || std::is_same_v) { - Int64 signed_length = start_field.get(); - - if (signed_length < 0) - { - return {false, static_cast(-signed_length)}; - } - else - { - return {true, static_cast(signed_length)}; - } + if (val < 0) + return {false, static_cast(-val)}; + return {true, static_cast(val)}; } else { - static_assert(std::is_same_v); - return {true, start_field.get()}; + static_assert( + std::is_same_v || std::is_same_v || std::is_same_v + || std::is_same_v); + return {true, val}; } } @@ -1845,8 +1974,19 @@ class FunctionSubstringUTF8 : public IFunction static bool getNumberType(DataTypePtr type, F && f) { return castTypeToEither< +<<<<<<< HEAD DataTypeInt64, DataTypeUInt64>(type.get(), std::forward(f)); +======= + DataTypeUInt8, + DataTypeUInt16, + DataTypeUInt32, + DataTypeUInt64, + DataTypeInt8, + DataTypeInt16, + DataTypeInt32, + DataTypeInt64>(type.get(), std::forward(f)); +>>>>>>> 9f85bb2f16 (Support other integer types for SubstringUTF8 & RightUTF8 functions (#9507)) } }; @@ -1891,6 +2031,7 @@ class FunctionRightUTF8 : public IFunction const ColumnPtr column_string = block.getByPosition(arguments[0]).column; const ColumnPtr column_length = block.getByPosition(arguments[1]).column; +<<<<<<< HEAD bool is_length_type_valid = getLengthType(block.getByPosition(arguments[1]).type, [&](const auto & length_type, bool) { using LengthType = std::decay_t; // Int64 / UInt64 @@ -1903,6 +2044,33 @@ class FunctionRightUTF8 : public IFunction { // vector const size_t length = getValueFromLengthField((*column_length)[0]); +======= + bool is_length_type_valid + = getLengthType(block.getByPosition(arguments[1]).type, [&](const auto & length_type, bool) { + using LengthType = std::decay_t; + using LengthFieldType = typename LengthType::FieldType; + + const ColumnVector * column_vector_length + = FunctionSubstringUTF8::getInnerColumnVector(column_length); + if unlikely (!column_vector_length) + throw Exception( + fmt::format( + "Illegal type {} of argument 2 of function {}", + block.getByPosition(arguments[1]).type->getName(), + getName()), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + + auto col_res = ColumnString::create(); + if (const auto * col_string = checkAndGetColumn(column_string.get())) + { + if (column_length->isColumnConst()) + { + // vector const + size_t length = FunctionSubstringUTF8::getValueFromLengthColumn( + *column_vector_length, + 0); +>>>>>>> 9f85bb2f16 (Support other integer types for SubstringUTF8 & RightUTF8 functions (#9507)) // for const 0, return const blank string. if (0 == length) @@ -1911,6 +2079,7 @@ class FunctionRightUTF8 : public IFunction return true; } +<<<<<<< HEAD RightUTF8Impl::vectorConst(col_string->getChars(), col_string->getOffsets(), length, col_res->getChars(), col_res->getOffsets()); } else @@ -1942,6 +2111,61 @@ class FunctionRightUTF8 : public IFunction block.getByPosition(result).column = std::move(col_res); return true; }); +======= + RightUTF8Impl::vectorConst( + col_string->getChars(), + col_string->getOffsets(), + length, + col_res->getChars(), + col_res->getOffsets()); + } + else + { + // vector vector + auto get_length_func = [column_vector_length](size_t i) { + return FunctionSubstringUTF8::getValueFromLengthColumn( + *column_vector_length, + i); + }; + RightUTF8Impl::vectorVector( + col_string->getChars(), + col_string->getOffsets(), + get_length_func, + col_res->getChars(), + col_res->getOffsets()); + } + } + else if ( + const ColumnConst * col_const_string = checkAndGetColumnConst(column_string.get())) + { + // const vector + const auto * col_string_from_const + = checkAndGetColumn(col_const_string->getDataColumnPtr().get()); + assert(col_string_from_const); + // When useDefaultImplementationForConstants is true, string and length are not both constants + assert(!column_length->isColumnConst()); + auto get_length_func = [column_vector_length](size_t i) { + return FunctionSubstringUTF8::getValueFromLengthColumn( + *column_vector_length, + i); + }; + RightUTF8Impl::constVector( + column_length->size(), + col_string_from_const->getChars(), + col_string_from_const->getOffsets(), + get_length_func, + col_res->getChars(), + col_res->getOffsets()); + } + else + { + // Impossible to reach here + return false; + } + block.getByPosition(result).column = std::move(col_res); + return true; + }); +>>>>>>> 9f85bb2f16 (Support other integer types for SubstringUTF8 & RightUTF8 functions (#9507)) if (!is_length_type_valid) throw Exception(fmt::format("2nd argument of function {} must have UInt/Int type.", getName())); @@ -1953,6 +2177,7 @@ class FunctionRightUTF8 : public IFunction getLengthType(DataTypePtr type, F && f) { return castTypeToEither< +<<<<<<< HEAD DataTypeInt64, DataTypeUInt64>(type.get(), std::forward(f)); } @@ -1970,6 +2195,16 @@ class FunctionRightUTF8 : public IFunction static_assert(std::is_same_v); return length_field.get(); } +======= + DataTypeUInt8, + DataTypeUInt16, + DataTypeUInt32, + DataTypeUInt64, + DataTypeInt8, + DataTypeInt16, + DataTypeInt32, + DataTypeInt64>(type.get(), std::forward(f)); +>>>>>>> 9f85bb2f16 (Support other integer types for SubstringUTF8 & RightUTF8 functions (#9507)) } }; diff --git a/dbms/src/Functions/tests/gtest_string_left.cpp b/dbms/src/Functions/tests/gtest_string_left.cpp index 23600f4e8b2..96286c5441c 100644 --- a/dbms/src/Functions/tests/gtest_string_left.cpp +++ b/dbms/src/Functions/tests/gtest_string_left.cpp @@ -67,6 +67,7 @@ class StringLeftTest : public DB::tests::FunctionTest for (bool is_length_const : is_consts) inner_test(is_str_const, is_length_const); } +<<<<<<< HEAD template void testInvalidLengthType() @@ -85,12 +86,20 @@ class StringLeftTest : public DB::tests::FunctionTest for (bool is_length_const : is_consts) inner_test(is_str_const, is_length_const); } +======= +>>>>>>> 9f85bb2f16 (Support other integer types for SubstringUTF8 & RightUTF8 functions (#9507)) }; TEST_F(StringLeftTest, testBoundary) try { + testBoundary(); + testBoundary(); + testBoundary(); testBoundary(); + testBoundary(); + testBoundary(); + testBoundary(); testBoundary(); } CATCH @@ -98,6 +107,16 @@ CATCH TEST_F(StringLeftTest, testMoreCases) try { +#define CALL(A, B, C) \ + test(A, B, C); \ + test(A, B, C); \ + test(A, B, C); \ + test(A, B, C); \ + test(A, B, C); \ + test(A, B, C); \ + test(A, B, C); \ + test(A, B, C); + // test big string // big_string.size() > length String big_string; @@ -105,23 +124,19 @@ try String unit_string = "big string is 我!!!!!!!"; for (size_t i = 0; i < 1000; ++i) big_string += unit_string; - test(big_string, 22, unit_string); - test(big_string, 22, unit_string); + CALL(big_string, 22, unit_string); // test origin_str.size() == length String origin_str = "我的 size = 12"; - test(origin_str, 12, origin_str); - test(origin_str, 12, origin_str); + CALL(origin_str, 12, origin_str); // test origin_str.size() < length - test(origin_str, 22, origin_str); - test(origin_str, 22, origin_str); + CALL(origin_str, 22, origin_str); // Mixed language String english_str = "This is English"; String mixed_language_str = english_str + ",这是中文,C'est français,これが日本の"; - test(mixed_language_str, english_str.size(), english_str); - test(mixed_language_str, english_str.size(), english_str); + CALL(mixed_language_str, english_str.size(), english_str); // column size != 1 // case 1 @@ -145,18 +160,8 @@ try func_name, createConstColumn>(8, second_case_string), createColumn>({0, 1, 0, 1, 0, 0, 1, 1}))); -} -CATCH -TEST_F(StringLeftTest, testInvalidLengthType) -try -{ - testInvalidLengthType(); - testInvalidLengthType(); - testInvalidLengthType(); - testInvalidLengthType(); - testInvalidLengthType(); - testInvalidLengthType(); +#undef CALL } CATCH diff --git a/dbms/src/Functions/tests/gtest_strings_right.cpp b/dbms/src/Functions/tests/gtest_strings_right.cpp index aea13f2b03b..08e6058b36e 100644 --- a/dbms/src/Functions/tests/gtest_strings_right.cpp +++ b/dbms/src/Functions/tests/gtest_strings_right.cpp @@ -66,6 +66,7 @@ class StringRightTest : public DB::tests::FunctionTest for (bool is_length_const : is_consts) inner_test(is_str_const, is_length_const); } +<<<<<<< HEAD template void testInvalidLengthType() @@ -84,12 +85,20 @@ class StringRightTest : public DB::tests::FunctionTest for (bool is_length_const : is_consts) inner_test(is_str_const, is_length_const); } +======= +>>>>>>> 9f85bb2f16 (Support other integer types for SubstringUTF8 & RightUTF8 functions (#9507)) }; TEST_F(StringRightTest, testBoundary) try { + testBoundary(); + testBoundary(); + testBoundary(); testBoundary(); + testBoundary(); + testBoundary(); + testBoundary(); testBoundary(); } CATCH @@ -97,6 +106,16 @@ CATCH TEST_F(StringRightTest, testMoreCases) try { +#define CALL(A, B, C) \ + test(A, B, C); \ + test(A, B, C); \ + test(A, B, C); \ + test(A, B, C); \ + test(A, B, C); \ + test(A, B, C); \ + test(A, B, C); \ + test(A, B, C); + // test big string // big_string.size() > length String big_string; @@ -104,23 +123,19 @@ try String unit_string = "big string is 我!!!!!!!"; for (size_t i = 0; i < 1000; ++i) big_string += unit_string; - test(big_string, 22, unit_string); - test(big_string, 22, unit_string); + CALL(big_string, 22, unit_string); // test origin_str.size() == length String origin_str = "我的 size = 12"; - test(origin_str, 12, origin_str); - test(origin_str, 12, origin_str); + CALL(origin_str, 12, origin_str); // test origin_str.size() < length - test(origin_str, 22, origin_str); - test(origin_str, 22, origin_str); + CALL(origin_str, 22, origin_str); // Mixed language String english_str = "This is English"; String mixed_language_str = "这是中文,C'est français,これが日本の," + english_str; - test(mixed_language_str, english_str.size(), english_str); - test(mixed_language_str, english_str.size(), english_str); + CALL(mixed_language_str, english_str.size(), english_str); // column size != 1 // case 1 @@ -144,18 +159,8 @@ try func_name, createConstColumn>(8, second_case_string), createColumn>({0, 1, 0, 1, 0, 0, 1, 1}))); -} -CATCH -TEST_F(StringRightTest, testInvalidLengthType) -try -{ - testInvalidLengthType(); - testInvalidLengthType(); - testInvalidLengthType(); - testInvalidLengthType(); - testInvalidLengthType(); - testInvalidLengthType(); +#undef CALL } CATCH diff --git a/dbms/src/Functions/tests/gtest_substring.cpp b/dbms/src/Functions/tests/gtest_substring.cpp index b7a7e53ad08..326a33292d3 100644 --- a/dbms/src/Functions/tests/gtest_substring.cpp +++ b/dbms/src/Functions/tests/gtest_substring.cpp @@ -28,9 +28,160 @@ class SubString : public DB::tests::FunctionTest { }; +template +class TestNullableSigned +{ +public: + static void operator()(SubString & sub_string) + { + ASSERT_COLUMN_EQ( + createColumn>({"p.co", "ww.p", "pingcap", "com", ".com", "", "", "", {}, {}, {}}), + sub_string.executeFunction( + "substringUTF8", + createColumn>( + {"www.pingcap.com", + "ww.pingcap.com", + "w.pingcap.com", + ".pingcap.com", + "pingcap.com", + "pingcap.com", + "pingcap.com", + "pingcap.com", + {}, + "pingcap", + "pingcap"}), + createColumn({-5, 1, 3, -3, 8, 2, -100, 0, 2, {}, -3}), + createColumn({4, 4, 7, 4, 5, -5, 2, 3, 6, 4, {}}))); + } +}; + +template +class TestSigned +{ +public: + static void operator()(SubString & sub_string) + { + ASSERT_COLUMN_EQ( + createColumn>({"p.co", "ww.p", "pingcap", "com", ".com", "", "", "", {}}), + sub_string.executeFunction( + "substringUTF8", + createColumn>( + {"www.pingcap.com", + "ww.pingcap.com", + "w.pingcap.com", + ".pingcap.com", + "pingcap.com", + "pingcap.com", + "pingcap.com", + "pingcap.com", + {}}), + createColumn({-5, 1, 3, -3, 8, 2, -100, 0, 2}), + createColumn({4, 4, 7, 4, 5, -5, 2, 3, 6}))); + } +}; + +template +class TestNullableUnsigned +{ +public: + static void operator()(SubString & sub_string) + { + ASSERT_COLUMN_EQ( + createColumn>({"p.co", "ww.p", "pingcap", "com", ".com", "", "", {}, {}, {}}), + sub_string.executeFunction( + "substringUTF8", + createColumn>( + {"www.pingcap.com", + "ww.pingcap.com", + "w.pingcap.com", + ".pingcap.com", + "pingcap.com", + "pingcap.com", + "pingcap.com", + {}, + "pingcap", + "pingcap"}), + createColumn({11, 1, 3, 10, 8, 2, 0, 9, {}, 7}), + createColumn({4, 4, 7, 4, 5, 0, 3, 6, 1, {}}))); + } +}; + +template +class TestUnsigned +{ +public: + static void operator()(SubString & sub_string) + { + ASSERT_COLUMN_EQ( + createColumn>({"p.co", "ww.p", "pingcap", "com", ".com", "", "", {}}), + sub_string.executeFunction( + "substringUTF8", + createColumn>( + {"www.pingcap.com", + "ww.pingcap.com", + "w.pingcap.com", + ".pingcap.com", + "pingcap.com", + "pingcap.com", + "pingcap.com", + {}}), + createColumn({11, 1, 3, 10, 8, 2, 0, 2}), + createColumn({4, 4, 7, 4, 5, 0, 3, 1}))); + } +}; + +template +class TestConstPos +{ +public: + static void operator()(SubString & sub_string) + { + ASSERT_COLUMN_EQ( + createColumn>({"w", "ww", "w.p", ".pin"}), + sub_string.executeFunction( + "substringUTF8", + createColumn>({"www.pingcap.com", "ww.pingcap.com", "w.pingcap.com", ".pingcap.com"}), + createConstColumn(4, 1), + createColumn({1, 2, 3, 4}))); + } +}; + +template +class TestConstLength +{ +public: + static void operator()(SubString & sub_string) + { + ASSERT_COLUMN_EQ( + createColumn>({"www.", "w.pi", "ping", "ngca"}), + sub_string.executeFunction( + "substringUTF8", + createColumn>({"www.pingcap.com", "ww.pingcap.com", "w.pingcap.com", ".pingcap.com"}), + createColumn({1, 2, 3, 4}), + createConstColumn(4, 4))); + } +}; + TEST_F(SubString, subStringUTF8Test) try { + TestTypePair::run(*this); + TestTypePair::run(*this); + + TestTypePair::run(*this); + TestTypePair::run(*this); + TestTypePair::run(*this); + + TestTypePair::run(*this); + TestTypePair::run(*this); + TestTypePair::run(*this); + + TestTypePair::run(*this); + TestTypePair::run(*this); + + TestTypePair::run(*this); + TestTypePair::run(*this); + // column, const, const ASSERT_COLUMN_EQ( createColumn>({"www.", "ww.p", "w.pi", ".pin"}), @@ -39,6 +190,7 @@ try createColumn>({"www.pingcap.com", "ww.pingcap.com", "w.pingcap.com", ".pingcap.com"}), createConstColumn>(4, 1), createConstColumn>(4, 4))); + // const, const, const ASSERT_COLUMN_EQ( createConstColumn(1, "www."), @@ -47,6 +199,7 @@ try createConstColumn>(1, "www.pingcap.com"), createConstColumn>(1, 1), createConstColumn>(1, 4))); +<<<<<<< HEAD // Test Null ASSERT_COLUMN_EQ( createColumn>({{}, "www."}), @@ -56,8 +209,10 @@ try {{}, "www.pingcap.com"}), createConstColumn>(2, 1), createConstColumn>(2, 4))); +======= +>>>>>>> 9f85bb2f16 (Support other integer types for SubstringUTF8 & RightUTF8 functions (#9507)) } CATCH } // namespace tests -} // namespace DB \ No newline at end of file +} // namespace DB diff --git a/dbms/src/TestUtils/FunctionTestUtils.h b/dbms/src/TestUtils/FunctionTestUtils.h index 4e86076d91a..a29a74b3fb8 100644 --- a/dbms/src/TestUtils/FunctionTestUtils.h +++ b/dbms/src/TestUtils/FunctionTestUtils.h @@ -834,6 +834,85 @@ class FunctionTest : public ::testing::Test std::unique_ptr dag_context_ptr; }; +template +struct TestTypeList +{ +}; + +using TestNullableIntTypes = TestTypeList, Nullable, Nullable, Nullable>; + +using TestNullableUIntTypes = TestTypeList, Nullable, Nullable, Nullable>; + +using TestIntTypes = TestTypeList; + +using TestUIntTypes = TestTypeList; + +using TestAllIntTypes + = TestTypeList, Nullable, Nullable, Nullable, Int8, Int16, Int32, Int64>; + +using TestAllUIntTypes = TestTypeList< + Nullable, + Nullable, + Nullable, + Nullable, + UInt8, + UInt16, + UInt32, + UInt64>; + +template class Func, typename FuncParam> +struct TestTypeSingle; + +template class Func, typename FuncParam> +struct TestTypeSingle, Func, FuncParam> +{ + static void run(FuncParam & p) + { + Func::operator()(p); + // Recursively handle the rest of T2List + TestTypeSingle, Func, FuncParam>::run(p); + } +}; + +template class Func, typename FuncParam> +struct TestTypeSingle, Func, FuncParam> +{ + static void run(FuncParam &) + { + // Do nothing when T2List is empty + } +}; + +template class Func, typename FuncParam> +struct TestTypePair; + +template < + typename T1, + typename... T1Rest, + typename T2List, + template + class Func, + typename FuncParam> +struct TestTypePair, T2List, Func, FuncParam> +{ + static void run(FuncParam & p) + { + // For the current T1, traverse all types in T2List + TestTypeSingle::run(p); + // Recursively handle the rest of T1List + TestTypePair, T2List, Func, FuncParam>::run(p); + } +}; + +template class Func, typename FuncParam> +struct TestTypePair, T2List, Func, FuncParam> +{ + static void run(FuncParam &) + { + // Do nothing when T1List is empty + } +}; + #define ASSERT_COLUMN_EQ(expected, actual) ASSERT_TRUE(DB::tests::columnEqual((expected), (actual))) #define ASSERT_BLOCK_EQ(expected, actual) DB::tests::blockEqual((expected), (actual)) diff --git a/tests/fullstack-test/expr/substring_utf8.test b/tests/fullstack-test/expr/substring_utf8.test index b5bc0d837a2..e33a89b703f 100644 --- a/tests/fullstack-test/expr/substring_utf8.test +++ b/tests/fullstack-test/expr/substring_utf8.test @@ -13,12 +13,13 @@ # limitations under the License. mysql> drop table if exists test.t -mysql> create table test.t(a char(10)) -mysql> insert into test.t values(''), ('abc') +mysql> create table test.t(a char(10), b int, c tinyint unsigned) +mysql> insert into test.t values('', -3, 2), ('abc', -3, 2) mysql> alter table test.t set tiflash replica 1 func> wait_table test t +<<<<<<< HEAD mysql> set session tidb_isolation_read_engines='tiflash'; set tidb_allow_mpp=0; select * from test.t where substring(a, -3, 4) = 'abc' a abc @@ -28,6 +29,21 @@ a abc mysql> set session tidb_isolation_read_engines='tiflash'; set tidb_allow_mpp=0; select * from test.t where substring(a, -4, 3) = 'abc' +======= +mysql> set session tidb_isolation_read_engines='tiflash'; set tidb_allow_mpp=0; set tidb_allow_tiflash_cop = ON; select a from test.t where substring(a, -3, 4) = 'abc' +a +abc + +mysql> set session tidb_isolation_read_engines='tiflash'; set tidb_allow_mpp=0; set tidb_allow_tiflash_cop = ON; select a from test.t where substring(a, -3, 2) = 'ab' +a +abc + +mysql> set session tidb_isolation_read_engines='tiflash'; set tidb_allow_mpp=0; set tidb_allow_tiflash_cop = ON; select a from test.t where substring(a, b, c) = 'ab' +a +abc + +mysql> set session tidb_isolation_read_engines='tiflash'; set tidb_allow_mpp=0; set tidb_allow_tiflash_cop = ON; select a from test.t where substring(a, -4, 3) = 'abc' +>>>>>>> 9f85bb2f16 (Support other integer types for SubstringUTF8 & RightUTF8 functions (#9507)) # Empty mysql> set session tidb_isolation_read_engines='tiflash'; set tidb_allow_mpp=0; select count(*) from test.t where substring(a, 0, 3) = '' order by a