From 9fca53d14bc91fe70e1da6500a68fe03b9fc2231 Mon Sep 17 00:00:00 2001 From: Fu Zhe Date: Sun, 21 Nov 2021 22:29:48 +0800 Subject: [PATCH] Little refinements on DAG code (#3482) --- dbms/src/Debug/dbgFuncCoprocessor.cpp | 3 +- dbms/src/Flash/BatchCoprocessorHandler.cpp | 6 +- .../Coprocessor/DAGExpressionAnalyzer.cpp | 236 ++- .../Coprocessor/DAGQueryBlockInterpreter.cpp | 13 +- .../Coprocessor/DAGQueryBlockInterpreter.h | 1 + dbms/src/Flash/Coprocessor/DAGQuerySource.cpp | 4 +- dbms/src/Flash/Coprocessor/DAGQuerySource.h | 6 +- .../Coprocessor/DAGStorageInterpreter.cpp | 4 +- dbms/src/Flash/Coprocessor/DAGUtils.cpp | 1578 +++++++++-------- dbms/src/Flash/Coprocessor/DAGUtils.h | 28 +- dbms/src/Flash/Coprocessor/InterpreterDAG.cpp | 14 +- dbms/src/Flash/CoprocessorHandler.cpp | 3 +- dbms/src/Flash/Mpp/MPPTask.cpp | 2 +- 13 files changed, 934 insertions(+), 964 deletions(-) diff --git a/dbms/src/Debug/dbgFuncCoprocessor.cpp b/dbms/src/Debug/dbgFuncCoprocessor.cpp index 64682c34260..e0377db787f 100644 --- a/dbms/src/Debug/dbgFuncCoprocessor.cpp +++ b/dbms/src/Debug/dbgFuncCoprocessor.cpp @@ -752,8 +752,7 @@ void astToPB(const DAGSchema & input, ASTPtr ast, tipb::Expr * expr, uint32_t co astToPB(input, child_ast, child, collator_id, context); } // for like need to add the third argument - tipb::Expr * constant_expr = expr->add_children(); - constructInt64LiteralTiExpr(*constant_expr, 92); + *expr->add_children() = constructInt64LiteralTiExpr(92); return; } case tipb::ScalarFuncSig::FromUnixTime2Arg: diff --git a/dbms/src/Flash/BatchCoprocessorHandler.cpp b/dbms/src/Flash/BatchCoprocessorHandler.cpp index 8430a831af4..5579062c7ef 100644 --- a/dbms/src/Flash/BatchCoprocessorHandler.cpp +++ b/dbms/src/Flash/BatchCoprocessorHandler.cpp @@ -42,11 +42,7 @@ grpc::Status BatchCoprocessorHandler::execute() SCOPE_EXIT( { GET_METRIC(tiflash_coprocessor_handling_request_count, type_super_batch_cop_dag).Decrement(); }); - const auto dag_request = ({ - tipb::DAGRequest dag_req; - getDAGRequestFromStringWithRetry(dag_req, cop_request->data()); - std::move(dag_req); - }); + auto dag_request = getDAGRequestFromStringWithRetry(cop_request->data()); RegionInfoMap regions; RegionInfoList retry_regions; for (auto & r : cop_request->regions()) diff --git a/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp b/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp index 8bf8f229de8..b70c65b1c84 100644 --- a/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp +++ b/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include @@ -31,6 +32,67 @@ extern const int COP_BAD_DAG_REQUEST; extern const int UNSUPPORTED_METHOD; } // namespace ErrorCodes +namespace +{ +String genFuncString(const String & func_name, const Names & argument_names, const TiDB::TiDBCollators & collators) +{ + assert(!collators.empty()); + FmtBuffer buf; + buf.fmtAppend("{}({})_collator", func_name, fmt::join(argument_names.begin(), argument_names.end(), ", ")); + for (const auto & collator : collators) + { + if (collator == nullptr) + buf.append("_0"); + else + buf.fmtAppend("_{}", collator->getCollatorId()); + } + buf.append(" "); + return buf.toString(); +} + +String getUniqueName(const Block & block, const String & prefix) +{ + for (int i = 1;; ++i) + { + auto name = fmt::format("{}{}", prefix, i); + if (!block.has(name)) + return name; + } +} + +struct DateAdd +{ + static constexpr auto name = "date_add"; + static const std::unordered_map unit_to_func_name_map; +}; + +const std::unordered_map DateAdd::unit_to_func_name_map + = { + {"DAY", "addDays"}, + {"WEEK", "addWeeks"}, + {"MONTH", "addMonths"}, + {"YEAR", "addYears"}, + {"HOUR", "addHours"}, + {"MINUTE", "addMinutes"}, + {"SECOND", "addSeconds"}}; + +struct DateSub +{ + static constexpr auto name = "date_sub"; + static const std::unordered_map unit_to_func_name_map; +}; + +const std::unordered_map DateSub::unit_to_func_name_map + = { + {"DAY", "subtractDays"}, + {"WEEK", "subtractWeeks"}, + {"MONTH", "subtractMonths"}, + {"YEAR", "subtractYears"}, + {"HOUR", "subtractHours"}, + {"MINUTE", "subtractMinutes"}, + {"SECOND", "subtractSeconds"}}; +} // namespace + class DAGExpressionAnalyzerHelper { public: @@ -86,45 +148,12 @@ class DAGExpressionAnalyzerHelper DAGExpressionAnalyzer * analyzer, const tipb::Expr & expr, ExpressionActionsPtr & actions); -}; -static String genFuncString(const String & func_name, const Names & argument_names, const TiDB::TiDBCollators & collators) -{ - assert(!collators.empty()); - std::stringstream ss; - ss << func_name << "("; - bool first = true; - for (const String & argument_name : argument_names) - { - if (first) - { - first = false; - } - else - { - ss << ", "; - } - ss << argument_name; - } - ss << ")_collator"; - for (const auto & collator : collators) - { - if (collator == nullptr) - ss << "_0"; - else - ss << "_" << collator->getCollatorId(); - } - ss << " "; - return ss.str(); -} + using FunctionBuilder = std::function; + using FunctionBuilderMap = std::unordered_map; -static String getUniqueName(const Block & block, const String & prefix) -{ - int i = 1; - while (block.has(prefix + toString(i))) - ++i; - return prefix + toString(i); -} + static FunctionBuilderMap function_builder_map; +}; String DAGExpressionAnalyzerHelper::buildMultiIfFunction( DAGExpressionAnalyzer * analyzer, @@ -138,14 +167,13 @@ String DAGExpressionAnalyzerHelper::buildMultiIfFunction( Names argument_names; for (int i = 0; i < expr.children_size(); i++) { - String name = analyzer->getActions(expr.children(i), actions, i != expr.children_size() - 1 && i % 2 == 0); + bool output_as_uint8_type = (i + 1) != expr.children_size() && (i % 2 == 0); + String name = analyzer->getActions(expr.children(i), actions, output_as_uint8_type); argument_names.push_back(name); } if (argument_names.size() % 2 == 0) { - tipb::Expr null_expr; - constructNULLLiteralTiExpr(null_expr); - String name = analyzer->getActions(null_expr, actions); + String name = analyzer->getActions(constructNULLLiteralTiExpr(), actions); argument_names.push_back(name); } return analyzer->applyFunction(func_name, argument_names, actions, getCollatorFromExpr(expr)); @@ -276,8 +304,7 @@ String DAGExpressionAnalyzerHelper::buildLeftUTF8Function( argument_names.push_back(str); // the second parameter: const(1) - auto const_one = tipb::Expr(); - constructInt64LiteralTiExpr(const_one, 1); + auto const_one = constructInt64LiteralTiExpr(1); auto col_const_one = analyzer->getActions(const_one, actions, false); argument_names.push_back(col_const_one); @@ -325,46 +352,13 @@ String DAGExpressionAnalyzerHelper::buildCastFunction( String name = analyzer->getActions(expr.children(0), actions); DataTypePtr expected_type = getDataTypeByFieldTypeForComputingLayer(expr.field_type()); - tipb::Expr type_expr; - constructStringLiteralTiExpr(type_expr, expected_type->getName()); + tipb::Expr type_expr = constructStringLiteralTiExpr(expected_type->getName()); auto type_expr_name = analyzer->getActions(type_expr, actions); // todo extract in_union from tipb::Expr return buildCastFunctionInternal(analyzer, {name, type_expr_name}, false, expr.field_type(), actions); } -struct DateAdd -{ - static constexpr auto name = "date_add"; - static const std::unordered_map unit_to_func_name_map; -}; - -const std::unordered_map DateAdd::unit_to_func_name_map - = { - {"DAY", "addDays"}, - {"WEEK", "addWeeks"}, - {"MONTH", "addMonths"}, - {"YEAR", "addYears"}, - {"HOUR", "addHours"}, - {"MINUTE", "addMinutes"}, - {"SECOND", "addSeconds"}}; - -struct DateSub -{ - static constexpr auto name = "date_sub"; - static const std::unordered_map unit_to_func_name_map; -}; - -const std::unordered_map DateSub::unit_to_func_name_map - = { - {"DAY", "subtractDays"}, - {"WEEK", "subtractWeeks"}, - {"MONTH", "subtractMonths"}, - {"YEAR", "subtractYears"}, - {"HOUR", "subtractHours"}, - {"MINUTE", "subtractMinutes"}, - {"SECOND", "subtractSeconds"}}; - template String DAGExpressionAnalyzerHelper::buildDateAddOrSubFunction( DAGExpressionAnalyzer * analyzer, @@ -464,8 +458,7 @@ String DAGExpressionAnalyzerHelper::buildRoundFunction( auto input_arg_name = analyzer->getActions(expr.children(0), actions); - auto const_zero = tipb::Expr(); - constructInt64LiteralTiExpr(const_zero, 0); + auto const_zero = constructInt64LiteralTiExpr(0); auto const_zero_arg_name = analyzer->getActions(const_zero, actions); Names argument_names; @@ -475,29 +468,28 @@ String DAGExpressionAnalyzerHelper::buildRoundFunction( return analyzer->applyFunction("tidbRoundWithFrac", argument_names, actions, getCollatorFromExpr(expr)); } -static std::unordered_map> - function_builder_map( - {{"in", DAGExpressionAnalyzerHelper::buildInFunction}, - {"notIn", DAGExpressionAnalyzerHelper::buildInFunction}, - {"globalIn", DAGExpressionAnalyzerHelper::buildInFunction}, - {"globalNotIn", DAGExpressionAnalyzerHelper::buildInFunction}, - {"tidbIn", DAGExpressionAnalyzerHelper::buildInFunction}, - {"tidbNotIn", DAGExpressionAnalyzerHelper::buildInFunction}, - {"ifNull", DAGExpressionAnalyzerHelper::buildIfNullFunction}, - {"multiIf", DAGExpressionAnalyzerHelper::buildMultiIfFunction}, - {"tidb_cast", DAGExpressionAnalyzerHelper::buildCastFunction}, - {"and", DAGExpressionAnalyzerHelper::buildLogicalFunction}, - {"or", DAGExpressionAnalyzerHelper::buildLogicalFunction}, - {"xor", DAGExpressionAnalyzerHelper::buildLogicalFunction}, - {"not", DAGExpressionAnalyzerHelper::buildLogicalFunction}, - {"bitAnd", DAGExpressionAnalyzerHelper::buildBitwiseFunction}, - {"bitOr", DAGExpressionAnalyzerHelper::buildBitwiseFunction}, - {"bitXor", DAGExpressionAnalyzerHelper::buildBitwiseFunction}, - {"bitNot", DAGExpressionAnalyzerHelper::buildBitwiseFunction}, - {"leftUTF8", DAGExpressionAnalyzerHelper::buildLeftUTF8Function}, - {"date_add", DAGExpressionAnalyzerHelper::buildDateAddOrSubFunction}, - {"date_sub", DAGExpressionAnalyzerHelper::buildDateAddOrSubFunction}, - {"tidbRound", DAGExpressionAnalyzerHelper::buildRoundFunction}}); +DAGExpressionAnalyzerHelper::FunctionBuilderMap DAGExpressionAnalyzerHelper::function_builder_map( + {{"in", DAGExpressionAnalyzerHelper::buildInFunction}, + {"notIn", DAGExpressionAnalyzerHelper::buildInFunction}, + {"globalIn", DAGExpressionAnalyzerHelper::buildInFunction}, + {"globalNotIn", DAGExpressionAnalyzerHelper::buildInFunction}, + {"tidbIn", DAGExpressionAnalyzerHelper::buildInFunction}, + {"tidbNotIn", DAGExpressionAnalyzerHelper::buildInFunction}, + {"ifNull", DAGExpressionAnalyzerHelper::buildIfNullFunction}, + {"multiIf", DAGExpressionAnalyzerHelper::buildMultiIfFunction}, + {"tidb_cast", DAGExpressionAnalyzerHelper::buildCastFunction}, + {"and", DAGExpressionAnalyzerHelper::buildLogicalFunction}, + {"or", DAGExpressionAnalyzerHelper::buildLogicalFunction}, + {"xor", DAGExpressionAnalyzerHelper::buildLogicalFunction}, + {"not", DAGExpressionAnalyzerHelper::buildLogicalFunction}, + {"bitAnd", DAGExpressionAnalyzerHelper::buildBitwiseFunction}, + {"bitOr", DAGExpressionAnalyzerHelper::buildBitwiseFunction}, + {"bitXor", DAGExpressionAnalyzerHelper::buildBitwiseFunction}, + {"bitNot", DAGExpressionAnalyzerHelper::buildBitwiseFunction}, + {"leftUTF8", DAGExpressionAnalyzerHelper::buildLeftUTF8Function}, + {"date_add", DAGExpressionAnalyzerHelper::buildDateAddOrSubFunction}, + {"date_sub", DAGExpressionAnalyzerHelper::buildDateAddOrSubFunction}, + {"tidbRound", DAGExpressionAnalyzerHelper::buildRoundFunction}}); DAGExpressionAnalyzer::DAGExpressionAnalyzer(std::vector source_columns_, const Context & context_) : source_columns(std::move(source_columns_)) @@ -864,8 +856,7 @@ String DAGExpressionAnalyzer::convertToUInt8(ExpressionActionsPtr & actions, con const auto & org_type = removeNullable(actions->getSampleBlock().getByName(column_name).type); if (org_type->isNumber() || org_type->isDecimal()) { - tipb::Expr const_expr; - constructInt64LiteralTiExpr(const_expr, 0); + tipb::Expr const_expr = constructInt64LiteralTiExpr(0); auto const_expr_name = getActions(const_expr, actions); return applyFunction("notEquals", {column_name, const_expr_name}, actions, nullptr); } @@ -876,8 +867,7 @@ String DAGExpressionAnalyzer::convertToUInt8(ExpressionActionsPtr & actions, con // TODO: Use TypeDouble as return type, to be compatible with TiDB field_type.set_tp(TiDB::TypeDouble); field_type.set_flen(-1); - tipb::Expr type_expr; - constructStringLiteralTiExpr(type_expr, "Nullable(Double)"); + tipb::Expr type_expr = constructStringLiteralTiExpr("Nullable(Double)"); auto type_expr_name = getActions(type_expr, actions); String num_col_name = DAGExpressionAnalyzerHelper::buildCastFunctionInternal( this, @@ -886,15 +876,13 @@ String DAGExpressionAnalyzer::convertToUInt8(ExpressionActionsPtr & actions, con field_type, actions); - tipb::Expr const_expr; - constructInt64LiteralTiExpr(const_expr, 0); + tipb::Expr const_expr = constructInt64LiteralTiExpr(0); auto const_expr_name = getActions(const_expr, actions); return applyFunction("notEquals", {num_col_name, const_expr_name}, actions, nullptr); } if (org_type->isDateOrDateTime()) { - tipb::Expr const_expr; - constructDateTimeLiteralTiExpr(const_expr, 0); + tipb::Expr const_expr = constructDateTimeLiteralTiExpr(0); auto const_expr_name = getActions(const_expr, actions); return applyFunction("notEquals", {column_name, const_expr_name}, actions, nullptr); } @@ -929,15 +917,14 @@ const std::vector & DAGExpressionAnalyzer::getCurrentInputColum return after_agg ? aggregated_columns : source_columns; } -void constructTZExpr( - tipb::Expr & tz_expr, +tipb::Expr constructTZExpr( const TimezoneInfo & dag_timezone_info, bool from_utc) { if (dag_timezone_info.is_name_based) - constructStringLiteralTiExpr(tz_expr, dag_timezone_info.timezone_name); + return constructStringLiteralTiExpr(dag_timezone_info.timezone_name); else - constructInt64LiteralTiExpr(tz_expr, from_utc ? dag_timezone_info.timezone_offset : -dag_timezone_info.timezone_offset); + return constructInt64LiteralTiExpr(from_utc ? dag_timezone_info.timezone_offset : -dag_timezone_info.timezone_offset); } String DAGExpressionAnalyzer::appendTimeZoneCast( @@ -959,8 +946,7 @@ bool DAGExpressionAnalyzer::appendExtraCastsAfterTS( initChain(chain, getCurrentInputColumns()); ExpressionActionsPtr actions = chain.getLastActions(); // For TimeZone - tipb::Expr tz_expr; - constructTZExpr(tz_expr, context.getTimezoneInfo(), true); + tipb::Expr tz_expr = constructTZExpr(context.getTimezoneInfo(), true); String tz_col = getActions(tz_expr, actions); static const String convert_time_zone_form_utc = "ConvertTimeZoneFromUTC"; static const String convert_time_zone_by_offset = "ConvertTimeZoneByOffset"; @@ -981,11 +967,10 @@ bool DAGExpressionAnalyzer::appendExtraCastsAfterTS( if (need_cast_column[i] == ExtraCastAfterTSMode::AppendDurationCast) { - tipb::Expr fsp_expr; if (columns[i].decimal() > 6) throw Exception("fsp must <= 6", ErrorCodes::LOGICAL_ERROR); auto fsp = columns[i].decimal() < 0 ? 0 : columns[i].decimal(); - constructInt64LiteralTiExpr(fsp_expr, fsp); + tipb::Expr fsp_expr = constructInt64LiteralTiExpr(fsp); fsp_col = getActions(fsp_expr, actions); String casted_name = appendDurationCast(fsp_col, source_columns[i].name, dur_func_name, actions); source_columns[i].name = casted_name; @@ -1216,6 +1201,7 @@ NamesWithAliases DAGExpressionAnalyzer::appendFinalProject( { if (!output_offsets.empty()) { + /// root query block for (auto i : output_offsets) { final_project.emplace_back( @@ -1225,6 +1211,7 @@ NamesWithAliases DAGExpressionAnalyzer::appendFinalProject( } else { + /// non-root query block for (const auto & element : current_columns) { final_project.emplace_back(element.name, unique_name_generator.toUniqueName(column_prefix + element.name)); @@ -1238,8 +1225,7 @@ NamesWithAliases DAGExpressionAnalyzer::appendFinalProject( initChain(chain, getCurrentInputColumns()); ExpressionActionsChain::Step & step = chain.steps.back(); - tipb::Expr tz_expr; - constructTZExpr(tz_expr, context.getTimezoneInfo(), false); + tipb::Expr tz_expr = constructTZExpr(context.getTimezoneInfo(), false); String tz_col; String tz_cast_func_name = context.getTimezoneInfo().is_name_based ? "ConvertTimeZoneToUTC" : "ConvertTimeZoneByOffset"; std::vector casted(schema.size(), 0); @@ -1330,8 +1316,7 @@ String DAGExpressionAnalyzer::appendCast(const DataTypePtr & target_type, Expres { // need to add cast function // first construct the second argument - tipb::Expr type_expr; - constructStringLiteralTiExpr(type_expr, target_type->getName()); + tipb::Expr type_expr = constructStringLiteralTiExpr(target_type->getName()); auto type_expr_name = getActions(type_expr, actions); String cast_expr_name = applyFunction("CAST", {expr_name, type_expr_name}, actions, nullptr); return cast_expr_name; @@ -1411,8 +1396,7 @@ String DAGExpressionAnalyzer::getActions(const tipb::Expr & expr, ExpressionActi if (expr.field_type().tp() == TiDB::TypeTimestamp && !context.getTimezoneInfo().is_utc_timezone) { /// append timezone cast for timestamp literal - tipb::Expr tz_expr; - constructTZExpr(tz_expr, context.getTimezoneInfo(), true); + tipb::Expr tz_expr = constructTZExpr(context.getTimezoneInfo(), true); String func_name = context.getTimezoneInfo().is_name_based ? "ConvertTimeZoneFromUTC" : "ConvertTimeZoneByOffset"; String tz_col = getActions(tz_expr, actions); String casted_name = appendTimeZoneCast(tz_col, ret, func_name, actions); @@ -1426,9 +1410,9 @@ String DAGExpressionAnalyzer::getActions(const tipb::Expr & expr, ExpressionActi else if (isScalarFunctionExpr(expr)) { const String & func_name = getFunctionName(expr); - if (function_builder_map.find(func_name) != function_builder_map.end()) + if (DAGExpressionAnalyzerHelper::function_builder_map.count(func_name) != 0) { - ret = function_builder_map[func_name](this, expr, actions); + ret = DAGExpressionAnalyzerHelper::function_builder_map[func_name](this, expr, actions); } else { diff --git a/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.cpp b/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.cpp index a1a0b9be21a..e35444c7b48 100644 --- a/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.cpp +++ b/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.cpp @@ -40,6 +40,7 @@ DAGQueryBlockInterpreter::DAGQueryBlockInterpreter( Context & context_, const std::vector & input_streams_vec_, const DAGQueryBlock & query_block_, + size_t max_streams_, bool keep_session_timezone_info_, const DAGQuerySource & dag_, std::vector & subqueries_for_sets_, @@ -50,6 +51,7 @@ DAGQueryBlockInterpreter::DAGQueryBlockInterpreter( , query_block(query_block_) , keep_session_timezone_info(keep_session_timezone_info_) , rqst(dag_.getDAGRequest()) + , max_streams(max_streams_) , dag(dag_) , subqueries_for_sets(subqueries_for_sets_) , exchange_receiver_map(exchange_receiver_map_) @@ -60,15 +62,6 @@ DAGQueryBlockInterpreter::DAGQueryBlockInterpreter( for (const auto & condition : query_block.selection->selection().conditions()) conditions.push_back(&condition); } - const Settings & settings = context.getSettingsRef(); - if (dag.isBatchCop()) - max_streams = settings.max_threads; - else - max_streams = 1; - if (max_streams > 1) - { - max_streams *= settings.max_streams_to_max_threads_ratio; - } } BlockInputStreamPtr combinedNonJoinedDataStream(DAGPipeline & pipeline, size_t max_threads, const LogWithPrefixPtr & log) @@ -209,7 +202,7 @@ AnalysisResult analyzeExpressions( query_block.output_field_types, query_block.output_offsets, query_block.qb_column_prefix, - keep_session_timezone_info || !query_block.isRootQueryBlock()); + keep_session_timezone_info); res.before_order_and_select = chain.getLastActions(); diff --git a/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.h b/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.h index 49122a452da..a1177463946 100644 --- a/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.h +++ b/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.h @@ -36,6 +36,7 @@ class DAGQueryBlockInterpreter Context & context_, const std::vector & input_streams_vec_, const DAGQueryBlock & query_block_, + size_t max_streams_, bool keep_session_timezone_info_, const DAGQuerySource & dag_, std::vector & subqueries_for_sets_, diff --git a/dbms/src/Flash/Coprocessor/DAGQuerySource.cpp b/dbms/src/Flash/Coprocessor/DAGQuerySource.cpp index 4c3b0439b6b..87d59ba9bab 100644 --- a/dbms/src/Flash/Coprocessor/DAGQuerySource.cpp +++ b/dbms/src/Flash/Coprocessor/DAGQuerySource.cpp @@ -15,12 +15,12 @@ DAGQuerySource::DAGQuerySource( const RegionInfoList & regions_for_remote_read_, const tipb::DAGRequest & dag_request_, const LogWithPrefixPtr & log_, - const bool is_batch_cop_) + const bool is_batch_cop_or_mpp_) : context(context_) , regions(regions_) , regions_for_remote_read(regions_for_remote_read_) , dag_request(dag_request_) - , is_batch_cop(is_batch_cop_) + , is_batch_cop_or_mpp(is_batch_cop_or_mpp_) , log(log_) { if (dag_request.has_root_executor()) diff --git a/dbms/src/Flash/Coprocessor/DAGQuerySource.h b/dbms/src/Flash/Coprocessor/DAGQuerySource.h index 0f6f68bc2c1..668ff1bc34a 100644 --- a/dbms/src/Flash/Coprocessor/DAGQuerySource.h +++ b/dbms/src/Flash/Coprocessor/DAGQuerySource.h @@ -24,7 +24,7 @@ class DAGQuerySource : public IQuerySource const RegionInfoList & regions_needs_remote_read_, const tipb::DAGRequest & dag_request_, const LogWithPrefixPtr & log_, - const bool is_batch_cop_ = false); + const bool is_batch_cop_or_mpp_ = false); std::tuple parse(size_t) override; String str(size_t max_query_size) override; @@ -42,7 +42,7 @@ class DAGQuerySource : public IQuerySource const RegionInfoMap & getRegions() const { return regions; } const RegionInfoList & getRegionsForRemoteRead() const { return regions_for_remote_read; } - bool isBatchCop() const { return is_batch_cop; } + bool isBatchCopOrMpp() const { return is_batch_cop_or_mpp; } DAGContext & getDAGContext() const { return *context.getDAGContext(); } @@ -64,7 +64,7 @@ class DAGQuerySource : public IQuerySource std::shared_ptr root_query_block; ASTPtr ast; - const bool is_batch_cop; + const bool is_batch_cop_or_mpp; LogWithPrefixPtr log; }; diff --git a/dbms/src/Flash/Coprocessor/DAGStorageInterpreter.cpp b/dbms/src/Flash/Coprocessor/DAGStorageInterpreter.cpp index b3b4c25f2a0..c21000d9242 100644 --- a/dbms/src/Flash/Coprocessor/DAGStorageInterpreter.cpp +++ b/dbms/src/Flash/Coprocessor/DAGStorageInterpreter.cpp @@ -141,7 +141,7 @@ DAGStorageInterpreter::DAGStorageInterpreter( void DAGStorageInterpreter::execute(DAGPipeline & pipeline) { - if (dag.isBatchCop()) + if (dag.isBatchCopOrMpp()) learner_read_snapshot = doBatchCopLearnerRead(); else learner_read_snapshot = doCopLearnerRead(); @@ -285,7 +285,7 @@ void DAGStorageInterpreter::doLocalRead(DAGPipeline & pipeline, size_t max_block catch (RegionException & e) { /// Recover from region exception when super batch is enable - if (dag.isBatchCop()) + if (dag.isBatchCopOrMpp()) { // clean all streams from local because we are not sure the correctness of those streams pipeline.streams.clear(); diff --git a/dbms/src/Flash/Coprocessor/DAGUtils.cpp b/dbms/src/Flash/Coprocessor/DAGUtils.cpp index d0b7b84050e..c150d5b1f07 100644 --- a/dbms/src/Flash/Coprocessor/DAGUtils.cpp +++ b/dbms/src/Flash/Coprocessor/DAGUtils.cpp @@ -16,8 +16,6 @@ namespace DB { namespace ErrorCodes { -extern const int UNSUPPORTED_METHOD; -extern const int LOGICAL_ERROR; extern const int NOT_IMPLEMENTED; extern const int UNKNOWN_USER; extern const int WRONG_PASSWORD; @@ -27,672 +25,149 @@ extern const int IP_ADDRESS_NOT_ALLOWED; const Int8 VAR_SIZE = 0; -bool isScalarFunctionExpr(const tipb::Expr & expr) -{ - return expr.tp() == tipb::ExprType::ScalarFunc; -} +extern const String uniq_raw_res_name; -bool isFunctionExpr(const tipb::Expr & expr) +namespace { - return isScalarFunctionExpr(expr) || isAggFunctionExpr(expr); -} +const std::unordered_map agg_func_map({ + {tipb::ExprType::Count, "count"}, + {tipb::ExprType::Sum, "sum"}, + {tipb::ExprType::Min, "min"}, + {tipb::ExprType::Max, "max"}, + {tipb::ExprType::First, "first_row"}, + {tipb::ExprType::ApproxCountDistinct, uniq_raw_res_name}, + {tipb::ExprType::GroupConcat, "groupArray"}, + //{tipb::ExprType::Avg, ""}, + //{tipb::ExprType::Agg_BitAnd, ""}, + //{tipb::ExprType::Agg_BitOr, ""}, + //{tipb::ExprType::Agg_BitXor, ""}, + //{tipb::ExprType::Std, ""}, + //{tipb::ExprType::Stddev, ""}, + //{tipb::ExprType::StddevPop, ""}, + //{tipb::ExprType::StddevSamp, ""}, + //{tipb::ExprType::VarPop, ""}, + //{tipb::ExprType::VarSamp, ""}, + //{tipb::ExprType::Variance, ""}, + //{tipb::ExprType::JsonArrayAgg, ""}, + //{tipb::ExprType::JsonObjectAgg, ""}, +}); -const String & getAggFunctionName(const tipb::Expr & expr) -{ - if (expr.has_distinct()) - { - if (distinct_agg_func_map.find(expr.tp()) != distinct_agg_func_map.end()) - { - return distinct_agg_func_map[expr.tp()]; - } - } - else - { - if (agg_func_map.find(expr.tp()) != agg_func_map.end()) - { - return agg_func_map[expr.tp()]; - } - } +const std::unordered_map distinct_agg_func_map({ + {tipb::ExprType::Count, "countDistinct"}, + {tipb::ExprType::GroupConcat, "groupUniqArray"}, +}); - const auto errmsg - = tipb::ExprType_Name(expr.tp()) + "(distinct=" + (expr.has_distinct() ? "true" : "false") + ")" + " is not supported."; - throw TiFlashException(errmsg, Errors::Coprocessor::Unimplemented); -} +const std::unordered_map scalar_func_map({ + {tipb::ScalarFuncSig::CastIntAsInt, "tidb_cast"}, + {tipb::ScalarFuncSig::CastIntAsReal, "tidb_cast"}, + {tipb::ScalarFuncSig::CastIntAsString, "tidb_cast"}, + {tipb::ScalarFuncSig::CastIntAsDecimal, "tidb_cast"}, + {tipb::ScalarFuncSig::CastIntAsTime, "tidb_cast"}, + //{tipb::ScalarFuncSig::CastIntAsDuration, "cast"}, + //{tipb::ScalarFuncSig::CastIntAsJson, "cast"}, -const String & getFunctionName(const tipb::Expr & expr) -{ - if (isAggFunctionExpr(expr)) - { - return getAggFunctionName(expr); - } - else - { - if (scalar_func_map.find(expr.sig()) == scalar_func_map.end()) - { - throw TiFlashException(tipb::ScalarFuncSig_Name(expr.sig()) + " is not supported.", Errors::Coprocessor::Unimplemented); - } - return scalar_func_map[expr.sig()]; - } -} + {tipb::ScalarFuncSig::CastRealAsInt, "tidb_cast"}, + {tipb::ScalarFuncSig::CastRealAsReal, "tidb_cast"}, + {tipb::ScalarFuncSig::CastRealAsString, "tidb_cast"}, + {tipb::ScalarFuncSig::CastRealAsDecimal, "tidb_cast"}, + {tipb::ScalarFuncSig::CastRealAsTime, "tidb_cast"}, + //{tipb::ScalarFuncSig::CastRealAsDuration, "cast"}, + //{tipb::ScalarFuncSig::CastRealAsJson, "cast"}, -String exprToString(const tipb::Expr & expr, const std::vector & input_col) -{ - std::stringstream ss; - String func_name; - Field f; - switch (expr.tp()) - { - case tipb::ExprType::Null: - return "NULL"; - case tipb::ExprType::Int64: - return std::to_string(decodeDAGInt64(expr.val())); - case tipb::ExprType::Uint64: - return std::to_string(decodeDAGUInt64(expr.val())); - case tipb::ExprType::Float32: - return std::to_string(decodeDAGFloat32(expr.val())); - case tipb::ExprType::Float64: - return std::to_string(decodeDAGFloat64(expr.val())); - case tipb::ExprType::String: - return decodeDAGString(expr.val()); - case tipb::ExprType::Bytes: - return decodeDAGBytes(expr.val()); - case tipb::ExprType::MysqlDecimal: - { - auto field = decodeDAGDecimal(expr.val()); - if (field.getType() == Field::Types::Decimal32) - return field.get>().toString(); - else if (field.getType() == Field::Types::Decimal64) - return field.get>().toString(); - else if (field.getType() == Field::Types::Decimal128) - return field.get>().toString(); - else if (field.getType() == Field::Types::Decimal256) - return field.get>().toString(); - else - throw TiFlashException("Not decimal literal" + expr.DebugString(), Errors::Coprocessor::BadRequest); - } - case tipb::ExprType::MysqlTime: - { - if (!expr.has_field_type()) - throw TiFlashException("MySQL Time literal without field_type" + expr.DebugString(), Errors::Coprocessor::BadRequest); - auto t = decodeDAGUInt64(expr.val()); - auto ret = std::to_string(TiDB::DatumFlat(t, static_cast(expr.field_type().tp())).field().get()); - if (expr.field_type().tp() == TiDB::TypeTimestamp) - ret = ret + "_ts"; - return ret; - } - case tipb::ExprType::MysqlDuration: - { - if (!expr.has_field_type()) - throw TiFlashException("MySQL Duration literal without field_type" + expr.DebugString(), Errors::Coprocessor::BadRequest); - auto t = decodeDAGInt64(expr.val()); - auto ret = std::to_string(TiDB::DatumFlat(t, static_cast(expr.field_type().tp())).field().get()); - return ret; - } - case tipb::ExprType::ColumnRef: - return getColumnNameForColumnExpr(expr, input_col); - case tipb::ExprType::Count: - case tipb::ExprType::Sum: - case tipb::ExprType::Avg: - case tipb::ExprType::Min: - case tipb::ExprType::Max: - case tipb::ExprType::First: - case tipb::ExprType::ApproxCountDistinct: - case tipb::ExprType::GroupConcat: - func_name = getAggFunctionName(expr); - break; - case tipb::ExprType::ScalarFunc: - if (scalar_func_map.find(expr.sig()) == scalar_func_map.end()) - { - throw TiFlashException(tipb::ScalarFuncSig_Name(expr.sig()) + " not supported", Errors::Coprocessor::Unimplemented); - } - func_name = scalar_func_map.find(expr.sig())->second; - break; - default: - throw TiFlashException(tipb::ExprType_Name(expr.tp()) + " not supported", Errors::Coprocessor::Unimplemented); - } - // build function expr - if (functionIsInOrGlobalInOperator(func_name)) - { - // for in, we could not represent the function expr using func_name(param1, param2, ...) - ss << exprToString(expr.children(0), input_col) << " " << func_name << " ("; - bool first = true; - for (int i = 1; i < expr.children_size(); i++) - { - String s = exprToString(expr.children(i), input_col); - if (first) - first = false; - else - ss << ", "; - ss << s; - } - ss << ")"; - } - else - { - ss << func_name << "("; - bool first = true; - for (const tipb::Expr & child : expr.children()) - { - String s = exprToString(child, input_col); - if (first) - first = false; - else - ss << ", "; - ss << s; - } - ss << ")"; - } - return ss.str(); -} + {tipb::ScalarFuncSig::CastDecimalAsInt, "tidb_cast"}, + {tipb::ScalarFuncSig::CastDecimalAsReal, "tidb_cast"}, + {tipb::ScalarFuncSig::CastDecimalAsString, "tidb_cast"}, + {tipb::ScalarFuncSig::CastDecimalAsDecimal, "tidb_cast"}, + {tipb::ScalarFuncSig::CastDecimalAsTime, "tidb_cast"}, + //{tipb::ScalarFuncSig::CastDecimalAsDuration, "cast"}, + //{tipb::ScalarFuncSig::CastDecimalAsJson, "cast"}, -const String & getTypeName(const tipb::Expr & expr) -{ - return tipb::ExprType_Name(expr.tp()); -} + {tipb::ScalarFuncSig::CastStringAsInt, "tidb_cast"}, + {tipb::ScalarFuncSig::CastStringAsReal, "tidb_cast"}, + {tipb::ScalarFuncSig::CastStringAsString, "tidb_cast"}, + {tipb::ScalarFuncSig::CastStringAsDecimal, "tidb_cast"}, + {tipb::ScalarFuncSig::CastStringAsTime, "tidb_cast"}, + //{tipb::ScalarFuncSig::CastStringAsDuration, "cast"}, + //{tipb::ScalarFuncSig::CastStringAsJson, "cast"}, -bool isAggFunctionExpr(const tipb::Expr & expr) -{ - switch (expr.tp()) - { - case tipb::ExprType::Count: - case tipb::ExprType::Sum: - case tipb::ExprType::Avg: - case tipb::ExprType::Min: - case tipb::ExprType::Max: - case tipb::ExprType::First: - case tipb::ExprType::GroupConcat: - case tipb::ExprType::Agg_BitAnd: - case tipb::ExprType::Agg_BitOr: - case tipb::ExprType::Agg_BitXor: - case tipb::ExprType::Std: - case tipb::ExprType::Stddev: - case tipb::ExprType::StddevPop: - case tipb::ExprType::StddevSamp: - case tipb::ExprType::VarPop: - case tipb::ExprType::VarSamp: - case tipb::ExprType::Variance: - case tipb::ExprType::JsonArrayAgg: - case tipb::ExprType::JsonObjectAgg: - case tipb::ExprType::ApproxCountDistinct: - return true; - default: - return false; - } -} + {tipb::ScalarFuncSig::CastTimeAsInt, "tidb_cast"}, + {tipb::ScalarFuncSig::CastTimeAsReal, "tidb_cast"}, + {tipb::ScalarFuncSig::CastTimeAsString, "tidb_cast"}, + {tipb::ScalarFuncSig::CastTimeAsDecimal, "tidb_cast"}, + {tipb::ScalarFuncSig::CastTimeAsTime, "tidb_cast"}, + //{tipb::ScalarFuncSig::CastTimeAsDuration, "cast"}, + //{tipb::ScalarFuncSig::CastTimeAsJson, "cast"}, -bool isLiteralExpr(const tipb::Expr & expr) -{ - switch (expr.tp()) - { - case tipb::ExprType::Null: - case tipb::ExprType::Int64: - case tipb::ExprType::Uint64: - case tipb::ExprType::Float32: - case tipb::ExprType::Float64: - case tipb::ExprType::String: - case tipb::ExprType::Bytes: - case tipb::ExprType::MysqlBit: - case tipb::ExprType::MysqlDecimal: - case tipb::ExprType::MysqlDuration: - case tipb::ExprType::MysqlEnum: - case tipb::ExprType::MysqlHex: - case tipb::ExprType::MysqlSet: - case tipb::ExprType::MysqlTime: - case tipb::ExprType::MysqlJson: - case tipb::ExprType::ValueList: - return true; - default: - return false; - } -} - -bool isColumnExpr(const tipb::Expr & expr) -{ - return expr.tp() == tipb::ExprType::ColumnRef; -} + //{tipb::ScalarFuncSig::CastDurationAsInt, "cast"}, + //{tipb::ScalarFuncSig::CastDurationAsReal, "cast"}, + //{tipb::ScalarFuncSig::CastDurationAsString, "cast"}, + //{tipb::ScalarFuncSig::CastDurationAsDecimal, "cast"}, + //{tipb::ScalarFuncSig::CastDurationAsTime, "cast"}, + {tipb::ScalarFuncSig::CastDurationAsDuration, "tidb_cast"}, + //{tipb::ScalarFuncSig::CastDurationAsJson, "cast"}, -Field decodeLiteral(const tipb::Expr & expr) -{ - switch (expr.tp()) - { - case tipb::ExprType::Null: - return Field(); - case tipb::ExprType::Int64: - return decodeDAGInt64(expr.val()); - case tipb::ExprType::Uint64: - return decodeDAGUInt64(expr.val()); - case tipb::ExprType::Float32: - return Float64(decodeDAGFloat32(expr.val())); - case tipb::ExprType::Float64: - return decodeDAGFloat64(expr.val()); - case tipb::ExprType::String: - return decodeDAGString(expr.val()); - case tipb::ExprType::Bytes: - return decodeDAGBytes(expr.val()); - case tipb::ExprType::MysqlDecimal: - return decodeDAGDecimal(expr.val()); - case tipb::ExprType::MysqlTime: - { - if (!expr.has_field_type()) - throw TiFlashException("MySQL Time literal without field_type" + expr.DebugString(), Errors::Coprocessor::BadRequest); - auto t = decodeDAGUInt64(expr.val()); - return TiDB::DatumFlat(t, static_cast(expr.field_type().tp())).field(); - } - case tipb::ExprType::MysqlDuration: - { - if (!expr.has_field_type()) - throw TiFlashException("MySQL Duration literal without field_type" + expr.DebugString(), Errors::Coprocessor::BadRequest); - auto t = decodeDAGInt64(expr.val()); - return TiDB::DatumFlat(t, static_cast(expr.field_type().tp())).field(); - } - case tipb::ExprType::MysqlBit: - case tipb::ExprType::MysqlEnum: - case tipb::ExprType::MysqlHex: - case tipb::ExprType::MysqlSet: - case tipb::ExprType::MysqlJson: - case tipb::ExprType::ValueList: - throw TiFlashException(tipb::ExprType_Name(expr.tp()) + " is not supported yet", Errors::Coprocessor::Unimplemented); - default: - throw TiFlashException("Should not reach here: not a literal expression", Errors::Coprocessor::Internal); - } -} + //{tipb::ScalarFuncSig::CastJsonAsInt, "cast"}, + //{tipb::ScalarFuncSig::CastJsonAsReal, "cast"}, + //{tipb::ScalarFuncSig::CastJsonAsString, "cast"}, + //{tipb::ScalarFuncSig::CastJsonAsDecimal, "cast"}, + //{tipb::ScalarFuncSig::CastJsonAsTime, "cast"}, + //{tipb::ScalarFuncSig::CastJsonAsDuration, "cast"}, + //{tipb::ScalarFuncSig::CastJsonAsJson, "cast"}, -String getColumnNameForColumnExpr(const tipb::Expr & expr, const std::vector & input_col) -{ - auto column_index = decodeDAGInt64(expr.val()); - if (column_index < 0 || column_index >= static_cast(input_col.size())) - { - throw TiFlashException("Column index out of bound", Errors::Coprocessor::BadRequest); - } - return input_col[column_index].name; -} + {tipb::ScalarFuncSig::CoalesceInt, "coalesce"}, + {tipb::ScalarFuncSig::CoalesceReal, "coalesce"}, + {tipb::ScalarFuncSig::CoalesceString, "coalesce"}, + {tipb::ScalarFuncSig::CoalesceDecimal, "coalesce"}, + {tipb::ScalarFuncSig::CoalesceTime, "coalesce"}, + {tipb::ScalarFuncSig::CoalesceDuration, "coalesce"}, + {tipb::ScalarFuncSig::CoalesceJson, "coalesce"}, -// For some historical or unknown reasons, TiDB might set an invalid -// field type. This function checks if the expr has a valid field type. -// So far the known invalid field types are: -// 1. decimal type with scale == -1 -// 2. decimal type with precision == 0 -bool exprHasValidFieldType(const tipb::Expr & expr) -{ - return expr.has_field_type() - && !(expr.field_type().tp() == TiDB::TP::TypeNewDecimal - && (expr.field_type().decimal() == -1 || expr.field_type().flen() == 0)); -} + {tipb::ScalarFuncSig::LTInt, "less"}, + {tipb::ScalarFuncSig::LTReal, "less"}, + {tipb::ScalarFuncSig::LTString, "less"}, + {tipb::ScalarFuncSig::LTDecimal, "less"}, + {tipb::ScalarFuncSig::LTTime, "less"}, + {tipb::ScalarFuncSig::LTDuration, "less"}, + {tipb::ScalarFuncSig::LTJson, "less"}, -bool isUnsupportedEncodeType(const std::vector & types, tipb::EncodeType encode_type) -{ - const static std::unordered_map> unsupported_types_map({ - {tipb::EncodeType::TypeCHBlock, {TiDB::TypeSet, TiDB::TypeGeometry, TiDB::TypeNull, TiDB::TypeEnum, TiDB::TypeJSON, TiDB::TypeBit}}, - {tipb::EncodeType::TypeChunk, {TiDB::TypeSet, TiDB::TypeGeometry, TiDB::TypeNull}}, - }); + {tipb::ScalarFuncSig::LEInt, "lessOrEquals"}, + {tipb::ScalarFuncSig::LEReal, "lessOrEquals"}, + {tipb::ScalarFuncSig::LEString, "lessOrEquals"}, + {tipb::ScalarFuncSig::LEDecimal, "lessOrEquals"}, + {tipb::ScalarFuncSig::LETime, "lessOrEquals"}, + {tipb::ScalarFuncSig::LEDuration, "lessOrEquals"}, + {tipb::ScalarFuncSig::LEJson, "lessOrEquals"}, - auto unsupported_set = unsupported_types_map.find(encode_type); - if (unsupported_set == unsupported_types_map.end()) - return false; - for (const auto & type : types) - { - if (unsupported_set->second.find(type.tp()) != unsupported_set->second.end()) - return true; - } - return false; -} + {tipb::ScalarFuncSig::GTInt, "greater"}, + {tipb::ScalarFuncSig::GTReal, "greater"}, + {tipb::ScalarFuncSig::GTString, "greater"}, + {tipb::ScalarFuncSig::GTDecimal, "greater"}, + {tipb::ScalarFuncSig::GTTime, "greater"}, + {tipb::ScalarFuncSig::GTDuration, "greater"}, + {tipb::ScalarFuncSig::GTJson, "greater"}, -DataTypePtr inferDataType4Literal(const tipb::Expr & expr) -{ - Field value = decodeLiteral(expr); - DataTypePtr flash_type = applyVisitor(FieldToDataType(), value); - /// need to extract target_type from expr.field_type() because the flash_type derived from - /// value is just a `memory type`, which does not have enough information, for example: - /// for date literal, the flash_type is `UInt64` - DataTypePtr target_type{}; - if (expr.tp() == tipb::ExprType::Null) - { - // todo We should use DataTypeNothing as NULL literal's TiFlash Type, because TiFlash has a lot of - // optimization for DataTypeNothing, but there are still some bugs when using DataTypeNothing: when - // TiFlash try to return data to TiDB or exchange data between TiFlash node, since codec only recognize - // TiDB type, use DataTypeNothing will meet error in the codec, so do not use DataTypeNothing until - // we fix the codec issue. - if (exprHasValidFieldType(expr)) - { - target_type = getDataTypeByFieldTypeForComputingLayer(expr.field_type()); - } - else - { - if (expr.has_field_type() && expr.field_type().tp() == TiDB::TP::TypeNewDecimal) - target_type = createDecimal(1, 0); - else - target_type = flash_type; - } - target_type = makeNullable(target_type); - } - else - { - if (expr.tp() == tipb::ExprType::MysqlDecimal) - { - /// to fix https://github.com/pingcap/tics/issues/1425, when TiDB push down - /// a decimal literal, it contains two types: one is the type that encoded - /// in Decimal value itself(i.e. expr.val()), the other is the type that in - /// expr.field_type(). According to TiDB and Mysql behavior, the computing - /// layer should use the type in expr.val(), which means we should ignore - /// the type in expr.field_type() - target_type = flash_type; - } - else - { - target_type = exprHasValidFieldType(expr) ? getDataTypeByFieldTypeForComputingLayer(expr.field_type()) : flash_type; - } - // We should remove nullable for constant value since TiDB may not set NOT_NULL flag for literal expression. - target_type = removeNullable(target_type); - } - return target_type; -} + {tipb::ScalarFuncSig::GreatestInt, "greatest"}, + {tipb::ScalarFuncSig::GreatestReal, "greatest"}, + {tipb::ScalarFuncSig::GreatestString, "greatest"}, + {tipb::ScalarFuncSig::GreatestDecimal, "greatest"}, + {tipb::ScalarFuncSig::GreatestTime, "greatest"}, -UInt8 getFieldLengthForArrowEncode(Int32 tp) -{ - switch (tp) - { - case TiDB::TypeTiny: - case TiDB::TypeShort: - case TiDB::TypeInt24: - case TiDB::TypeLong: - case TiDB::TypeLongLong: - case TiDB::TypeYear: - case TiDB::TypeDouble: - case TiDB::TypeTime: - case TiDB::TypeDate: - case TiDB::TypeDatetime: - case TiDB::TypeNewDate: - case TiDB::TypeTimestamp: - return 8; - case TiDB::TypeFloat: - return 4; - case TiDB::TypeDecimal: - case TiDB::TypeNewDecimal: - return 40; - case TiDB::TypeVarchar: - case TiDB::TypeVarString: - case TiDB::TypeString: - case TiDB::TypeBlob: - case TiDB::TypeTinyBlob: - case TiDB::TypeMediumBlob: - case TiDB::TypeLongBlob: - case TiDB::TypeBit: - case TiDB::TypeEnum: - case TiDB::TypeJSON: - return VAR_SIZE; - default: - throw TiFlashException("not supported field type in arrow encode: " + std::to_string(tp), Errors::Coprocessor::Internal); - } -} + {tipb::ScalarFuncSig::LeastInt, "least"}, + {tipb::ScalarFuncSig::LeastReal, "least"}, + {tipb::ScalarFuncSig::LeastString, "least"}, + {tipb::ScalarFuncSig::LeastDecimal, "least"}, + {tipb::ScalarFuncSig::LeastTime, "least"}, -void constructStringLiteralTiExpr(tipb::Expr & expr, const String & value) -{ - expr.set_tp(tipb::ExprType::String); - expr.set_val(value); - auto * field_type = expr.mutable_field_type(); - field_type->set_tp(TiDB::TypeString); - field_type->set_flag(TiDB::ColumnFlagNotNull); -} + //{tipb::ScalarFuncSig::IntervalInt, "cast"}, + //{tipb::ScalarFuncSig::IntervalReal, "cast"}, -void constructInt64LiteralTiExpr(tipb::Expr & expr, Int64 value) -{ - expr.set_tp(tipb::ExprType::Int64); - WriteBufferFromOwnString ss; - encodeDAGInt64(value, ss); - expr.set_val(ss.releaseStr()); - auto * field_type = expr.mutable_field_type(); - field_type->set_tp(TiDB::TypeLongLong); - field_type->set_flag(TiDB::ColumnFlagNotNull); -} - -void constructDateTimeLiteralTiExpr(tipb::Expr & expr, UInt64 packed_value) -{ - expr.set_tp(tipb::ExprType::MysqlTime); - WriteBufferFromOwnString ss; - encodeDAGUInt64(packed_value, ss); - expr.set_val(ss.releaseStr()); - auto * field_type = expr.mutable_field_type(); - field_type->set_tp(TiDB::TypeDatetime); - field_type->set_flag(TiDB::ColumnFlagNotNull); -} - -void constructNULLLiteralTiExpr(tipb::Expr & expr) -{ - expr.set_tp(tipb::ExprType::Null); - auto * field_type = expr.mutable_field_type(); - field_type->set_tp(TiDB::TypeNull); -} - -std::shared_ptr getCollatorFromExpr(const tipb::Expr & expr) -{ - if (expr.has_field_type()) - return getCollatorFromFieldType(expr.field_type()); - return nullptr; -} - -SortDescription getSortDescription(const std::vector & order_columns, const google::protobuf::RepeatedPtrField & by_items) -{ - SortDescription order_descr; - order_descr.reserve(by_items.size()); - for (int i = 0; i < by_items.size(); i++) - { - const auto & name = order_columns[i].name; - int direction = by_items[i].desc() ? -1 : 1; - // MySQL/TiDB treats NULL as "minimum". - int nulls_direction = -1; - std::shared_ptr collator = nullptr; - if (removeNullable(order_columns[i].type)->isString()) - collator = getCollatorFromExpr(by_items[i].expr()); - - order_descr.emplace_back(name, direction, nulls_direction, collator); - } - return order_descr; -} - -TiDB::TiDBCollatorPtr getCollatorFromFieldType(const tipb::FieldType & field_type) -{ - if (field_type.collate() < 0) - return TiDB::ITiDBCollator::getCollator(-field_type.collate()); - return nullptr; -} - -bool hasUnsignedFlag(const tipb::FieldType & tp) -{ - return tp.flag() & TiDB::ColumnFlagUnsigned; -} - -grpc::StatusCode tiflashErrorCodeToGrpcStatusCode(int error_code) -{ - /// do not use switch statement because ErrorCodes::XXXX is not a compile time constant - if (error_code == ErrorCodes::NOT_IMPLEMENTED) - return grpc::StatusCode::UNIMPLEMENTED; - if (error_code == ErrorCodes::UNKNOWN_USER || error_code == ErrorCodes::WRONG_PASSWORD || error_code == ErrorCodes::REQUIRED_PASSWORD - || error_code == ErrorCodes::IP_ADDRESS_NOT_ALLOWED) - return grpc::StatusCode::UNAUTHENTICATED; - return grpc::StatusCode::INTERNAL; -} - -void assertBlockSchema(const DataTypes & expected_types, const Block & block, const std::string & context_description) -{ - size_t columns = expected_types.size(); - if (block.columns() != columns) - throw Exception("Block schema mismatch in " + context_description + ": different number of columns: expected " - + std::to_string(columns) + " columns, got " + std::to_string(block.columns()) + " columns"); - - for (size_t i = 0; i < columns; ++i) - { - const auto & actual = block.getByPosition(i).type; - const auto & expected = expected_types[i]; - - if (!expected->equals(*actual)) - { - throw Exception("Block schema mismatch in " + context_description + ": different types: expected " + expected->getName() - + ", got " + actual->getName()); - } - } -} - -void getDAGRequestFromStringWithRetry(tipb::DAGRequest & dag_req, const String & s) -{ - if (!dag_req.ParseFromString(s)) - { - /// ParseFromString will use the default recursion limit, which is 100 to decode the plan, if the plan tree is too deep, - /// it may exceed this limit, so just try again by double the recursion limit - ::google::protobuf::io::CodedInputStream coded_input_stream(reinterpret_cast(s.data()), s.size()); - coded_input_stream.SetRecursionLimit(::google::protobuf::io::CodedInputStream::GetDefaultRecursionLimit() * 2); - if (!dag_req.ParseFromCodedStream(&coded_input_stream)) - { - /// just return error if decode failed this time, because it's really a corner case, and even if we can decode the plan - /// successfully by using a very large value of the recursion limit, it is kinds of meaningless because the runtime - /// performance of this task may be very bad if the plan tree is too deep - throw TiFlashException( - std::string(__PRETTY_FUNCTION__) + ": Invalid encoded plan, the most likely is that the plan/expression tree is too deep", - Errors::Coprocessor::BadRequest); - } - } -} - -extern const String uniq_raw_res_name; - -std::unordered_map agg_func_map({ - {tipb::ExprType::Count, "count"}, - {tipb::ExprType::Sum, "sum"}, - {tipb::ExprType::Min, "min"}, - {tipb::ExprType::Max, "max"}, - {tipb::ExprType::First, "first_row"}, - {tipb::ExprType::ApproxCountDistinct, uniq_raw_res_name}, - {tipb::ExprType::GroupConcat, "groupArray"}, - //{tipb::ExprType::Avg, ""}, - //{tipb::ExprType::Agg_BitAnd, ""}, - //{tipb::ExprType::Agg_BitOr, ""}, - //{tipb::ExprType::Agg_BitXor, ""}, - //{tipb::ExprType::Std, ""}, - //{tipb::ExprType::Stddev, ""}, - //{tipb::ExprType::StddevPop, ""}, - //{tipb::ExprType::StddevSamp, ""}, - //{tipb::ExprType::VarPop, ""}, - //{tipb::ExprType::VarSamp, ""}, - //{tipb::ExprType::Variance, ""}, - //{tipb::ExprType::JsonArrayAgg, ""}, - //{tipb::ExprType::JsonObjectAgg, ""}, -}); - -std::unordered_map distinct_agg_func_map({ - {tipb::ExprType::Count, "countDistinct"}, - {tipb::ExprType::GroupConcat, "groupUniqArray"}, -}); - -std::unordered_map scalar_func_map({ - {tipb::ScalarFuncSig::CastIntAsInt, "tidb_cast"}, - {tipb::ScalarFuncSig::CastIntAsReal, "tidb_cast"}, - {tipb::ScalarFuncSig::CastIntAsString, "tidb_cast"}, - {tipb::ScalarFuncSig::CastIntAsDecimal, "tidb_cast"}, - {tipb::ScalarFuncSig::CastIntAsTime, "tidb_cast"}, - //{tipb::ScalarFuncSig::CastIntAsDuration, "cast"}, - //{tipb::ScalarFuncSig::CastIntAsJson, "cast"}, - - {tipb::ScalarFuncSig::CastRealAsInt, "tidb_cast"}, - {tipb::ScalarFuncSig::CastRealAsReal, "tidb_cast"}, - {tipb::ScalarFuncSig::CastRealAsString, "tidb_cast"}, - {tipb::ScalarFuncSig::CastRealAsDecimal, "tidb_cast"}, - {tipb::ScalarFuncSig::CastRealAsTime, "tidb_cast"}, - //{tipb::ScalarFuncSig::CastRealAsDuration, "cast"}, - //{tipb::ScalarFuncSig::CastRealAsJson, "cast"}, - - {tipb::ScalarFuncSig::CastDecimalAsInt, "tidb_cast"}, - {tipb::ScalarFuncSig::CastDecimalAsReal, "tidb_cast"}, - {tipb::ScalarFuncSig::CastDecimalAsString, "tidb_cast"}, - {tipb::ScalarFuncSig::CastDecimalAsDecimal, "tidb_cast"}, - {tipb::ScalarFuncSig::CastDecimalAsTime, "tidb_cast"}, - //{tipb::ScalarFuncSig::CastDecimalAsDuration, "cast"}, - //{tipb::ScalarFuncSig::CastDecimalAsJson, "cast"}, - - {tipb::ScalarFuncSig::CastStringAsInt, "tidb_cast"}, - {tipb::ScalarFuncSig::CastStringAsReal, "tidb_cast"}, - {tipb::ScalarFuncSig::CastStringAsString, "tidb_cast"}, - {tipb::ScalarFuncSig::CastStringAsDecimal, "tidb_cast"}, - {tipb::ScalarFuncSig::CastStringAsTime, "tidb_cast"}, - //{tipb::ScalarFuncSig::CastStringAsDuration, "cast"}, - //{tipb::ScalarFuncSig::CastStringAsJson, "cast"}, - - {tipb::ScalarFuncSig::CastTimeAsInt, "tidb_cast"}, - {tipb::ScalarFuncSig::CastTimeAsReal, "tidb_cast"}, - {tipb::ScalarFuncSig::CastTimeAsString, "tidb_cast"}, - {tipb::ScalarFuncSig::CastTimeAsDecimal, "tidb_cast"}, - {tipb::ScalarFuncSig::CastTimeAsTime, "tidb_cast"}, - //{tipb::ScalarFuncSig::CastTimeAsDuration, "cast"}, - //{tipb::ScalarFuncSig::CastTimeAsJson, "cast"}, - - //{tipb::ScalarFuncSig::CastDurationAsInt, "cast"}, - //{tipb::ScalarFuncSig::CastDurationAsReal, "cast"}, - //{tipb::ScalarFuncSig::CastDurationAsString, "cast"}, - //{tipb::ScalarFuncSig::CastDurationAsDecimal, "cast"}, - //{tipb::ScalarFuncSig::CastDurationAsTime, "cast"}, - {tipb::ScalarFuncSig::CastDurationAsDuration, "tidb_cast"}, - //{tipb::ScalarFuncSig::CastDurationAsJson, "cast"}, - - //{tipb::ScalarFuncSig::CastJsonAsInt, "cast"}, - //{tipb::ScalarFuncSig::CastJsonAsReal, "cast"}, - //{tipb::ScalarFuncSig::CastJsonAsString, "cast"}, - //{tipb::ScalarFuncSig::CastJsonAsDecimal, "cast"}, - //{tipb::ScalarFuncSig::CastJsonAsTime, "cast"}, - //{tipb::ScalarFuncSig::CastJsonAsDuration, "cast"}, - //{tipb::ScalarFuncSig::CastJsonAsJson, "cast"}, - - {tipb::ScalarFuncSig::CoalesceInt, "coalesce"}, - {tipb::ScalarFuncSig::CoalesceReal, "coalesce"}, - {tipb::ScalarFuncSig::CoalesceString, "coalesce"}, - {tipb::ScalarFuncSig::CoalesceDecimal, "coalesce"}, - {tipb::ScalarFuncSig::CoalesceTime, "coalesce"}, - {tipb::ScalarFuncSig::CoalesceDuration, "coalesce"}, - {tipb::ScalarFuncSig::CoalesceJson, "coalesce"}, - - {tipb::ScalarFuncSig::LTInt, "less"}, - {tipb::ScalarFuncSig::LTReal, "less"}, - {tipb::ScalarFuncSig::LTString, "less"}, - {tipb::ScalarFuncSig::LTDecimal, "less"}, - {tipb::ScalarFuncSig::LTTime, "less"}, - {tipb::ScalarFuncSig::LTDuration, "less"}, - {tipb::ScalarFuncSig::LTJson, "less"}, - - {tipb::ScalarFuncSig::LEInt, "lessOrEquals"}, - {tipb::ScalarFuncSig::LEReal, "lessOrEquals"}, - {tipb::ScalarFuncSig::LEString, "lessOrEquals"}, - {tipb::ScalarFuncSig::LEDecimal, "lessOrEquals"}, - {tipb::ScalarFuncSig::LETime, "lessOrEquals"}, - {tipb::ScalarFuncSig::LEDuration, "lessOrEquals"}, - {tipb::ScalarFuncSig::LEJson, "lessOrEquals"}, - - {tipb::ScalarFuncSig::GTInt, "greater"}, - {tipb::ScalarFuncSig::GTReal, "greater"}, - {tipb::ScalarFuncSig::GTString, "greater"}, - {tipb::ScalarFuncSig::GTDecimal, "greater"}, - {tipb::ScalarFuncSig::GTTime, "greater"}, - {tipb::ScalarFuncSig::GTDuration, "greater"}, - {tipb::ScalarFuncSig::GTJson, "greater"}, - - {tipb::ScalarFuncSig::GreatestInt, "greatest"}, - {tipb::ScalarFuncSig::GreatestReal, "greatest"}, - {tipb::ScalarFuncSig::GreatestString, "greatest"}, - {tipb::ScalarFuncSig::GreatestDecimal, "greatest"}, - {tipb::ScalarFuncSig::GreatestTime, "greatest"}, - - {tipb::ScalarFuncSig::LeastInt, "least"}, - {tipb::ScalarFuncSig::LeastReal, "least"}, - {tipb::ScalarFuncSig::LeastString, "least"}, - {tipb::ScalarFuncSig::LeastDecimal, "least"}, - {tipb::ScalarFuncSig::LeastTime, "least"}, - - //{tipb::ScalarFuncSig::IntervalInt, "cast"}, - //{tipb::ScalarFuncSig::IntervalReal, "cast"}, - - {tipb::ScalarFuncSig::GEInt, "greaterOrEquals"}, - {tipb::ScalarFuncSig::GEReal, "greaterOrEquals"}, - {tipb::ScalarFuncSig::GEString, "greaterOrEquals"}, - {tipb::ScalarFuncSig::GEDecimal, "greaterOrEquals"}, - {tipb::ScalarFuncSig::GETime, "greaterOrEquals"}, - {tipb::ScalarFuncSig::GEDuration, "greaterOrEquals"}, - {tipb::ScalarFuncSig::GEJson, "greaterOrEquals"}, + {tipb::ScalarFuncSig::GEInt, "greaterOrEquals"}, + {tipb::ScalarFuncSig::GEReal, "greaterOrEquals"}, + {tipb::ScalarFuncSig::GEString, "greaterOrEquals"}, + {tipb::ScalarFuncSig::GEDecimal, "greaterOrEquals"}, + {tipb::ScalarFuncSig::GETime, "greaterOrEquals"}, + {tipb::ScalarFuncSig::GEDuration, "greaterOrEquals"}, + {tipb::ScalarFuncSig::GEJson, "greaterOrEquals"}, {tipb::ScalarFuncSig::EQInt, "equals"}, {tipb::ScalarFuncSig::EQReal, "equals"}, @@ -1034,152 +509,687 @@ std::unordered_map scalar_func_map({ //{tipb::ScalarFuncSig::UTCTimestampWithArg, "cast"}, //{tipb::ScalarFuncSig::UTCTimestampWithoutArg, "cast"}, - //{tipb::ScalarFuncSig::AddDatetimeAndDuration, "cast"}, - //{tipb::ScalarFuncSig::AddDatetimeAndString, "cast"}, - //{tipb::ScalarFuncSig::AddTimeDateTimeNull, "cast"}, - //{tipb::ScalarFuncSig::AddStringAndDuration, "cast"}, - //{tipb::ScalarFuncSig::AddStringAndString, "cast"}, - //{tipb::ScalarFuncSig::AddTimeStringNull, "cast"}, - //{tipb::ScalarFuncSig::AddDurationAndDuration, "cast"}, - //{tipb::ScalarFuncSig::AddDurationAndString, "cast"}, - //{tipb::ScalarFuncSig::AddTimeDurationNull, "cast"}, - //{tipb::ScalarFuncSig::AddDateAndDuration, "cast"}, - //{tipb::ScalarFuncSig::AddDateAndString, "cast"}, + //{tipb::ScalarFuncSig::AddDatetimeAndDuration, "cast"}, + //{tipb::ScalarFuncSig::AddDatetimeAndString, "cast"}, + //{tipb::ScalarFuncSig::AddTimeDateTimeNull, "cast"}, + //{tipb::ScalarFuncSig::AddStringAndDuration, "cast"}, + //{tipb::ScalarFuncSig::AddStringAndString, "cast"}, + //{tipb::ScalarFuncSig::AddTimeStringNull, "cast"}, + //{tipb::ScalarFuncSig::AddDurationAndDuration, "cast"}, + //{tipb::ScalarFuncSig::AddDurationAndString, "cast"}, + //{tipb::ScalarFuncSig::AddTimeDurationNull, "cast"}, + //{tipb::ScalarFuncSig::AddDateAndDuration, "cast"}, + //{tipb::ScalarFuncSig::AddDateAndString, "cast"}, + + //{tipb::ScalarFuncSig::SubDateAndDuration, "cast"}, + //{tipb::ScalarFuncSig::SubDateAndString, "cast"}, + //{tipb::ScalarFuncSig::SubTimeDateTimeNull, "cast"}, + //{tipb::ScalarFuncSig::SubStringAndDuration, "cast"}, + //{tipb::ScalarFuncSig::SubStringAndString, "cast"}, + //{tipb::ScalarFuncSig::SubTimeStringNull, "cast"}, + //{tipb::ScalarFuncSig::SubDurationAndDuration, "cast"}, + //{tipb::ScalarFuncSig::SubDurationAndString, "cast"}, + //{tipb::ScalarFuncSig::SubDateAndDuration, "cast"}, + //{tipb::ScalarFuncSig::SubDateAndString, "cast"}, + + //{tipb::ScalarFuncSig::UnixTimestampCurrent, "cast"}, + {tipb::ScalarFuncSig::UnixTimestampInt, "tidbUnixTimeStampInt"}, + {tipb::ScalarFuncSig::UnixTimestampDec, "tidbUnixTimeStampDec"}, + + //{tipb::ScalarFuncSig::ConvertTz, "cast"}, + //{tipb::ScalarFuncSig::MakeDate, "cast"}, + //{tipb::ScalarFuncSig::MakeTime, "cast"}, + //{tipb::ScalarFuncSig::PeriodAdd, "cast"}, + //{tipb::ScalarFuncSig::PeriodDiff, "cast"}, + //{tipb::ScalarFuncSig::Quarter, "cast"}, + + //{tipb::ScalarFuncSig::SecToTime, "cast"}, + //{tipb::ScalarFuncSig::TimeToSec, "cast"}, + //{tipb::ScalarFuncSig::TimestampAdd, "cast"}, + //{tipb::ScalarFuncSig::ToDays, "cast"}, + //{tipb::ScalarFuncSig::ToSeconds, "cast"}, + //{tipb::ScalarFuncSig::UTCTimeWithArg, "cast"}, + //{tipb::ScalarFuncSig::UTCTimestampWithoutArg, "cast"}, + //{tipb::ScalarFuncSig::Timestamp1Arg, "cast"}, + //{tipb::ScalarFuncSig::Timestamp2Args, "cast"}, + //{tipb::ScalarFuncSig::TimestampLiteral, "cast"}, + + //{tipb::ScalarFuncSig::LastDay, "cast"}, + {tipb::ScalarFuncSig::StrToDateDate, "strToDateDate"}, + {tipb::ScalarFuncSig::StrToDateDatetime, "strToDateDatetime"}, + // {tipb::ScalarFuncSig::StrToDateDuration, "cast"}, + {tipb::ScalarFuncSig::FromUnixTime1Arg, "fromUnixTime"}, + {tipb::ScalarFuncSig::FromUnixTime2Arg, "fromUnixTime"}, + {tipb::ScalarFuncSig::ExtractDatetime, "extractMyDateTime"}, + //{tipb::ScalarFuncSig::ExtractDuration, "cast"}, + + //{tipb::ScalarFuncSig::AddDateStringString, "cast"}, + {tipb::ScalarFuncSig::AddDateStringInt, "date_add"}, + //{tipb::ScalarFuncSig::AddDateStringDecimal, "cast"}, + //{tipb::ScalarFuncSig::AddDateIntString, "cast"}, + //{tipb::ScalarFuncSig::AddDateIntInt, "cast"}, + //{tipb::ScalarFuncSig::AddDateDatetimeString, "date_add"}, + {tipb::ScalarFuncSig::AddDateDatetimeInt, "date_add"}, + + //{tipb::ScalarFuncSig::SubDateStringString, "cast"}, + {tipb::ScalarFuncSig::SubDateStringInt, "date_sub"}, + //{tipb::ScalarFuncSig::SubDateStringDecimal, "cast"}, + //{tipb::ScalarFuncSig::SubDateIntString, "cast"}, + //{tipb::ScalarFuncSig::SubDateIntInt, "cast"}, + //{tipb::ScalarFuncSig::SubDateDatetimeString, "cast"}, + {tipb::ScalarFuncSig::SubDateDatetimeInt, "date_sub"}, + + //{tipb::ScalarFuncSig::FromDays, "cast"}, + //{tipb::ScalarFuncSig::TimeFormat, "cast"}, + {tipb::ScalarFuncSig::TimestampDiff, "tidbTimestampDiff"}, + + //{tipb::ScalarFuncSig::BitLength, "cast"}, + //{tipb::ScalarFuncSig::Bin, "cast"}, + {tipb::ScalarFuncSig::ASCII, "ascii"}, + //{tipb::ScalarFuncSig::Char, "cast"}, + {tipb::ScalarFuncSig::CharLengthUTF8, "lengthUTF8"}, + {tipb::ScalarFuncSig::Concat, "tidbConcat"}, + {tipb::ScalarFuncSig::ConcatWS, "tidbConcatWS"}, + //{tipb::ScalarFuncSig::Convert, "cast"}, + //{tipb::ScalarFuncSig::Elt, "cast"}, + //{tipb::ScalarFuncSig::ExportSet3Arg, "cast"}, + //{tipb::ScalarFuncSig::ExportSet4Arg, "cast"}, + //{tipb::ScalarFuncSig::ExportSet5Arg, "cast"}, + //{tipb::ScalarFuncSig::FieldInt, "cast"}, + //{tipb::ScalarFuncSig::FieldReal, "cast"}, + //{tipb::ScalarFuncSig::FieldString, "cast"}, + + //{tipb::ScalarFuncSig::FindInSet, "cast"}, + //{tipb::ScalarFuncSig::Format, "cast"}, + //{tipb::ScalarFuncSig::FormatWithLocale, "cast"}, + //{tipb::ScalarFuncSig::FromBase64, "cast"}, + //{tipb::ScalarFuncSig::HexIntArg, "cast"}, + //{tipb::ScalarFuncSig::HexStrArg, "cast"}, + //{tipb::ScalarFuncSig::InsertUTF8, "cast"}, + //{tipb::ScalarFuncSig::Insert, "cast"}, + //{tipb::ScalarFuncSig::InstrUTF8, "cast"}, + //{tipb::ScalarFuncSig::Instr, "cast"}, + + {tipb::ScalarFuncSig::LeftUTF8, "leftUTF8"}, + //{tipb::ScalarFuncSig::Left, "cast"}, + {tipb::ScalarFuncSig::Length, "length"}, + {tipb::ScalarFuncSig::Locate2ArgsUTF8, "position"}, + //{tipb::ScalarFuncSig::Locate3ArgsUTF8, "cast"}, + {tipb::ScalarFuncSig::Locate2Args, "position"}, + //{tipb::ScalarFuncSig::Locate3Args, "cast"}, + + {tipb::ScalarFuncSig::Lower, "lowerBinary"}, + {tipb::ScalarFuncSig::LowerUTF8, "lowerUTF8"}, + //{tipb::ScalarFuncSig::LpadUTF8, "cast"}, + //{tipb::ScalarFuncSig::Lpad, "cast"}, + //{tipb::ScalarFuncSig::MakeSet, "cast"}, + //{tipb::ScalarFuncSig::OctInt, "cast"}, + //{tipb::ScalarFuncSig::OctString, "cast"}, + //{tipb::ScalarFuncSig::Ord, "cast"}, + //{tipb::ScalarFuncSig::Quote, "cast"}, + //{tipb::ScalarFuncSig::Repeat, "cast"}, + {tipb::ScalarFuncSig::Replace, "replaceAll"}, + //{tipb::ScalarFuncSig::ReverseUTF8, "cast"}, + //{tipb::ScalarFuncSig::Reverse, "cast"}, + {tipb::ScalarFuncSig::RightUTF8, "rightUTF8"}, + //{tipb::ScalarFuncSig::Right, "cast"}, + //{tipb::ScalarFuncSig::RpadUTF8, "cast"}, + //{tipb::ScalarFuncSig::Rpad, "cast"}, + //{tipb::ScalarFuncSig::Space, "cast"}, + //{tipb::ScalarFuncSig::Strcmp, "cast"}, + {tipb::ScalarFuncSig::Substring2ArgsUTF8, "substringUTF8"}, + {tipb::ScalarFuncSig::Substring3ArgsUTF8, "substringUTF8"}, + //{tipb::ScalarFuncSig::Substring2Args, "cast"}, + //{tipb::ScalarFuncSig::Substring3Args, "cast"}, + {tipb::ScalarFuncSig::SubstringIndex, "substringIndex"}, + {tipb::ScalarFuncSig::Format, "format"}, + {tipb::ScalarFuncSig::FormatWithLocale, "formatWithLocale"}, + + //{tipb::ScalarFuncSig::ToBase64, "cast"}, + {tipb::ScalarFuncSig::Trim1Arg, "tidbTrim"}, + {tipb::ScalarFuncSig::Trim2Args, "tidbTrim"}, + {tipb::ScalarFuncSig::Trim3Args, "tidbTrim"}, + {tipb::ScalarFuncSig::LTrim, "tidbLTrim"}, + {tipb::ScalarFuncSig::RTrim, "tidbRTrim"}, + //{tipb::ScalarFuncSig::UnHex, "cast"}, + {tipb::ScalarFuncSig::UpperUTF8, "upperUTF8"}, + {tipb::ScalarFuncSig::Upper, "upperBinary"}, + //{tipb::ScalarFuncSig::CharLength, "upper"}, +}); +} // namespace + +bool isScalarFunctionExpr(const tipb::Expr & expr) +{ + return expr.tp() == tipb::ExprType::ScalarFunc; +} + +bool isFunctionExpr(const tipb::Expr & expr) +{ + return isScalarFunctionExpr(expr) || isAggFunctionExpr(expr); +} + +const String & getAggFunctionName(const tipb::Expr & expr) +{ + if (expr.has_distinct()) + { + auto it = distinct_agg_func_map.find(expr.tp()); + if (it != distinct_agg_func_map.end()) + return it->second; + } + else + { + auto it = agg_func_map.find(expr.tp()); + if (it != agg_func_map.end()) + return it->second; + } + + const auto errmsg = fmt::format( + "{}(distinct={}) is not supported.", + tipb::ExprType_Name(expr.tp()), + expr.has_distinct() ? "true" : "false"); + throw TiFlashException(errmsg, Errors::Coprocessor::Unimplemented); +} + +const String & getFunctionName(const tipb::Expr & expr) +{ + if (isAggFunctionExpr(expr)) + { + return getAggFunctionName(expr); + } + else + { + auto it = scalar_func_map.find(expr.sig()); + if (it == scalar_func_map.end()) + throw TiFlashException(tipb::ScalarFuncSig_Name(expr.sig()) + " is not supported.", Errors::Coprocessor::Unimplemented); + return it->second; + } +} + +String exprToString(const tipb::Expr & expr, const std::vector & input_col) +{ + std::stringstream ss; + String func_name; + Field f; + switch (expr.tp()) + { + case tipb::ExprType::Null: + return "NULL"; + case tipb::ExprType::Int64: + return std::to_string(decodeDAGInt64(expr.val())); + case tipb::ExprType::Uint64: + return std::to_string(decodeDAGUInt64(expr.val())); + case tipb::ExprType::Float32: + return std::to_string(decodeDAGFloat32(expr.val())); + case tipb::ExprType::Float64: + return std::to_string(decodeDAGFloat64(expr.val())); + case tipb::ExprType::String: + return decodeDAGString(expr.val()); + case tipb::ExprType::Bytes: + return decodeDAGBytes(expr.val()); + case tipb::ExprType::MysqlDecimal: + { + auto field = decodeDAGDecimal(expr.val()); + if (field.getType() == Field::Types::Decimal32) + return field.get>().toString(); + else if (field.getType() == Field::Types::Decimal64) + return field.get>().toString(); + else if (field.getType() == Field::Types::Decimal128) + return field.get>().toString(); + else if (field.getType() == Field::Types::Decimal256) + return field.get>().toString(); + else + throw TiFlashException("Not decimal literal" + expr.DebugString(), Errors::Coprocessor::BadRequest); + } + case tipb::ExprType::MysqlTime: + { + if (!expr.has_field_type()) + throw TiFlashException("MySQL Time literal without field_type" + expr.DebugString(), Errors::Coprocessor::BadRequest); + auto t = decodeDAGUInt64(expr.val()); + auto ret = std::to_string(TiDB::DatumFlat(t, static_cast(expr.field_type().tp())).field().get()); + if (expr.field_type().tp() == TiDB::TypeTimestamp) + ret = ret + "_ts"; + return ret; + } + case tipb::ExprType::MysqlDuration: + { + if (!expr.has_field_type()) + throw TiFlashException("MySQL Duration literal without field_type" + expr.DebugString(), Errors::Coprocessor::BadRequest); + auto t = decodeDAGInt64(expr.val()); + auto ret = std::to_string(TiDB::DatumFlat(t, static_cast(expr.field_type().tp())).field().get()); + return ret; + } + case tipb::ExprType::ColumnRef: + return getColumnNameForColumnExpr(expr, input_col); + case tipb::ExprType::Count: + case tipb::ExprType::Sum: + case tipb::ExprType::Avg: + case tipb::ExprType::Min: + case tipb::ExprType::Max: + case tipb::ExprType::First: + case tipb::ExprType::ApproxCountDistinct: + case tipb::ExprType::GroupConcat: + func_name = getAggFunctionName(expr); + break; + case tipb::ExprType::ScalarFunc: + if (scalar_func_map.find(expr.sig()) == scalar_func_map.end()) + { + throw TiFlashException(tipb::ScalarFuncSig_Name(expr.sig()) + " not supported", Errors::Coprocessor::Unimplemented); + } + func_name = scalar_func_map.find(expr.sig())->second; + break; + default: + throw TiFlashException(tipb::ExprType_Name(expr.tp()) + " not supported", Errors::Coprocessor::Unimplemented); + } + // build function expr + if (functionIsInOrGlobalInOperator(func_name)) + { + // for in, we could not represent the function expr using func_name(param1, param2, ...) + ss << exprToString(expr.children(0), input_col) << " " << func_name << " ("; + bool first = true; + for (int i = 1; i < expr.children_size(); i++) + { + String s = exprToString(expr.children(i), input_col); + if (first) + first = false; + else + ss << ", "; + ss << s; + } + ss << ")"; + } + else + { + ss << func_name << "("; + bool first = true; + for (const tipb::Expr & child : expr.children()) + { + String s = exprToString(child, input_col); + if (first) + first = false; + else + ss << ", "; + ss << s; + } + ss << ")"; + } + return ss.str(); +} + +const String & getTypeName(const tipb::Expr & expr) +{ + return tipb::ExprType_Name(expr.tp()); +} + +bool isAggFunctionExpr(const tipb::Expr & expr) +{ + switch (expr.tp()) + { + case tipb::ExprType::Count: + case tipb::ExprType::Sum: + case tipb::ExprType::Avg: + case tipb::ExprType::Min: + case tipb::ExprType::Max: + case tipb::ExprType::First: + case tipb::ExprType::GroupConcat: + case tipb::ExprType::Agg_BitAnd: + case tipb::ExprType::Agg_BitOr: + case tipb::ExprType::Agg_BitXor: + case tipb::ExprType::Std: + case tipb::ExprType::Stddev: + case tipb::ExprType::StddevPop: + case tipb::ExprType::StddevSamp: + case tipb::ExprType::VarPop: + case tipb::ExprType::VarSamp: + case tipb::ExprType::Variance: + case tipb::ExprType::JsonArrayAgg: + case tipb::ExprType::JsonObjectAgg: + case tipb::ExprType::ApproxCountDistinct: + return true; + default: + return false; + } +} + +bool isLiteralExpr(const tipb::Expr & expr) +{ + switch (expr.tp()) + { + case tipb::ExprType::Null: + case tipb::ExprType::Int64: + case tipb::ExprType::Uint64: + case tipb::ExprType::Float32: + case tipb::ExprType::Float64: + case tipb::ExprType::String: + case tipb::ExprType::Bytes: + case tipb::ExprType::MysqlBit: + case tipb::ExprType::MysqlDecimal: + case tipb::ExprType::MysqlDuration: + case tipb::ExprType::MysqlEnum: + case tipb::ExprType::MysqlHex: + case tipb::ExprType::MysqlSet: + case tipb::ExprType::MysqlTime: + case tipb::ExprType::MysqlJson: + case tipb::ExprType::ValueList: + return true; + default: + return false; + } +} + +bool isColumnExpr(const tipb::Expr & expr) +{ + return expr.tp() == tipb::ExprType::ColumnRef; +} + +Field decodeLiteral(const tipb::Expr & expr) +{ + switch (expr.tp()) + { + case tipb::ExprType::Null: + return Field(); + case tipb::ExprType::Int64: + return decodeDAGInt64(expr.val()); + case tipb::ExprType::Uint64: + return decodeDAGUInt64(expr.val()); + case tipb::ExprType::Float32: + return Float64(decodeDAGFloat32(expr.val())); + case tipb::ExprType::Float64: + return decodeDAGFloat64(expr.val()); + case tipb::ExprType::String: + return decodeDAGString(expr.val()); + case tipb::ExprType::Bytes: + return decodeDAGBytes(expr.val()); + case tipb::ExprType::MysqlDecimal: + return decodeDAGDecimal(expr.val()); + case tipb::ExprType::MysqlTime: + { + if (!expr.has_field_type()) + throw TiFlashException("MySQL Time literal without field_type" + expr.DebugString(), Errors::Coprocessor::BadRequest); + auto t = decodeDAGUInt64(expr.val()); + return TiDB::DatumFlat(t, static_cast(expr.field_type().tp())).field(); + } + case tipb::ExprType::MysqlDuration: + { + if (!expr.has_field_type()) + throw TiFlashException("MySQL Duration literal without field_type" + expr.DebugString(), Errors::Coprocessor::BadRequest); + auto t = decodeDAGInt64(expr.val()); + return TiDB::DatumFlat(t, static_cast(expr.field_type().tp())).field(); + } + case tipb::ExprType::MysqlBit: + case tipb::ExprType::MysqlEnum: + case tipb::ExprType::MysqlHex: + case tipb::ExprType::MysqlSet: + case tipb::ExprType::MysqlJson: + case tipb::ExprType::ValueList: + throw TiFlashException(tipb::ExprType_Name(expr.tp()) + " is not supported yet", Errors::Coprocessor::Unimplemented); + default: + throw TiFlashException("Should not reach here: not a literal expression", Errors::Coprocessor::Internal); + } +} + +String getColumnNameForColumnExpr(const tipb::Expr & expr, const std::vector & input_col) +{ + auto column_index = decodeDAGInt64(expr.val()); + if (column_index < 0 || column_index >= static_cast(input_col.size())) + { + throw TiFlashException("Column index out of bound", Errors::Coprocessor::BadRequest); + } + return input_col[column_index].name; +} + +// For some historical or unknown reasons, TiDB might set an invalid +// field type. This function checks if the expr has a valid field type. +// So far the known invalid field types are: +// 1. decimal type with scale == -1 +// 2. decimal type with precision == 0 +bool exprHasValidFieldType(const tipb::Expr & expr) +{ + return expr.has_field_type() + && !(expr.field_type().tp() == TiDB::TP::TypeNewDecimal + && (expr.field_type().decimal() == -1 || expr.field_type().flen() == 0)); +} + +bool isUnsupportedEncodeType(const std::vector & types, tipb::EncodeType encode_type) +{ + const static std::unordered_map> unsupported_types_map({ + {tipb::EncodeType::TypeCHBlock, {TiDB::TypeSet, TiDB::TypeGeometry, TiDB::TypeNull, TiDB::TypeEnum, TiDB::TypeJSON, TiDB::TypeBit}}, + {tipb::EncodeType::TypeChunk, {TiDB::TypeSet, TiDB::TypeGeometry, TiDB::TypeNull}}, + }); + + auto unsupported_set = unsupported_types_map.find(encode_type); + if (unsupported_set == unsupported_types_map.end()) + return false; + for (const auto & type : types) + { + if (unsupported_set->second.find(type.tp()) != unsupported_set->second.end()) + return true; + } + return false; +} + +DataTypePtr inferDataType4Literal(const tipb::Expr & expr) +{ + Field value = decodeLiteral(expr); + DataTypePtr flash_type = applyVisitor(FieldToDataType(), value); + /// need to extract target_type from expr.field_type() because the flash_type derived from + /// value is just a `memory type`, which does not have enough information, for example: + /// for date literal, the flash_type is `UInt64` + DataTypePtr target_type{}; + if (expr.tp() == tipb::ExprType::Null) + { + // todo We should use DataTypeNothing as NULL literal's TiFlash Type, because TiFlash has a lot of + // optimization for DataTypeNothing, but there are still some bugs when using DataTypeNothing: when + // TiFlash try to return data to TiDB or exchange data between TiFlash node, since codec only recognize + // TiDB type, use DataTypeNothing will meet error in the codec, so do not use DataTypeNothing until + // we fix the codec issue. + if (exprHasValidFieldType(expr)) + { + target_type = getDataTypeByFieldTypeForComputingLayer(expr.field_type()); + } + else + { + if (expr.has_field_type() && expr.field_type().tp() == TiDB::TP::TypeNewDecimal) + target_type = createDecimal(1, 0); + else + target_type = flash_type; + } + target_type = makeNullable(target_type); + } + else + { + if (expr.tp() == tipb::ExprType::MysqlDecimal) + { + /// to fix https://github.com/pingcap/tics/issues/1425, when TiDB push down + /// a decimal literal, it contains two types: one is the type that encoded + /// in Decimal value itself(i.e. expr.val()), the other is the type that in + /// expr.field_type(). According to TiDB and Mysql behavior, the computing + /// layer should use the type in expr.val(), which means we should ignore + /// the type in expr.field_type() + target_type = flash_type; + } + else + { + target_type = exprHasValidFieldType(expr) ? getDataTypeByFieldTypeForComputingLayer(expr.field_type()) : flash_type; + } + // We should remove nullable for constant value since TiDB may not set NOT_NULL flag for literal expression. + target_type = removeNullable(target_type); + } + return target_type; +} + +UInt8 getFieldLengthForArrowEncode(Int32 tp) +{ + switch (tp) + { + case TiDB::TypeTiny: + case TiDB::TypeShort: + case TiDB::TypeInt24: + case TiDB::TypeLong: + case TiDB::TypeLongLong: + case TiDB::TypeYear: + case TiDB::TypeDouble: + case TiDB::TypeTime: + case TiDB::TypeDate: + case TiDB::TypeDatetime: + case TiDB::TypeNewDate: + case TiDB::TypeTimestamp: + return 8; + case TiDB::TypeFloat: + return 4; + case TiDB::TypeDecimal: + case TiDB::TypeNewDecimal: + return 40; + case TiDB::TypeVarchar: + case TiDB::TypeVarString: + case TiDB::TypeString: + case TiDB::TypeBlob: + case TiDB::TypeTinyBlob: + case TiDB::TypeMediumBlob: + case TiDB::TypeLongBlob: + case TiDB::TypeBit: + case TiDB::TypeEnum: + case TiDB::TypeJSON: + return VAR_SIZE; + default: + throw TiFlashException("not supported field type in arrow encode: " + std::to_string(tp), Errors::Coprocessor::Internal); + } +} + +tipb::Expr constructStringLiteralTiExpr(const String & value) +{ + tipb::Expr expr; + expr.set_tp(tipb::ExprType::String); + expr.set_val(value); + auto * field_type = expr.mutable_field_type(); + field_type->set_tp(TiDB::TypeString); + field_type->set_flag(TiDB::ColumnFlagNotNull); + return expr; +} - //{tipb::ScalarFuncSig::SubDateAndDuration, "cast"}, - //{tipb::ScalarFuncSig::SubDateAndString, "cast"}, - //{tipb::ScalarFuncSig::SubTimeDateTimeNull, "cast"}, - //{tipb::ScalarFuncSig::SubStringAndDuration, "cast"}, - //{tipb::ScalarFuncSig::SubStringAndString, "cast"}, - //{tipb::ScalarFuncSig::SubTimeStringNull, "cast"}, - //{tipb::ScalarFuncSig::SubDurationAndDuration, "cast"}, - //{tipb::ScalarFuncSig::SubDurationAndString, "cast"}, - //{tipb::ScalarFuncSig::SubDateAndDuration, "cast"}, - //{tipb::ScalarFuncSig::SubDateAndString, "cast"}, +tipb::Expr constructInt64LiteralTiExpr(Int64 value) +{ + tipb::Expr expr; + expr.set_tp(tipb::ExprType::Int64); + WriteBufferFromOwnString ss; + encodeDAGInt64(value, ss); + expr.set_val(ss.releaseStr()); + auto * field_type = expr.mutable_field_type(); + field_type->set_tp(TiDB::TypeLongLong); + field_type->set_flag(TiDB::ColumnFlagNotNull); + return expr; +} - //{tipb::ScalarFuncSig::UnixTimestampCurrent, "cast"}, - {tipb::ScalarFuncSig::UnixTimestampInt, "tidbUnixTimeStampInt"}, - {tipb::ScalarFuncSig::UnixTimestampDec, "tidbUnixTimeStampDec"}, +tipb::Expr constructDateTimeLiteralTiExpr(UInt64 packed_value) +{ + tipb::Expr expr; + expr.set_tp(tipb::ExprType::MysqlTime); + WriteBufferFromOwnString ss; + encodeDAGUInt64(packed_value, ss); + expr.set_val(ss.releaseStr()); + auto * field_type = expr.mutable_field_type(); + field_type->set_tp(TiDB::TypeDatetime); + field_type->set_flag(TiDB::ColumnFlagNotNull); + return expr; +} - //{tipb::ScalarFuncSig::ConvertTz, "cast"}, - //{tipb::ScalarFuncSig::MakeDate, "cast"}, - //{tipb::ScalarFuncSig::MakeTime, "cast"}, - //{tipb::ScalarFuncSig::PeriodAdd, "cast"}, - //{tipb::ScalarFuncSig::PeriodDiff, "cast"}, - //{tipb::ScalarFuncSig::Quarter, "cast"}, +tipb::Expr constructNULLLiteralTiExpr() +{ + tipb::Expr expr; + expr.set_tp(tipb::ExprType::Null); + auto * field_type = expr.mutable_field_type(); + field_type->set_tp(TiDB::TypeNull); + return expr; +} - //{tipb::ScalarFuncSig::SecToTime, "cast"}, - //{tipb::ScalarFuncSig::TimeToSec, "cast"}, - //{tipb::ScalarFuncSig::TimestampAdd, "cast"}, - //{tipb::ScalarFuncSig::ToDays, "cast"}, - //{tipb::ScalarFuncSig::ToSeconds, "cast"}, - //{tipb::ScalarFuncSig::UTCTimeWithArg, "cast"}, - //{tipb::ScalarFuncSig::UTCTimestampWithoutArg, "cast"}, - //{tipb::ScalarFuncSig::Timestamp1Arg, "cast"}, - //{tipb::ScalarFuncSig::Timestamp2Args, "cast"}, - //{tipb::ScalarFuncSig::TimestampLiteral, "cast"}, +std::shared_ptr getCollatorFromExpr(const tipb::Expr & expr) +{ + if (expr.has_field_type()) + return getCollatorFromFieldType(expr.field_type()); + return nullptr; +} - //{tipb::ScalarFuncSig::LastDay, "cast"}, - {tipb::ScalarFuncSig::StrToDateDate, "strToDateDate"}, - {tipb::ScalarFuncSig::StrToDateDatetime, "strToDateDatetime"}, - // {tipb::ScalarFuncSig::StrToDateDuration, "cast"}, - {tipb::ScalarFuncSig::FromUnixTime1Arg, "fromUnixTime"}, - {tipb::ScalarFuncSig::FromUnixTime2Arg, "fromUnixTime"}, - {tipb::ScalarFuncSig::ExtractDatetime, "extractMyDateTime"}, - //{tipb::ScalarFuncSig::ExtractDuration, "cast"}, +SortDescription getSortDescription(const std::vector & order_columns, const google::protobuf::RepeatedPtrField & by_items) +{ + SortDescription order_descr; + order_descr.reserve(by_items.size()); + for (int i = 0; i < by_items.size(); i++) + { + const auto & name = order_columns[i].name; + int direction = by_items[i].desc() ? -1 : 1; + // MySQL/TiDB treats NULL as "minimum". + int nulls_direction = -1; + std::shared_ptr collator = nullptr; + if (removeNullable(order_columns[i].type)->isString()) + collator = getCollatorFromExpr(by_items[i].expr()); - //{tipb::ScalarFuncSig::AddDateStringString, "cast"}, - {tipb::ScalarFuncSig::AddDateStringInt, "date_add"}, - //{tipb::ScalarFuncSig::AddDateStringDecimal, "cast"}, - //{tipb::ScalarFuncSig::AddDateIntString, "cast"}, - //{tipb::ScalarFuncSig::AddDateIntInt, "cast"}, - //{tipb::ScalarFuncSig::AddDateDatetimeString, "date_add"}, - {tipb::ScalarFuncSig::AddDateDatetimeInt, "date_add"}, + order_descr.emplace_back(name, direction, nulls_direction, collator); + } + return order_descr; +} - //{tipb::ScalarFuncSig::SubDateStringString, "cast"}, - {tipb::ScalarFuncSig::SubDateStringInt, "date_sub"}, - //{tipb::ScalarFuncSig::SubDateStringDecimal, "cast"}, - //{tipb::ScalarFuncSig::SubDateIntString, "cast"}, - //{tipb::ScalarFuncSig::SubDateIntInt, "cast"}, - //{tipb::ScalarFuncSig::SubDateDatetimeString, "cast"}, - {tipb::ScalarFuncSig::SubDateDatetimeInt, "date_sub"}, +TiDB::TiDBCollatorPtr getCollatorFromFieldType(const tipb::FieldType & field_type) +{ + if (field_type.collate() < 0) + return TiDB::ITiDBCollator::getCollator(-field_type.collate()); + return nullptr; +} - //{tipb::ScalarFuncSig::FromDays, "cast"}, - //{tipb::ScalarFuncSig::TimeFormat, "cast"}, - {tipb::ScalarFuncSig::TimestampDiff, "tidbTimestampDiff"}, +bool hasUnsignedFlag(const tipb::FieldType & tp) +{ + return tp.flag() & TiDB::ColumnFlagUnsigned; +} - //{tipb::ScalarFuncSig::BitLength, "cast"}, - //{tipb::ScalarFuncSig::Bin, "cast"}, - {tipb::ScalarFuncSig::ASCII, "ascii"}, - //{tipb::ScalarFuncSig::Char, "cast"}, - {tipb::ScalarFuncSig::CharLengthUTF8, "lengthUTF8"}, - {tipb::ScalarFuncSig::Concat, "tidbConcat"}, - {tipb::ScalarFuncSig::ConcatWS, "tidbConcatWS"}, - //{tipb::ScalarFuncSig::Convert, "cast"}, - //{tipb::ScalarFuncSig::Elt, "cast"}, - //{tipb::ScalarFuncSig::ExportSet3Arg, "cast"}, - //{tipb::ScalarFuncSig::ExportSet4Arg, "cast"}, - //{tipb::ScalarFuncSig::ExportSet5Arg, "cast"}, - //{tipb::ScalarFuncSig::FieldInt, "cast"}, - //{tipb::ScalarFuncSig::FieldReal, "cast"}, - //{tipb::ScalarFuncSig::FieldString, "cast"}, +grpc::StatusCode tiflashErrorCodeToGrpcStatusCode(int error_code) +{ + /// do not use switch statement because ErrorCodes::XXXX is not a compile time constant + if (error_code == ErrorCodes::NOT_IMPLEMENTED) + return grpc::StatusCode::UNIMPLEMENTED; + if (error_code == ErrorCodes::UNKNOWN_USER || error_code == ErrorCodes::WRONG_PASSWORD || error_code == ErrorCodes::REQUIRED_PASSWORD + || error_code == ErrorCodes::IP_ADDRESS_NOT_ALLOWED) + return grpc::StatusCode::UNAUTHENTICATED; + return grpc::StatusCode::INTERNAL; +} - //{tipb::ScalarFuncSig::FindInSet, "cast"}, - //{tipb::ScalarFuncSig::Format, "cast"}, - //{tipb::ScalarFuncSig::FormatWithLocale, "cast"}, - //{tipb::ScalarFuncSig::FromBase64, "cast"}, - //{tipb::ScalarFuncSig::HexIntArg, "cast"}, - //{tipb::ScalarFuncSig::HexStrArg, "cast"}, - //{tipb::ScalarFuncSig::InsertUTF8, "cast"}, - //{tipb::ScalarFuncSig::Insert, "cast"}, - //{tipb::ScalarFuncSig::InstrUTF8, "cast"}, - //{tipb::ScalarFuncSig::Instr, "cast"}, +void assertBlockSchema(const DataTypes & expected_types, const Block & block, const std::string & context_description) +{ + size_t columns = expected_types.size(); + if (block.columns() != columns) + throw Exception("Block schema mismatch in " + context_description + ": different number of columns: expected " + + std::to_string(columns) + " columns, got " + std::to_string(block.columns()) + " columns"); - {tipb::ScalarFuncSig::LeftUTF8, "leftUTF8"}, - //{tipb::ScalarFuncSig::Left, "cast"}, - {tipb::ScalarFuncSig::Length, "length"}, - {tipb::ScalarFuncSig::Locate2ArgsUTF8, "position"}, - //{tipb::ScalarFuncSig::Locate3ArgsUTF8, "cast"}, - {tipb::ScalarFuncSig::Locate2Args, "position"}, - //{tipb::ScalarFuncSig::Locate3Args, "cast"}, + for (size_t i = 0; i < columns; ++i) + { + const auto & actual = block.getByPosition(i).type; + const auto & expected = expected_types[i]; - {tipb::ScalarFuncSig::Lower, "lowerBinary"}, - {tipb::ScalarFuncSig::LowerUTF8, "lowerUTF8"}, - //{tipb::ScalarFuncSig::LpadUTF8, "cast"}, - //{tipb::ScalarFuncSig::Lpad, "cast"}, - //{tipb::ScalarFuncSig::MakeSet, "cast"}, - //{tipb::ScalarFuncSig::OctInt, "cast"}, - //{tipb::ScalarFuncSig::OctString, "cast"}, - //{tipb::ScalarFuncSig::Ord, "cast"}, - //{tipb::ScalarFuncSig::Quote, "cast"}, - //{tipb::ScalarFuncSig::Repeat, "cast"}, - {tipb::ScalarFuncSig::Replace, "replaceAll"}, - //{tipb::ScalarFuncSig::ReverseUTF8, "cast"}, - //{tipb::ScalarFuncSig::Reverse, "cast"}, - {tipb::ScalarFuncSig::RightUTF8, "rightUTF8"}, - //{tipb::ScalarFuncSig::Right, "cast"}, - //{tipb::ScalarFuncSig::RpadUTF8, "cast"}, - //{tipb::ScalarFuncSig::Rpad, "cast"}, - //{tipb::ScalarFuncSig::Space, "cast"}, - //{tipb::ScalarFuncSig::Strcmp, "cast"}, - {tipb::ScalarFuncSig::Substring2ArgsUTF8, "substringUTF8"}, - {tipb::ScalarFuncSig::Substring3ArgsUTF8, "substringUTF8"}, - //{tipb::ScalarFuncSig::Substring2Args, "cast"}, - //{tipb::ScalarFuncSig::Substring3Args, "cast"}, - {tipb::ScalarFuncSig::SubstringIndex, "substringIndex"}, - {tipb::ScalarFuncSig::Format, "format"}, - {tipb::ScalarFuncSig::FormatWithLocale, "formatWithLocale"}, + if (!expected->equals(*actual)) + { + throw Exception("Block schema mismatch in " + context_description + ": different types: expected " + expected->getName() + + ", got " + actual->getName()); + } + } +} - //{tipb::ScalarFuncSig::ToBase64, "cast"}, - {tipb::ScalarFuncSig::Trim1Arg, "tidbTrim"}, - {tipb::ScalarFuncSig::Trim2Args, "tidbTrim"}, - {tipb::ScalarFuncSig::Trim3Args, "tidbTrim"}, - {tipb::ScalarFuncSig::LTrim, "tidbLTrim"}, - {tipb::ScalarFuncSig::RTrim, "tidbRTrim"}, - //{tipb::ScalarFuncSig::UnHex, "cast"}, - {tipb::ScalarFuncSig::UpperUTF8, "upperUTF8"}, - {tipb::ScalarFuncSig::Upper, "upperBinary"}, - //{tipb::ScalarFuncSig::CharLength, "upper"}, -}); +tipb::DAGRequest getDAGRequestFromStringWithRetry(const String & s) +{ + tipb::DAGRequest dag_req; + if (!dag_req.ParseFromString(s)) + { + /// ParseFromString will use the default recursion limit, which is 100 to decode the plan, if the plan tree is too deep, + /// it may exceed this limit, so just try again by double the recursion limit + ::google::protobuf::io::CodedInputStream coded_input_stream(reinterpret_cast(s.data()), s.size()); + coded_input_stream.SetRecursionLimit(::google::protobuf::io::CodedInputStream::GetDefaultRecursionLimit() * 2); + if (!dag_req.ParseFromCodedStream(&coded_input_stream)) + { + /// just return error if decode failed this time, because it's really a corner case, and even if we can decode the plan + /// successfully by using a very large value of the recursion limit, it is kinds of meaningless because the runtime + /// performance of this task may be very bad if the plan tree is too deep + throw TiFlashException( + std::string(__PRETTY_FUNCTION__) + ": Invalid encoded plan, the most likely is that the plan/expression tree is too deep", + Errors::Coprocessor::BadRequest); + } + } + return dag_req; +} } // namespace DB diff --git a/dbms/src/Flash/Coprocessor/DAGUtils.h b/dbms/src/Flash/Coprocessor/DAGUtils.h index a208bfd6b7d..c3e286aad8f 100644 --- a/dbms/src/Flash/Coprocessor/DAGUtils.h +++ b/dbms/src/Flash/Coprocessor/DAGUtils.h @@ -1,12 +1,5 @@ #pragma once -#include - -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wunused-parameter" -#include -#pragma GCC diagnostic pop - #include #include #include @@ -15,6 +8,9 @@ #include #include #include +#include + +#include namespace DB { @@ -30,16 +26,15 @@ String getColumnNameForColumnExpr(const tipb::Expr & expr, const std::vector & input_col); bool exprHasValidFieldType(const tipb::Expr & expr); -void constructStringLiteralTiExpr(tipb::Expr & expr, const String & value); -void constructInt64LiteralTiExpr(tipb::Expr & expr, Int64 value); -void constructDateTimeLiteralTiExpr(tipb::Expr & expr, UInt64 packed_value); -void constructNULLLiteralTiExpr(tipb::Expr & expr); +tipb::Expr constructStringLiteralTiExpr(const String & value); +tipb::Expr constructInt64LiteralTiExpr(Int64 value); +tipb::Expr constructDateTimeLiteralTiExpr(UInt64 packed_value); +tipb::Expr constructNULLLiteralTiExpr(); DataTypePtr inferDataType4Literal(const tipb::Expr & expr); -SortDescription getSortDescription(const std::vector & order_columns, const google::protobuf::RepeatedPtrField & by_items); +SortDescription getSortDescription( + const std::vector & order_columns, + const google::protobuf::RepeatedPtrField & by_items); -extern std::unordered_map agg_func_map; -extern std::unordered_map distinct_agg_func_map; -extern std::unordered_map scalar_func_map; extern const Int8 VAR_SIZE; UInt8 getFieldLengthForArrowEncode(Int32 tp); @@ -69,6 +64,7 @@ class UniqueNameGenerator return ret_name; } }; -void getDAGRequestFromStringWithRetry(tipb::DAGRequest & req, const String & s); + +tipb::DAGRequest getDAGRequestFromStringWithRetry(const String & s); } // namespace DB diff --git a/dbms/src/Flash/Coprocessor/InterpreterDAG.cpp b/dbms/src/Flash/Coprocessor/InterpreterDAG.cpp index 5e9b0f43a8d..ec69837fd7e 100644 --- a/dbms/src/Flash/Coprocessor/InterpreterDAG.cpp +++ b/dbms/src/Flash/Coprocessor/InterpreterDAG.cpp @@ -18,15 +18,6 @@ namespace DB { -namespace ErrorCodes -{ -extern const int UNKNOWN_TABLE; -extern const int TOO_MANY_COLUMNS; -extern const int SCHEMA_VERSION_ERROR; -extern const int UNKNOWN_EXCEPTION; -extern const int COP_BAD_DAG_REQUEST; -} // namespace ErrorCodes - InterpreterDAG::InterpreterDAG(Context & context_, const DAGQuerySource & dag_, const LogWithPrefixPtr & log_) : context(context_) , dag(dag_) @@ -35,7 +26,7 @@ InterpreterDAG::InterpreterDAG(Context & context_, const DAGQuerySource & dag_, , log(log_) { const Settings & settings = context.getSettingsRef(); - if (dag.isBatchCop()) + if (dag.isBatchCopOrMpp()) max_streams = settings.max_threads; else max_streams = 1; @@ -60,7 +51,8 @@ BlockInputStreams InterpreterDAG::executeQueryBlock(DAGQueryBlock & query_block, context, input_streams_vec, query_block, - keep_session_timezone_info, + max_streams, + keep_session_timezone_info || !query_block.isRootQueryBlock(), dag, subqueries_for_sets, mpp_exchange_receiver_maps, diff --git a/dbms/src/Flash/CoprocessorHandler.cpp b/dbms/src/Flash/CoprocessorHandler.cpp index 29d30dd8739..dfadc186f85 100644 --- a/dbms/src/Flash/CoprocessorHandler.cpp +++ b/dbms/src/Flash/CoprocessorHandler.cpp @@ -67,8 +67,7 @@ grpc::Status CoprocessorHandler::execute() GET_METRIC(tiflash_coprocessor_handling_request_count, type_cop_dag).Increment(); SCOPE_EXIT({ GET_METRIC(tiflash_coprocessor_handling_request_count, type_cop_dag).Decrement(); }); - tipb::DAGRequest dag_request; - getDAGRequestFromStringWithRetry(dag_request, cop_request->data()); + tipb::DAGRequest dag_request = getDAGRequestFromStringWithRetry(cop_request->data()); LOG_DEBUG(log, __PRETTY_FUNCTION__ << ": Handling DAG request: " << dag_request.DebugString()); if (dag_request.has_is_rpn_expr() && dag_request.is_rpn_expr()) throw TiFlashException( diff --git a/dbms/src/Flash/Mpp/MPPTask.cpp b/dbms/src/Flash/Mpp/MPPTask.cpp index 040fa70f2da..7716ef338e4 100644 --- a/dbms/src/Flash/Mpp/MPPTask.cpp +++ b/dbms/src/Flash/Mpp/MPPTask.cpp @@ -137,7 +137,7 @@ bool needRemoteRead(const RegionInfo & region_info, const TMTContext & tmt_conte std::vector MPPTask::prepare(const mpp::DispatchTaskRequest & task_request) { - getDAGRequestFromStringWithRetry(dag_req, task_request.encoded_plan()); + dag_req = getDAGRequestFromStringWithRetry(task_request.encoded_plan()); TMTContext & tmt_context = context.getTMTContext(); /// MPP task will only use key ranges in mpp::DispatchTaskRequest::regions. The ones defined in tipb::TableScan /// will never be used and can be removed later.