From eef9e22ecdfd6cb8a34c353eaf5ad1dba8406981 Mon Sep 17 00:00:00 2001 From: xzhangxian1008 Date: Sun, 31 Jul 2022 18:30:04 +0800 Subject: [PATCH 01/11] Add benchmark for collation (#5491) close pingcap/tiflash#5500 --- dbms/src/Functions/tests/bench_collation.cpp | 150 +++++++++++++++++++ 1 file changed, 150 insertions(+) create mode 100644 dbms/src/Functions/tests/bench_collation.cpp diff --git a/dbms/src/Functions/tests/bench_collation.cpp b/dbms/src/Functions/tests/bench_collation.cpp new file mode 100644 index 00000000000..a9054119dd2 --- /dev/null +++ b/dbms/src/Functions/tests/bench_collation.cpp @@ -0,0 +1,150 @@ +// Copyright 2022 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +/// this is a hack, include the cpp file so we can test MatchImpl directly +#include // NOLINT + +namespace DB +{ +namespace tests +{ + +class CollationBench : public benchmark::Fixture +{ +public: + using ColStringType = typename TypeTraits::FieldType; + using ColUInt8Type = typename TypeTraits::FieldType; + + ColumnsWithTypeAndName data{toVec("col0", std::vector(1000000, "aaaaaaaaaaaaa")), + toVec("col1", std::vector(1000000, "aaaaaaaaaaaaa")), + toVec("result", std::vector{})}; + + ColumnsWithTypeAndName like_data{toVec("col0", std::vector(1000000, "qwdgefwabchfue")), + createConstColumn(1000000, "%abc%"), + createConstColumn(1000000, static_cast('\\')), + toVec("result", std::vector{})}; +}; + +class CollationLessBench : public CollationBench +{ +public: + void SetUp(const benchmark::State &) override {} +}; + +class CollationEqBench : public CollationBench +{ +public: + void SetUp(const benchmark::State &) override {} +}; + +class CollationLikeBench : public CollationBench +{ +public: + void SetUp(const benchmark::State &) override {} +}; + +#define BENCH_LESS_COLLATOR(collator) \ + BENCHMARK_DEFINE_F(CollationLessBench, collator) \ + (benchmark::State & state) \ + try \ + { \ + FunctionLess fl; \ + TiDB::TiDBCollatorPtr collator = TiDB::ITiDBCollator::getCollator(TiDB::ITiDBCollator::collator); \ + fl.setCollator(collator); \ + Block block(data); \ + ColumnNumbers arguments{0, 1}; \ + for (auto _ : state) \ + { \ + fl.executeImpl(block, arguments, 2); \ + } \ + } \ + CATCH \ + BENCHMARK_REGISTER_F(CollationLessBench, collator)->Iterations(10); + + +#define BENCH_EQ_COLLATOR(collator) \ + BENCHMARK_DEFINE_F(CollationEqBench, collator) \ + (benchmark::State & state) \ + try \ + { \ + FunctionEquals fe; \ + TiDB::TiDBCollatorPtr collator = TiDB::ITiDBCollator::getCollator(TiDB::ITiDBCollator::collator); \ + fe.setCollator(collator); \ + Block block(data); \ + ColumnNumbers arguments{0, 1}; \ + for (auto _ : state) \ + { \ + fe.executeImpl(block, arguments, 2); \ + } \ + } \ + CATCH \ + BENCHMARK_REGISTER_F(CollationEqBench, collator)->Iterations(10); + + +#define BENCH_LIKE_COLLATOR(collator) \ + BENCHMARK_DEFINE_F(CollationLikeBench, collator) \ + (benchmark::State & state) \ + try \ + { \ + FunctionLike3Args fl; \ + TiDB::TiDBCollatorPtr collator = TiDB::ITiDBCollator::getCollator(TiDB::ITiDBCollator::collator); \ + fl.setCollator(collator); \ + Block block(like_data); \ + ColumnNumbers arguments{0, 1, 2}; \ + for (auto _ : state) \ + { \ + fl.executeImpl(block, arguments, 3); \ + } \ + } \ + CATCH \ + BENCHMARK_REGISTER_F(CollationLikeBench, collator)->Iterations(10); + + +BENCH_LESS_COLLATOR(UTF8MB4_BIN); +BENCH_LESS_COLLATOR(UTF8MB4_GENERAL_CI); +BENCH_LESS_COLLATOR(UTF8MB4_UNICODE_CI); +BENCH_LESS_COLLATOR(UTF8_BIN); +BENCH_LESS_COLLATOR(UTF8_GENERAL_CI); +BENCH_LESS_COLLATOR(UTF8_UNICODE_CI); +BENCH_LESS_COLLATOR(ASCII_BIN); +BENCH_LESS_COLLATOR(BINARY); +BENCH_LESS_COLLATOR(LATIN1_BIN); + +BENCH_EQ_COLLATOR(UTF8MB4_BIN); +BENCH_EQ_COLLATOR(UTF8MB4_GENERAL_CI); +BENCH_EQ_COLLATOR(UTF8MB4_UNICODE_CI); +BENCH_EQ_COLLATOR(UTF8_BIN); +BENCH_EQ_COLLATOR(UTF8_GENERAL_CI); +BENCH_EQ_COLLATOR(UTF8_UNICODE_CI); +BENCH_EQ_COLLATOR(ASCII_BIN); +BENCH_EQ_COLLATOR(BINARY); +BENCH_EQ_COLLATOR(LATIN1_BIN); + +BENCH_LIKE_COLLATOR(UTF8MB4_BIN); +BENCH_LIKE_COLLATOR(UTF8MB4_GENERAL_CI); +BENCH_LIKE_COLLATOR(UTF8MB4_UNICODE_CI); +BENCH_LIKE_COLLATOR(UTF8_BIN); +BENCH_LIKE_COLLATOR(UTF8_GENERAL_CI); +BENCH_LIKE_COLLATOR(UTF8_UNICODE_CI); +BENCH_LIKE_COLLATOR(ASCII_BIN); +BENCH_LIKE_COLLATOR(BINARY); +BENCH_LIKE_COLLATOR(LATIN1_BIN); + +} // namespace tests +} // namespace DB From bebd45a85eed13cffc55275ffbd3f98f9a438f8e Mon Sep 17 00:00:00 2001 From: Fu Zhe Date: Tue, 2 Aug 2022 11:26:05 +0800 Subject: [PATCH 02/11] *: Combine LOG_XXX and LOG_FMT_XXX (#5512) ref pingcap/tiflash#5511 --- libs/libcommon/include/common/MacroUtils.h | 45 +++++++ libs/libcommon/include/common/logger_useful.h | 113 ++++++++---------- 2 files changed, 95 insertions(+), 63 deletions(-) create mode 100644 libs/libcommon/include/common/MacroUtils.h diff --git a/libs/libcommon/include/common/MacroUtils.h b/libs/libcommon/include/common/MacroUtils.h new file mode 100644 index 00000000000..e7466f62536 --- /dev/null +++ b/libs/libcommon/include/common/MacroUtils.h @@ -0,0 +1,45 @@ +// Copyright 2022 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#define TF_GET_1ST_ARG(a, ...) a +#define TF_GET_2ND_ARG(a1, a2, ...) a2 +#define TF_GET_3RD_ARG(a1, a2, a3, ...) a3 +#define TF_GET_4TH_ARG(a1, a2, a3, a4, ...) a4 +#define TF_GET_5TH_ARG(a1, a2, a3, a4, a5, ...) a5 +#define TF_GET_6TH_ARG(a1, a2, a3, a4, a5, a6, ...) a6 +#define TF_GET_7TH_ARG(a1, a2, a3, a4, a5, a6, a7, ...) a7 +#define TF_GET_8TH_ARG(a1, a2, a3, a4, a5, a6, a7, a8, ...) a8 +#define TF_GET_9TH_ARG(a1, a2, a3, a4, a5, a6, a7, a8, a9, ...) a9 +#define TF_GET_10TH_ARG(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, ...) a10 +#define TF_GET_11TH_ARG(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, ...) a11 +#define TF_GET_12TH_ARG(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, ...) a12 +#define TF_GET_13TH_ARG(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, ...) a13 +#define TF_GET_14TH_ARG(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, ...) a14 +#define TF_GET_15TH_ARG(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, ...) a15 +#define TF_GET_16TH_ARG(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, ...) a16 +#define TF_GET_17TH_ARG(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, ...) a17 +#define TF_GET_18TH_ARG(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, ...) a18 +#define TF_GET_19TH_ARG(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19, ...) a19 +#define TF_GET_20TH_ARG(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19, a20, ...) a20 +#define TF_GET_21TH_ARG(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19, a20, a21, ...) a21 +#define TF_GET_22TH_ARG(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19, a20, a21, a22, ...) a22 +#define TF_GET_23TH_ARG(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19, a20, a21, a22, a23, ...) a23 +#define TF_GET_24TH_ARG(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19, a20, a21, a22, a23, a24, ...) a24 +#define TF_GET_25TH_ARG(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19, a20, a21, a22, a23, a24, a25, ...) a25 +#define TF_GET_26TH_ARG(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19, a20, a21, a22, a23, a24, a25, a26, ...) a26 +#define TF_GET_27TH_ARG(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19, a20, a21, a22, a23, a24, a25, a26, a27, ...) a27 +#define TF_GET_28TH_ARG(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19, a20, a21, a22, a23, a24, a25, a26, a27, a28, ...) a28 +#define TF_GET_29TH_ARG(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19, a20, a21, a22, a23, a24, a25, a26, a27, a28, a29, ...) a29 diff --git a/libs/libcommon/include/common/logger_useful.h b/libs/libcommon/include/common/logger_useful.h index e3981baf34c..44e7d45fca1 100644 --- a/libs/libcommon/include/common/logger_useful.h +++ b/libs/libcommon/include/common/logger_useful.h @@ -17,6 +17,7 @@ /// Macros for convenient usage of Poco logger. #include +#include #include #include @@ -26,17 +27,6 @@ namespace LogFmtDetails { -template -inline constexpr size_t numArgs(Ts &&...) -{ - return sizeof...(Ts); -} -template -inline constexpr auto firstArg(T && x, Ts &&...) -{ - return std::forward(x); -} - // https://stackoverflow.com/questions/8487986/file-macro-shows-full-path/54335644#54335644 template inline constexpr size_t getFileNameOffset(const T (&str)[S], size_t i = S - 1) @@ -50,8 +40,8 @@ inline constexpr size_t getFileNameOffset(T (&/*str*/)[1]) return 0; } -template -std::string toCheckedFmtStr(const S & format, const Ignored &, Args &&... args) +template +std::string toCheckedFmtStr(const S & format, Args &&... args) { // The second arg is the same as `format`, just ignore // Apply `make_args_checked` for checks `format` validity at compile time. @@ -60,61 +50,58 @@ std::string toCheckedFmtStr(const S & format, const Ignored &, Args &&... args) } } // namespace LogFmtDetails -/// Logs a message to a specified logger with that level. - -#define LOG_IMPL(logger, PRIORITY, message) \ - do \ - { \ - if ((logger)->is((PRIORITY))) \ - { \ - Poco::Message poco_message( \ - /*source*/ (logger)->name(), \ - /*text*/ message, \ - /*prio*/ (PRIORITY), \ - /*file*/ &__FILE__[LogFmtDetails::getFileNameOffset(__FILE__)], \ - /*line*/ __LINE__); \ - (logger)->log(poco_message); \ - } \ - } while (false) - -#define LOG_TRACE(logger, message) LOG_IMPL(logger, Poco::Message::PRIO_TRACE, message) -#define LOG_DEBUG(logger, message) LOG_IMPL(logger, Poco::Message::PRIO_DEBUG, message) -#define LOG_INFO(logger, message) LOG_IMPL(logger, Poco::Message::PRIO_INFORMATION, message) -#define LOG_WARNING(logger, message) LOG_IMPL(logger, Poco::Message::PRIO_WARNING, message) -#define LOG_ERROR(logger, message) LOG_IMPL(logger, Poco::Message::PRIO_ERROR, message) -#define LOG_FATAL(logger, message) LOG_IMPL(logger, Poco::Message::PRIO_FATAL, message) - - /// Logs a message to a specified logger with that level. /// If more than one argument is provided, /// the first argument is interpreted as template with {}-substitutions /// and the latter arguments treat as values to substitute. /// If only one argument is provided, it is threat as message without substitutions. -#define LOG_GET_FIRST_ARG(arg, ...) arg -#define LOG_FMT_IMPL(logger, PRIORITY, ...) \ - do \ - { \ - if ((logger)->is((PRIORITY))) \ - { \ - std::string formatted_message = LogFmtDetails::numArgs(__VA_ARGS__) > 1 \ - ? LogFmtDetails::toCheckedFmtStr( \ - FMT_STRING(LOG_GET_FIRST_ARG(__VA_ARGS__)), \ - __VA_ARGS__) \ - : LogFmtDetails::firstArg(__VA_ARGS__); \ - Poco::Message poco_message( \ - /*source*/ (logger)->name(), \ - /*text*/ formatted_message, \ - /*prio*/ (PRIORITY), \ - /*file*/ &__FILE__[LogFmtDetails::getFileNameOffset(__FILE__)], \ - /*line*/ __LINE__); \ - (logger)->log(poco_message); \ - } \ +#define LOG_INTERNAL(logger, PRIORITY, message) \ + do \ + { \ + Poco::Message poco_message( \ + /*source*/ (logger)->name(), \ + /*text*/ (message), \ + /*prio*/ (PRIORITY), \ + /*file*/ &__FILE__[LogFmtDetails::getFileNameOffset(__FILE__)], \ + /*line*/ __LINE__); \ + (logger)->log(poco_message); \ } while (false) -#define LOG_FMT_TRACE(logger, ...) LOG_FMT_IMPL(logger, Poco::Message::PRIO_TRACE, __VA_ARGS__) -#define LOG_FMT_DEBUG(logger, ...) LOG_FMT_IMPL(logger, Poco::Message::PRIO_DEBUG, __VA_ARGS__) -#define LOG_FMT_INFO(logger, ...) LOG_FMT_IMPL(logger, Poco::Message::PRIO_INFORMATION, __VA_ARGS__) -#define LOG_FMT_WARNING(logger, ...) LOG_FMT_IMPL(logger, Poco::Message::PRIO_WARNING, __VA_ARGS__) -#define LOG_FMT_ERROR(logger, ...) LOG_FMT_IMPL(logger, Poco::Message::PRIO_ERROR, __VA_ARGS__) -#define LOG_FMT_FATAL(logger, ...) LOG_FMT_IMPL(logger, Poco::Message::PRIO_FATAL, __VA_ARGS__) + +#define LOG_IMPL_0(logger, PRIORITY, message) \ + do \ + { \ + if ((logger)->is((PRIORITY))) \ + LOG_INTERNAL(logger, PRIORITY, message); \ + } while (false) + +#define LOG_IMPL_1(logger, PRIORITY, fmt_str, ...) \ + do \ + { \ + if ((logger)->is((PRIORITY))) \ + { \ + auto _message = LogFmtDetails::toCheckedFmtStr(FMT_STRING(fmt_str), __VA_ARGS__); \ + LOG_INTERNAL(logger, PRIORITY, _message); \ + } \ + } while (false) + +#define LOG_IMPL_CHOSER(...) TF_GET_29TH_ARG(__VA_ARGS__, LOG_IMPL_1, LOG_IMPL_1, LOG_IMPL_1, LOG_IMPL_1, LOG_IMPL_1, LOG_IMPL_1, LOG_IMPL_1, LOG_IMPL_1, LOG_IMPL_1, LOG_IMPL_1, LOG_IMPL_1, LOG_IMPL_1, LOG_IMPL_1, LOG_IMPL_1, LOG_IMPL_1, LOG_IMPL_1, LOG_IMPL_1, LOG_IMPL_1, LOG_IMPL_1, LOG_IMPL_1, LOG_IMPL_1, LOG_IMPL_1, LOG_IMPL_1, LOG_IMPL_1, LOG_IMPL_1, LOG_IMPL_1, LOG_IMPL_1, LOG_IMPL_0) + +// clang-format off +#define LOG_IMPL(logger, PRIORITY, ...) LOG_IMPL_CHOSER(__VA_ARGS__)(logger, PRIORITY, __VA_ARGS__) +// clang-format on + +#define LOG_TRACE(logger, ...) LOG_IMPL(logger, Poco::Message::PRIO_TRACE, __VA_ARGS__) +#define LOG_DEBUG(logger, ...) LOG_IMPL(logger, Poco::Message::PRIO_DEBUG, __VA_ARGS__) +#define LOG_INFO(logger, ...) LOG_IMPL(logger, Poco::Message::PRIO_INFORMATION, __VA_ARGS__) +#define LOG_WARNING(logger, ...) LOG_IMPL(logger, Poco::Message::PRIO_WARNING, __VA_ARGS__) +#define LOG_ERROR(logger, ...) LOG_IMPL(logger, Poco::Message::PRIO_ERROR, __VA_ARGS__) +#define LOG_FATAL(logger, ...) LOG_IMPL(logger, Poco::Message::PRIO_FATAL, __VA_ARGS__) + +#define LOG_FMT_TRACE(...) LOG_TRACE(__VA_ARGS__) +#define LOG_FMT_DEBUG(...) LOG_DEBUG(__VA_ARGS__) +#define LOG_FMT_INFO(...) LOG_INFO(__VA_ARGS__) +#define LOG_FMT_WARNING(...) LOG_WARNING(__VA_ARGS__) +#define LOG_FMT_ERROR(...) LOG_ERROR(__VA_ARGS__) +#define LOG_FMT_FATAL(...) LOG_FATAL(__VA_ARGS__) From 4972cf3faa4c53cf0341ce6c3ea1f3e8750b9e8b Mon Sep 17 00:00:00 2001 From: yanweiqi <592838129@qq.com> Date: Tue, 2 Aug 2022 12:08:05 +0800 Subject: [PATCH 03/11] *: decouple FlashGrpcServerHolder from Server.cpp (#5516) ref pingcap/tiflash#4609 --- dbms/src/Flash/DiagnosticsService.cpp | 12 +- dbms/src/Flash/DiagnosticsService.h | 8 +- dbms/src/Server/CMakeLists.txt | 1 + dbms/src/Server/FlashGrpcServerHolder.cpp | 198 ++++++++++++++++++++++ dbms/src/Server/FlashGrpcServerHolder.h | 48 ++++++ dbms/src/Server/Server.cpp | 194 +-------------------- dbms/src/Server/Server.h | 2 +- 7 files changed, 260 insertions(+), 203 deletions(-) create mode 100644 dbms/src/Server/FlashGrpcServerHolder.cpp create mode 100644 dbms/src/Server/FlashGrpcServerHolder.h diff --git a/dbms/src/Flash/DiagnosticsService.cpp b/dbms/src/Flash/DiagnosticsService.cpp index 937f2794fa8..11de7687e46 100644 --- a/dbms/src/Flash/DiagnosticsService.cpp +++ b/dbms/src/Flash/DiagnosticsService.cpp @@ -38,7 +38,7 @@ ::grpc::Status DiagnosticsService::server_info( ::diagnosticspb::ServerInfoResponse * response) try { - const TiFlashRaftProxyHelper * helper = server.context().getTMTContext().getKVStore()->getProxyHelper(); + const TiFlashRaftProxyHelper * helper = context.getTMTContext().getKVStore()->getProxyHelper(); if (helper) { std::string req = request->SerializeAsString(); @@ -63,18 +63,18 @@ catch (const std::exception & e) } // get & filter(ts of last record < start-time) all files in same log directory. -std::list getFilesToSearch(IServer & server, Poco::Logger * log, const int64_t start_time) +std::list getFilesToSearch(Poco::Util::LayeredConfiguration & config, Poco::Logger * log, const int64_t start_time) { std::list files_to_search; std::string log_dir; // log directory - auto error_log_file_prefix = server.config().getString("logger.errorlog", "*"); - auto tracing_log_file_prefix = server.config().getString("logger.tracing_log", "*"); + auto error_log_file_prefix = config.getString("logger.errorlog", "*"); + auto tracing_log_file_prefix = config.getString("logger.tracing_log", "*"); // ignore tiflash error log and mpp task tracing log std::vector ignore_log_file_prefixes = {error_log_file_prefix, tracing_log_file_prefix}; { - auto log_file_prefix = server.config().getString("logger.log"); + auto log_file_prefix = config.getString("logger.log"); if (auto it = log_file_prefix.rfind('/'); it != std::string::npos) { log_dir = std::string(log_file_prefix.begin(), log_file_prefix.begin() + it); @@ -163,7 +163,7 @@ ::grpc::Status DiagnosticsService::search_log( LOG_FMT_DEBUG(log, "Handling SearchLog done: {}", request->DebugString()); }); - auto files_to_search = getFilesToSearch(server, log, start_time); + auto files_to_search = getFilesToSearch(config, log, start_time); for (const auto & path : files_to_search) { diff --git a/dbms/src/Flash/DiagnosticsService.h b/dbms/src/Flash/DiagnosticsService.h index a48e1e51a0c..1bba7d63c53 100644 --- a/dbms/src/Flash/DiagnosticsService.h +++ b/dbms/src/Flash/DiagnosticsService.h @@ -32,9 +32,10 @@ class DiagnosticsService final : public ::diagnosticspb::Diagnostics::Service , private boost::noncopyable { public: - explicit DiagnosticsService(IServer & _server) + explicit DiagnosticsService(Context & context_, Poco::Util::LayeredConfiguration & config_) : log(&Poco::Logger::get("DiagnosticsService")) - , server(_server) + , context(context_) + , config(config_) {} ~DiagnosticsService() override = default; @@ -51,8 +52,9 @@ class DiagnosticsService final : public ::diagnosticspb::Diagnostics::Service private: Poco::Logger * log; + Context & context; - IServer & server; + Poco::Util::LayeredConfiguration & config; }; } // namespace DB diff --git a/dbms/src/Server/CMakeLists.txt b/dbms/src/Server/CMakeLists.txt index 2948bb076db..77ab5e69838 100644 --- a/dbms/src/Server/CMakeLists.txt +++ b/dbms/src/Server/CMakeLists.txt @@ -29,6 +29,7 @@ configure_file (config_tools.h.in ${CMAKE_CURRENT_BINARY_DIR}/config_tools.h) add_library (clickhouse-server-lib HTTPHandler.cpp + FlashGrpcServerHolder.cpp MetricsTransmitter.cpp MetricsPrometheus.cpp NotFoundHandler.cpp diff --git a/dbms/src/Server/FlashGrpcServerHolder.cpp b/dbms/src/Server/FlashGrpcServerHolder.cpp new file mode 100644 index 00000000000..c82f79976e8 --- /dev/null +++ b/dbms/src/Server/FlashGrpcServerHolder.cpp @@ -0,0 +1,198 @@ +// Copyright 2022 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include +namespace DB +{ +namespace ErrorCodes +{ +extern const int IP_ADDRESS_NOT_ALLOWED; +} // namespace ErrorCodes +namespace +{ +void handleRpcs(grpc::ServerCompletionQueue * curcq, const LoggerPtr & log) +{ + GET_METRIC(tiflash_thread_count, type_total_rpc_async_worker).Increment(); + SCOPE_EXIT({ + GET_METRIC(tiflash_thread_count, type_total_rpc_async_worker).Decrement(); + }); + void * tag = nullptr; // uniquely identifies a request. + bool ok = false; + while (true) + { + String err_msg; + try + { + // Block waiting to read the next event from the completion queue. The + // event is uniquely identified by its tag, which in this case is the + // memory address of a EstablishCallData instance. + // The return value of Next should always be checked. This return value + // tells us whether there is any kind of event or cq is shutting down. + if (!curcq->Next(&tag, &ok)) + { + LOG_FMT_INFO(log, "CQ is fully drained and shut down"); + break; + } + GET_METRIC(tiflash_thread_count, type_active_rpc_async_worker).Increment(); + SCOPE_EXIT({ + GET_METRIC(tiflash_thread_count, type_active_rpc_async_worker).Decrement(); + }); + // If ok is false, it means server is shutdown. + // We need not log all not ok events, since the volumn is large which will pollute the content of log. + if (ok) + static_cast(tag)->proceed(); + else + static_cast(tag)->cancel(); + } + catch (Exception & e) + { + err_msg = e.displayText(); + LOG_FMT_ERROR(log, "handleRpcs meets error: {} Stack Trace : {}", err_msg, e.getStackTrace().toString()); + } + catch (pingcap::Exception & e) + { + err_msg = e.message(); + LOG_FMT_ERROR(log, "handleRpcs meets error: {}", err_msg); + } + catch (std::exception & e) + { + err_msg = e.what(); + LOG_FMT_ERROR(log, "handleRpcs meets error: {}", err_msg); + } + catch (...) + { + err_msg = "unrecovered error"; + LOG_FMT_ERROR(log, "handleRpcs meets error: {}", err_msg); + throw; + } + } +} +} // namespace + +FlashGrpcServerHolder::FlashGrpcServerHolder(Context & context, Poco::Util::LayeredConfiguration & config_, TiFlashSecurityConfig & security_config, const TiFlashRaftConfig & raft_config, const LoggerPtr & log_) + : log(log_) + , is_shutdown(std::make_shared>(false)) +{ + grpc::ServerBuilder builder; + if (security_config.has_tls_config) + { + grpc::SslServerCredentialsOptions server_cred(GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY); + auto options = security_config.readAndCacheSecurityInfo(); + server_cred.pem_root_certs = options.pem_root_certs; + server_cred.pem_key_cert_pairs.push_back( + grpc::SslServerCredentialsOptions::PemKeyCertPair{options.pem_private_key, options.pem_cert_chain}); + builder.AddListeningPort(raft_config.flash_server_addr, grpc::SslServerCredentials(server_cred)); + } + else + { + builder.AddListeningPort(raft_config.flash_server_addr, grpc::InsecureServerCredentials()); + } + + /// Init and register flash service. + bool enable_async_server = context.getSettingsRef().enable_async_server; + if (enable_async_server) + flash_service = std::make_unique(security_config, context); + else + flash_service = std::make_unique(security_config, context); + diagnostics_service = std::make_unique(context, config_); + builder.SetOption(grpc::MakeChannelArgumentOption(GRPC_ARG_HTTP2_MIN_RECV_PING_INTERVAL_WITHOUT_DATA_MS, 5 * 1000)); + builder.SetOption(grpc::MakeChannelArgumentOption(GRPC_ARG_HTTP2_MIN_SENT_PING_INTERVAL_WITHOUT_DATA_MS, 10 * 1000)); + builder.SetOption(grpc::MakeChannelArgumentOption(GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS, 1)); + // number of grpc thread pool's non-temporary threads, better tune it up to avoid frequent creation/destruction of threads + auto max_grpc_pollers = context.getSettingsRef().max_grpc_pollers; + if (max_grpc_pollers > 0 && max_grpc_pollers <= std::numeric_limits::max()) + builder.SetSyncServerOption(grpc::ServerBuilder::SyncServerOption::MAX_POLLERS, max_grpc_pollers); + builder.RegisterService(flash_service.get()); + LOG_FMT_INFO(log, "Flash service registered"); + builder.RegisterService(diagnostics_service.get()); + LOG_FMT_INFO(log, "Diagnostics service registered"); + + /// Kick off grpc server. + // Prevent TiKV from throwing "Received message larger than max (4404462 vs. 4194304)" error. + builder.SetMaxReceiveMessageSize(-1); + builder.SetMaxSendMessageSize(-1); + thread_manager = DB::newThreadManager(); + int async_cq_num = context.getSettingsRef().async_cqs; + if (enable_async_server) + { + for (int i = 0; i < async_cq_num; ++i) + { + cqs.emplace_back(builder.AddCompletionQueue()); + notify_cqs.emplace_back(builder.AddCompletionQueue()); + } + } + flash_grpc_server = builder.BuildAndStart(); + if (!flash_grpc_server) + { + throw Exception("Exception happens when start grpc server, the flash.service_addr may be invalid, flash.service_addr is " + raft_config.flash_server_addr, ErrorCodes::IP_ADDRESS_NOT_ALLOWED); + } + LOG_FMT_INFO(log, "Flash grpc server listening on [{}]", raft_config.flash_server_addr); + Debug::setServiceAddr(raft_config.flash_server_addr); + if (enable_async_server) + { + int preallocated_request_count_per_poller = context.getSettingsRef().preallocated_request_count_per_poller; + int pollers_per_cq = context.getSettingsRef().async_pollers_per_cq; + for (int i = 0; i < async_cq_num * pollers_per_cq; ++i) + { + auto * cq = cqs[i / pollers_per_cq].get(); + auto * notify_cq = notify_cqs[i / pollers_per_cq].get(); + for (int j = 0; j < preallocated_request_count_per_poller; ++j) + { + // EstablishCallData will handle its lifecycle by itself. + EstablishCallData::spawn(assert_cast(flash_service.get()), cq, notify_cq, is_shutdown); + } + thread_manager->schedule(false, "async_poller", [cq, this] { handleRpcs(cq, log); }); + thread_manager->schedule(false, "async_poller", [notify_cq, this] { handleRpcs(notify_cq, log); }); + } + } +} + +FlashGrpcServerHolder::~FlashGrpcServerHolder() +{ + try + { + /// Shut down grpc server. + LOG_FMT_INFO(log, "Begin to shut down flash grpc server"); + flash_grpc_server->Shutdown(); + *is_shutdown = true; + // Wait all existed MPPTunnels done to prevent crash. + // If all existed MPPTunnels are done, almost in all cases it means all existed MPPTasks and ExchangeReceivers are also done. + const int max_wait_cnt = 300; + int wait_cnt = 0; + while (GET_METRIC(tiflash_object_count, type_count_of_mpptunnel).Value() >= 1 && (wait_cnt++ < max_wait_cnt)) + std::this_thread::sleep_for(std::chrono::seconds(1)); + + for (auto & cq : cqs) + cq->Shutdown(); + for (auto & cq : notify_cqs) + cq->Shutdown(); + thread_manager->wait(); + flash_grpc_server->Wait(); + flash_grpc_server.reset(); + if (GRPCCompletionQueuePool::global_instance) + GRPCCompletionQueuePool::global_instance->markShutdown(); + LOG_FMT_INFO(log, "Shut down flash grpc server"); + + /// Close flash service. + LOG_FMT_INFO(log, "Begin to shut down flash service"); + flash_service.reset(); + LOG_FMT_INFO(log, "Shut down flash service"); + } + catch (...) + { + auto message = getCurrentExceptionMessage(false); + LOG_FMT_FATAL(log, "Exception happens in destructor of FlashGrpcServerHolder with message: {}", message); + std::terminate(); + } +} +} // namespace DB \ No newline at end of file diff --git a/dbms/src/Server/FlashGrpcServerHolder.h b/dbms/src/Server/FlashGrpcServerHolder.h new file mode 100644 index 00000000000..81c50dc609b --- /dev/null +++ b/dbms/src/Server/FlashGrpcServerHolder.h @@ -0,0 +1,48 @@ +// Copyright 2022 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace DB +{ +class FlashGrpcServerHolder +{ +public: + FlashGrpcServerHolder( + Context & context, + Poco::Util::LayeredConfiguration & config_, + TiFlashSecurityConfig & security_config, + const TiFlashRaftConfig & raft_config, + const LoggerPtr & log_); + ~FlashGrpcServerHolder(); + +private: + const LoggerPtr & log; + std::shared_ptr> is_shutdown; + std::unique_ptr flash_service = nullptr; + std::unique_ptr diagnostics_service = nullptr; + std::unique_ptr flash_grpc_server = nullptr; + // cqs and notify_cqs are used for processing async grpc events (currently only EstablishMPPConnection). + std::vector> cqs; + std::vector> notify_cqs; + std::shared_ptr thread_manager; +}; + +} // namespace DB \ No newline at end of file diff --git a/dbms/src/Server/Server.cpp b/dbms/src/Server/Server.cpp index 24b0dfd2a69..607b7e3e6c8 100644 --- a/dbms/src/Server/Server.cpp +++ b/dbms/src/Server/Server.cpp @@ -194,7 +194,6 @@ extern const int NO_ELEMENTS_IN_CONFIG; extern const int SUPPORT_IS_DISABLED; extern const int ARGUMENT_OUT_OF_BOUND; extern const int INVALID_CONFIG_PARAMETER; -extern const int IP_ADDRESS_NOT_ALLOWED; } // namespace ErrorCodes namespace Debug @@ -518,196 +517,6 @@ void initStores(Context & global_context, const LoggerPtr & log, bool lazily_ini } } -void handleRpcs(grpc::ServerCompletionQueue * curcq, const LoggerPtr & log) -{ - GET_METRIC(tiflash_thread_count, type_total_rpc_async_worker).Increment(); - SCOPE_EXIT({ - GET_METRIC(tiflash_thread_count, type_total_rpc_async_worker).Decrement(); - }); - void * tag = nullptr; // uniquely identifies a request. - bool ok = false; - while (true) - { - String err_msg; - try - { - // Block waiting to read the next event from the completion queue. The - // event is uniquely identified by its tag, which in this case is the - // memory address of a EstablishCallData instance. - // The return value of Next should always be checked. This return value - // tells us whether there is any kind of event or cq is shutting down. - if (!curcq->Next(&tag, &ok)) - { - LOG_FMT_INFO(grpc_log, "CQ is fully drained and shut down"); - break; - } - GET_METRIC(tiflash_thread_count, type_active_rpc_async_worker).Increment(); - SCOPE_EXIT({ - GET_METRIC(tiflash_thread_count, type_active_rpc_async_worker).Decrement(); - }); - // If ok is false, it means server is shutdown. - // We need not log all not ok events, since the volumn is large which will pollute the content of log. - if (ok) - static_cast(tag)->proceed(); - else - static_cast(tag)->cancel(); - } - catch (Exception & e) - { - err_msg = e.displayText(); - LOG_FMT_ERROR(log, "handleRpcs meets error: {} Stack Trace : {}", err_msg, e.getStackTrace().toString()); - } - catch (pingcap::Exception & e) - { - err_msg = e.message(); - LOG_FMT_ERROR(log, "handleRpcs meets error: {}", err_msg); - } - catch (std::exception & e) - { - err_msg = e.what(); - LOG_FMT_ERROR(log, "handleRpcs meets error: {}", err_msg); - } - catch (...) - { - err_msg = "unrecovered error"; - LOG_FMT_ERROR(log, "handleRpcs meets error: {}", err_msg); - throw; - } - } -} - -class Server::FlashGrpcServerHolder -{ -public: - FlashGrpcServerHolder(Server & server, const TiFlashRaftConfig & raft_config, const LoggerPtr & log_) - : log(log_) - , is_shutdown(std::make_shared>(false)) - { - grpc::ServerBuilder builder; - if (server.security_config.has_tls_config) - { - grpc::SslServerCredentialsOptions server_cred(GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY); - auto options = server.security_config.readAndCacheSecurityInfo(); - server_cred.pem_root_certs = options.pem_root_certs; - server_cred.pem_key_cert_pairs.push_back( - grpc::SslServerCredentialsOptions::PemKeyCertPair{options.pem_private_key, options.pem_cert_chain}); - builder.AddListeningPort(raft_config.flash_server_addr, grpc::SslServerCredentials(server_cred)); - } - else - { - builder.AddListeningPort(raft_config.flash_server_addr, grpc::InsecureServerCredentials()); - } - - /// Init and register flash service. - bool enable_async_server = server.context().getSettingsRef().enable_async_server; - if (enable_async_server) - flash_service = std::make_unique(server.securityConfig(), server.context()); - else - flash_service = std::make_unique(server.securityConfig(), server.context()); - diagnostics_service = std::make_unique(server); - builder.SetOption(grpc::MakeChannelArgumentOption(GRPC_ARG_HTTP2_MIN_RECV_PING_INTERVAL_WITHOUT_DATA_MS, 5 * 1000)); - builder.SetOption(grpc::MakeChannelArgumentOption(GRPC_ARG_HTTP2_MIN_SENT_PING_INTERVAL_WITHOUT_DATA_MS, 10 * 1000)); - builder.SetOption(grpc::MakeChannelArgumentOption(GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS, 1)); - // number of grpc thread pool's non-temporary threads, better tune it up to avoid frequent creation/destruction of threads - auto max_grpc_pollers = server.context().getSettingsRef().max_grpc_pollers; - if (max_grpc_pollers > 0 && max_grpc_pollers <= std::numeric_limits::max()) - builder.SetSyncServerOption(grpc::ServerBuilder::SyncServerOption::MAX_POLLERS, max_grpc_pollers); - builder.RegisterService(flash_service.get()); - LOG_FMT_INFO(log, "Flash service registered"); - builder.RegisterService(diagnostics_service.get()); - LOG_FMT_INFO(log, "Diagnostics service registered"); - - /// Kick off grpc server. - // Prevent TiKV from throwing "Received message larger than max (4404462 vs. 4194304)" error. - builder.SetMaxReceiveMessageSize(-1); - builder.SetMaxSendMessageSize(-1); - thread_manager = DB::newThreadManager(); - int async_cq_num = server.context().getSettingsRef().async_cqs; - if (enable_async_server) - { - for (int i = 0; i < async_cq_num; ++i) - { - cqs.emplace_back(builder.AddCompletionQueue()); - notify_cqs.emplace_back(builder.AddCompletionQueue()); - } - } - flash_grpc_server = builder.BuildAndStart(); - if (!flash_grpc_server) - { - throw Exception("Exception happens when start grpc server, the flash.service_addr may be invalid, flash.service_addr is " + raft_config.flash_server_addr, ErrorCodes::IP_ADDRESS_NOT_ALLOWED); - } - LOG_FMT_INFO(log, "Flash grpc server listening on [{}]", raft_config.flash_server_addr); - Debug::setServiceAddr(raft_config.flash_server_addr); - if (enable_async_server) - { - int preallocated_request_count_per_poller = server.context().getSettingsRef().preallocated_request_count_per_poller; - int pollers_per_cq = server.context().getSettingsRef().async_pollers_per_cq; - for (int i = 0; i < async_cq_num * pollers_per_cq; ++i) - { - auto * cq = cqs[i / pollers_per_cq].get(); - auto * notify_cq = notify_cqs[i / pollers_per_cq].get(); - for (int j = 0; j < preallocated_request_count_per_poller; ++j) - { - // EstablishCallData will handle its lifecycle by itself. - EstablishCallData::spawn(assert_cast(flash_service.get()), cq, notify_cq, is_shutdown); - } - thread_manager->schedule(false, "async_poller", [cq, this] { handleRpcs(cq, log); }); - thread_manager->schedule(false, "async_poller", [notify_cq, this] { handleRpcs(notify_cq, log); }); - } - } - } - - ~FlashGrpcServerHolder() - { - try - { - /// Shut down grpc server. - LOG_FMT_INFO(log, "Begin to shut down flash grpc server"); - flash_grpc_server->Shutdown(); - *is_shutdown = true; - // Wait all existed MPPTunnels done to prevent crash. - // If all existed MPPTunnels are done, almost in all cases it means all existed MPPTasks and ExchangeReceivers are also done. - const int max_wait_cnt = 300; - int wait_cnt = 0; - while (GET_METRIC(tiflash_object_count, type_count_of_mpptunnel).Value() >= 1 && (wait_cnt++ < max_wait_cnt)) - std::this_thread::sleep_for(std::chrono::seconds(1)); - - for (auto & cq : cqs) - cq->Shutdown(); - for (auto & cq : notify_cqs) - cq->Shutdown(); - thread_manager->wait(); - flash_grpc_server->Wait(); - flash_grpc_server.reset(); - if (GRPCCompletionQueuePool::global_instance) - GRPCCompletionQueuePool::global_instance->markShutdown(); - LOG_FMT_INFO(log, "Shut down flash grpc server"); - - /// Close flash service. - LOG_FMT_INFO(log, "Begin to shut down flash service"); - flash_service.reset(); - LOG_FMT_INFO(log, "Shut down flash service"); - } - catch (...) - { - auto message = getCurrentExceptionMessage(false); - LOG_FMT_FATAL(log, "Exception happens in destructor of FlashGrpcServerHolder with message: {}", message); - std::terminate(); - } - } - -private: - const LoggerPtr & log; - std::shared_ptr> is_shutdown; - std::unique_ptr flash_service = nullptr; - std::unique_ptr diagnostics_service = nullptr; - std::unique_ptr flash_grpc_server = nullptr; - // cqs and notify_cqs are used for processing async grpc events (currently only EstablishMPPConnection). - std::vector> cqs; - std::vector> notify_cqs; - std::shared_ptr thread_manager; -}; - class Server::TcpHttpServersHolder { public: @@ -1080,7 +889,6 @@ int Server::main(const std::vector & /*args*/) LOG_FMT_INFO(log, "TiFlashRaftProxyHelper is null, failed to get server info"); } - // print necessary grpc log. grpc_log = Logger::get("grpc"); gpr_set_log_verbosity(GPR_LOG_SEVERITY_DEBUG); gpr_set_log_function(&printGRPCLog); @@ -1432,7 +1240,7 @@ int Server::main(const std::vector & /*args*/) } /// Then, startup grpc server to serve raft and/or flash services. - FlashGrpcServerHolder flash_grpc_server_holder(*this, raft_config, log); + FlashGrpcServerHolder flash_grpc_server_holder(this->context(), this->config(), this->security_config, raft_config, log); { TcpHttpServersHolder tcpHttpServersHolder(*this, settings, log); diff --git a/dbms/src/Server/Server.h b/dbms/src/Server/Server.h index 07c5b955a92..9f083d298cf 100644 --- a/dbms/src/Server/Server.h +++ b/dbms/src/Server/Server.h @@ -14,6 +14,7 @@ #pragma once +#include #include #include #include @@ -72,7 +73,6 @@ class Server : public BaseDaemon ServerInfo server_info; - class FlashGrpcServerHolder; class TcpHttpServersHolder; }; From 30fc64c323d262e0faad10150d31025d06efe517 Mon Sep 17 00:00:00 2001 From: Zhigao Tong Date: Tue, 2 Aug 2022 17:02:06 +0800 Subject: [PATCH 04/11] Optimize expression `LIKE() ESCAPE()` for bin collator (#5489) ref pingcap/tiflash#5294 --- .../Functions/CollationOperatorOptimized.h | 14 +- .../CollationStringSearchOptimized.h | 472 ++++++++++++++++++ dbms/src/Functions/FunctionsStringSearch.cpp | 58 +-- dbms/src/Storages/Transaction/Collator.cpp | 2 - .../Transaction/tests/gtest_tidb_collator.cpp | 141 +++++- 5 files changed, 637 insertions(+), 50 deletions(-) create mode 100644 dbms/src/Functions/CollationStringSearchOptimized.h diff --git a/dbms/src/Functions/CollationOperatorOptimized.h b/dbms/src/Functions/CollationOperatorOptimized.h index 8276a41fa17..e1bf36a537f 100644 --- a/dbms/src/Functions/CollationOperatorOptimized.h +++ b/dbms/src/Functions/CollationOperatorOptimized.h @@ -60,8 +60,6 @@ __attribute__((flatten, always_inline)) inline void LoopTwoColumns( { ColumnString::Offset a_prev_offset = 0; ColumnString::Offset b_prev_offset = 0; - const auto * a_ptr = reinterpret_cast(a_data.data()); - const auto * b_ptr = reinterpret_cast(b_data.data()); for (size_t i = 0; i < size; ++i) { @@ -69,10 +67,9 @@ __attribute__((flatten, always_inline)) inline void LoopTwoColumns( auto b_size = b_offsets[i] - b_prev_offset; // Remove last zero byte. - func({a_ptr, a_size - 1}, {b_ptr, b_size - 1}, i); - - a_ptr += a_size; - b_ptr += b_size; + func({reinterpret_cast(&a_data[a_prev_offset]), a_size - 1}, + {reinterpret_cast(&b_data[b_prev_offset]), b_size - 1}, + i); a_prev_offset = a_offsets[i]; b_prev_offset = b_offsets[i]; @@ -89,16 +86,13 @@ __attribute__((flatten, always_inline)) inline void LoopOneColumn( F && func) { ColumnString::Offset a_prev_offset = 0; - const auto * a_ptr = reinterpret_cast(a_data.data()); for (size_t i = 0; i < size; ++i) { auto a_size = a_offsets[i] - a_prev_offset; // Remove last zero byte. - func({a_ptr, a_size - 1}, i); - - a_ptr += a_size; + func({reinterpret_cast(&a_data[a_prev_offset]), a_size - 1}, i); a_prev_offset = a_offsets[i]; } } diff --git a/dbms/src/Functions/CollationStringSearchOptimized.h b/dbms/src/Functions/CollationStringSearchOptimized.h new file mode 100644 index 00000000000..499b95ce36e --- /dev/null +++ b/dbms/src/Functions/CollationStringSearchOptimized.h @@ -0,0 +1,472 @@ +// Copyright 2022 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace TiDB +{ + +static constexpr char ANY = '%'; +static constexpr char ONE = '_'; + +/* + Unicode Code UTF-8 Code + 0000~007F 0xxxxxxx + 0080~07FF 110xxxxx 10xxxxxx + 0800~FFFF 1110xxxx 10xxxxxx 10xxxxxx + 10000~10FFFF 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx +*/ +template +inline size_t BinCharSizeFromHead(const uint8_t b0) +{ + if constexpr (!utf8) + { + return 1; + } + return DB::UTF8::seqLength(b0); +} + +template +inline size_t BinCharSizeFromEnd(const char * b_, const char * begin_) +{ + if constexpr (!utf8) + { + return 1; + } + + const auto * b = reinterpret_cast(b_); + if (*b < 0x80) + { + return 1; + } + const auto * ori = b; + + const auto * begin = reinterpret_cast(begin_); + + // check range in case that bin str is invalid + while (begin < b && *b < 0xC0) + { + --b; + } + return ori - b + 1; +} + +template +struct BinStrPattern +{ + void compile(std::string_view pattern, char escape_) + { + { + match_sub_str.clear(); + match_sub_str.reserve(8); + match_types.clear(); + match_types.reserve(8); + } + escape = escape_; + + auto last_match_start = std::string_view::npos; + + const auto & fn_try_add_last_match_str = [&](size_t end_offset) { + if (last_match_start != std::string_view::npos) + { + match_sub_str.emplace_back(&pattern[last_match_start], end_offset - last_match_start); + match_types.emplace_back(MatchType::Match); + // reset + last_match_start = std::string_view::npos; + } + }; + + for (size_t offset = 0; offset < pattern.size();) + { + auto c = pattern[offset]; + auto cur_offset = offset; + auto size = BinCharSizeFromHead(pattern[offset]); + offset += size; // move next + + if (size == 1) + { + if (c == escape) + { + fn_try_add_last_match_str(cur_offset); + + if (offset < pattern.size()) + { + // start from current offset + last_match_start = offset; + + // use next to match + auto new_size = BinCharSizeFromHead(pattern[offset]); + offset += new_size; // move next + } + else + { + // use `escape` to match + match_sub_str.emplace_back(&escape, sizeof(escape)); + match_types.emplace_back(MatchType::Match); + } + } + else if (c == ANY) + { + fn_try_add_last_match_str(cur_offset); + match_types.emplace_back(MatchType::Any); + } + else if (c == ONE) + { + fn_try_add_last_match_str(cur_offset); + match_types.emplace_back(MatchType::One); + } + else + { + // if last match start offset is none, start from current offset. + last_match_start = last_match_start == std::string_view::npos ? cur_offset : last_match_start; + } + } + else + { + // if last match start offset is none, start from current offset. + last_match_start = last_match_start == std::string_view::npos ? cur_offset : last_match_start; + } + } + fn_try_add_last_match_str(pattern.size()); + } + struct MatchDesc + { + ssize_t pattern_index_start{}, pattern_index_end{}; + ssize_t match_str_index_start{}, match_str_index_end{}; + ssize_t src_index_start{}, src_index_end{}; + + bool isSrcValid() const + { + return !isSrcEmpty(); + } + bool isSrcEmpty() const + { + return src_index_start >= src_index_end; + } + size_t srcSize() const + { + return src_index_end - src_index_start; + } + std::string_view getSrcStrView(const char * src_data, size_t size) const + { + return std::string_view{src_data + src_index_start, size}; + } + void srcMoveByOffset(size_t size) + { + src_index_start += size; + } + void srcSkipChar(const char * src_data) + { + auto size = BinCharSizeFromHead(src_data[src_index_start]); + srcMoveByOffset(size); + } + bool patternEmpty() const + { + return pattern_index_start >= pattern_index_end; + } + void makeSrcInvalid() + { + src_index_start = src_index_end; + } + }; + + // check str equality + // - make src invalid if remain size if smaller than required + bool matchStrEqual(const std::string_view & src, MatchDesc & desc) const + { + const auto & match_str = match_sub_str[desc.match_str_index_start]; + if (desc.srcSize() < match_str.size()) + { + desc.makeSrcInvalid(); + return false; + } + if (DB::RawStrEqualCompare(desc.getSrcStrView(src.data(), match_str.size()), match_str)) + { + return false; + } + desc.match_str_index_start++; + desc.srcMoveByOffset(match_str.size()); + return true; + } + + // match from start exactly + // - return true if meet % + // - return false if failed to match else true + bool matchExactly(const std::string_view & src, MatchDesc & cur_match_desc) const + { + // match from start + for (; !cur_match_desc.patternEmpty(); cur_match_desc.pattern_index_start++) + { + const auto & type = match_types[cur_match_desc.pattern_index_start]; + if (type == MatchType::Any) + { + // break from loop + break; + } + + if (type == MatchType::Match) + { + if (!matchStrEqual(src, cur_match_desc)) + return false; + } + else + { + // src must be not empty + if (!cur_match_desc.isSrcValid()) + return false; + cur_match_desc.srcSkipChar(src.data()); + } + } + return true; + }; + + // match from end exactly + // - return true if meet % + // - return false if failed to match else true + bool matchExactlyReverse(const std::string_view & src, MatchDesc & cur_match_desc) const + { + for (; !cur_match_desc.patternEmpty(); --cur_match_desc.pattern_index_end) + { + const auto & type = match_types[cur_match_desc.pattern_index_end - 1]; + if (type == MatchType::Any) + { + break; + } + + if (type == MatchType::Match) + { + const auto & match_str = match_sub_str[cur_match_desc.match_str_index_end - 1]; + if (cur_match_desc.srcSize() < match_str.size()) + { + return false; + } + + if (DB::RawStrEqualCompare({src.data() + cur_match_desc.src_index_end - match_str.size(), match_str.size()}, match_str)) + { + return false; + } + cur_match_desc.match_str_index_end--; + cur_match_desc.src_index_end -= match_str.size(); + } + else + { + // src must be not empty + if (!cur_match_desc.isSrcValid()) + return false; + + auto size = BinCharSizeFromEnd(&src[cur_match_desc.src_index_end - 1], &src[cur_match_desc.src_index_start]); + cur_match_desc.src_index_end -= size; // remove from end + } + } + return true; + }; + + // search by pattern `...%..%` + // - return true if meet % + // - return false if failed to search + bool searchByPattern(const std::string_view & src, MatchDesc & desc) const + { + assert(match_types[desc.pattern_index_end - 1] == MatchType::Any); + assert(!desc.patternEmpty()); + + // leading `MatchType::One` can be removed first + for (; match_types[desc.pattern_index_start] == MatchType::One; desc.pattern_index_start++) + { + // src must be not empty + if (!desc.isSrcValid()) + return false; + desc.srcSkipChar(src.data()); + } + + if (match_types[desc.pattern_index_start] == MatchType::Any) + { + return true; + } + + // current type is MatchType::Match + // loop: + // - search next position of match sub str + // - if position found, start to match exactly + // - if match fail, fallback to loop + // - if match success, return match end pos + // - if position not found, return with fail. + for (;;) + { + const auto & match_str = match_sub_str[desc.match_str_index_start]; + auto src_view = desc.getSrcStrView(src.data(), desc.srcSize()); + auto pos = std::string_view::npos; + + // search sub str + // - seachers like `ASCIICaseSensitiveStringSearcher` or `Volnitsky` are too heavy for small str + // - TODO: optimize strstr search by simd + { + pos = src_view.find(match_str); + // pos = sse2_strstr(src_view, match_str); + } + + if (pos == std::string_view::npos) + { + return false; + } + else + { + // move to sub str position + desc.src_index_start = pos + src_view.data() - src.data(); + + MatchDesc new_desc = desc; + new_desc.srcMoveByOffset(match_str.size()); // start to check rest + new_desc.match_str_index_start++; + new_desc.pattern_index_start++; + + if (!matchExactly(src, new_desc)) + { + if (!new_desc.isSrcValid()) + return false; + // skip one char and restart to search + desc.srcSkipChar(src.data()); + } + else + { + desc = new_desc; + return true; + } + } + } + }; + + bool match(std::string_view src) const + { + MatchDesc cur_match_desc; + { + cur_match_desc.pattern_index_end = match_types.size(); + cur_match_desc.match_str_index_end = match_sub_str.size(); + cur_match_desc.src_index_end = src.size(); + } + + // if pattern starts or ends with `MatchType::Match` or `MatchType::One`, match exactly + { + // match from start + if (!matchExactly(src, cur_match_desc)) + { + return false; + } + // match from end + if (!matchExactlyReverse(src, cur_match_desc)) + { + return false; + } + } + + // if remain pattern is empty, src must be empty + if (cur_match_desc.patternEmpty()) + { + return cur_match_desc.isSrcEmpty(); + } + + assert(match_types[cur_match_desc.pattern_index_end - 1] == MatchType::Any); + + // remain pattern should be %..%...% + // search sub str one by one based on greedy rule + for (;;) + { + assert(match_types[cur_match_desc.pattern_index_start] == MatchType::Any); + + // move to next match type + cur_match_desc.pattern_index_start++; + + if (cur_match_desc.patternEmpty()) // if % is the last one + break; + + if (!searchByPattern(src, cur_match_desc)) + return false; + } + return true; + } + + enum class MatchType + { + Match, + One, + Any, + }; + + std::vector match_types; + std::vector match_sub_str; + char escape{}; +}; +} // namespace TiDB + +namespace DB +{ +template +ALWAYS_INLINE inline void BinStringPatternMatch( + const ColumnString::Chars_t & a_data, + const ColumnString::Offsets & a_offsets, + const std::string_view & pattern_str, + uint8_t escape_char, + Result & c) +{ + TiDB::BinStrPattern matcher; + matcher.compile(pattern_str, escape_char); + LoopOneColumn(a_data, a_offsets, a_offsets.size(), [&](const std::string_view & view, size_t i) { + c[i] = revert ^ matcher.match(view); + }); +} + +template +ALWAYS_INLINE inline bool StringPatternMatch( + const ColumnString::Chars_t & a_data, + const ColumnString::Offsets & a_offsets, + const std::string_view & pattern_str, + uint8_t escape_char, + const TiDB::TiDBCollatorPtr & collator, + Result & c) +{ + bool use_optimized_path = false; + + switch (collator->getCollatorId()) + { + case TiDB::ITiDBCollator::UTF8MB4_BIN: + case TiDB::ITiDBCollator::UTF8_BIN: + { + BinStringPatternMatch(a_data, a_offsets, pattern_str, escape_char, c); + use_optimized_path = true; + break; + } + case TiDB::ITiDBCollator::BINARY: + case TiDB::ITiDBCollator::ASCII_BIN: + case TiDB::ITiDBCollator::LATIN1_BIN: + { + BinStringPatternMatch(a_data, a_offsets, pattern_str, escape_char, c); + use_optimized_path = true; + break; + } + + default: + break; + } + return use_optimized_path; +} +} // namespace DB \ No newline at end of file diff --git a/dbms/src/Functions/FunctionsStringSearch.cpp b/dbms/src/Functions/FunctionsStringSearch.cpp index f0c6cd6f303..5da3ee55e60 100644 --- a/dbms/src/Functions/FunctionsStringSearch.cpp +++ b/dbms/src/Functions/FunctionsStringSearch.cpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -504,14 +505,14 @@ struct MatchImpl /// fully supported.(Only case sensitive/insensitive is supported) if (like && collator != nullptr) { - auto matcher = collator->pattern(); - matcher->compile(orig_pattern, escape_char); - size_t size = offsets.size(); - size_t prev_offset = 0; - for (size_t i = 0; i < size; ++i) + bool use_optimized_path = StringPatternMatch(data, offsets, orig_pattern, escape_char, collator, res); + if (!use_optimized_path) { - res[i] = revert ^ matcher->match(reinterpret_cast(&data[prev_offset]), offsets[i] - prev_offset - 1); - prev_offset = offsets[i]; + auto matcher = collator->pattern(); + matcher->compile(orig_pattern, escape_char); + LoopOneColumn(data, offsets, offsets.size(), [&](const std::string_view & view, size_t i) { + res[i] = revert ^ matcher->match(view.data(), view.size()); + }); } return; } @@ -1930,18 +1931,18 @@ class FunctionStringReplace : public IFunction const String & match_type, ColumnWithTypeAndName & column_result) const { - const ColumnConst * c1_const = typeid_cast(column_needle.get()); - const ColumnConst * c2_const = typeid_cast(column_replacement.get()); - String needle = c1_const->getValue(); - String replacement = c2_const->getValue(); + const auto * c1_const = typeid_cast(column_needle.get()); + const auto * c2_const = typeid_cast(column_replacement.get()); + auto needle = c1_const->getValue(); + auto replacement = c2_const->getValue(); - if (const ColumnString * col = checkAndGetColumn(column_src.get())) + if (const auto * col = checkAndGetColumn(column_src.get())) { auto col_res = ColumnString::create(); Impl::vector(col->getChars(), col->getOffsets(), needle, replacement, pos, occ, match_type, collator, col_res->getChars(), col_res->getOffsets()); column_result.column = std::move(col_res); } - else if (const ColumnFixedString * col = checkAndGetColumn(column_src.get())) + else if (const auto * col = checkAndGetColumn(column_src.get())) { auto col_res = ColumnString::create(); Impl::vectorFixed(col->getChars(), col->getN(), needle, replacement, pos, occ, match_type, collator, col_res->getChars(), col_res->getOffsets()); @@ -1964,17 +1965,17 @@ class FunctionStringReplace : public IFunction { if constexpr (Impl::support_non_const_needle) { - const ColumnString * col_needle = typeid_cast(column_needle.get()); - const ColumnConst * col_replacement_const = typeid_cast(column_replacement.get()); - String replacement = col_replacement_const->getValue(); + const auto * col_needle = typeid_cast(column_needle.get()); + const auto * col_replacement_const = typeid_cast(column_replacement.get()); + auto replacement = col_replacement_const->getValue(); - if (const ColumnString * col = checkAndGetColumn(column_src.get())) + if (const auto * col = checkAndGetColumn(column_src.get())) { auto col_res = ColumnString::create(); Impl::vectorNonConstNeedle(col->getChars(), col->getOffsets(), col_needle->getChars(), col_needle->getOffsets(), replacement, pos, occ, match_type, collator, col_res->getChars(), col_res->getOffsets()); column_result.column = std::move(col_res); } - else if (const ColumnFixedString * col = checkAndGetColumn(column_src.get())) + else if (const auto * col = checkAndGetColumn(column_src.get())) { auto col_res = ColumnString::create(); Impl::vectorFixedNonConstNeedle(col->getChars(), col->getN(), col_needle->getChars(), col_needle->getOffsets(), replacement, pos, occ, match_type, collator, col_res->getChars(), col_res->getOffsets()); @@ -2002,17 +2003,17 @@ class FunctionStringReplace : public IFunction { if constexpr (Impl::support_non_const_replacement) { - const ColumnConst * col_needle_const = typeid_cast(column_needle.get()); - String needle = col_needle_const->getValue(); - const ColumnString * col_replacement = typeid_cast(column_replacement.get()); + const auto * col_needle_const = typeid_cast(column_needle.get()); + auto needle = col_needle_const->getValue(); + const auto * col_replacement = typeid_cast(column_replacement.get()); - if (const ColumnString * col = checkAndGetColumn(column_src.get())) + if (const auto * col = checkAndGetColumn(column_src.get())) { auto col_res = ColumnString::create(); Impl::vectorNonConstReplacement(col->getChars(), col->getOffsets(), needle, col_replacement->getChars(), col_replacement->getOffsets(), pos, occ, match_type, collator, col_res->getChars(), col_res->getOffsets()); column_result.column = std::move(col_res); } - else if (const ColumnFixedString * col = checkAndGetColumn(column_src.get())) + else if (const auto * col = checkAndGetColumn(column_src.get())) { auto col_res = ColumnString::create(); Impl::vectorFixedNonConstReplacement(col->getChars(), col->getN(), needle, col_replacement->getChars(), col_replacement->getOffsets(), pos, occ, match_type, collator, col_res->getChars(), col_res->getOffsets()); @@ -2040,16 +2041,16 @@ class FunctionStringReplace : public IFunction { if constexpr (Impl::support_non_const_needle && Impl::support_non_const_replacement) { - const ColumnString * col_needle = typeid_cast(column_needle.get()); - const ColumnString * col_replacement = typeid_cast(column_replacement.get()); + const auto * col_needle = typeid_cast(column_needle.get()); + const auto * col_replacement = typeid_cast(column_replacement.get()); - if (const ColumnString * col = checkAndGetColumn(column_src.get())) + if (const auto * col = checkAndGetColumn(column_src.get())) { auto col_res = ColumnString::create(); Impl::vectorNonConstNeedleReplacement(col->getChars(), col->getOffsets(), col_needle->getChars(), col_needle->getOffsets(), col_replacement->getChars(), col_replacement->getOffsets(), pos, occ, match_type, collator, col_res->getChars(), col_res->getOffsets()); column_result.column = std::move(col_res); } - else if (const ColumnFixedString * col = checkAndGetColumn(column_src.get())) + else if (const auto * col = checkAndGetColumn(column_src.get())) { auto col_res = ColumnString::create(); Impl::vectorFixedNonConstNeedleReplacement(col->getChars(), col->getN(), col_needle->getChars(), col_needle->getOffsets(), col_replacement->getChars(), col_replacement->getOffsets(), pos, occ, match_type, collator, col_res->getChars(), col_res->getOffsets()); @@ -2065,7 +2066,8 @@ class FunctionStringReplace : public IFunction throw Exception("Argument at index 2 and 3 for function replace must be constant", ErrorCodes::ILLEGAL_COLUMN); } } - TiDB::TiDBCollatorPtr collator; + + TiDB::TiDBCollatorPtr collator{}; }; struct NamePosition diff --git a/dbms/src/Storages/Transaction/Collator.cpp b/dbms/src/Storages/Transaction/Collator.cpp index fc40701e3c5..a8434cd7eb7 100644 --- a/dbms/src/Storages/Transaction/Collator.cpp +++ b/dbms/src/Storages/Transaction/Collator.cpp @@ -17,8 +17,6 @@ #include #include -#include - namespace DB::ErrorCodes { extern const int LOGICAL_ERROR; diff --git a/dbms/src/Storages/Transaction/tests/gtest_tidb_collator.cpp b/dbms/src/Storages/Transaction/tests/gtest_tidb_collator.cpp index 13c51dba2db..904cefb26ef 100644 --- a/dbms/src/Storages/Transaction/tests/gtest_tidb_collator.cpp +++ b/dbms/src/Storages/Transaction/tests/gtest_tidb_collator.cpp @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include #include @@ -108,8 +109,82 @@ const typename CollatorCases::PatternCase CollatorCases::pattern_cases[] = { {{"ÀÀ", {false, false, false, false, false}}, {"aÀÀ", {false, false, true, false, true}}, {"ÀÀÀa", {true, true, true, true, true}}, {"aÀÀÀ", {false, false, true, false, true}}}}, {"___a", {{"中a", {true, true, false, false, false}}, {"中文字a", {false, false, true, true, true}}}}, {"𐐭", {{"𐐨", {false, false, true, false, false}}}}, + { + "%pending%deposits%", + { + {"riously after the carefully pending foxes. deposits are careful", {true, true, true, true, true}}, + {"pendingdeposits", {true, true, true, true, true}}, + {"pendingdeposits", {true, true, true, true, true}}, + }, + }, + { + "1234567\\", // `ESCAPE` at last + { + {"1234567\\", {true, true, true, true, true}}, + {"1234567", {false, false, false, false, false}}, + {"1234567\\1", {false, false, false, false, false}}, + }, + }, + { + "1234567\\910", // `ESCAPE` at middle + { + {"1234567\\910", {false, false, false, false, false}}, + {"1234567910", {true, true, true, true, true}}, + }, + }, + { + "%__", // test match from end + { + {"1", {false, false, false, false, false}}, // 1 bytes + {"À", {true, true, false, false, false}}, // 2 bytes + {"12", {true, true, true, true, true}}, // 2 bytes + {"中", {true, true, false, false, false}}, // 3 bytes + {"À1", {true, true, true, true, true}}, // 3 bytes + {"ÀÀ", {true, true, true, true, true}}, // 4 bytes + {"𒀈", {true, true, false, false, false}}, // 4 bytes 1 char + {"À中", {true, true, true, true, true}}, // 5 bytes + {"中中", {true, true, true, true, true}}, // 6 bytes + }, + }, + { + "%__%", // test + { + {"1", {false, false, false, false, false}}, // 1 bytes + {"À", {true, true, false, false, false}}, // 2 bytes + {"12", {true, true, true, true, true}}, // 2 bytes + {"中", {true, true, false, false, false}}, // 3 bytes + {"À1", {true, true, true, true, true}}, // 3 bytes + {"ÀÀ", {true, true, true, true, true}}, // 4 bytes + {"𒀈", {true, true, false, false, false}}, // 4 bytes 1 char + }, + }, + { + "%一_二", // test match from end + { + {"xx一a二", {true, true, true, true, true}}, + {"xx一À二", {false, false, true, true, true}}, + }, + }, + { + "%一_三%四五六%七", + { + {"一二三四五七", {false, false, false, false, false}}, + {"0一二三四五六.七", {false, false, true, true, true}}, + {"一二四五六七", {false, false, false, false, false}}, + {"一2三.四五六...七", {true, true, true, true, true}}, + }, + }, + { + "%一_三%", + { + {"000一二3", {false, false, false, false, false}}, + {"000一", {false, false, false, false, false}}, + }, + }, }; +static constexpr char ESCAPE = '\\'; + template void testCollator() { @@ -130,18 +205,64 @@ void testCollator() std::string buf; ASSERT_EQ(collator->sortKey(s.data(), s.length(), buf).toString(), ans); } - auto pattern = collator->pattern(); - for (const auto & c : CollatorCases::pattern_cases) { - const std::string & p = c.first; - pattern->compile(p, '\\'); - const auto & inner_cases = c.second; - for (const auto & inner_c : inner_cases) + TiDB::BinStrPattern matcher; + matcher.compile("%%%", '%'); + ASSERT_TRUE(matcher.match("%%")); + matcher.compile("%%", '.'); + ASSERT_TRUE(matcher.match("")); + + auto pattern = collator->pattern(); + pattern->compile("%%%", '%'); + ASSERT_TRUE(pattern->match("%%", 2)); + } + { + auto pattern = collator->pattern(); + for (const auto & c : CollatorCases::pattern_cases) { - const std::string & s = inner_c.first; - bool ans = std::get(inner_c.second); - std::cout << "Pattern case (" << p << ", " << s << ", " << ans << ")" << std::endl; - ASSERT_EQ(pattern->match(s.data(), s.length()), ans); + const std::string & p = c.first; + const auto & inner_cases = c.second; + + ColumnString::Chars_t strs; + ColumnString::Offsets offsets; + std::vector res; + { // init data + ColumnString::Offset current_new_offset = 0; + for (const auto & inner_c : inner_cases) + { + const auto s = inner_c.first + char(0); + { + current_new_offset += s.size(); + offsets.push_back(current_new_offset); + } + { + strs.resize(strs.size() + s.size()); + memcpySmallAllowReadWriteOverflow15( + &strs[strs.size() - s.size()], + s.data(), + s.size()); + } + res.emplace_back(0); + } + } + if (!StringPatternMatch(strs, offsets, p, ESCAPE, collator, res)) + { + pattern->compile(p, ESCAPE); + for (size_t idx = 0; idx < std::size(inner_cases); ++idx) + { + const auto & inner_c = inner_cases[idx]; + const std::string & s = inner_c.first; + res[idx] = pattern->match(s.data(), s.length()); + } + } + + for (size_t idx = 0; idx < std::size(inner_cases); ++idx) + { + const auto & inner_c = inner_cases[idx]; + bool ans = std::get(inner_c.second); + std::cout << "Pattern case (" << p << ", " << inner_c.first << ", " << ans << ")" << std::endl; + ASSERT_EQ(res[idx], ans); + } } } } From 1022499407d5636c99a74a160659955232b333e1 Mon Sep 17 00:00:00 2001 From: lidezhu <47731263+lidezhu@users.noreply.github.com> Date: Tue, 2 Aug 2022 20:44:05 +0800 Subject: [PATCH 05/11] fix wrong path for BlobStat under multi-disks env (#5520) close pingcap/tiflash#5519 --- dbms/src/Storages/Page/V3/BlobStore.cpp | 35 ++++- dbms/src/Storages/Page/V3/BlobStore.h | 9 +- .../Page/V3/tests/gtest_blob_store.cpp | 132 +++++++++++++----- dbms/src/TestUtils/MockDiskDelegator.h | 40 ++++-- 4 files changed, 164 insertions(+), 52 deletions(-) diff --git a/dbms/src/Storages/Page/V3/BlobStore.cpp b/dbms/src/Storages/Page/V3/BlobStore.cpp index dc527b8fc65..c13d1288d82 100644 --- a/dbms/src/Storages/Page/V3/BlobStore.cpp +++ b/dbms/src/Storages/Page/V3/BlobStore.cpp @@ -509,8 +509,10 @@ void BlobStore::removePosFromStats(BlobFileId blob_id, BlobFileOffset offset, si { LOG_FMT_INFO(log, "Removing BlobFile [blob_id={}]", blob_id); auto lock_stats = blob_stats.lock(); + // need get blob file before remove its stat otherwise we cannot find the blob file + auto blob_file = getBlobFile(blob_id); blob_stats.eraseStat(std::move(stat), lock_stats); - getBlobFile(blob_id)->remove(); + blob_file->remove(); cached_files.remove(blob_id); } } @@ -1189,8 +1191,7 @@ PageEntriesEdit BlobStore::gc(std::map & String BlobStore::getBlobFileParentPath(BlobFileId blob_id) { - PageFileIdAndLevel id_lvl{blob_id, 0}; - String parent_path = delegator->choosePath(id_lvl); + String parent_path = blob_stats.blobIdToPath(blob_id); if (auto f = Poco::File(parent_path); !f.exists()) f.createDirectories(); @@ -1353,10 +1354,9 @@ BlobStatPtr BlobStore::BlobStats::createBigPageStatNotChecking(BlobFileId blob_f return stat; } -void BlobStore::BlobStats::eraseStat(const BlobStatPtr && stat, const std::lock_guard &) +void BlobStore::BlobStats::eraseStat(const BlobStatPtr && stat, const std::lock_guard & lock) { - PageFileIdAndLevel id_lvl{stat->id, 0}; - stats_map[delegator->getPageFilePath(id_lvl)].remove(stat); + stats_map[blobIdToPathImpl(stat->id, lock)].remove(stat); } void BlobStore::BlobStats::eraseStat(BlobFileId blob_file_id, const std::lock_guard & lock) @@ -1477,6 +1477,29 @@ BlobStatPtr BlobStore::BlobStats::blobIdToStat(BlobFileId file_id, bool ignore_n return nullptr; } +String BlobStore::BlobStats::blobIdToPath(BlobFileId file_id) +{ + auto guard = lock(); + return blobIdToPathImpl(file_id, guard); +} + +String BlobStore::BlobStats::blobIdToPathImpl(BlobFileId file_id, const std::lock_guard &) +{ + for (const auto & [path, stats] : stats_map) + { + for (const auto & stat : stats) + { + if (stat->id == file_id) + { + return path; + } + } + } + throw Exception(fmt::format("Can't find BlobStat with [blob_id={}]", + file_id), + ErrorCodes::LOGICAL_ERROR); +} + /********************* * BlobStat methods * ********************/ diff --git a/dbms/src/Storages/Page/V3/BlobStore.h b/dbms/src/Storages/Page/V3/BlobStore.h index 6b139b98557..6fa077156c8 100644 --- a/dbms/src/Storages/Page/V3/BlobStore.h +++ b/dbms/src/Storages/Page/V3/BlobStore.h @@ -223,7 +223,14 @@ class BlobStore : private Allocator BlobStatPtr blobIdToStat(BlobFileId file_id, bool ignore_not_exist = false); - std::map> getStats() const + String blobIdToPath(BlobFileId file_id); + + private: + String blobIdToPathImpl(BlobFileId file_id, const std::lock_guard &); + + public: + using StatsMap = std::map>; + StatsMap getStats() const { auto guard = lock(); return stats_map; diff --git a/dbms/src/Storages/Page/V3/tests/gtest_blob_store.cpp b/dbms/src/Storages/Page/V3/tests/gtest_blob_store.cpp index f9daacc4cce..1839c5cc5c0 100644 --- a/dbms/src/Storages/Page/V3/tests/gtest_blob_store.cpp +++ b/dbms/src/Storages/Page/V3/tests/gtest_blob_store.cpp @@ -33,6 +33,8 @@ namespace DB::PS::V3::tests using BlobStat = BlobStore::BlobStats::BlobStat; using BlobStats = BlobStore::BlobStats; +constexpr size_t path_num = 3; + class BlobStoreStatsTest : public DB::base::TiFlashStorageTestBasic { public: @@ -42,7 +44,12 @@ class BlobStoreStatsTest : public DB::base::TiFlashStorageTestBasic auto path = getTemporaryPath(); DB::tests::TiFlashTestEnv::tryRemovePath(path); createIfNotExist(path); - delegator = std::make_shared(path); + Strings paths; + for (size_t i = 0; i < path_num; i++) + { + paths.emplace_back(fmt::format("{}/{}", path, i)); + } + delegator = std::make_shared(paths); } protected: @@ -51,6 +58,16 @@ class BlobStoreStatsTest : public DB::base::TiFlashStorageTestBasic PSDiskDelegatorPtr delegator; }; +static size_t getTotalStatsNum(const BlobStore::BlobStats::StatsMap & stats_map) +{ + size_t total_stats_num = 0; + for (auto iter = stats_map.begin(); iter != stats_map.end(); iter++) + { + total_stats_num += iter->second.size(); + } + return total_stats_num; +} + TEST_F(BlobStoreStatsTest, RestoreEmpty) { BlobStats stats(logger, delegator, config); @@ -108,8 +125,8 @@ try auto stats_copy = stats.getStats(); - ASSERT_EQ(stats_copy.size(), 1); - ASSERT_EQ(stats_copy.begin()->second.size(), 2); + ASSERT_EQ(stats_copy.size(), std::min(getTotalStatsNum(stats_copy), path_num)); + ASSERT_EQ(getTotalStatsNum(stats_copy), 2); EXPECT_EQ(stats.roll_id, 13); auto stat1 = stats.blobIdToStat(file_id1); @@ -142,13 +159,13 @@ TEST_F(BlobStoreStatsTest, testStats) auto stats_copy = stats.getStats(); - ASSERT_EQ(stats_copy.size(), 1); - ASSERT_EQ(stats_copy.begin()->second.size(), 3); + ASSERT_EQ(stats_copy.size(), std::min(getTotalStatsNum(stats_copy), path_num)); + ASSERT_EQ(getTotalStatsNum(stats_copy), 3); ASSERT_EQ(stats.roll_id, 3); stats.eraseStat(0, stats.lock()); stats.eraseStat(1, stats.lock()); - ASSERT_EQ(stats.stats_map.size(), 1); + ASSERT_EQ(getTotalStatsNum(stats.getStats()), 1); ASSERT_EQ(stats.roll_id, 3); } @@ -278,7 +295,12 @@ class BlobStoreTest : public DB::base::TiFlashStorageTestBasic auto path = getTemporaryPath(); DB::tests::TiFlashTestEnv::tryRemovePath(path); createIfNotExist(path); - delegator = std::make_shared(path); + Strings paths; + for (size_t i = 0; i < path_num; i++) + { + paths.emplace_back(fmt::format("{}/{}", path, i)); + } + delegator = std::make_shared(paths); } protected: @@ -296,10 +318,13 @@ try BlobFileId file_id1 = 10; BlobFileId file_id2 = 12; - const auto & path = getTemporaryPath(); - createIfNotExist(path); - Poco::File(fmt::format("{}/{}{}", path, BlobFile::BLOB_PREFIX_NAME, file_id1)).createFile(); - Poco::File(fmt::format("{}/{}{}", path, BlobFile::BLOB_PREFIX_NAME, file_id2)).createFile(); + const auto & paths = delegator->listPaths(); + for (const auto & path : paths) + { + createIfNotExist(path); + } + Poco::File(fmt::format("{}/{}{}", paths[rand() % path_num], BlobFile::BLOB_PREFIX_NAME, file_id1)).createFile(); + Poco::File(fmt::format("{}/{}{}", paths[rand() % path_num], BlobFile::BLOB_PREFIX_NAME, file_id2)).createFile(); blob_store.registerPaths(); { @@ -386,11 +411,20 @@ try write_batch.clear(); }; - auto check_in_disk_file = [](String parent_path, std::vector exited_blobs) -> bool { + auto check_in_disk_file = [](const Strings & paths, std::vector exited_blobs) -> bool { for (const auto blob_id : exited_blobs) { - Poco::File file(fmt::format("{}/{}{}", parent_path, BlobFile::BLOB_PREFIX_NAME, blob_id)); - if (!file.exists()) + bool exists = false; + for (const auto & path : paths) + { + Poco::File file(fmt::format("{}/{}{}", path, BlobFile::BLOB_PREFIX_NAME, blob_id)); + if (file.exists()) + { + exists = true; + break; + } + } + if (!exists) { return false; } @@ -415,83 +449,95 @@ try // Case 1, all of blob been restored { - auto test_path = getTemporaryPath(); + auto test_paths = delegator->listPaths(); auto blob_store = BlobStore(getCurrentTestName(), file_provider, delegator, config); write_blob_datas(blob_store); - ASSERT_TRUE(check_in_disk_file(test_path, {1, 2, 3})); + ASSERT_TRUE(check_in_disk_file(test_paths, {1, 2, 3})); auto blob_store_check = BlobStore(getCurrentTestName(), file_provider, delegator, config); restore_blobs(blob_store_check, {1, 2, 3}); blob_store_check.blob_stats.restore(); - ASSERT_TRUE(check_in_disk_file(test_path, {1, 2, 3})); - DB::tests::TiFlashTestEnv::tryRemovePath(test_path); - createIfNotExist(test_path); + ASSERT_TRUE(check_in_disk_file(test_paths, {1, 2, 3})); + for (const auto & path : test_paths) + { + DB::tests::TiFlashTestEnv::tryRemovePath(path); + createIfNotExist(path); + } } // Case 2, only recover blob 1 { - auto test_path = getTemporaryPath(); + auto test_paths = delegator->listPaths(); auto blob_store = BlobStore(getCurrentTestName(), file_provider, delegator, config); write_blob_datas(blob_store); - ASSERT_TRUE(check_in_disk_file(test_path, {1, 2, 3})); + ASSERT_TRUE(check_in_disk_file(test_paths, {1, 2, 3})); auto blob_store_check = BlobStore(getCurrentTestName(), file_provider, delegator, config); restore_blobs(blob_store_check, {1}); blob_store_check.blob_stats.restore(); - ASSERT_TRUE(check_in_disk_file(test_path, {1})); - DB::tests::TiFlashTestEnv::tryRemovePath(test_path); - createIfNotExist(test_path); + ASSERT_TRUE(check_in_disk_file(test_paths, {1})); + for (const auto & path : test_paths) + { + DB::tests::TiFlashTestEnv::tryRemovePath(path); + createIfNotExist(path); + } } // Case 3, only recover blob 2 { - auto test_path = getTemporaryPath(); + auto test_paths = delegator->listPaths(); auto blob_store = BlobStore(getCurrentTestName(), file_provider, delegator, config); write_blob_datas(blob_store); - ASSERT_TRUE(check_in_disk_file(test_path, {1, 2, 3})); + ASSERT_TRUE(check_in_disk_file(test_paths, {1, 2, 3})); auto blob_store_check = BlobStore(getCurrentTestName(), file_provider, delegator, config); restore_blobs(blob_store_check, {2}); blob_store_check.blob_stats.restore(); - ASSERT_TRUE(check_in_disk_file(test_path, {2})); - DB::tests::TiFlashTestEnv::tryRemovePath(test_path); - createIfNotExist(test_path); + ASSERT_TRUE(check_in_disk_file(test_paths, {2})); + for (const auto & path : test_paths) + { + DB::tests::TiFlashTestEnv::tryRemovePath(path); + createIfNotExist(path); + } } // Case 4, only recover blob 3 { - auto test_path = getTemporaryPath(); + auto test_paths = delegator->listPaths(); auto blob_store = BlobStore(getCurrentTestName(), file_provider, delegator, config); write_blob_datas(blob_store); - ASSERT_TRUE(check_in_disk_file(test_path, {1, 2, 3})); + ASSERT_TRUE(check_in_disk_file(test_paths, {1, 2, 3})); auto blob_store_check = BlobStore(getCurrentTestName(), file_provider, delegator, config); restore_blobs(blob_store_check, {3}); blob_store_check.blob_stats.restore(); - ASSERT_TRUE(check_in_disk_file(test_path, {3})); - DB::tests::TiFlashTestEnv::tryRemovePath(test_path); - createIfNotExist(test_path); + ASSERT_TRUE(check_in_disk_file(test_paths, {3})); + for (const auto & path : test_paths) + { + DB::tests::TiFlashTestEnv::tryRemovePath(path); + createIfNotExist(path); + } } // Case 5, recover a not exist blob { - auto test_path = getTemporaryPath(); + auto test_paths = delegator->listPaths(); auto blob_store = BlobStore(getCurrentTestName(), file_provider, delegator, config); write_blob_datas(blob_store); - ASSERT_TRUE(check_in_disk_file(test_path, {1, 2, 3})); + ASSERT_TRUE(check_in_disk_file(test_paths, {1, 2, 3})); auto blob_store_check = BlobStore(getCurrentTestName(), file_provider, delegator, config); ASSERT_THROW(restore_blobs(blob_store_check, {4}), DB::Exception); @@ -1034,14 +1080,16 @@ TEST_F(BlobStoreTest, testBlobStoreGcStats) auto edit = blob_store.write(wb, nullptr); size_t idx = 0; - PageEntriesV3 entries_del1, entries_del2; + PageEntriesV3 entries_del1, entries_del2, remain_entries; for (const auto & record : edit.getRecords()) { + bool deleted = false; for (size_t index : remove_entries_idx1) { if (idx == index) { entries_del1.emplace_back(record.entry); + deleted = true; break; } } @@ -1051,9 +1099,14 @@ TEST_F(BlobStoreTest, testBlobStoreGcStats) if (idx == index) { entries_del2.emplace_back(record.entry); + deleted = true; break; } } + if (!deleted) + { + remain_entries.emplace_back(record.entry); + } idx++; } @@ -1089,6 +1142,11 @@ TEST_F(BlobStoreTest, testBlobStoreGcStats) String path = blob_store.getBlobFile(1)->getPath(); Poco::File blob_file_in_disk(path); ASSERT_EQ(blob_file_in_disk.getSize(), stat->sm_total_size); + + // Check whether the stat can be totally removed + stat->changeToReadOnly(); + blob_store.remove(remain_entries); + ASSERT_EQ(getTotalStatsNum(blob_store.blob_stats.getStats()), 0); } TEST_F(BlobStoreTest, testBlobStoreGcStats2) diff --git a/dbms/src/TestUtils/MockDiskDelegator.h b/dbms/src/TestUtils/MockDiskDelegator.h index 5b0fd7ea5b7..1a7386e39c8 100644 --- a/dbms/src/TestUtils/MockDiskDelegator.h +++ b/dbms/src/TestUtils/MockDiskDelegator.h @@ -99,9 +99,9 @@ class MockDiskDelegatorMulti final : public PSDiskDelegator throw Exception("Should not generate MockDiskDelegatorMulti with empty paths"); } - bool fileExist(const PageFileIdAndLevel & /*id_lvl*/) const + bool fileExist(const PageFileIdAndLevel & id_lvl) const { - return true; + return page_path_map.find(id_lvl) != page_path_map.end(); } @@ -115,9 +115,14 @@ class MockDiskDelegatorMulti final : public PSDiskDelegator return paths[0]; } - String getPageFilePath(const PageFileIdAndLevel & /*id_lvl*/) const + String getPageFilePath(const PageFileIdAndLevel & id_lvl) const { - throw Exception("Not implemented", ErrorCodes::NOT_IMPLEMENTED); + auto iter = page_path_map.find(id_lvl); + if (likely(iter != page_path_map.end())) + { + return paths[iter->second]; + } + throw Exception(fmt::format("Can not find path for PageFile [id={}_{}]", id_lvl.first, id_lvl.second)); } void removePageFile(const PageFileIdAndLevel & /*id_lvl*/, size_t /*file_size*/, bool /*meta_left*/, bool /*remove_from_default_path*/) {} @@ -135,11 +140,28 @@ class MockDiskDelegatorMulti final : public PSDiskDelegator } size_t addPageFileUsedSize( - const PageFileIdAndLevel & /*id_lvl*/, + const PageFileIdAndLevel & id_lvl, size_t /*size_to_add*/, - const String & /*pf_parent_path*/, - bool /*need_insert_location*/) - { + const String & pf_parent_path, + bool need_insert_location) + { + if (need_insert_location) + { + UInt32 index = UINT32_MAX; + + for (size_t i = 0; i < paths.size(); i++) + { + if (paths[i] == pf_parent_path) + { + index = i; + break; + } + } + + if (unlikely(index == UINT32_MAX)) + throw Exception(fmt::format("Unrecognized path {}", pf_parent_path)); + page_path_map[id_lvl] = index; + } return 0; } @@ -154,6 +176,8 @@ class MockDiskDelegatorMulti final : public PSDiskDelegator private: Strings paths; size_t choose_idx = 0; + // PageFileID -> path index + PathPool::PageFilePathMap page_path_map; }; } // namespace tests From 3153a3b75bd3ac1112a9c526f54cef64414c6d5b Mon Sep 17 00:00:00 2001 From: yanweiqi <592838129@qq.com> Date: Wed, 3 Aug 2022 11:00:06 +0800 Subject: [PATCH 06/11] Test: Mock TiFlash compute service and dispatch MPPTask to single service. (#5450) ref pingcap/tiflash#4609 --- dbms/src/Debug/dbgFuncCoprocessor.cpp | 275 ++++++++++-------- dbms/src/Debug/dbgFuncCoprocessor.h | 4 + dbms/src/Flash/Coprocessor/DAGContext.cpp | 10 - dbms/src/Flash/Coprocessor/DAGContext.h | 12 - .../Coprocessor/DAGQueryBlockInterpreter.cpp | 22 +- .../Flash/Coprocessor/InterpreterUtils.cpp | 4 +- dbms/src/Flash/Coprocessor/MockSourceStream.h | 2 +- .../Flash/Coprocessor/TablesRegionsInfo.cpp | 2 +- dbms/src/Flash/CoprocessorHandler.cpp | 4 +- dbms/src/Flash/CoprocessorHandler.h | 2 +- dbms/src/Flash/FlashService.cpp | 9 +- dbms/src/Flash/FlashService.h | 2 +- dbms/src/Flash/Mpp/MPPTask.cpp | 1 + .../tests/gtest_aggregation_executor.cpp | 4 +- dbms/src/Flash/tests/gtest_compute_server.cpp | 109 +++++++ dbms/src/Interpreters/Context.cpp | 52 +++- dbms/src/Interpreters/Context.h | 23 +- dbms/src/Server/CMakeLists.txt | 6 + dbms/src/Server/IServer.h | 7 +- dbms/src/Server/MockComputeClient.h | 50 ++++ dbms/src/Server/Server.cpp | 2 +- dbms/src/TestUtils/CMakeLists.txt | 2 +- dbms/src/TestUtils/ExecutorTestUtils.cpp | 15 +- dbms/src/TestUtils/ExecutorTestUtils.h | 9 + dbms/src/TestUtils/MPPTaskTestUtils.h | 63 ++++ dbms/src/TestUtils/mockExecutor.cpp | 4 +- 26 files changed, 517 insertions(+), 178 deletions(-) create mode 100644 dbms/src/Flash/tests/gtest_compute_server.cpp create mode 100644 dbms/src/Server/MockComputeClient.h create mode 100644 dbms/src/TestUtils/MPPTaskTestUtils.h diff --git a/dbms/src/Debug/dbgFuncCoprocessor.cpp b/dbms/src/Debug/dbgFuncCoprocessor.cpp index 62a8b7537f1..72f3599ebc4 100644 --- a/dbms/src/Debug/dbgFuncCoprocessor.cpp +++ b/dbms/src/Debug/dbgFuncCoprocessor.cpp @@ -45,6 +45,7 @@ #include #include #include +#include #include #include #include @@ -183,155 +184,176 @@ void setTipbRegionInfo(coprocessor::RegionInfo * tipb_region_info, const std::pa range->set_end(RecordKVFormat::genRawKey(table_id, handle_range.second.handle_id)); } -BlockInputStreamPtr executeQuery(Context & context, RegionID region_id, const DAGProperties & properties, QueryTasks & query_tasks, MakeResOutputStream & func_wrap_output_stream) +BlockInputStreamPtr executeMPPQuery(Context & context, const DAGProperties & properties, QueryTasks & query_tasks) { - if (properties.is_mpp_query) + DAGSchema root_task_schema; + std::vector root_task_ids; + for (auto & task : query_tasks) { - DAGSchema root_task_schema; - std::vector root_task_ids; - for (auto & task : query_tasks) + if (task.is_root_task) { - if (task.is_root_task) - { - root_task_ids.push_back(task.task_id); - root_task_schema = task.result_schema; - } - auto req = std::make_shared(); - auto * tm = req->mutable_meta(); - tm->set_start_ts(properties.start_ts); - tm->set_partition_id(task.partition_id); - tm->set_address(Debug::LOCAL_HOST); - tm->set_task_id(task.task_id); - auto * encoded_plan = req->mutable_encoded_plan(); - task.dag_request->AppendToString(encoded_plan); - req->set_timeout(properties.mpp_timeout); - req->set_schema_ver(DEFAULT_UNSPECIFIED_SCHEMA_VERSION); - auto table_id = task.table_id; - if (table_id != -1) + root_task_ids.push_back(task.task_id); + root_task_schema = task.result_schema; + } + auto req = std::make_shared(); + auto * tm = req->mutable_meta(); + tm->set_start_ts(properties.start_ts); + tm->set_partition_id(task.partition_id); + tm->set_address(Debug::LOCAL_HOST); + tm->set_task_id(task.task_id); + auto * encoded_plan = req->mutable_encoded_plan(); + task.dag_request->AppendToString(encoded_plan); + req->set_timeout(properties.mpp_timeout); + req->set_schema_ver(DEFAULT_UNSPECIFIED_SCHEMA_VERSION); + auto table_id = task.table_id; + if (table_id != -1) + { + /// contains a table scan + const auto & table_info = MockTiDB::instance().getTableInfoByID(table_id); + if (table_info->is_partition_table) { - /// contains a table scan - const auto & table_info = MockTiDB::instance().getTableInfoByID(table_id); - if (table_info->is_partition_table) + size_t current_region_size = 0; + coprocessor::TableRegions * current_table_regions = nullptr; + for (const auto & partition : table_info->partition.definitions) { - size_t current_region_size = 0; - coprocessor::TableRegions * current_table_regions = nullptr; - for (const auto & partition : table_info->partition.definitions) + const auto partition_id = partition.id; + auto regions = context.getTMTContext().getRegionTable().getRegionsByTable(partition_id); + for (size_t i = 0; i < regions.size(); ++i) { - const auto partition_id = partition.id; - auto regions = context.getTMTContext().getRegionTable().getRegionsByTable(partition_id); - for (size_t i = 0; i < regions.size(); ++i) + if ((current_region_size + i) % properties.mpp_partition_num != static_cast(task.partition_id)) + continue; + if (current_table_regions != nullptr && current_table_regions->physical_table_id() != partition_id) + current_table_regions = nullptr; + if (current_table_regions == nullptr) { - if ((current_region_size + i) % properties.mpp_partition_num != static_cast(task.partition_id)) - continue; - if (current_table_regions != nullptr && current_table_regions->physical_table_id() != partition_id) - current_table_regions = nullptr; - if (current_table_regions == nullptr) - { - current_table_regions = req->add_table_regions(); - current_table_regions->set_physical_table_id(partition_id); - } - setTipbRegionInfo(current_table_regions->add_regions(), regions[i], partition_id); + current_table_regions = req->add_table_regions(); + current_table_regions->set_physical_table_id(partition_id); } - current_region_size += regions.size(); + setTipbRegionInfo(current_table_regions->add_regions(), regions[i], partition_id); } - if (current_region_size < static_cast(properties.mpp_partition_num)) - throw Exception("Not supported: table region num less than mpp partition num"); + current_region_size += regions.size(); } - else + if (current_region_size < static_cast(properties.mpp_partition_num)) + throw Exception("Not supported: table region num less than mpp partition num"); + } + else + { + auto regions = context.getTMTContext().getRegionTable().getRegionsByTable(table_id); + if (regions.size() < static_cast(properties.mpp_partition_num)) + throw Exception("Not supported: table region num less than mpp partition num"); + for (size_t i = 0; i < regions.size(); ++i) { - auto regions = context.getTMTContext().getRegionTable().getRegionsByTable(table_id); - if (regions.size() < static_cast(properties.mpp_partition_num)) - throw Exception("Not supported: table region num less than mpp partition num"); - for (size_t i = 0; i < regions.size(); ++i) - { - if (i % properties.mpp_partition_num != static_cast(task.partition_id)) - continue; - setTipbRegionInfo(req->add_regions(), regions[i], table_id); - } + if (i % properties.mpp_partition_num != static_cast(task.partition_id)) + continue; + setTipbRegionInfo(req->add_regions(), regions[i], table_id); } } - pingcap::kv::RpcCall call(req); - context.getTMTContext().getCluster()->rpc_client->sendRequest(Debug::LOCAL_HOST, call, 1000); - if (call.getResp()->has_error()) - throw Exception("Meet error while dispatch mpp task: " + call.getResp()->error().msg()); } - tipb::ExchangeReceiver tipb_exchange_receiver; - for (const auto root_task_id : root_task_ids) + + if (context.isMPPTest()) { - mpp::TaskMeta tm; - tm.set_start_ts(properties.start_ts); - tm.set_address(Debug::LOCAL_HOST); - tm.set_task_id(root_task_id); - tm.set_partition_id(-1); - auto * tm_string = tipb_exchange_receiver.add_encoded_task_meta(); - tm.AppendToString(tm_string); + MockComputeClient client( + grpc::CreateChannel(Debug::LOCAL_HOST, grpc::InsecureChannelCredentials())); + client.runDispatchMPPTask(req); } - for (auto & field : root_task_schema) + else { - auto tipb_type = TiDB::columnInfoToFieldType(field.second); - tipb_type.set_collate(properties.collator); - auto * field_type = tipb_exchange_receiver.add_field_types(); - *field_type = tipb_type; + pingcap::kv::RpcCall call(req); + context.getTMTContext().getCluster()->rpc_client->sendRequest(Debug::LOCAL_HOST, call, 1000); + if (call.getResp()->has_error()) + throw Exception("Meet error while dispatch mpp task: " + call.getResp()->error().msg()); } - mpp::TaskMeta root_tm; - root_tm.set_start_ts(properties.start_ts); - root_tm.set_address(Debug::LOCAL_HOST); - root_tm.set_task_id(-1); - root_tm.set_partition_id(-1); - std::shared_ptr exchange_receiver - = std::make_shared( - std::make_shared( - tipb_exchange_receiver, - root_tm, - context.getTMTContext().getKVCluster(), - context.getTMTContext().getMPPTaskManager(), - context.getSettingsRef().enable_local_tunnel, - context.getSettingsRef().enable_async_grpc_client), - tipb_exchange_receiver.encoded_task_meta_size(), - 10, - /*req_id=*/"", - /*executor_id=*/"", - /*fine_grained_shuffle_stream_count=*/0); - BlockInputStreamPtr ret = std::make_shared(exchange_receiver, /*req_id=*/"", /*executor_id=*/"", /*stream_id*/ 0); - return ret; + } + tipb::ExchangeReceiver tipb_exchange_receiver; + for (const auto root_task_id : root_task_ids) + { + mpp::TaskMeta tm; + tm.set_start_ts(properties.start_ts); + tm.set_address(Debug::LOCAL_HOST); + tm.set_task_id(root_task_id); + tm.set_partition_id(-1); + auto * tm_string = tipb_exchange_receiver.add_encoded_task_meta(); + tm.AppendToString(tm_string); + } + for (auto & field : root_task_schema) + { + auto tipb_type = TiDB::columnInfoToFieldType(field.second); + tipb_type.set_collate(properties.collator); + auto * field_type = tipb_exchange_receiver.add_field_types(); + *field_type = tipb_type; + } + mpp::TaskMeta root_tm; + root_tm.set_start_ts(properties.start_ts); + root_tm.set_address(Debug::LOCAL_HOST); + root_tm.set_task_id(-1); + root_tm.set_partition_id(-1); + std::shared_ptr exchange_receiver + = std::make_shared( + std::make_shared( + tipb_exchange_receiver, + root_tm, + context.getTMTContext().getKVCluster(), + context.getTMTContext().getMPPTaskManager(), + context.getSettingsRef().enable_local_tunnel, + context.getSettingsRef().enable_async_grpc_client), + tipb_exchange_receiver.encoded_task_meta_size(), + 10, + /*req_id=*/"", + /*executor_id=*/"", + /*fine_grained_shuffle_stream_count=*/0); + BlockInputStreamPtr ret = std::make_shared(exchange_receiver, /*req_id=*/"", /*executor_id=*/"", /*stream_id*/ 0); + return ret; +} + +BlockInputStreamPtr executeNonMPPQuery(Context & context, RegionID region_id, const DAGProperties & properties, QueryTasks & query_tasks, MakeResOutputStream & func_wrap_output_stream) +{ + auto & task = query_tasks[0]; + auto table_id = task.table_id; + RegionPtr region; + if (region_id == InvalidRegionID) + { + auto regions = context.getTMTContext().getRegionTable().getRegionsByTable(table_id); + if (regions.empty()) + throw Exception("No region for table", ErrorCodes::BAD_ARGUMENTS); + region = regions[0].second; + region_id = regions[0].first; } else { - auto & task = query_tasks[0]; - auto table_id = task.table_id; - RegionPtr region; - if (region_id == InvalidRegionID) - { - auto regions = context.getTMTContext().getRegionTable().getRegionsByTable(table_id); - if (regions.empty()) - throw Exception("No region for table", ErrorCodes::BAD_ARGUMENTS); - region = regions[0].second; - region_id = regions[0].first; - } - else - { - region = context.getTMTContext().getKVStore()->getRegion(region_id); - if (!region) - throw Exception("No such region", ErrorCodes::BAD_ARGUMENTS); - } - auto handle_range = getHandleRangeByTable(region->getRange()->rawKeys(), table_id); - std::vector> key_ranges; - DecodedTiKVKeyPtr start_key = std::make_shared(RecordKVFormat::genRawKey(table_id, handle_range.first.handle_id)); - DecodedTiKVKeyPtr end_key = std::make_shared(RecordKVFormat::genRawKey(table_id, handle_range.second.handle_id)); - key_ranges.emplace_back(std::make_pair(std::move(start_key), std::move(end_key))); - tipb::SelectResponse dag_response = executeDAGRequest( - context, - *task.dag_request, - region_id, - region->version(), - region->confVer(), - properties.start_ts, - key_ranges); - - return func_wrap_output_stream(outputDAGResponse(context, task.result_schema, dag_response)); + region = context.getTMTContext().getKVStore()->getRegion(region_id); + if (!region) + throw Exception("No such region", ErrorCodes::BAD_ARGUMENTS); + } + auto handle_range = getHandleRangeByTable(region->getRange()->rawKeys(), table_id); + std::vector> key_ranges; + DecodedTiKVKeyPtr start_key = std::make_shared(RecordKVFormat::genRawKey(table_id, handle_range.first.handle_id)); + DecodedTiKVKeyPtr end_key = std::make_shared(RecordKVFormat::genRawKey(table_id, handle_range.second.handle_id)); + key_ranges.emplace_back(std::make_pair(std::move(start_key), std::move(end_key))); + tipb::SelectResponse dag_response = executeDAGRequest( + context, + *task.dag_request, + region_id, + region->version(), + region->confVer(), + properties.start_ts, + key_ranges); + + return func_wrap_output_stream(outputDAGResponse(context, task.result_schema, dag_response)); +} + +BlockInputStreamPtr executeQuery(Context & context, RegionID region_id, const DAGProperties & properties, QueryTasks & query_tasks, MakeResOutputStream & func_wrap_output_stream) +{ + if (properties.is_mpp_query) + { + return executeMPPQuery(context, properties, query_tasks); + } + else + { + return executeNonMPPQuery(context, region_id, properties, query_tasks, func_wrap_output_stream); } } + void dbgFuncTiDBQueryFromNaturalDag(Context & context, const ASTs & args, DBGInvoker::Printer output) { if (args.size() != 1) @@ -400,7 +422,7 @@ bool runAndCompareDagReq(const coprocessor::Request & req, const coprocessor::Re bool unequal_flag = false; DAGProperties properties = getDAGProperties(""); - std::vector> key_ranges = CoprocessorHandler::GenCopKeyRange(req.ranges()); + std::vector> key_ranges = CoprocessorHandler::genCopKeyRange(req.ranges()); static auto log = Logger::get("MockDAG"); LOG_FMT_INFO(log, "Handling DAG request: {}", dag_request.DebugString()); tipb::SelectResponse dag_response; @@ -514,7 +536,6 @@ struct QueryFragment dag_request.set_time_zone_name(properties.tz_name); dag_request.set_time_zone_offset(properties.tz_offset); dag_request.set_flags(dag_request.flags() | (1u << 1u /* TRUNCATE_AS_WARNING */) | (1u << 6u /* OVERFLOW_AS_WARNING */)); - if (is_top_fragment) { if (properties.encode_type == "chunk") diff --git a/dbms/src/Debug/dbgFuncCoprocessor.h b/dbms/src/Debug/dbgFuncCoprocessor.h index 5fe39c7e626..41456e54ac4 100644 --- a/dbms/src/Debug/dbgFuncCoprocessor.h +++ b/dbms/src/Debug/dbgFuncCoprocessor.h @@ -82,6 +82,10 @@ QueryTasks queryPlanToQueryTasks( ExecutorPtr root_executor, size_t & executor_index, const Context & context); + +BlockInputStreamPtr executeQuery(Context & context, RegionID region_id, const DAGProperties & properties, QueryTasks & query_tasks, MakeResOutputStream & func_wrap_output_stream); + +BlockInputStreamPtr executeMPPQuery(Context & context, const DAGProperties & properties, QueryTasks & query_tasks); namespace Debug { void setServiceAddr(const std::string & addr); diff --git a/dbms/src/Flash/Coprocessor/DAGContext.cpp b/dbms/src/Flash/Coprocessor/DAGContext.cpp index 9fc07489c94..47c3a3b2450 100644 --- a/dbms/src/Flash/Coprocessor/DAGContext.cpp +++ b/dbms/src/Flash/Coprocessor/DAGContext.cpp @@ -236,14 +236,4 @@ const SingleTableRegions & DAGContext::getTableRegionsInfoByTableID(Int64 table_ { return tables_regions_info.getTableRegionInfoByTableID(table_id); } - -ColumnsWithTypeAndName DAGContext::columnsForTest(String executor_id) -{ - auto it = columns_for_test_map.find(executor_id); - if (unlikely(it == columns_for_test_map.end())) - { - throw DB::Exception("Don't have columns for mock source executors"); - } - return it->second; -} } // namespace DB diff --git a/dbms/src/Flash/Coprocessor/DAGContext.h b/dbms/src/Flash/Coprocessor/DAGContext.h index 4613568ac0d..c2136933a14 100644 --- a/dbms/src/Flash/Coprocessor/DAGContext.h +++ b/dbms/src/Flash/Coprocessor/DAGContext.h @@ -162,7 +162,6 @@ class DAGContext , warning_count(0) { assert(dag_request->has_root_executor() && dag_request->root_executor().has_executor_id()); - // only mpp task has join executor. initExecutorIdToJoinIdMap(); initOutputInfo(); @@ -179,7 +178,6 @@ class DAGContext , max_recorded_error_count(max_error_count_) , warnings(max_recorded_error_count) , warning_count(0) - , is_test(true) {} // for tests need to run query tasks. @@ -194,7 +192,6 @@ class DAGContext , max_recorded_error_count(getMaxErrorCount(*dag_request)) , warnings(max_recorded_error_count) , warning_count(0) - , is_test(true) { assert(dag_request->has_root_executor() || dag_request->executors_size() > 0); return_executor_id = dag_request->root_executor().has_executor_id() || dag_request->executors(0).has_executor_id(); @@ -309,12 +306,6 @@ class DAGContext void updateFinalConcurrency(size_t cur_streams_size, size_t streams_upper_limit); - bool isTest() const { return is_test; } - void setColumnsForTest(std::unordered_map & columns_for_test_map_) { columns_for_test_map = columns_for_test_map_; } - ColumnsWithTypeAndName columnsForTest(String executor_id); - - bool columnsForTestEmpty() { return columns_for_test_map.empty(); } - ExchangeReceiverPtr getMPPExchangeReceiver(const String & executor_id) const; void setMPPReceiverSet(const MPPReceiverSetPtr & receiver_set) { @@ -391,9 +382,6 @@ class DAGContext /// vector of SubqueriesForSets(such as join build subquery). /// The order of the vector is also the order of the subquery. std::vector subqueries; - - bool is_test = false; /// switch for test, do not use it in production. - std::unordered_map columns_for_test_map; /// , for multiple sources }; } // namespace DB diff --git a/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.cpp b/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.cpp index d07a74b4ac7..df647520094 100644 --- a/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.cpp +++ b/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.cpp @@ -99,7 +99,7 @@ AnalysisResult analyzeExpressions( ExpressionActionsChain chain; // selection on table scan had been executed in handleTableScan // In test mode, filter is not pushed down to table scan - if (query_block.selection && (!query_block.isTableScanSource() || context.getDAGContext()->isTest())) + if (query_block.selection && (!query_block.isTableScanSource() || context.isTest())) { std::vector where_conditions; for (const auto & c : query_block.selection->selection().conditions()) @@ -159,7 +159,7 @@ AnalysisResult analyzeExpressions( // for tests, we need to mock tableScan blockInputStream as the source stream. void DAGQueryBlockInterpreter::handleMockTableScan(const TiDBTableScan & table_scan, DAGPipeline & pipeline) { - if (context.getDAGContext()->columnsForTestEmpty() || context.getDAGContext()->columnsForTest(table_scan.getTableScanExecutorID()).empty()) + if (context.columnsForTestEmpty() || context.columnsForTest(table_scan.getTableScanExecutorID()).empty()) { auto names_and_types = genNamesAndTypes(table_scan); auto columns_with_type_and_name = getColumnWithTypeAndName(names_and_types); @@ -279,7 +279,7 @@ void DAGQueryBlockInterpreter::handleJoin(const tipb::Join & join, DAGPipeline & join_execute_info.join_build_streams.push_back(stream); }); // for test, join executor need the return blocks to output. - executeUnion(build_pipeline, max_streams, log, /*ignore_block=*/!dagContext().isTest(), "for join"); + executeUnion(build_pipeline, max_streams, log, /*ignore_block=*/!context.isTest(), "for join"); right_query.source = build_pipeline.firstStream(); right_query.join = join_ptr; @@ -492,7 +492,7 @@ void DAGQueryBlockInterpreter::handleExchangeReceiver(DAGPipeline & pipeline) // for tests, we need to mock ExchangeReceiver blockInputStream as the source stream. void DAGQueryBlockInterpreter::handleMockExchangeReceiver(DAGPipeline & pipeline) { - if (context.getDAGContext()->columnsForTestEmpty() || context.getDAGContext()->columnsForTest(query_block.source_name).empty()) + if (context.columnsForTestEmpty() || context.columnsForTest(query_block.source_name).empty()) { for (size_t i = 0; i < max_streams; ++i) { @@ -590,10 +590,14 @@ void DAGQueryBlockInterpreter::executeImpl(DAGPipeline & pipeline) } else if (query_block.source->tp() == tipb::ExecType::TypeExchangeReceiver) { - if (unlikely(dagContext().isTest())) + if (unlikely(context.isExecutorTest())) handleMockExchangeReceiver(pipeline); else + { + // for MPP test, we can use real exchangeReceiver to run an query across different compute nodes + // or use one compute node to simulate MPP process. handleExchangeReceiver(pipeline); + } recordProfileStreams(pipeline, query_block.source_name); } else if (query_block.source->tp() == tipb::ExecType::TypeProjection) @@ -604,7 +608,7 @@ void DAGQueryBlockInterpreter::executeImpl(DAGPipeline & pipeline) else if (query_block.isTableScanSource()) { TiDBTableScan table_scan(query_block.source, query_block.source_name, dagContext()); - if (unlikely(dagContext().isTest())) + if (unlikely(context.isTest())) handleMockTableScan(table_scan, pipeline); else handleTableScan(table_scan, pipeline); @@ -685,10 +689,14 @@ void DAGQueryBlockInterpreter::executeImpl(DAGPipeline & pipeline) // execute exchange_sender if (query_block.exchange_sender) { - if (unlikely(dagContext().isTest())) + if (unlikely(context.isExecutorTest())) handleMockExchangeSender(pipeline); else + { + // for MPP test, we can use real exchangeReceiver to run an query across different compute nodes + // or use one compute node to simulate MPP process. handleExchangeSender(pipeline); + } recordProfileStreams(pipeline, query_block.exchange_sender_name); } } diff --git a/dbms/src/Flash/Coprocessor/InterpreterUtils.cpp b/dbms/src/Flash/Coprocessor/InterpreterUtils.cpp index 3c07071785a..3da680c925e 100644 --- a/dbms/src/Flash/Coprocessor/InterpreterUtils.cpp +++ b/dbms/src/Flash/Coprocessor/InterpreterUtils.cpp @@ -175,8 +175,10 @@ void executeCreatingSets( { DAGContext & dag_context = *context.getDAGContext(); /// add union to run in parallel if needed - if (unlikely(dag_context.isTest())) + if (unlikely(context.isExecutorTest())) executeUnion(pipeline, max_streams, log, /*ignore_block=*/false, "for test"); + else if (context.isMPPTest()) + executeUnion(pipeline, max_streams, log, /*ignore_block=*/true, "for mpp test"); else if (dag_context.isMPPTask()) /// MPPTask do not need the returned blocks. executeUnion(pipeline, max_streams, log, /*ignore_block=*/true, "for mpp"); diff --git a/dbms/src/Flash/Coprocessor/MockSourceStream.h b/dbms/src/Flash/Coprocessor/MockSourceStream.h index 039cba22e3d..5b69630b40f 100644 --- a/dbms/src/Flash/Coprocessor/MockSourceStream.h +++ b/dbms/src/Flash/Coprocessor/MockSourceStream.h @@ -28,7 +28,7 @@ std::pair>> mockSourceStr NamesAndTypes names_and_types; size_t rows = 0; std::vector> mock_source_streams; - columns_with_type_and_name = context.getDAGContext()->columnsForTest(executor_id); + columns_with_type_and_name = context.columnsForTest(executor_id); for (const auto & col : columns_with_type_and_name) { if (rows == 0) diff --git a/dbms/src/Flash/Coprocessor/TablesRegionsInfo.cpp b/dbms/src/Flash/Coprocessor/TablesRegionsInfo.cpp index ab4a0f82e95..b61a19c3177 100644 --- a/dbms/src/Flash/Coprocessor/TablesRegionsInfo.cpp +++ b/dbms/src/Flash/Coprocessor/TablesRegionsInfo.cpp @@ -59,7 +59,7 @@ static void insertRegionInfoToTablesRegionInfo(const google::protobuf::RepeatedP auto & table_region_info = tables_region_infos.getOrCreateTableRegionInfoByTableID(table_id); for (const auto & r : regions) { - RegionInfo region_info(r.region_id(), r.region_epoch().version(), r.region_epoch().conf_ver(), CoprocessorHandler::GenCopKeyRange(r.ranges()), nullptr); + RegionInfo region_info(r.region_id(), r.region_epoch().version(), r.region_epoch().conf_ver(), CoprocessorHandler::genCopKeyRange(r.ranges()), nullptr); if (region_info.key_ranges.empty()) { throw TiFlashException( diff --git a/dbms/src/Flash/CoprocessorHandler.cpp b/dbms/src/Flash/CoprocessorHandler.cpp index 3d653025b83..5478eca7e22 100644 --- a/dbms/src/Flash/CoprocessorHandler.cpp +++ b/dbms/src/Flash/CoprocessorHandler.cpp @@ -53,7 +53,7 @@ CoprocessorHandler::CoprocessorHandler( , log(&Poco::Logger::get("CoprocessorHandler")) {} -std::vector> CoprocessorHandler::GenCopKeyRange( +std::vector> CoprocessorHandler::genCopKeyRange( const ::google::protobuf::RepeatedPtrField<::coprocessor::KeyRange> & ranges) { std::vector> key_ranges; @@ -100,7 +100,7 @@ grpc::Status CoprocessorHandler::execute() cop_context.kv_context.region_id(), cop_context.kv_context.region_epoch().version(), cop_context.kv_context.region_epoch().conf_ver(), - GenCopKeyRange(cop_request->ranges()), + genCopKeyRange(cop_request->ranges()), &bypass_lock_ts)); DAGContext dag_context(dag_request); diff --git a/dbms/src/Flash/CoprocessorHandler.h b/dbms/src/Flash/CoprocessorHandler.h index 8f8cab1b297..67b08cddbdf 100644 --- a/dbms/src/Flash/CoprocessorHandler.h +++ b/dbms/src/Flash/CoprocessorHandler.h @@ -53,7 +53,7 @@ class CoprocessorHandler virtual grpc::Status execute(); - static std::vector> GenCopKeyRange( + static std::vector> genCopKeyRange( const ::google::protobuf::RepeatedPtrField<::coprocessor::KeyRange> & ranges); protected: diff --git a/dbms/src/Flash/FlashService.cpp b/dbms/src/Flash/FlashService.cpp index 1ab5a36d606..11ef37030a6 100644 --- a/dbms/src/Flash/FlashService.cpp +++ b/dbms/src/Flash/FlashService.cpp @@ -157,7 +157,8 @@ ::grpc::Status FlashService::DispatchMPPTask( { CPUAffinityManager::getInstance().bindSelfGrpcThread(); LOG_FMT_DEBUG(log, "Handling mpp dispatch request: {}", request->DebugString()); - if (!security_config.checkGrpcContext(grpc_context)) + // For MPP test, we don't care about security config. + if (!context.isMPPTest() && !security_config.checkGrpcContext(grpc_context)) { return grpc::Status(grpc::PERMISSION_DENIED, tls_err_msg); } @@ -380,7 +381,9 @@ std::tuple FlashService::createDBContext(const grpc::S std::string client_ip = peer.substr(pos + 1); Poco::Net::SocketAddress client_address(client_ip); - tmp_context->setUser(user, password, client_address, quota_key); + // For MPP test, we don't care about security config. + if (!context.isMPPTest()) + tmp_context->setUser(user, password, client_address, quota_key); String query_id = getClientMetaVarWithDefault(grpc_context, "query_id", ""); tmp_context->setCurrentQueryId(query_id); @@ -436,4 +439,4 @@ ::grpc::Status FlashService::Compact(::grpc::ServerContext * grpc_context, const return manual_compact_manager->handleRequest(request, response); } -} // namespace DB +} // namespace DB \ No newline at end of file diff --git a/dbms/src/Flash/FlashService.h b/dbms/src/Flash/FlashService.h index 67425a0755b..7a25aae4fa2 100644 --- a/dbms/src/Flash/FlashService.h +++ b/dbms/src/Flash/FlashService.h @@ -122,4 +122,4 @@ class AsyncFlashService final : public FlashService } }; -} // namespace DB +} // namespace DB \ No newline at end of file diff --git a/dbms/src/Flash/Mpp/MPPTask.cpp b/dbms/src/Flash/Mpp/MPPTask.cpp index fc333c7e1fa..43ced96f844 100644 --- a/dbms/src/Flash/Mpp/MPPTask.cpp +++ b/dbms/src/Flash/Mpp/MPPTask.cpp @@ -278,6 +278,7 @@ void MPPTask::prepare(const mpp::DispatchTaskRequest & task_request) dag_context->log = log; dag_context->tables_regions_info = std::move(tables_regions_info); dag_context->tidb_host = context->getClientInfo().current_address.toString(); + context->setDAGContext(dag_context.get()); if (dag_context->isRootMPPTask()) diff --git a/dbms/src/Flash/tests/gtest_aggregation_executor.cpp b/dbms/src/Flash/tests/gtest_aggregation_executor.cpp index 9cbf0b80f89..30923dd5475 100644 --- a/dbms/src/Flash/tests/gtest_aggregation_executor.cpp +++ b/dbms/src/Flash/tests/gtest_aggregation_executor.cpp @@ -325,8 +325,8 @@ try .aggregation({}, {col("s1")}) .build(context); { - ASSERT_COLUMNS_EQ_R(executeStreams(request), - createColumns({toNullableVec("s1", {{}, "banana"})})); + ASSERT_COLUMNS_EQ_UR(executeStreams(request), + createColumns({toNullableVec("s1", {{}, "banana"})})); } } CATCH diff --git a/dbms/src/Flash/tests/gtest_compute_server.cpp b/dbms/src/Flash/tests/gtest_compute_server.cpp new file mode 100644 index 00000000000..d2e0d0b78e5 --- /dev/null +++ b/dbms/src/Flash/tests/gtest_compute_server.cpp @@ -0,0 +1,109 @@ +// Copyright 2022 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +namespace DB +{ +namespace tests +{ +class ComputeServerRunner : public DB::tests::MPPTaskTestUtils +{ +public: + void initializeContext() override + { + ExecutorTest::initializeContext(); + /// for agg + context.addMockTable( + {"test_db", "test_table_1"}, + {{"s1", TiDB::TP::TypeLong}, {"s2", TiDB::TP::TypeString}, {"s3", TiDB::TP::TypeString}}, + {toNullableVec("s1", {1, {}, 10000000}), toNullableVec("s2", {"apple", {}, "banana"}), toNullableVec("s3", {"apple", {}, "banana"})}); + + /// for join + context.addMockTable( + {"test_db", "l_table"}, + {{"s", TiDB::TP::TypeString}, {"join_c", TiDB::TP::TypeString}}, + {toNullableVec("s", {"banana", {}, "banana"}), toNullableVec("join_c", {"apple", {}, "banana"})}); + context.addMockTable( + {"test_db", "r_table"}, + {{"s", TiDB::TP::TypeString}, {"join_c", TiDB::TP::TypeString}}, + {toNullableVec("s", {"banana", {}, "banana"}), toNullableVec("join_c", {"apple", {}, "banana"})}); + } +}; + +TEST_F(ComputeServerRunner, runAggTasks) +try +{ + { + auto tasks = context.scan("test_db", "test_table_1") + .aggregation({Max(col("s1"))}, {col("s2"), col("s3")}) + .project({"max(s1)"}) + .buildMPPTasks(context); + + size_t task_size = tasks.size(); + + std::vector expected_strings = { + "exchange_sender_5 | type:Hash, {<0, Long>, <1, String>, <2, String>}\n" + " aggregation_4 | group_by: {<1, String>, <2, String>}, agg_func: {max(<0, Long>)}\n" + " table_scan_0 | {<0, Long>, <1, String>, <2, String>}\n", + "exchange_sender_3 | type:PassThrough, {<0, Long>}\n" + " project_2 | {<0, Long>}\n" + " aggregation_1 | group_by: {<1, String>, <2, String>}, agg_func: {max(<0, Long>)}\n" + " exchange_receiver_6 | type:PassThrough, {<0, Long>, <1, String>, <2, String>}\n"}; + for (size_t i = 0; i < task_size; ++i) + { + ASSERT_DAGREQUEST_EQAUL(expected_strings[i], tasks[i].dag_request); + } + + auto expected_cols = {toNullableVec({1, {}, 10000000})}; + ASSERT_MPPTASK_EQUAL(tasks, expected_cols); + } +} +CATCH + +TEST_F(ComputeServerRunner, runJoinTasks) +try +{ + auto tasks = context + .scan("test_db", "l_table") + .join(context.scan("test_db", "r_table"), {col("join_c")}, tipb::JoinType::TypeLeftOuterJoin) + .topN("join_c", false, 2) + .buildMPPTasks(context); + + size_t task_size = tasks.size(); + std::vector expected_strings = { + "exchange_sender_6 | type:Hash, {<0, String>}\n" + " table_scan_1 | {<0, String>}", + "exchange_sender_5 | type:Hash, {<0, String>, <1, String>}\n" + " table_scan_0 | {<0, String>, <1, String>}", + "exchange_sender_4 | type:PassThrough, {<0, String>, <1, String>, <2, String>}\n" + " topn_3 | order_by: {(<1, String>, desc: false)}, limit: 2\n" + " Join_2 | LeftOuterJoin, HashJoin. left_join_keys: {<0, String>}, right_join_keys: {<0, String>}\n" + " exchange_receiver_7 | type:PassThrough, {<0, String>, <1, String>}\n" + " exchange_receiver_8 | type:PassThrough, {<0, String>}"}; + for (size_t i = 0; i < task_size; ++i) + { + ASSERT_DAGREQUEST_EQAUL(expected_strings[i], tasks[i].dag_request); + } + + auto expected_cols = { + toNullableVec({{}, "banana"}), + toNullableVec({{}, "apple"}), + toNullableVec({{}, {}})}; + ASSERT_MPPTASK_EQUAL(tasks, expected_cols); +} +CATCH + +} // namespace tests +} // namespace DB diff --git a/dbms/src/Interpreters/Context.cpp b/dbms/src/Interpreters/Context.cpp index 7cd0cb5ad53..83e678101db 100644 --- a/dbms/src/Interpreters/Context.cpp +++ b/dbms/src/Interpreters/Context.cpp @@ -1894,7 +1894,7 @@ size_t Context::getMaxStreams() const bool is_cop_request = false; if (dag_context != nullptr) { - if (dag_context->isTest()) + if (isExecutorTest()) max_streams = dag_context->initialize_concurrency; else if (!dag_context->isBatchCop() && !dag_context->isMPPTask()) { @@ -1912,6 +1912,56 @@ size_t Context::getMaxStreams() const return max_streams; } +bool Context::isMPPTest() const +{ + return test_mode == mpp_test; +} + +void Context::setMPPTest() +{ + test_mode = mpp_test; +} + +bool Context::isExecutorTest() const +{ + return test_mode == executor_test; +} + +void Context::setExecutorTest() +{ + test_mode = executor_test; +} + +bool Context::isTest() const +{ + return test_mode != non_test; +} + +void Context::setColumnsForTest(std::unordered_map & columns_for_test_map_) +{ + columns_for_test_map = columns_for_test_map_; +} + +std::unordered_map & Context::getColumnsForTestMap() +{ + return columns_for_test_map; +} + +ColumnsWithTypeAndName Context::columnsForTest(String executor_id) +{ + auto it = columns_for_test_map.find(executor_id); + if (unlikely(it == columns_for_test_map.end())) + { + throw DB::Exception("Don't have columns for mock source executors"); + } + return it->second; +} + +bool Context::columnsForTestEmpty() +{ + return columns_for_test_map.empty(); +} + SessionCleaner::~SessionCleaner() { try diff --git a/dbms/src/Interpreters/Context.h b/dbms/src/Interpreters/Context.h index 7663b40f612..af49c37c041 100644 --- a/dbms/src/Interpreters/Context.h +++ b/dbms/src/Interpreters/Context.h @@ -14,6 +14,7 @@ #pragma once +#include #include #include #include @@ -154,10 +155,19 @@ class Context bool use_l0_opt = true; + enum TestMode + { + non_test, + mpp_test, + executor_test + }; + TestMode test_mode = non_test; + TimezoneInfo timezone_info; DAGContext * dag_context = nullptr; - + // TODO: add MockStorage. + std::unordered_map columns_for_test_map; /// , for multiple sources using DatabasePtr = std::shared_ptr; using Databases = std::map>; @@ -463,6 +473,17 @@ class Context size_t getMaxStreams() const; + /// For executor tests and MPPTask tests. + bool isMPPTest() const; + void setMPPTest(); + bool isExecutorTest() const; + void setExecutorTest(); + bool isTest() const; + void setColumnsForTest(std::unordered_map & columns_for_test_map_); + std::unordered_map & getColumnsForTestMap(); + ColumnsWithTypeAndName columnsForTest(String executor_id); + bool columnsForTestEmpty(); + private: /** Check if the current client has access to the specified database. * If access is denied, throw an exception. diff --git a/dbms/src/Server/CMakeLists.txt b/dbms/src/Server/CMakeLists.txt index 77ab5e69838..7a4443d1dc9 100644 --- a/dbms/src/Server/CMakeLists.txt +++ b/dbms/src/Server/CMakeLists.txt @@ -27,6 +27,12 @@ option(ENABLE_TIFLASH_PAGECTL "Enable pagectl: tools to debug page storage" ${EN configure_file (config_tools.h.in ${CMAKE_CURRENT_BINARY_DIR}/config_tools.h) +add_library(server_for_test + FlashGrpcServerHolder.cpp +) + +target_link_libraries (server_for_test PUBLIC clickhouse_common_io clickhouse_storages_system) + add_library (clickhouse-server-lib HTTPHandler.cpp FlashGrpcServerHolder.cpp diff --git a/dbms/src/Server/IServer.h b/dbms/src/Server/IServer.h index d8112954d6e..72934f60a80 100644 --- a/dbms/src/Server/IServer.h +++ b/dbms/src/Server/IServer.h @@ -14,11 +14,10 @@ #pragma once -#include -#include - #include #include +#include +#include namespace DB @@ -44,4 +43,4 @@ class IServer virtual ~IServer() {} }; -} +} // namespace DB diff --git a/dbms/src/Server/MockComputeClient.h b/dbms/src/Server/MockComputeClient.h new file mode 100644 index 00000000000..b4b2e05a7fb --- /dev/null +++ b/dbms/src/Server/MockComputeClient.h @@ -0,0 +1,50 @@ +// Copyright 2022 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +using grpc::Status; +using grpc_impl::Channel; + +namespace DB +{ + +/// Send RPC Requests to FlashService +/// TODO: Support more methods that FlashService serve. +/// TODO: Support more config of RPC client. +class MockComputeClient +{ +public: + explicit MockComputeClient(std::shared_ptr channel) + : stub(tikvpb::Tikv::NewStub(channel)) + {} + + void runDispatchMPPTask(std::shared_ptr request) + { + mpp::DispatchTaskResponse response; + grpc::ClientContext context; + Status status = stub->DispatchMPPTask(&context, *request, &response); + if (!status.ok()) + { + throw Exception(fmt::format("Meet error while dispatch mpp task, error code = {}, message = {}", status.error_code(), status.error_message())); + } + } + +private: + std::unique_ptr stub{}; +}; +} // namespace DB \ No newline at end of file diff --git a/dbms/src/Server/Server.cpp b/dbms/src/Server/Server.cpp index 607b7e3e6c8..5226089c9f7 100644 --- a/dbms/src/Server/Server.cpp +++ b/dbms/src/Server/Server.cpp @@ -1398,4 +1398,4 @@ int mainEntryClickHouseServer(int argc, char ** argv) auto code = DB::getCurrentExceptionCode(); return code ? code : 1; } -} +} \ No newline at end of file diff --git a/dbms/src/TestUtils/CMakeLists.txt b/dbms/src/TestUtils/CMakeLists.txt index 2adee4f9859..38f0ddfeb59 100644 --- a/dbms/src/TestUtils/CMakeLists.txt +++ b/dbms/src/TestUtils/CMakeLists.txt @@ -18,7 +18,7 @@ add_headers_and_sources(test_util .) list(REMOVE_ITEM test_util_sources "bench_dbms_main.cpp" "gtests_dbms_main.cpp") add_library(test_util_gtest_main ${test_util_headers} ${test_util_sources} gtests_dbms_main.cpp) -target_link_libraries(test_util_gtest_main dbms gtest_main clickhouse_aggregate_functions) +target_link_libraries(test_util_gtest_main dbms gtest_main clickhouse_aggregate_functions server_for_test) add_library(test_util_bench_main ${test_util_headers} ${test_util_sources} bench_dbms_main.cpp) target_link_libraries(test_util_bench_main dbms gtest_main benchmark) diff --git a/dbms/src/TestUtils/ExecutorTestUtils.cpp b/dbms/src/TestUtils/ExecutorTestUtils.cpp index 634e483abd2..0f1b8f0128b 100644 --- a/dbms/src/TestUtils/ExecutorTestUtils.cpp +++ b/dbms/src/TestUtils/ExecutorTestUtils.cpp @@ -34,6 +34,7 @@ void ExecutorTest::initializeContext() dag_context_ptr = std::make_unique(1024); context = MockDAGRequestContext(TiFlashTestEnv::getContext()); dag_context_ptr->log = Logger::get("executorTest"); + TiFlashTestEnv::getGlobalContext().setExecutorTest(); } void ExecutorTest::SetUpTestCase() @@ -66,6 +67,7 @@ void ExecutorTest::executeInterpreter(const String & expected_string, const std: { DAGContext dag_context(*request, "interpreter_test", concurrency); context.context.setDAGContext(&dag_context); + context.context.setExecutorTest(); // Currently, don't care about regions information in interpreter tests. DAGQuerySource dag(context.context); auto res = executeQuery(dag, context.context, false, QueryProcessingStage::Complete); @@ -120,7 +122,8 @@ DB::ColumnsWithTypeAndName readBlock(BlockInputStreamPtr stream) DB::ColumnsWithTypeAndName ExecutorTest::executeStreams(const std::shared_ptr & request, std::unordered_map & source_columns_map, size_t concurrency) { DAGContext dag_context(*request, "executor_test", concurrency); - dag_context.setColumnsForTest(source_columns_map); + context.context.setExecutorTest(); + context.context.setColumnsForTest(source_columns_map); context.context.setDAGContext(&dag_context); // Currently, don't care about regions information in tests. DAGQuerySource dag(context.context); @@ -139,6 +142,16 @@ DB::ColumnsWithTypeAndName ExecutorTest::executeStreamsWithSingleSource(const st return executeStreams(request, source_columns_map, concurrency); } +DB::ColumnsWithTypeAndName ExecutorTest::executeMPPTasks(QueryTasks & tasks) +{ + DAGProperties properties; + // enable mpp + properties.is_mpp_query = true; + context.context.setMPPTest(); + auto res = executeMPPQuery(context.context, properties, tasks); + return readBlock(res); +} + void ExecutorTest::dagRequestEqual(const String & expected_string, const std::shared_ptr & actual) { ASSERT_EQ(Poco::trim(expected_string), Poco::trim(ExecutorSerializer().serialize(actual.get()))); diff --git a/dbms/src/TestUtils/ExecutorTestUtils.h b/dbms/src/TestUtils/ExecutorTestUtils.h index 59b829e04b5..ee1743fcd78 100644 --- a/dbms/src/TestUtils/ExecutorTestUtils.h +++ b/dbms/src/TestUtils/ExecutorTestUtils.h @@ -15,6 +15,7 @@ #pragma once #include +#include #include #include #include @@ -79,6 +80,7 @@ class ExecutorTest : public ::testing::Test const std::shared_ptr & request, std::unordered_map & source_columns_map, size_t concurrency = 1); + ColumnsWithTypeAndName executeStreams( const std::shared_ptr & request, size_t concurrency = 1); @@ -89,6 +91,8 @@ class ExecutorTest : public ::testing::Test SourceType type = TableScan, size_t concurrency = 1); + ColumnsWithTypeAndName executeMPPTasks(QueryTasks & tasks); + protected: MockDAGRequestContext context; std::unique_ptr dag_context_ptr; @@ -96,4 +100,9 @@ class ExecutorTest : public ::testing::Test #define ASSERT_DAGREQUEST_EQAUL(str, request) dagRequestEqual((str), (request)); #define ASSERT_BLOCKINPUTSTREAM_EQAUL(str, request, concurrency) executeInterpreter((str), (request), (concurrency)) +#define ASSERT_MPPTASK_EQUAL(tasks, expect_cols) \ + TiFlashTestEnv::getGlobalContext().setColumnsForTest(context.executorIdColumnsMap()); \ + TiFlashTestEnv::getGlobalContext().setMPPTest(); \ + ASSERT_COLUMNS_EQ_UR(executeMPPTasks(tasks), expected_cols); + } // namespace DB::tests diff --git a/dbms/src/TestUtils/MPPTaskTestUtils.h b/dbms/src/TestUtils/MPPTaskTestUtils.h new file mode 100644 index 00000000000..9e710c6d00f --- /dev/null +++ b/dbms/src/TestUtils/MPPTaskTestUtils.h @@ -0,0 +1,63 @@ +// Copyright 2022 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +namespace DB::tests +{ + +class MPPTaskTestUtils : public ExecutorTest +{ +public: + static void SetUpTestCase() + { + ExecutorTest::SetUpTestCase(); + TiFlashSecurityConfig security_config; + TiFlashRaftConfig raft_config; + raft_config.flash_server_addr = "0.0.0.0:3930"; // TODO:: each FlashGrpcServer should have unique addr. + Poco::AutoPtr config = new Poco::Util::LayeredConfiguration; + log_ptr = Logger::get("compute_test"); + compute_server_ptr = std::make_unique(TiFlashTestEnv::getGlobalContext(), *config, security_config, raft_config, log_ptr); + } + + static void TearDownTestCase() + { + compute_server_ptr.reset(); + } + +protected: + // TODO: Mock a simple storage layer to store test input. + // Currently the lifetime of a server is held in this scope. + // TODO: Add ComputeServerManager to maintain the lifetime of a bunch of servers. + // Note: May go through GRPC fail number 14 --> socket closed, + // if you start a server, send a request to the server using pingcap::kv::RpcClient, + // then close the server and start the server using the same addr, + // then send a request to the new server using pingcap::kv::RpcClient. + static std::unique_ptr compute_server_ptr; + static LoggerPtr log_ptr; +}; + +std::unique_ptr MPPTaskTestUtils::compute_server_ptr = nullptr; +LoggerPtr MPPTaskTestUtils::log_ptr = nullptr; + + +#define ASSERT_MPPTASK_EQUAL(tasks, expect_cols) \ + TiFlashTestEnv::getGlobalContext().setColumnsForTest(context.executorIdColumnsMap()); \ + TiFlashTestEnv::getGlobalContext().setMPPTest(); \ + ASSERT_COLUMNS_EQ_UR(executeMPPTasks(tasks), expected_cols); + +} // namespace DB::tests diff --git a/dbms/src/TestUtils/mockExecutor.cpp b/dbms/src/TestUtils/mockExecutor.cpp index de65ab692c8..2267aba6115 100644 --- a/dbms/src/TestUtils/mockExecutor.cpp +++ b/dbms/src/TestUtils/mockExecutor.cpp @@ -107,12 +107,14 @@ void columnPrune(ExecutorPtr executor) // Split a DAGRequest into multiple QueryTasks which can be dispatched to multiple Compute nodes. -// Currently we don't support window functions. +// Currently we don't support window functions +// and MPPTask with multiple partitions. QueryTasks DAGRequestBuilder::buildMPPTasks(MockDAGRequestContext & mock_context) { columnPrune(root); // enable mpp properties.is_mpp_query = true; + // TODO find a way to record service info. auto query_tasks = queryPlanToQueryTasks(properties, root, executor_index, mock_context.context); root.reset(); executor_index = 0; From 0bafed9b2be3da85a9ec8cb3e160250ee67446ec Mon Sep 17 00:00:00 2001 From: Zhigao Tong Date: Wed, 3 Aug 2022 11:50:05 +0800 Subject: [PATCH 07/11] gtests: fix tiflash-sanitizer-daily (ASan) (#5523) ref pingcap/tiflash#5294 --- dbms/src/Storages/Transaction/tests/gtest_tidb_collator.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dbms/src/Storages/Transaction/tests/gtest_tidb_collator.cpp b/dbms/src/Storages/Transaction/tests/gtest_tidb_collator.cpp index 904cefb26ef..0ba71d6bbbb 100644 --- a/dbms/src/Storages/Transaction/tests/gtest_tidb_collator.cpp +++ b/dbms/src/Storages/Transaction/tests/gtest_tidb_collator.cpp @@ -237,7 +237,7 @@ void testCollator() } { strs.resize(strs.size() + s.size()); - memcpySmallAllowReadWriteOverflow15( + std::memcpy( &strs[strs.size() - s.size()], s.data(), s.size()); From e17ef39aeb3ea7aeb8149cfa9c2aab2d916c272e Mon Sep 17 00:00:00 2001 From: Annie of the Stars <105339527+AnnieoftheStars@users.noreply.github.com> Date: Thu, 4 Aug 2022 15:50:06 +0800 Subject: [PATCH 08/11] feat: implement shiftLeft function push down (#5495) close pingcap/tiflash#5099 --- .../DAGExpressionAnalyzerHelper.cpp | 1 + dbms/src/Flash/Coprocessor/DAGUtils.cpp | 2 +- dbms/src/Functions/bitShiftLeft.cpp | 13 +- .../Functions/tests/gtest_bitshiftleft.cpp | 273 ++++++++++++++++++ .../expr/bitshift_operator.test | 26 ++ 5 files changed, 313 insertions(+), 2 deletions(-) create mode 100644 dbms/src/Functions/tests/gtest_bitshiftleft.cpp diff --git a/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzerHelper.cpp b/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzerHelper.cpp index 23bbb4586b3..b45ade0f7d2 100644 --- a/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzerHelper.cpp +++ b/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzerHelper.cpp @@ -450,6 +450,7 @@ DAGExpressionAnalyzerHelper::FunctionBuilderMap DAGExpressionAnalyzerHelper::fun {"bitOr", DAGExpressionAnalyzerHelper::buildBitwiseFunction}, {"bitXor", DAGExpressionAnalyzerHelper::buildBitwiseFunction}, {"bitNot", DAGExpressionAnalyzerHelper::buildBitwiseFunction}, + {"bitShiftLeft", DAGExpressionAnalyzerHelper::buildBitwiseFunction}, {"bitShiftRight", DAGExpressionAnalyzerHelper::buildBitwiseFunction}, {"leftUTF8", DAGExpressionAnalyzerHelper::buildLeftUTF8Function}, {"date_add", DAGExpressionAnalyzerHelper::buildDateAddOrSubFunction}, diff --git a/dbms/src/Flash/Coprocessor/DAGUtils.cpp b/dbms/src/Flash/Coprocessor/DAGUtils.cpp index b1d7b99e356..c83f029c7f9 100644 --- a/dbms/src/Flash/Coprocessor/DAGUtils.cpp +++ b/dbms/src/Flash/Coprocessor/DAGUtils.cpp @@ -331,7 +331,7 @@ const std::unordered_map scalar_func_map({ {tipb::ScalarFuncSig::DecimalIsFalse, "isFalse"}, {tipb::ScalarFuncSig::DecimalIsFalseWithNull, "isFalseWithNull"}, - //{tipb::ScalarFuncSig::LeftShift, "cast"}, + {tipb::ScalarFuncSig::LeftShift, "bitShiftLeft"}, {tipb::ScalarFuncSig::RightShift, "bitShiftRight"}, //{tipb::ScalarFuncSig::BitCount, "cast"}, diff --git a/dbms/src/Functions/bitShiftLeft.cpp b/dbms/src/Functions/bitShiftLeft.cpp index 1ada5ca748c..ea254665cf3 100644 --- a/dbms/src/Functions/bitShiftLeft.cpp +++ b/dbms/src/Functions/bitShiftLeft.cpp @@ -29,7 +29,18 @@ struct BitShiftLeftImpl template static Result apply(A a, B b) { - return static_cast(a) << static_cast(b); + // It is an undefined behavior for shift operation in c++ that the right operand is negative or greater than + // or equal to the number of digits of the bits in the (promoted) left operand. + // See https://en.cppreference.com/w/cpp/language/operator_arithmetic for details. + if (static_cast(b) >= std::numeric_limits(a))>::digits) + { + return static_cast(0); + } + // Note that we do not consider the case that the right operand is negative, + // since other types will all be cast to uint64 before shift operation + // according to DAGExpressionAnalyzerHelper::buildBitwiseFunction. + // Therefore, we simply suppress clang-tidy checking here. + return static_cast(a) << static_cast(b); // NOLINT(clang-analyzer-core.UndefinedBinaryOperatorResult) } template static Result apply(A, B, UInt8 &) diff --git a/dbms/src/Functions/tests/gtest_bitshiftleft.cpp b/dbms/src/Functions/tests/gtest_bitshiftleft.cpp new file mode 100644 index 00000000000..dc1f854b736 --- /dev/null +++ b/dbms/src/Functions/tests/gtest_bitshiftleft.cpp @@ -0,0 +1,273 @@ +// Copyright 2022 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +namespace DB +{ +namespace tests +{ +class TestFunctionBitShiftLeft : public DB::tests::FunctionTest +{ +}; + +#define ASSERT_BITSHIFTLEFT(t1, t2, result) \ + ASSERT_COLUMN_EQ(result, executeFunction("bitShiftLeft", {t1, t2})) + +TEST_F(TestFunctionBitShiftLeft, Simple) +try +{ + ASSERT_BITSHIFTLEFT(createColumn>({11}), + createColumn>({3}), + createColumn>({88})); +} +CATCH + +/// Note: Only IntX and UIntX will be received by BitShiftLeft, others will be casted by TiDB planner. +/// Note: BitShiftLeft will further cast other types to UInt64 before doing shift. +TEST_F(TestFunctionBitShiftLeft, TypePromotion) +try +{ + // Type Promotion + ASSERT_BITSHIFTLEFT(createColumn>({-1}), createColumn>({1}), createColumn>({18446744073709551614ull})); + ASSERT_BITSHIFTLEFT(createColumn>({-1}), createColumn>({1}), createColumn>({18446744073709551614ull})); + ASSERT_BITSHIFTLEFT(createColumn>({-1}), createColumn>({1}), createColumn>({18446744073709551614ull})); + ASSERT_BITSHIFTLEFT(createColumn>({-1}), createColumn>({1}), createColumn>({18446744073709551614ull})); + + ASSERT_BITSHIFTLEFT(createColumn>({1}), createColumn>({0}), createColumn>({1})); + ASSERT_BITSHIFTLEFT(createColumn>({1}), createColumn>({0}), createColumn>({1})); + ASSERT_BITSHIFTLEFT(createColumn>({1}), createColumn>({0}), createColumn>({1})); + ASSERT_BITSHIFTLEFT(createColumn>({1}), createColumn>({0}), createColumn>({1})); + + // Type Promotion across signed/unsigned + ASSERT_BITSHIFTLEFT(createColumn>({-1}), createColumn>({0}), createColumn>({18446744073709551615ull})); + ASSERT_BITSHIFTLEFT(createColumn>({-1}), createColumn>({0}), createColumn>({18446744073709551615ull})); + ASSERT_BITSHIFTLEFT(createColumn>({1}), createColumn>({0}), createColumn>({1})); + ASSERT_BITSHIFTLEFT(createColumn>({1}), createColumn>({0}), createColumn>({1})); +} +CATCH + +TEST_F(TestFunctionBitShiftLeft, Nullable) +try +{ + // Non Nullable + ASSERT_BITSHIFTLEFT(createColumn({1}), createColumn({0}), createColumn({1})); + ASSERT_BITSHIFTLEFT(createColumn({1}), createColumn({0}), createColumn({1})); + ASSERT_BITSHIFTLEFT(createColumn({1}), createColumn({0}), createColumn({1})); + ASSERT_BITSHIFTLEFT(createColumn({1}), createColumn({0}), createColumn({1})); + + ASSERT_BITSHIFTLEFT(createColumn({1}), createColumn({0}), createColumn({1})); + ASSERT_BITSHIFTLEFT(createColumn({1}), createColumn({0}), createColumn({1})); + ASSERT_BITSHIFTLEFT(createColumn({1}), createColumn({0}), createColumn({1})); + ASSERT_BITSHIFTLEFT(createColumn({1}), createColumn({0}), createColumn({1})); + + ASSERT_BITSHIFTLEFT(createColumn({1}), createColumn({0}), createColumn({1})); + ASSERT_BITSHIFTLEFT(createColumn({1}), createColumn({0}), createColumn({1})); + ASSERT_BITSHIFTLEFT(createColumn({1}), createColumn({0}), createColumn({1})); + ASSERT_BITSHIFTLEFT(createColumn({1}), createColumn({0}), createColumn({1})); + + // Across Nullable and non-Nullable + ASSERT_BITSHIFTLEFT(createColumn({1}), createColumn>({0}), createColumn>({1})); + ASSERT_BITSHIFTLEFT(createColumn({1}), createColumn>({0}), createColumn>({1})); + ASSERT_BITSHIFTLEFT(createColumn({1}), createColumn>({0}), createColumn>({1})); + ASSERT_BITSHIFTLEFT(createColumn({1}), createColumn>({0}), createColumn>({1})); + + ASSERT_BITSHIFTLEFT(createColumn({1}), createColumn>({0}), createColumn>({1})); + ASSERT_BITSHIFTLEFT(createColumn({1}), createColumn>({0}), createColumn>({1})); + ASSERT_BITSHIFTLEFT(createColumn({1}), createColumn>({0}), createColumn>({1})); + ASSERT_BITSHIFTLEFT(createColumn({1}), createColumn>({0}), createColumn>({1})); + + ASSERT_BITSHIFTLEFT(createColumn({1}), createColumn>({0}), createColumn>({1})); + ASSERT_BITSHIFTLEFT(createColumn({1}), createColumn>({0}), createColumn>({1})); + ASSERT_BITSHIFTLEFT(createColumn({1}), createColumn>({0}), createColumn>({1})); + ASSERT_BITSHIFTLEFT(createColumn({1}), createColumn>({0}), createColumn>({1})); + + ASSERT_BITSHIFTLEFT(createColumn>({1}), createColumn({0}), createColumn>({1})); + ASSERT_BITSHIFTLEFT(createColumn>({1}), createColumn({0}), createColumn>({1})); + ASSERT_BITSHIFTLEFT(createColumn>({1}), createColumn({0}), createColumn>({1})); + ASSERT_BITSHIFTLEFT(createColumn>({1}), createColumn({0}), createColumn>({1})); + + ASSERT_BITSHIFTLEFT(createColumn>({1}), createColumn({0}), createColumn>({1})); + ASSERT_BITSHIFTLEFT(createColumn>({1}), createColumn({0}), createColumn>({1})); + ASSERT_BITSHIFTLEFT(createColumn>({1}), createColumn({0}), createColumn>({1})); + ASSERT_BITSHIFTLEFT(createColumn>({1}), createColumn({0}), createColumn>({1})); + + ASSERT_BITSHIFTLEFT(createColumn>({1}), createColumn({0}), createColumn>({1})); + ASSERT_BITSHIFTLEFT(createColumn>({1}), createColumn({0}), createColumn>({1})); + ASSERT_BITSHIFTLEFT(createColumn>({1}), createColumn({0}), createColumn>({1})); + ASSERT_BITSHIFTLEFT(createColumn>({1}), createColumn({0}), createColumn>({1})); +} +CATCH + +TEST_F(TestFunctionBitShiftLeft, TypeCastWithConst) +try +{ + /// need test these kinds of columns: + /// 1. ColumnVector + /// 2. ColumnVector + /// 3. ColumnConst + /// 4. ColumnConst, value != null + /// 5. ColumnConst, value = null + + ASSERT_BITSHIFTLEFT(createColumn({0, 0, 1, 1}), createColumn({0, 1, 0, 1}), createColumn({0, 0, 1, 2})); + ASSERT_BITSHIFTLEFT(createColumn({0, 0, 1, 1}), createColumn>({0, 1, std::nullopt, std::nullopt}), createColumn>({0, 0, std::nullopt, std::nullopt})); + ASSERT_BITSHIFTLEFT(createColumn({0, 0, 1, 1}), createConstColumn(4, 0), createColumn({0, 0, 1, 1})); + ASSERT_BITSHIFTLEFT(createColumn({0, 0, 1, 1}), createConstColumn>(4, 0), createColumn({0, 0, 1, 1})); + ASSERT_BITSHIFTLEFT(createColumn({0, 0, 1, 1}), createConstColumn>(4, std::nullopt), createConstColumn>(4, std::nullopt)); // become const in wrapInNullable + + ASSERT_BITSHIFTLEFT(createColumn>({0, 1, std::nullopt, std::nullopt}), createColumn({0, 1, 0, 1}), createColumn>({0, 2, std::nullopt, std::nullopt})); + ASSERT_BITSHIFTLEFT(createColumn>({0, 1, std::nullopt, std::nullopt}), createColumn>({0, 1, std::nullopt, std::nullopt}), createColumn>({0, 2, std::nullopt, std::nullopt})); + ASSERT_BITSHIFTLEFT(createColumn>({0, 1, std::nullopt, std::nullopt}), createConstColumn(4, 0), createColumn>({0, 1, std::nullopt, std::nullopt})); + ASSERT_BITSHIFTLEFT(createColumn>({0, 1, std::nullopt, std::nullopt}), createConstColumn(4, 0), createColumn>({0, 1, std::nullopt, std::nullopt})); + ASSERT_BITSHIFTLEFT(createColumn>({0, 1, std::nullopt, std::nullopt}), createConstColumn>(4, std::nullopt), createConstColumn>(4, std::nullopt)); + + ASSERT_BITSHIFTLEFT(createConstColumn(4, 1), createColumn({0, 1, 0, 1}), createColumn({1, 2, 1, 2})); + ASSERT_BITSHIFTLEFT(createConstColumn(4, 1), createColumn>({0, 1, std::nullopt, std::nullopt}), createColumn>({1, 2, std::nullopt, std::nullopt})); + ASSERT_BITSHIFTLEFT(createConstColumn(4, 1), createConstColumn(4, 0), createConstColumn(4, 1)); + ASSERT_BITSHIFTLEFT(createConstColumn(4, 1), createConstColumn>(4, 0), createConstColumn(4, 1)); + ASSERT_BITSHIFTLEFT(createConstColumn(4, 1), createConstColumn>(4, std::nullopt), createConstColumn>(4, std::nullopt)); + + ASSERT_BITSHIFTLEFT(createConstColumn>(4, 1), createColumn({0, 1, 0, 1}), createColumn({1, 2, 1, 2})); + ASSERT_BITSHIFTLEFT(createConstColumn>(4, 1), createColumn>({0, 1, std::nullopt, std::nullopt}), createColumn>({1, 2, std::nullopt, std::nullopt})); + ASSERT_BITSHIFTLEFT(createConstColumn>(4, 1), createConstColumn(4, 0), createConstColumn(4, 1)); + ASSERT_BITSHIFTLEFT(createConstColumn>(4, 1), createConstColumn>(4, 0), createConstColumn(4, 1)); + ASSERT_BITSHIFTLEFT(createConstColumn>(4, 1), createConstColumn>(4, std::nullopt), createConstColumn>(4, std::nullopt)); + + ASSERT_BITSHIFTLEFT(createConstColumn>(4, std::nullopt), createColumn({0, 1, 0, 1}), createConstColumn>(4, std::nullopt)); + ASSERT_BITSHIFTLEFT(createConstColumn>(4, std::nullopt), createColumn>({0, 1, std::nullopt, std::nullopt}), createConstColumn>(4, std::nullopt)); + ASSERT_BITSHIFTLEFT(createConstColumn>(4, std::nullopt), createConstColumn(4, 0), createConstColumn>(4, std::nullopt)); + ASSERT_BITSHIFTLEFT(createConstColumn>(4, std::nullopt), createConstColumn(4, 0), createConstColumn>(4, std::nullopt)); + ASSERT_BITSHIFTLEFT(createConstColumn>(4, std::nullopt), createConstColumn>(4, std::nullopt), createConstColumn>(4, std::nullopt)); +} +CATCH + +TEST_F(TestFunctionBitShiftLeft, Boundary) +try +{ + ASSERT_BITSHIFTLEFT(createColumn({127, 127, -128, -128}), createColumn({0, 7, 0, 7}), createColumn({127, 16256, 18446744073709551488ull, 18446744073709535232ull})); + ASSERT_BITSHIFTLEFT(createColumn({32767, 32767, -32768, -32768}), createColumn({0, 15, 0, 15}), createColumn({32767, 1073709056, 18446744073709518848ull, 18446744072635809792ull})); + + ASSERT_BITSHIFTLEFT(createColumn({0, 0, 1, 1, -1, -1, INT64_MAX, INT64_MAX, INT64_MIN, INT64_MIN}), + createColumn({0, 63, 0, 63, 0, 63, 0, 63, 0, 63}), + createColumn({0, 0, 1, 9223372036854775808ull, 18446744073709551615ull, 9223372036854775808ull, 9223372036854775807ull, 9223372036854775808ull, 9223372036854775808ull, 0})); +} +CATCH + +TEST_F(TestFunctionBitShiftLeft, UINT64) +try +{ + ASSERT_BITSHIFTLEFT(createColumn({0, UINT64_MAX}), + createColumn({63, 63}), + createColumn({0, 9223372036854775808ull})); + + ASSERT_BITSHIFTLEFT(createColumn>({0, UINT64_MAX, std::nullopt}), + createColumn>({63, 63, 63}), + createColumn>({0, 9223372036854775808ull, std::nullopt})); + + ASSERT_BITSHIFTLEFT(createColumn>({0, UINT64_MAX, std::nullopt}), + createColumn({63, 63, 63}), + createColumn>({0, 9223372036854775808ull, std::nullopt})); + + ASSERT_BITSHIFTLEFT(createColumn({0, UINT64_MAX}), + createColumn>({63, 63}), + createColumn>({0, 9223372036854775808ull})); + + ASSERT_BITSHIFTLEFT(createColumn({0, 0, 1, 1, -1, -1, INT64_MAX, INT64_MAX, INT64_MIN, INT64_MIN}), + createColumn({0, UINT64_MAX, 0, UINT64_MAX, 0, UINT64_MAX, 0, UINT64_MAX, 0, UINT64_MAX}), + createColumn({0, 0, 1, 0, 18446744073709551615ull, 0, INT64_MAX, 0, 9223372036854775808ull, 0})); + + + ASSERT_BITSHIFTLEFT(createColumn({0, 0, UINT64_MAX, UINT64_MAX}), + createColumn({0, UINT64_MAX, 0, UINT64_MAX}), + createColumn({0, 0, UINT64_MAX, 0})); + + ASSERT_BITSHIFTLEFT(createColumn>({0, 0, UINT64_MAX, UINT64_MAX, 0, std::nullopt}), + createColumn>({0, UINT64_MAX, 0, UINT64_MAX, std::nullopt, 0}), + createColumn>({0, 0, UINT64_MAX, 0, std::nullopt, std::nullopt})); + + ASSERT_BITSHIFTLEFT(createColumn>({0, 0, UINT64_MAX, UINT64_MAX, std::nullopt}), + createColumn({0, UINT64_MAX, 0, UINT64_MAX, 0}), + createColumn>({0, 0, UINT64_MAX, 0, std::nullopt})); + + ASSERT_BITSHIFTLEFT(createColumn({0, UINT64_MAX, 0, UINT64_MAX, 0}), + createColumn>({0, 0, UINT64_MAX, UINT64_MAX, std::nullopt}), + createColumn>({0, UINT64_MAX, 0, 0, std::nullopt})); + + /* + std::mt19937 gen(std::random_device{}()); + std::uniform_int_distribution dis( + std::numeric_limits::min(), + std::numeric_limits::max() + ); + size_t count = 100; + std::vector v1(count), v2(count), res(count); + for (size_t i=0; i({9590146258092595307ull,10707019645721939841ull,5881591201215258519ull,5632031715512392749ull,14458961951991665960ull,5128508001433333971ull,9332472261576879778ull,3894254697792674671ull,3686723867265895770ull,5228369181925377045ull,15257679458410288384ull,11771509639552105900ull,18418679795092108085ull,9855659417571021370ull,4077869334292632100ull,5391460931979150774ull,9467385864486395009ull,7668968327203373028ull,3000081540418664818ull,2526651933183909841ull,10848278664633919825ull,12728734861754908523ull,13066785977192899361ull,18052975499614771372ull,17129542673283010420ull,11328216476757777123ull,3472066935339592657ull,2176492510111902732ull,12990949065266798350ull,15009988423859086182ull,16923734737646560129ull,10145931322192369702ull,15411802497347660213ull,2061619294213261472ull,17580364414153312329ull,12125264059890046972ull,14105236895418755204ull,18398730584471114130ull,299685962966216102ull,16198499948406097599ull,2471494867787664413ull,4570374989605017140ull,6824988039821274928ull,9677834555477405663ull,4412087854764460966ull,4057574216689522200ull,4469278412271442169ull,2195483625075429663ull,11881108446786181585ull,5924332126027566084ull,3250541536484026531ull,14841325794537432635ull,8914756506150941362ull,2386255466980056883ull,13760701820714978003ull,8972642741869197165ull,16342958971153533843ull,6661808439105638065ull,13203982733978217950ull,12178802970326924901ull,12334199694527701475ull,17567161341866858140ull,12414652975269657715ull,2934467071243815081ull,6925619887984457268ull,16075344167916704125ull,9771482348696948711ull,7565184903186541823ull,1052089026349370803ull,1166166123827253256ull,14765270110884348174ull,1006637556365116527ull,4456932086646133707ull,14569355292824480150ull,644090626559619706ull,6739333562626065749ull,13046051594097911684ull,1889670651667918016ull,809698145438197192ull,1008561797702510136ull,15735380769470049759ull,13969312851421751236ull,5652663870952895520ull,6073389077452560954ull,17203692740201015991ull,12627416461571881794ull,17002372509101573427ull,12750188156795333967ull,4802871164855562574ull,5798598266310893099ull,11109002639116226418ull,11683464915217468028ull,2720046121015539944ull,12315139438585481078ull,12493077138920487022ull,2413460630123408726ull,2090136026867714615ull,9216644480497846345ull,16828372569894418887ull,2320332815156167319ull,}), + createColumn({24,54,57,48,35,51,39,15,49,0,52,6,60,11,21,15,7,31,42,34,53,23,25,7,11,46,27,42,48,25,23,22,19,46,61,49,17,35,24,36,21,32,56,49,46,19,16,12,5,26,28,58,56,45,25,2,47,25,24,36,19,43,3,30,20,26,56,63,59,61,24,24,50,33,8,60,31,1,37,51,43,54,51,22,30,12,9,12,51,31,60,27,32,17,4,54,57,0,35,38}), + createColumn({3891574726585221120ull,16158915463005339648ull,3314649325744685056ull,11253651043868737536ull,14614366883160260608ull,9698501797542363136ull,13616441058819833856ull,10809179421395091456ull,1924162940794044416ull,5228369181925377045ull,5764607523034234880ull,15506853982952712960ull,5764607523034234880ull,3652470547202297856ull,1277578713864601600ull,2923825176436736000ull,12787025863137706112ull,16453584519119765504ull,8207185798278152192ull,10514919160037769216ull,3035426148847714304ull,8458186846902419456ull,4259638639887122432ull,4937854736996783616ull,14042910761747718144ull,17778170635277565952ull,15774523305408593920ull,1346629065141911552ull,11461098101704491008ull,16423978333718446080ull,9943717430914187264ull,12830106927325249536ull,14247866488892948480ull,1308295691751129088ull,2305843009213693952ull,3456512714006855680ull,13579056934690488320ull,3179695887026749440ull,14674633785438371840ull,7988378517584740352ull,8042116021024194560ull,10501433581203619840ull,3458764513820540928ull,6898951679178178560ull,4821525613565181952ull,3604107309594181632ull,1227624260973428736ull,9136564412408262656ull,11260588822966778400ull,13293846752680476672ull,16329074277007491072ull,17005592192950992896ull,12826251738751172608ull,14116075635155664896ull,14706384829390258176ull,17443826893767237044ull,16557906242492694528ull,17693849764497457152ull,6250760634544160768ull,4390981392981295104ull,733708992845971456ull,10859550892101730304ull,7083503433609503640ull,4513840660133969920ull,16826445582531821568ull,7371066879428788224ull,16645304222761353216ull,9223372036854775808ull,10952754293765046272ull,0ull,17267913956513546240ull,9668302355338100736ull,5704934827971575808ull,16610774249069608960ull,17313247809586231808ull,5764607523034234880ull,16559722369413808128ull,3779341303335836032ull,14149387538942394368ull,3584865303386914816ull,3706172222256185344ull,17365880163140632576ull,11601272640106397696ull,14700896002525626368ull,8874759427379429376ull,15674187990554648576ull,16798265942806783488ull,2038217561947303936ull,9975473174625648640ull,5860325600132071424ull,2305843009213693952ull,9406390505584984064ull,2292183673182617600ull,6063068395571249152ull,15421793485632276192ull,6160924290242838528ull,7926335344172072960ull,9216644480497846345ull,14114226497115914240ull,301923419086127104ull, +})); + // clang-format on +} +CATCH + +TEST_F(TestFunctionBitShiftLeft, UB) +try +{ + ASSERT_BITSHIFTLEFT(createColumn({127, -128}), createColumn({64, 64}), createColumn({0, 0})); + ASSERT_BITSHIFTLEFT(createColumn({127, -128}), createColumn({64, 64}), createColumn({0, 0})); + ASSERT_BITSHIFTLEFT(createColumn({32767, -32768}), createColumn({64, 64}), createColumn({0, 0})); + ASSERT_BITSHIFTLEFT(createColumn({INT32_MAX, INT32_MIN}), createColumn({64, 64}), createColumn({0, 0})); + ASSERT_BITSHIFTLEFT(createColumn({INT64_MAX, INT64_MIN}), createColumn({64, 64}), createColumn({0, 0})); + + ASSERT_BITSHIFTLEFT(createColumn({255}), createColumn({64}), createColumn({0})); + ASSERT_BITSHIFTLEFT(createColumn({255}), createColumn({64}), createColumn({0})); + ASSERT_BITSHIFTLEFT(createColumn({65535}), createColumn({64}), createColumn({0})); + ASSERT_BITSHIFTLEFT(createColumn({UINT32_MAX}), createColumn({64}), createColumn({0})); + ASSERT_BITSHIFTLEFT(createColumn({UINT64_MAX}), createColumn({64}), createColumn({0})); + + /* + std::mt19937 gen(std::random_device{}()); + std::uniform_int_distribution dis1( + std::numeric_limits::min(), + std::numeric_limits::max() + ); + std::uniform_int_distribution dis2( + 64, + std::numeric_limits::max() + ); + size_t count = 100; + std::vector v1(count), v2(count), res(count); + for (size_t i=0; i({17563387625296433369ull,5842891814427459261ull,15074502074821508463ull,386435802999553003ull,5487893274931198395ull,8125923807366590570ull,13340330062727071249ull,14908193031091561411ull,296805448857369387ull,8684453485792353774ull,13117933444495098288ull,3225762988982100714ull,11290506757949810556ull,14617912756126856962ull,9479575714707174581ull,11720728318194739598ull,14410575429605211363ull,12068356718035872518ull,80682389916710599ull,11003236134534292734ull,4412447398096224810ull,5331184707993902906ull,13827083432789678788ull,958142831027309576ull,16716461997317184701ull,17128750834581527743ull,11590434571174666313ull,10204342520615148287ull,11067791415848657283ull,17583875436196878829ull,186304014359496415ull,9381729025189804702ull,11502205568225715300ull,16472133582690439104ull,3743303387826342067ull,12860029445868505658ull,2244056593742923769ull,3275687468466891223ull,1545828456957460699ull,14187252460708728077ull,7551907967738536187ull,9754400233340010491ull,16293183350230169116ull,6298812696728711031ull,5915538565572009956ull,2284684518775825662ull,1130711226902262476ull,17158957721471765323ull,4220824385439711070ull,16559772875254313109ull,15397179690017513678ull,6300413832999049491ull,13787530251307637715ull,10132349060092695582ull,10446586881482901699ull,15759779838283537085ull,14402587207027333363ull,5546051719872960161ull,6545031029710296628ull,17407295406267098658ull,4259019625544816073ull,791895457880289787ull,8549227257401578066ull,15246278171168501125ull,1674668228908076954ull,849762797502000057ull,13302651500925764574ull,12438174880334092333ull,17701249772557033303ull,10742459186038873636ull,15671491258945407856ull,9352557101631889001ull,8914093883925002585ull,17935292744735591949ull,606989231583658922ull,6528503454270721815ull,14980539549624989095ull,13765196438235456668ull,3058323869228644592ull,14346577759191739044ull,1543206286382906519ull,1025562312317433790ull,17052896445025268012ull,18349597294988935754ull,17174604730104962524ull,11924965352621110201ull,502032511104181724ull,13845633389643139332ull,15436039204445155412ull,17809579006694175565ull,15166364145138562881ull,14062748599121933798ull,1777457178576774356ull,4985224560472716170ull,3881603168175384251ull,11555031280550342082ull,1252677486917153396ull,8744807353133366467ull,2048964426549800495ull,11945831330508218140ull}), + createColumn({10170168382087373376ull,13942906057510103279ull,14306855544784320750ull,14490595809993733085ull,13894994773265190130ull,4539744309919296456ull,5960451177866905037ull,8254484395144502863ull,15947787413795976654ull,11758211605027987396ull,1969488148300388656ull,3690019359734755802ull,2761764580773985855ull,8214602624640213083ull,10104065697094616841ull,12719638633135031120ull,1579356563729340757ull,9849275185119527902ull,13259386372230333703ull,11525029819550436195ull,9336811033678812919ull,18024494470684769699ull,16262431544264631528ull,747508930926384196ull,6852408932132380275ull,16193626994293308072ull,3668439249087146688ull,2815114051955836547ull,11623595120239653162ull,14051110314333912266ull,3470211856726399766ull,24326123962172569ull,8671257608652446190ull,4449812597097007527ull,420576121975887020ull,5947488386148607688ull,7703644645840795485ull,6034247569925525574ull,7207924845921854255ull,1628903707897208595ull,17386978449329440960ull,11483226896151393705ull,5716613009851345329ull,16637909040452729752ull,8923037827908078416ull,9873656643203662744ull,10728065007959141271ull,9289790999990424278ull,4984880433949807755ull,4081441263589415253ull,14141469082911534070ull,13537500414106209761ull,3771459446515323468ull,3803220448332221997ull,5872361935348309433ull,16931084214526286271ull,229523845509607962ull,6137782246193543039ull,9416772409169847829ull,1599362648203361205ull,16506466156139756326ull,5594658406006721749ull,6075086402035041023ull,8193007286917228777ull,10417397014019249797ull,8027438655725876281ull,3744210550662242051ull,4486285497361262297ull,5233531540019046791ull,8754524731116108320ull,11747423804646205366ull,10823676266667258725ull,2377832589519205884ull,10149462695348926053ull,9938241660263331809ull,1691565986446695935ull,12741697332550437523ull,14377762655871009378ull,2710283096015101134ull,1991666937410026062ull,16045620270500586077ull,3635648749144116339ull,3398892026331619397ull,13943407331484180936ull,6636826897898447964ull,2810976231716911209ull,12715335843259733155ull,9407059307990209078ull,4918637361023593506ull,13248043003208654795ull,8205307620677927795ull,2432590649498729202ull,12261496797882837416ull,12870696446667604684ull,13094194364612141901ull,16877489047893851037ull,3133779556474902761ull,10042552922284313547ull,2121324263442996583ull,15840313846181544148ull}), + createColumn({0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0})); + // clang-format on +} +CATCH + +} // namespace tests +} // namespace DB diff --git a/tests/fullstack-test/expr/bitshift_operator.test b/tests/fullstack-test/expr/bitshift_operator.test index 0d55a1b56a9..f88ce7ee490 100644 --- a/tests/fullstack-test/expr/bitshift_operator.test +++ b/tests/fullstack-test/expr/bitshift_operator.test @@ -40,4 +40,30 @@ mysql> set tidb_enforce_mpp=1; set @@session.tidb_isolation_read_engines = "tifl | -1 | +------+ +mysql> insert into test.t values(1), (NULL); + +mysql> set tidb_enforce_mpp=1; set @@session.tidb_isolation_read_engines = "tiflash"; select a<<0 as v1, a <<64 as v2, a << 10 as v3 from test.t; ++----------------------+------+----------------------+ +| v1 | v2 | v3 | ++----------------------+------+----------------------+ +| 18446744073709551615 | 0 | 18446744073709550592 | +| 1 | 0 | 1024 | +| NULL | NULL | NULL | ++----------------------+------+----------------------+ + +mysql> set tidb_enforce_mpp=1; set @@session.tidb_isolation_read_engines = "tiflash"; select a from test.t where a<<100000=0; ++------+ +| a | ++------+ +| -1 | +| 1 | ++------+ + +mysql> set tidb_enforce_mpp=1; set @@session.tidb_isolation_read_engines = "tiflash"; select a from test.t where a<<6=64; ++------+ +| a | ++------+ +| 1 | ++------+ + mysql> drop table if exists test.t From 127a29709ed0bae9d7645510780e5388440b6b2b Mon Sep 17 00:00:00 2001 From: lidezhu <47731263+lidezhu@users.noreply.github.com> Date: Thu, 4 Aug 2022 17:10:07 +0800 Subject: [PATCH 09/11] fix potential dead lock when remove blob file (#5536) close pingcap/tiflash#5532 --- dbms/src/Storages/Page/V3/BlobStore.cpp | 6 ++++-- dbms/src/Storages/Page/V3/tests/gtest_blob_store.cpp | 6 ++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/dbms/src/Storages/Page/V3/BlobStore.cpp b/dbms/src/Storages/Page/V3/BlobStore.cpp index c13d1288d82..fdcf737fe51 100644 --- a/dbms/src/Storages/Page/V3/BlobStore.cpp +++ b/dbms/src/Storages/Page/V3/BlobStore.cpp @@ -508,9 +508,11 @@ void BlobStore::removePosFromStats(BlobFileId blob_id, BlobFileOffset offset, si if (need_remove_stat) { LOG_FMT_INFO(log, "Removing BlobFile [blob_id={}]", blob_id); - auto lock_stats = blob_stats.lock(); - // need get blob file before remove its stat otherwise we cannot find the blob file + + // Need get blob file before remove its stat otherwise we cannot find the blob file + // And getBlobFile may get lock on blob_stats inside, so call it before acquire the lock. auto blob_file = getBlobFile(blob_id); + auto lock_stats = blob_stats.lock(); blob_stats.eraseStat(std::move(stat), lock_stats); blob_file->remove(); cached_files.remove(blob_id); diff --git a/dbms/src/Storages/Page/V3/tests/gtest_blob_store.cpp b/dbms/src/Storages/Page/V3/tests/gtest_blob_store.cpp index 1839c5cc5c0..a377461d13e 100644 --- a/dbms/src/Storages/Page/V3/tests/gtest_blob_store.cpp +++ b/dbms/src/Storages/Page/V3/tests/gtest_blob_store.cpp @@ -301,6 +301,10 @@ class BlobStoreTest : public DB::base::TiFlashStorageTestBasic paths.emplace_back(fmt::format("{}/{}", path, i)); } delegator = std::make_shared(paths); + + // Note although set config.cached_fd_size to 0, the cache fd size in blobstore still have capacity 1. + // Decrease cache size to make problems more easily be exposed. + config.cached_fd_size = 0; } protected: @@ -1143,6 +1147,8 @@ TEST_F(BlobStoreTest, testBlobStoreGcStats) Poco::File blob_file_in_disk(path); ASSERT_EQ(blob_file_in_disk.getSize(), stat->sm_total_size); + // Clear cache to reproduce https://github.com/pingcap/tiflash/issues/5532 + blob_store.cached_files.reset(); // Check whether the stat can be totally removed stat->changeToReadOnly(); blob_store.remove(remain_entries); From 0fb8f41a1c00027a3eb7043dfc77529938710081 Mon Sep 17 00:00:00 2001 From: Lloyd-Pottiger <60744015+Lloyd-Pottiger@users.noreply.github.com> Date: Thu, 4 Aug 2022 17:58:07 +0800 Subject: [PATCH 10/11] feat: calculate the io throughput in background in ReadLimiter (#5415) close pingcap/tiflash#5091, close pingcap/tiflash#5401 --- dbms/src/Encryption/RateLimiter.cpp | 119 ++++++++---------- dbms/src/Encryption/RateLimiter.h | 85 ++++++------- .../Encryption/tests/gtest_rate_limiter.cpp | 62 +++++---- dbms/src/Server/StorageConfigParser.cpp | 24 ++-- dbms/src/TestUtils/MockReadLimiter.h | 5 +- 5 files changed, 131 insertions(+), 164 deletions(-) diff --git a/dbms/src/Encryption/RateLimiter.cpp b/dbms/src/Encryption/RateLimiter.cpp index 8c986f93f71..053dc5a816b 100644 --- a/dbms/src/Encryption/RateLimiter.cpp +++ b/dbms/src/Encryption/RateLimiter.cpp @@ -289,47 +289,31 @@ void WriteLimiter::updateMaxBytesPerSec(Int64 max_bytes_per_sec) } ReadLimiter::ReadLimiter( - std::function getIOStatistic_, + std::function get_read_bytes_, Int64 rate_limit_per_sec_, LimiterType type_, - Int64 get_io_stat_period_us, UInt64 refill_period_ms_) : WriteLimiter(rate_limit_per_sec_, type_, refill_period_ms_) - , getIOStatistic(std::move(getIOStatistic_)) - , last_stat_bytes(getIOStatistic()) - , last_stat_time(now()) + , get_read_bytes(std::move(get_read_bytes_)) + , last_stat_bytes(get_read_bytes()) , log(Logger::get("ReadLimiter")) - , get_io_statistic_period_us(get_io_stat_period_us) {} Int64 ReadLimiter::getAvailableBalance() { - TimePoint us = now(); - // Not call getIOStatisctics() every time for performance. - // If the clock back, elapsed_us could be negative. - Int64 elapsed_us = std::chrono::duration_cast(us - last_stat_time).count(); - if (get_io_statistic_period_us != 0 && elapsed_us < get_io_statistic_period_us) - { - return available_balance; - } - - return refreshAvailableBalance(); -} - -Int64 ReadLimiter::refreshAvailableBalance() -{ - TimePoint us = now(); - Int64 bytes = getIOStatistic(); - if (bytes < last_stat_bytes) + Int64 bytes = get_read_bytes(); + if (unlikely(bytes < last_stat_bytes)) { LOG_FMT_WARNING( log, - "last_stat {}:{} current_stat {}:{}", - last_stat_time.time_since_epoch().count(), + "last_stat: {} current_stat: {}", last_stat_bytes, - us.time_since_epoch().count(), bytes); } + else if (likely(bytes == last_stat_bytes)) + { + return available_balance; + } else { Int64 real_alloc_bytes = bytes - last_stat_bytes; @@ -338,7 +322,6 @@ Int64 ReadLimiter::refreshAvailableBalance() alloc_bytes += real_alloc_bytes; } last_stat_bytes = bytes; - last_stat_time = us; return available_balance; } @@ -381,17 +364,18 @@ void ReadLimiter::refillAndAlloc() } } -IORateLimiter::IORateLimiter() +IORateLimiter::IORateLimiter(UInt64 update_read_info_period_ms_) : log(Logger::get("IORateLimiter")) , stop(false) + , update_read_info_period_ms(update_read_info_period_ms_) {} IORateLimiter::~IORateLimiter() { stop.store(true, std::memory_order_relaxed); - if (auto_tune_thread.joinable()) + if (auto_tune_and_get_read_info_thread.joinable()) { - auto_tune_thread.join(); + auto_tune_and_get_read_info_thread.join(); } } @@ -409,13 +393,13 @@ extern thread_local bool is_background_thread; WriteLimiterPtr IORateLimiter::getWriteLimiter() { - std::lock_guard lock(mtx_); + std::lock_guard lock(mtx); return is_background_thread ? bg_write_limiter : fg_write_limiter; } ReadLimiterPtr IORateLimiter::getReadLimiter() { - std::lock_guard lock(mtx_); + std::lock_guard lock(mtx); return is_background_thread ? bg_read_limiter : fg_read_limiter; } @@ -426,7 +410,7 @@ void IORateLimiter::updateConfig(Poco::Util::AbstractConfiguration & config_) { return; } - std::lock_guard lock(mtx_); + std::lock_guard lock(mtx); updateReadLimiter(io_config.getBgReadMaxBytesPerSec(), io_config.getFgReadMaxBytesPerSec()); updateWriteLimiter(io_config.getBgWriteMaxBytesPerSec(), io_config.getFgWriteMaxBytesPerSec()); } @@ -455,11 +439,10 @@ void IORateLimiter::updateReadLimiter(Int64 bg_bytes, Int64 fg_bytes) { LOG_FMT_INFO(log, "updateReadLimiter: bg_bytes {} fg_bytes {}", bg_bytes, fg_bytes); auto get_bg_read_io_statistic = [&]() { - return getCurrentIOInfo().bg_read_bytes; + return read_info.bg_read_bytes.load(std::memory_order_relaxed); }; auto get_fg_read_io_statistic = [&]() { - auto io_info = getCurrentIOInfo(); - return std::max(0, io_info.total_read_bytes - io_info.bg_read_bytes); + return read_info.fg_read_bytes.load(std::memory_order_relaxed); }; if (bg_bytes == 0) @@ -526,7 +509,7 @@ void IORateLimiter::setBackgroundThreadIds(std::vector thread_ids) LOG_FMT_INFO(log, "bg_thread_ids {} => {}", bg_thread_ids.size(), bg_thread_ids); } -std::pair IORateLimiter::getReadWriteBytes(const std::string & fname [[maybe_unused]]) +Int64 IORateLimiter::getReadBytes(const std::string & fname [[maybe_unused]]) { #if __linux__ std::ifstream ifs(fname); @@ -538,7 +521,6 @@ std::pair IORateLimiter::getReadWriteBytes(const std::string & fna } std::string s; Int64 read_bytes = -1; - Int64 write_bytes = -1; while (std::getline(ifs, s)) { if (s.empty()) @@ -557,49 +539,43 @@ std::pair IORateLimiter::getReadWriteBytes(const std::string & fna boost::algorithm::trim(values[1]); read_bytes = std::stoll(values[1]); } - else if (values[0] == "write_bytes") - { - boost::algorithm::trim(values[1]); - write_bytes = std::stoll(values[1]); - } } - if (read_bytes == -1 || write_bytes == -1) + if (read_bytes == -1) { - auto msg = fmt::format("read_bytes: {} write_bytes: {} Invalid result.", read_bytes, write_bytes); + auto msg = fmt::format("read_bytes: {}. Invalid result.", read_bytes); LOG_ERROR(log, msg); throw Exception(msg, ErrorCodes::UNKNOWN_EXCEPTION); } - return {read_bytes, write_bytes}; + return read_bytes; #else - return {0, 0}; + return 0; #endif } -IORateLimiter::IOInfo IORateLimiter::getCurrentIOInfo() +void IORateLimiter::getCurrentIOInfo() { static const pid_t pid = getpid(); - IOInfo io_info; - // Read I/O info of each background threads. + // Read read info of each background threads. + Int64 bg_read_bytes_tmp{0}; for (pid_t tid : bg_thread_ids) { const std::string thread_io_fname = fmt::format("/proc/{}/task/{}/io", pid, tid); - Int64 read_bytes, write_bytes; - std::tie(read_bytes, write_bytes) = getReadWriteBytes(thread_io_fname); - io_info.bg_read_bytes += read_bytes; - io_info.bg_write_bytes += write_bytes; + Int64 read_bytes; + read_bytes = getReadBytes(thread_io_fname); + bg_read_bytes_tmp += read_bytes; } + read_info.bg_read_bytes.store(bg_read_bytes_tmp, std::memory_order_relaxed); - // Read I/O info of this process. + // Read read info of this process. static const std::string proc_io_fname = fmt::format("/proc/{}/io", pid); - std::tie(io_info.total_read_bytes, io_info.total_write_bytes) = getReadWriteBytes(proc_io_fname); - io_info.update_time = std::chrono::system_clock::now(); - return io_info; + Int64 fg_read_bytes_tmp{getReadBytes(proc_io_fname) - bg_read_bytes_tmp}; + read_info.fg_read_bytes.store(std::max(0, fg_read_bytes_tmp), std::memory_order_relaxed); } void IORateLimiter::setStop() { - std::lock_guard lock(mtx_); + std::lock_guard lock(mtx); if (bg_write_limiter != nullptr) { auto sz = bg_write_limiter->setStop(); @@ -624,17 +600,28 @@ void IORateLimiter::setStop() void IORateLimiter::runAutoTune() { - auto auto_tune_worker = [&]() { + auto auto_tune_and_get_read_info_worker = [&]() { + using time_point = std::chrono::time_point; + using clock = std::chrono::system_clock; + time_point auto_tune_time = clock::now(); + time_point update_read_info_time = auto_tune_time; while (!stop.load(std::memory_order_relaxed)) { - ::sleep(io_config.auto_tune_sec > 0 ? io_config.auto_tune_sec : 1); - if (io_config.auto_tune_sec > 0) + std::this_thread::sleep_for(std::chrono::milliseconds(update_read_info_period_ms)); + auto now_time_point = clock::now(); + if ((io_config.auto_tune_sec > 0) && (now_time_point - auto_tune_time >= std::chrono::seconds(io_config.auto_tune_sec))) { autoTune(); + auto_tune_time = now_time_point; + } + if ((bg_read_limiter || fg_read_limiter) && likely(now_time_point - update_read_info_time >= std::chrono::milliseconds(update_read_info_period_ms))) + { + getCurrentIOInfo(); + update_read_info_time = now_time_point; } } }; - auto_tune_thread = std::thread(auto_tune_worker); + auto_tune_and_get_read_info_thread = std::thread(auto_tune_and_get_read_info_worker); } std::unique_ptr IORateLimiter::createIOLimitTuner() @@ -643,7 +630,7 @@ std::unique_ptr IORateLimiter::createIOLimitTuner() ReadLimiterPtr bg_read, fg_read; StorageIORateLimitConfig t_io_config; { - std::lock_guard lock(mtx_); + std::lock_guard lock(mtx); bg_write = bg_write_limiter; fg_write = fg_write_limiter; bg_read = bg_read_limiter; @@ -666,12 +653,12 @@ void IORateLimiter::autoTune() auto tune_result = tuner->tune(); if (tune_result.read_tuned) { - std::lock_guard lock(mtx_); + std::lock_guard lock(mtx); updateReadLimiter(tune_result.max_bg_read_bytes_per_sec, tune_result.max_fg_read_bytes_per_sec); } if (tune_result.write_tuned) { - std::lock_guard lock(mtx_); + std::lock_guard lock(mtx); updateWriteLimiter(tune_result.max_bg_write_bytes_per_sec, tune_result.max_fg_write_bytes_per_sec); } } diff --git a/dbms/src/Encryption/RateLimiter.h b/dbms/src/Encryption/RateLimiter.h index 9d29ea10e0a..4e7846b9a64 100644 --- a/dbms/src/Encryption/RateLimiter.h +++ b/dbms/src/Encryption/RateLimiter.h @@ -46,6 +46,25 @@ enum class LimiterType BG_READ = 4, }; +/// ReadInfo is used to store IO information. +/// bg_read_bytes is the bytes of the background read. +/// fg_read_bytes is the bytes of the foreground read. +struct ReadInfo +{ + std::atomic bg_read_bytes; + std::atomic fg_read_bytes; + + ReadInfo() + : bg_read_bytes(0) + , fg_read_bytes(0) + {} + + std::string toString() const + { + return fmt::format("fg_read_bytes {} bg_read_bytes {}", fg_read_bytes, bg_read_bytes); + } +}; + // WriteLimiter is to control write rate (bytes per second). // Because of the storage engine is append-only, the amount of data written by the storage engine // is equal to the amount of data written to the disk by the operating system. So, WriteLimiter @@ -147,57 +166,49 @@ using WriteLimiterPtr = std::shared_ptr; // // Constructor parameters: // -// `getIOStatistic_` is the function that obtain the amount of data read from /proc. -// -// `get_io_stat_period_us` is the interval between calling getIOStatistic_. +// `get_read_bytes_` is the function that obtain the amount of data from `getCurrentIOInfo()` which read from /proc filesystem. // // Other parameters are the same as WriteLimiter. class ReadLimiter : public WriteLimiter { public: ReadLimiter( - std::function getIOStatistic_, + std::function get_read_bytes_, Int64 rate_limit_per_sec_, LimiterType type_, - Int64 get_io_stat_period_us = 2000, UInt64 refill_period_ms_ = 100); #ifndef DBMS_PUBLIC_GTEST protected: #endif - virtual void refillAndAlloc() override; - virtual void consumeBytes(Int64 bytes) override; - virtual bool canGrant(Int64 bytes) override; + void refillAndAlloc() override; + void consumeBytes(Int64 bytes) override; + bool canGrant(Int64 bytes) override; #ifndef DBMS_PUBLIC_GTEST private: #endif Int64 getAvailableBalance(); - Int64 refreshAvailableBalance(); - std::function getIOStatistic; + std::function get_read_bytes; Int64 last_stat_bytes; - using TimePoint = std::chrono::time_point; - static TimePoint now() - { - return std::chrono::time_point_cast(std::chrono::system_clock::now()); - } - TimePoint last_stat_time; LoggerPtr log; - - Int64 get_io_statistic_period_us; }; using ReadLimiterPtr = std::shared_ptr; // IORateLimiter is the wrapper of WriteLimiter and ReadLimiter. // Currently, It supports four limiter type: background write, foreground write, background read and foreground read. +// +// Constructor parameters: +// +// `update_read_info_period_ms` is the interval between calling getCurrentIOInfo. Default is 30ms. class IORateLimiter { public: - IORateLimiter(); + explicit IORateLimiter(UInt64 update_read_info_period_ms_ = 30); ~IORateLimiter(); WriteLimiterPtr getWriteLimiter(); @@ -209,37 +220,12 @@ class IORateLimiter void setStop(); - struct IOInfo - { - Int64 total_write_bytes; - Int64 total_read_bytes; - Int64 bg_write_bytes; - Int64 bg_read_bytes; - std::chrono::time_point update_time; - - IOInfo() - : total_write_bytes(0) - , total_read_bytes(0) - , bg_write_bytes(0) - , bg_read_bytes(0) - {} - - std::string toString() const - { - return fmt::format("total_write_bytes {} total_read_bytes {} bg_write_bytes {} bg_read_bytes {}", - total_write_bytes, - total_read_bytes, - bg_write_bytes, - bg_read_bytes); - } - }; - #ifndef DBMS_PUBLIC_GTEST private: #endif - std::pair getReadWriteBytes(const std::string & fname); - IOInfo getCurrentIOInfo(); + Int64 getReadBytes(const std::string & fname); + void getCurrentIOInfo(); std::unique_ptr createIOLimitTuner(); void autoTune(); @@ -254,16 +240,17 @@ class IORateLimiter WriteLimiterPtr fg_write_limiter; ReadLimiterPtr bg_read_limiter; ReadLimiterPtr fg_read_limiter; - std::mutex mtx_; + std::mutex mtx; std::mutex bg_thread_ids_mtx; std::vector bg_thread_ids; - IOInfo last_io_info; LoggerPtr log; std::atomic stop; - std::thread auto_tune_thread; + std::thread auto_tune_and_get_read_info_thread; + ReadInfo read_info; + const UInt64 update_read_info_period_ms; // Noncopyable and nonmovable. DISALLOW_COPY_AND_MOVE(IORateLimiter); diff --git a/dbms/src/Encryption/tests/gtest_rate_limiter.cpp b/dbms/src/Encryption/tests/gtest_rate_limiter.cpp index 51984b86460..854473bae44 100644 --- a/dbms/src/Encryption/tests/gtest_rate_limiter.cpp +++ b/dbms/src/Encryption/tests/gtest_rate_limiter.cpp @@ -22,6 +22,8 @@ #include #include +#include "common/types.h" + #ifdef __linux__ #include #endif @@ -177,7 +179,7 @@ TEST(WriteLimiterTest, LimiterStat) ASSERT_EQ(stat.pct(), static_cast(alloc_bytes * 1000 / stat.elapsed_ms) * 100 / stat.maxBytesPerSec()) << stat.toString(); } -TEST(ReadLimiterTest, GetIOStatPeroid2000us) +TEST(ReadLimiterTest, GetIOStatPeroid200ms) { Int64 consumed = 0; auto get_stat = [&consumed]() { @@ -187,41 +189,31 @@ TEST(ReadLimiterTest, GetIOStatPeroid2000us) limiter.request(bytes); consumed += bytes; }; - Int64 get_io_stat_period_us = 2000; - auto wait_refresh = [&]() { - std::chrono::microseconds sleep_time(get_io_stat_period_us + 1); - std::this_thread::sleep_for(sleep_time); - }; using TimePointMS = std::chrono::time_point; Int64 bytes_per_sec = 1000; UInt64 refill_period_ms = 20; - ReadLimiter limiter(get_stat, bytes_per_sec, LimiterType::UNKNOW, get_io_stat_period_us, refill_period_ms); + ReadLimiter limiter(get_stat, bytes_per_sec, LimiterType::UNKNOW, refill_period_ms); TimePointMS t0 = std::chrono::time_point_cast(std::chrono::system_clock::now()); // Refill 20 every 20ms. ASSERT_EQ(limiter.getAvailableBalance(), 20); request(limiter, 1); - ASSERT_EQ(limiter.getAvailableBalance(), 20); - ASSERT_EQ(limiter.refreshAvailableBalance(), 19); - request(limiter, 9); ASSERT_EQ(limiter.getAvailableBalance(), 19); - wait_refresh(); - ASSERT_EQ(limiter.getAvailableBalance(), 10); - request(limiter, 11); - wait_refresh(); + request(limiter, 20); ASSERT_EQ(limiter.getAvailableBalance(), -1); request(limiter, 50); TimePointMS t1 = std::chrono::time_point_cast(std::chrono::system_clock::now()); UInt64 elasped = std::chrono::duration_cast(t1 - t0).count(); ASSERT_GE(elasped, refill_period_ms); - ASSERT_EQ(limiter.getAvailableBalance(), 19); - wait_refresh(); ASSERT_EQ(limiter.getAvailableBalance(), -31); request(limiter, 1); TimePointMS t2 = std::chrono::time_point_cast(std::chrono::system_clock::now()); elasped = std::chrono::duration_cast(t2 - t1).count(); ASSERT_GE(elasped, 2 * refill_period_ms); + ASSERT_EQ(limiter.getAvailableBalance(), 8); + request(limiter, 9); + ASSERT_EQ(limiter.getAvailableBalance(), -1); } void testSetStop(int blocked_thread_cnt) @@ -278,8 +270,12 @@ TEST(ReadLimiterTest, LimiterStat) limiter.request(bytes); consumed += bytes; }; - Int64 get_io_stat_period_us = 2000; - ReadLimiter read_limiter(get_stat, 1000, LimiterType::UNKNOW, get_io_stat_period_us, 100); + + Int64 bytes_per_sec = 1000; + UInt64 refill_period_ms = 100; + ReadLimiter read_limiter(get_stat, bytes_per_sec, LimiterType::UNKNOW, refill_period_ms); + ASSERT_EQ(read_limiter.getAvailableBalance(), 100); + try { read_limiter.getStat(); @@ -314,7 +310,7 @@ TEST(ReadLimiterTest, LimiterStat) request(read_limiter, 100); std::this_thread::sleep_for(100ms); - read_limiter.refreshAvailableBalance(); + ASSERT_EQ(read_limiter.getAvailableBalance(), 0); stat = read_limiter.getStat(); ASSERT_EQ(stat.alloc_bytes, 100ul); @@ -344,7 +340,7 @@ TEST(ReadLimiterTest, LimiterStat) } std::this_thread::sleep_for(100ms); - read_limiter.refreshAvailableBalance(); + ASSERT_EQ(read_limiter.getAvailableBalance(), -947); stat = read_limiter.getStat(); ASSERT_EQ(stat.alloc_bytes, alloc_bytes); @@ -376,7 +372,7 @@ TEST(IORateLimiterTest, IOStat) int buf_size = 4096; int ret = ::posix_memalign(&buf, buf_size, buf_size); ASSERT_EQ(ret, 0) << strerror(errno); - std::unique_ptr> defer_free(buf, [](void * p) { ::free(p); }); + std::unique_ptr> defer_free(buf, [](void * p) { ::free(p); }); // NOLINT(cppcoreguidelines-no-malloc) ssize_t n = ::pwrite(fd, buf, buf_size, 0); ASSERT_EQ(n, buf_size) << strerror(errno); @@ -384,12 +380,10 @@ TEST(IORateLimiterTest, IOStat) n = ::pread(fd, buf, buf_size, 0); ASSERT_EQ(n, buf_size) << strerror(errno); - //int ret = ::fsync(fd); - //ASSERT_EQ(ret, 0) << strerror(errno); - - auto io_info = io_rate_limiter.getCurrentIOInfo(); - ASSERT_GE(io_info.total_write_bytes, buf_size); - ASSERT_GE(io_info.total_read_bytes, buf_size); + io_rate_limiter.getCurrentIOInfo(); + Int64 bg_read_bytes = io_rate_limiter.read_info.bg_read_bytes.load(std::memory_order_relaxed); + Int64 fg_read_bytes = io_rate_limiter.read_info.fg_read_bytes.load(std::memory_order_relaxed); + ASSERT_GE(bg_read_bytes + fg_read_bytes, buf_size); } TEST(IORateLimiterTest, IOStatMultiThread) @@ -418,7 +412,7 @@ TEST(IORateLimiterTest, IOStatMultiThread) void * buf = nullptr; int ret = ::posix_memalign(&buf, buf_size, buf_size); - std::unique_ptr> auto_free(buf, [](void * p) { free(p); }); + std::unique_ptr> auto_free(buf, [](void * p) { free(p); }); // NOLINT(cppcoreguidelines-no-malloc) ASSERT_EQ(ret, 0) << strerror(errno); ssize_t n = ::pwrite(fd, buf, buf_size, 0); @@ -453,12 +447,12 @@ TEST(IORateLimiterTest, IOStatMultiThread) IORateLimiter io_rate_limiter; io_rate_limiter.setBackgroundThreadIds(bg_pids); - auto io_info = io_rate_limiter.getCurrentIOInfo(); - std::cout << io_info.toString() << std::endl; - ASSERT_GE(io_info.total_read_bytes, buf_size * (bg_thread_count + fg_thread_count)); - ASSERT_GE(io_info.total_write_bytes, buf_size * (bg_thread_count + fg_thread_count)); - ASSERT_GE(io_info.bg_read_bytes, buf_size * bg_thread_count); - ASSERT_GE(io_info.bg_write_bytes, buf_size * bg_thread_count); + + io_rate_limiter.getCurrentIOInfo(); + Int64 bg_read_bytes = io_rate_limiter.read_info.bg_read_bytes.load(std::memory_order_relaxed); + Int64 fg_read_bytes = io_rate_limiter.read_info.fg_read_bytes.load(std::memory_order_relaxed); + ASSERT_GE(fg_read_bytes, buf_size * fg_thread_count); + ASSERT_GE(bg_read_bytes, buf_size * bg_thread_count); stop.store(true); for (auto & t : threads) diff --git a/dbms/src/Server/StorageConfigParser.cpp b/dbms/src/Server/StorageConfigParser.cpp index 09b7807c397..1079935a318 100644 --- a/dbms/src/Server/StorageConfigParser.cpp +++ b/dbms/src/Server/StorageConfigParser.cpp @@ -455,42 +455,42 @@ UInt64 StorageIORateLimitConfig::totalWeight() const UInt64 StorageIORateLimitConfig::getFgWriteMaxBytesPerSec() const { - if (totalWeight() <= 0 || writeWeight() <= 0) + if (writeWeight() <= 0 || totalWeight() <= 0) { return 0; } - return use_max_bytes_per_sec ? max_bytes_per_sec / totalWeight() * fg_write_weight - : max_write_bytes_per_sec / writeWeight() * fg_write_weight; + return use_max_bytes_per_sec ? static_cast(1.0 * max_bytes_per_sec / totalWeight() * fg_write_weight) + : static_cast(1.0 * max_write_bytes_per_sec / writeWeight() * fg_write_weight); } UInt64 StorageIORateLimitConfig::getBgWriteMaxBytesPerSec() const { - if (totalWeight() <= 0 || writeWeight() <= 0) + if (writeWeight() <= 0 || totalWeight() <= 0) { return 0; } - return use_max_bytes_per_sec ? max_bytes_per_sec / totalWeight() * bg_write_weight - : max_write_bytes_per_sec / writeWeight() * bg_write_weight; + return use_max_bytes_per_sec ? static_cast(1.0 * max_bytes_per_sec / totalWeight() * bg_write_weight) + : static_cast(1.0 * max_write_bytes_per_sec / writeWeight() * bg_write_weight); } UInt64 StorageIORateLimitConfig::getFgReadMaxBytesPerSec() const { - if (totalWeight() <= 0 || readWeight() <= 0) + if (readWeight() <= 0 || totalWeight() <= 0) { return 0; } - return use_max_bytes_per_sec ? max_bytes_per_sec / totalWeight() * fg_read_weight - : max_read_bytes_per_sec / readWeight() * fg_read_weight; + return use_max_bytes_per_sec ? static_cast(1.0 * max_bytes_per_sec / totalWeight() * fg_read_weight) + : static_cast(1.0 * max_read_bytes_per_sec / readWeight() * fg_read_weight); } UInt64 StorageIORateLimitConfig::getBgReadMaxBytesPerSec() const { - if (totalWeight() <= 0 || readWeight() <= 0) + if (readWeight() <= 0 || totalWeight() <= 0) { return 0; } - return use_max_bytes_per_sec ? max_bytes_per_sec / totalWeight() * bg_read_weight - : max_read_bytes_per_sec / readWeight() * bg_read_weight; + return use_max_bytes_per_sec ? static_cast(1.0 * max_bytes_per_sec / totalWeight() * bg_read_weight) + : static_cast(1.0 * max_read_bytes_per_sec / readWeight() * bg_read_weight); } UInt64 StorageIORateLimitConfig::getWriteMaxBytesPerSec() const diff --git a/dbms/src/TestUtils/MockReadLimiter.h b/dbms/src/TestUtils/MockReadLimiter.h index 8acc96371e3..0bb69145a67 100644 --- a/dbms/src/TestUtils/MockReadLimiter.h +++ b/dbms/src/TestUtils/MockReadLimiter.h @@ -23,9 +23,8 @@ class MockReadLimiter final : public ReadLimiter std::function getIOStatistic_, Int64 rate_limit_per_sec_, LimiterType type_ = LimiterType::UNKNOW, - Int64 get_io_stat_period_us = 2000, UInt64 refill_period_ms_ = 100) - : ReadLimiter(getIOStatistic_, rate_limit_per_sec_, type_, get_io_stat_period_us, refill_period_ms_) + : ReadLimiter(getIOStatistic_, rate_limit_per_sec_, type_, refill_period_ms_) { } @@ -33,7 +32,7 @@ class MockReadLimiter final : public ReadLimiter void consumeBytes(Int64 bytes) override { // Need soft limit here. - WriteLimiter::consumeBytes(bytes); + WriteLimiter::consumeBytes(bytes); // NOLINT(bugprone-parent-virtual-call) } }; From 97fa91ff93b07ef2c23f36f93419599047514e10 Mon Sep 17 00:00:00 2001 From: Calvin Neo Date: Fri, 5 Aug 2022 10:38:06 +0800 Subject: [PATCH 11/11] Add more information when read index fail (#5526) close pingcap/tiflash#5525 --- dbms/src/Storages/Transaction/LearnerRead.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dbms/src/Storages/Transaction/LearnerRead.cpp b/dbms/src/Storages/Transaction/LearnerRead.cpp index e4f85c01d10..6fdec3e495b 100644 --- a/dbms/src/Storages/Transaction/LearnerRead.cpp +++ b/dbms/src/Storages/Transaction/LearnerRead.cpp @@ -311,7 +311,7 @@ LearnerReadSnapshot doLearnerRead( } } - auto handle_wait_timeout_region = [&unavailable_regions, for_batch_cop](const DB::RegionID region_id) { + auto handle_wait_timeout_region = [&unavailable_regions, for_batch_cop](const DB::RegionID region_id, UInt64 index) { if (!for_batch_cop) { // If server is being terminated / time-out, add the region_id into `unavailable_regions` to other store. @@ -319,7 +319,7 @@ LearnerReadSnapshot doLearnerRead( return; } // TODO: Maybe collect all the Regions that happen wait index timeout instead of just throwing one Region id - throw TiFlashException(fmt::format("Region {} is unavailable", region_id), Errors::Coprocessor::RegionError); + throw TiFlashException(fmt::format("Region {} is unavailable at {}", region_id, index), Errors::Coprocessor::RegionError); }; const auto wait_index_timeout_ms = tmt.waitIndexTimeout(); for (size_t region_idx = region_begin_idx, read_index_res_idx = 0; region_idx < region_end_idx; ++region_idx, ++read_index_res_idx) @@ -342,7 +342,7 @@ LearnerReadSnapshot doLearnerRead( auto [wait_res, time_cost] = region->waitIndex(index_to_wait, tmt.waitIndexTimeout(), [&tmt]() { return tmt.checkRunning(); }); if (wait_res != WaitIndexResult::Finished) { - handle_wait_timeout_region(region_to_query.region_id); + handle_wait_timeout_region(region_to_query.region_id, index_to_wait); continue; } if (time_cost > 0) @@ -357,7 +357,7 @@ LearnerReadSnapshot doLearnerRead( // for Regions one by one. if (!region->checkIndex(index_to_wait)) { - handle_wait_timeout_region(region_to_query.region_id); + handle_wait_timeout_region(region_to_query.region_id, index_to_wait); continue; } }