From f1060c74b6ac0cc55034bdf8c72d3028079c401d Mon Sep 17 00:00:00 2001 From: Fu Zhe Date: Mon, 13 Dec 2021 16:40:36 +0800 Subject: [PATCH 1/4] *: Pass more request-level parameters through DAGContext. (#3616) --- dbms/src/Debug/dbgFuncCoprocessor.cpp | 9 +- dbms/src/Flash/BatchCommandsHandler.cpp | 2 +- dbms/src/Flash/BatchCommandsHandler.h | 2 +- dbms/src/Flash/BatchCoprocessorHandler.cpp | 9 +- dbms/src/Flash/Coprocessor/DAGContext.h | 50 ++++--- dbms/src/Flash/Coprocessor/DAGDriver.cpp | 42 +++--- dbms/src/Flash/Coprocessor/DAGDriver.h | 14 +- .../Coprocessor/DAGQueryBlockInterpreter.cpp | 132 +++++++----------- .../Coprocessor/DAGQueryBlockInterpreter.h | 26 ++-- dbms/src/Flash/Coprocessor/DAGQuerySource.cpp | 46 +++--- dbms/src/Flash/Coprocessor/DAGQuerySource.h | 34 +---- .../Flash/Coprocessor/DAGResponseWriter.cpp | 10 +- .../src/Flash/Coprocessor/DAGResponseWriter.h | 6 +- .../Coprocessor/DAGStorageInterpreter.cpp | 36 ++--- .../Flash/Coprocessor/DAGStorageInterpreter.h | 5 +- dbms/src/Flash/Coprocessor/InterpreterDAG.cpp | 32 ++--- dbms/src/Flash/Coprocessor/InterpreterDAG.h | 12 +- .../Flash/Coprocessor/InterpreterUtils.cpp | 28 ++++ dbms/src/Flash/Coprocessor/InterpreterUtils.h | 2 + .../StreamingDAGResponseWriter.cpp | 43 +++--- .../Coprocessor/StreamingDAGResponseWriter.h | 6 +- .../Coprocessor/UnaryDAGResponseWriter.cpp | 20 ++- .../Coprocessor/UnaryDAGResponseWriter.h | 2 - dbms/src/Flash/CoprocessorHandler.cpp | 9 +- dbms/src/Flash/FlashService.cpp | 32 ++--- dbms/src/Flash/FlashService.h | 2 +- dbms/src/Flash/Mpp/MPPHandler.cpp | 6 +- dbms/src/Flash/Mpp/MPPHandler.h | 2 +- dbms/src/Flash/Mpp/MPPTask.cpp | 43 +++--- dbms/src/Flash/Mpp/MPPTask.h | 9 +- dbms/src/Flash/Mpp/getMPPTaskLog.cpp | 24 ++++ dbms/src/Flash/Mpp/getMPPTaskLog.h | 16 +-- dbms/src/Interpreters/Context.h | 2 + dbms/src/Storages/StorageMerge.cpp | 2 +- .../Storages/tests/gtest_filter_parser.cpp | 2 +- 35 files changed, 342 insertions(+), 375 deletions(-) create mode 100644 dbms/src/Flash/Mpp/getMPPTaskLog.cpp diff --git a/dbms/src/Debug/dbgFuncCoprocessor.cpp b/dbms/src/Debug/dbgFuncCoprocessor.cpp index 6c38d4b2aee..e62302228a3 100644 --- a/dbms/src/Debug/dbgFuncCoprocessor.cpp +++ b/dbms/src/Debug/dbgFuncCoprocessor.cpp @@ -2411,10 +2411,15 @@ tipb::SelectResponse executeDAGRequest(Context & context, const tipb::DAGRequest LOG_DEBUG(log, __PRETTY_FUNCTION__ << ": Handling DAG request: " << dag_request.DebugString()); tipb::SelectResponse dag_response; RegionInfoMap regions; - RegionInfoList retry_regions; regions.emplace(region_id, RegionInfo(region_id, region_version, region_conf_version, std::move(key_ranges), nullptr)); - DAGDriver driver(context, dag_request, regions, retry_regions, start_ts, DEFAULT_UNSPECIFIED_SCHEMA_VERSION, &dag_response, true); + + DAGContext dag_context(dag_request); + dag_context.regions_for_local_read = regions; + dag_context.log = std::make_shared(log, ""); + context.setDAGContext(&dag_context); + + DAGDriver driver(context, start_ts, DEFAULT_UNSPECIFIED_SCHEMA_VERSION, &dag_response, true); driver.execute(); LOG_DEBUG(log, __PRETTY_FUNCTION__ << ": Handle DAG request done"); return dag_response; diff --git a/dbms/src/Flash/BatchCommandsHandler.cpp b/dbms/src/Flash/BatchCommandsHandler.cpp index 96e6b156ff6..441172344f2 100644 --- a/dbms/src/Flash/BatchCommandsHandler.cpp +++ b/dbms/src/Flash/BatchCommandsHandler.cpp @@ -55,7 +55,7 @@ ThreadPool::Job BatchCommandsHandler::handleCommandJob( return; } - CoprocessorContext cop_context(context, cop_req.context(), batch_commands_context.grpc_server_context); + CoprocessorContext cop_context(*context, cop_req.context(), batch_commands_context.grpc_server_context); CoprocessorHandler cop_handler(cop_context, &cop_req, cop_resp); ret = cop_handler.execute(); diff --git a/dbms/src/Flash/BatchCommandsHandler.h b/dbms/src/Flash/BatchCommandsHandler.h index d0e1a5a7423..b0836bd570e 100644 --- a/dbms/src/Flash/BatchCommandsHandler.h +++ b/dbms/src/Flash/BatchCommandsHandler.h @@ -18,7 +18,7 @@ struct BatchCommandsContext /// Context creation function for each individual command - they should be handled isolated, /// given that context is being used to pass arguments regarding queries. - using DBContextCreationFunc = std::function(const grpc::ServerContext *)>; + using DBContextCreationFunc = std::function(const grpc::ServerContext *)>; DBContextCreationFunc db_context_creation_func; const grpc::ServerContext & grpc_server_context; diff --git a/dbms/src/Flash/BatchCoprocessorHandler.cpp b/dbms/src/Flash/BatchCoprocessorHandler.cpp index 9200220d320..2582f0eb46b 100644 --- a/dbms/src/Flash/BatchCoprocessorHandler.cpp +++ b/dbms/src/Flash/BatchCoprocessorHandler.cpp @@ -62,7 +62,14 @@ grpc::Status BatchCoprocessorHandler::execute() LOG_DEBUG(log, __PRETTY_FUNCTION__ << ": Handling " << regions.size() << " regions in DAG request: " << dag_request.DebugString()); - DAGDriver driver(cop_context.db_context, dag_request, regions, retry_regions, cop_request->start_ts() > 0 ? cop_request->start_ts() : dag_request.start_ts_fallback(), cop_request->schema_ver(), writer); + DAGContext dag_context(dag_request); + dag_context.is_batch_cop = true; + dag_context.regions_for_local_read = std::move(regions); + dag_context.regions_for_remote_read = std::move(retry_regions); + dag_context.log = std::make_shared(log, ""); + cop_context.db_context.setDAGContext(&dag_context); + + DAGDriver driver(cop_context.db_context, cop_request->start_ts() > 0 ? cop_request->start_ts() : dag_request.start_ts_fallback(), cop_request->schema_ver(), writer); // batch execution; driver.execute(); LOG_DEBUG(log, __PRETTY_FUNCTION__ << ": Handle DAG request done"); diff --git a/dbms/src/Flash/Coprocessor/DAGContext.h b/dbms/src/Flash/Coprocessor/DAGContext.h index 8113c795525..0c1ed00aae7 100644 --- a/dbms/src/Flash/Coprocessor/DAGContext.h +++ b/dbms/src/Flash/Coprocessor/DAGContext.h @@ -35,40 +35,43 @@ UInt64 inline getMaxErrorCount(const tipb::DAGRequest &) class DAGContext { public: - explicit DAGContext(const tipb::DAGRequest & dag_request) - : collect_execution_summaries(dag_request.has_collect_execution_summaries() && dag_request.collect_execution_summaries()) + explicit DAGContext(const tipb::DAGRequest & dag_request_) + : dag_request(&dag_request_) + , collect_execution_summaries(dag_request->has_collect_execution_summaries() && dag_request->collect_execution_summaries()) , is_mpp_task(false) , is_root_mpp_task(false) , tunnel_set(nullptr) - , flags(dag_request.flags()) - , sql_mode(dag_request.sql_mode()) - , max_recorded_error_count(getMaxErrorCount(dag_request)) + , flags(dag_request->flags()) + , sql_mode(dag_request->sql_mode()) + , max_recorded_error_count(getMaxErrorCount(*dag_request)) , warnings(max_recorded_error_count) , warning_count(0) { - assert(dag_request.has_root_executor() || dag_request.executors_size() > 0); - return_executor_id = dag_request.root_executor().has_executor_id() || dag_request.executors(0).has_executor_id(); + assert(dag_request->has_root_executor() || dag_request->executors_size() > 0); + return_executor_id = dag_request->root_executor().has_executor_id() || dag_request->executors(0).has_executor_id(); } - DAGContext(const tipb::DAGRequest & dag_request, const mpp::TaskMeta & meta_, bool is_root_mpp_task_) - : collect_execution_summaries(dag_request.has_collect_execution_summaries() && dag_request.collect_execution_summaries()) + DAGContext(const tipb::DAGRequest & dag_request_, const mpp::TaskMeta & meta_, bool is_root_mpp_task_) + : dag_request(&dag_request_) + , collect_execution_summaries(dag_request->has_collect_execution_summaries() && dag_request->collect_execution_summaries()) , return_executor_id(true) , is_mpp_task(true) , is_root_mpp_task(is_root_mpp_task_) , tunnel_set(nullptr) - , flags(dag_request.flags()) - , sql_mode(dag_request.sql_mode()) + , flags(dag_request->flags()) + , sql_mode(dag_request->sql_mode()) , mpp_task_meta(meta_) , mpp_task_id(mpp_task_meta.start_ts(), mpp_task_meta.task_id()) - , max_recorded_error_count(getMaxErrorCount(dag_request)) + , max_recorded_error_count(getMaxErrorCount(*dag_request)) , warnings(max_recorded_error_count) , warning_count(0) { - assert(dag_request.has_root_executor() && dag_request.root_executor().has_executor_id()); + assert(dag_request->has_root_executor() && dag_request->root_executor().has_executor_id()); } explicit DAGContext(UInt64 max_error_count_) - : collect_execution_summaries(false) + : dag_request(nullptr) + , collect_execution_summaries(false) , is_mpp_task(false) , is_root_mpp_task(false) , tunnel_set(nullptr) @@ -114,6 +117,7 @@ class DAGContext void clearWarnings() { warnings.clear(); } UInt64 getWarningCount() { return warning_count; } const mpp::TaskMeta & getMPPTaskMeta() const { return mpp_task_meta; } + bool isBatchCop() const { return is_batch_cop; } bool isMPPTask() const { return is_mpp_task; } /// root mpp task means mpp task that send data back to TiDB bool isRootMPPTask() const { return is_root_mpp_task; } @@ -126,17 +130,29 @@ class DAGContext std::pair getTableScanThroughput(); + const RegionInfoMap & getRegionsForLocalRead() const { return regions_for_local_read; } + const RegionInfoList & getRegionsForRemoteRead() const { return regions_for_remote_read; } + + const tipb::DAGRequest * dag_request; size_t final_concurrency = 1; Int64 compile_time_ns; String table_scan_executor_id = ""; bool collect_execution_summaries; bool return_executor_id; - bool is_mpp_task; - bool is_root_mpp_task; + bool is_mpp_task = false; + bool is_root_mpp_task = false; + bool is_batch_cop = false; MPPTunnelSetPtr tunnel_set; + RegionInfoMap regions_for_local_read; + RegionInfoList regions_for_remote_read; + // part of regions_for_local_read + regions_for_remote_read, only used for batch-cop RegionInfoList retry_regions; - LogWithPrefixPtr mpp_task_log; + LogWithPrefixPtr log; + + bool keep_session_timezone_info = false; + std::vector result_field_types; + tipb::EncodeType encode_type = tipb::EncodeType::TypeDefault; private: /// profile_streams_map is a map that maps from executor_id to ProfileStreamsInfo diff --git a/dbms/src/Flash/Coprocessor/DAGDriver.cpp b/dbms/src/Flash/Coprocessor/DAGDriver.cpp index ab58a5be611..252cc8273bc 100644 --- a/dbms/src/Flash/Coprocessor/DAGDriver.cpp +++ b/dbms/src/Flash/Coprocessor/DAGDriver.cpp @@ -25,12 +25,20 @@ extern const int LOGICAL_ERROR; extern const int UNKNOWN_EXCEPTION; } // namespace ErrorCodes +template +const tipb::DAGRequest & DAGDriver::dagRequest() const +{ + return *context.getDAGContext()->dag_request; +} + template <> -DAGDriver::DAGDriver(Context & context_, const tipb::DAGRequest & dag_request_, const RegionInfoMap & regions_, const RegionInfoList & retry_regions_, UInt64 start_ts, UInt64 schema_ver, tipb::SelectResponse * dag_response_, bool internal_) +DAGDriver::DAGDriver( + Context & context_, + UInt64 start_ts, + UInt64 schema_ver, + tipb::SelectResponse * dag_response_, + bool internal_) : context(context_) - , dag_request(dag_request_) - , regions(regions_) - , retry_regions(retry_regions_) , dag_response(dag_response_) , writer(nullptr) , internal(internal_) @@ -40,15 +48,17 @@ DAGDriver::DAGDriver(Context & context_, const tipb::DAGRequest & dag_req if (schema_ver) // schema_ver being 0 means TiDB/TiSpark hasn't specified schema version. context.setSetting("schema_version", schema_ver); - context.getTimezoneInfo().resetByDAGRequest(dag_request); + context.getTimezoneInfo().resetByDAGRequest(dagRequest()); } template <> -DAGDriver::DAGDriver(Context & context_, const tipb::DAGRequest & dag_request_, const RegionInfoMap & regions_, const RegionInfoList & retry_regions_, UInt64 start_ts, UInt64 schema_ver, ::grpc::ServerWriter<::coprocessor::BatchResponse> * writer_, bool internal_) +DAGDriver::DAGDriver( + Context & context_, + UInt64 start_ts, + UInt64 schema_ver, + ::grpc::ServerWriter<::coprocessor::BatchResponse> * writer_, + bool internal_) : context(context_) - , dag_request(dag_request_) - , regions(regions_) - , retry_regions(retry_regions_) , writer(writer_) , internal(internal_) , log(&Poco::Logger::get("DAGDriver")) @@ -57,7 +67,7 @@ DAGDriver::DAGDriver(Context & context_, const tipb::DAGRequest & dag_requ if (schema_ver) // schema_ver being 0 means TiDB/TiSpark hasn't specified schema version. context.setSetting("schema_version", schema_ver); - context.getTimezoneInfo().resetByDAGRequest(dag_request); + context.getTimezoneInfo().resetByDAGRequest(dagRequest()); } template @@ -65,9 +75,8 @@ void DAGDriver::execute() try { auto start_time = Clock::now(); - DAGContext dag_context(dag_request); - context.setDAGContext(&dag_context); - DAGQuerySource dag(context, regions, retry_regions, dag_request, std::make_shared(&Poco::Logger::get("CoprocessorHandler"), ""), batch); + DAGQuerySource dag(context); + DAGContext & dag_context = *context.getDAGContext(); BlockIO streams = executeQuery(dag, context, internal, QueryProcessingStage::Complete); if (!streams.in || streams.out) @@ -85,8 +94,6 @@ try std::unique_ptr response_writer = std::make_unique( dag_response, context.getSettings().dag_records_per_chunk, - dag.getEncodeType(), - dag.getResultFieldTypes(), dag_context); dag_output_stream = std::make_shared(streams.in->getHeader(), std::move(response_writer)); copyData(*streams.in, *dag_output_stream); @@ -117,10 +124,7 @@ try context.getSettings().dag_records_per_chunk, context.getSettings().batch_send_min_limit, true, - dag.getEncodeType(), - dag.getResultFieldTypes(), - dag_context, - nullptr); + dag_context); dag_output_stream = std::make_shared(streams.in->getHeader(), std::move(response_writer)); copyData(*streams.in, *dag_output_stream); } diff --git a/dbms/src/Flash/Coprocessor/DAGDriver.h b/dbms/src/Flash/Coprocessor/DAGDriver.h index b04b6c3c658..07ee510c551 100644 --- a/dbms/src/Flash/Coprocessor/DAGDriver.h +++ b/dbms/src/Flash/Coprocessor/DAGDriver.h @@ -23,9 +23,6 @@ class DAGDriver public: DAGDriver( Context & context_, - const tipb::DAGRequest & dag_request_, - const RegionInfoMap & regions_, - const RegionInfoList & retry_regions_, UInt64 start_ts, UInt64 schema_ver, tipb::SelectResponse * dag_response_, @@ -33,9 +30,6 @@ class DAGDriver DAGDriver( Context & context_, - const tipb::DAGRequest & dag_request_, - const RegionInfoMap & regions_, - const RegionInfoList & retry_regions_, UInt64 start_ts, UInt64 schema_ver, ::grpc::ServerWriter<::coprocessor::BatchResponse> * writer, @@ -46,13 +40,9 @@ class DAGDriver private: void recordError(Int32 err_code, const String & err_msg); -private: - Context & context; + const tipb::DAGRequest & dagRequest() const; - const tipb::DAGRequest & dag_request; - - const RegionInfoMap & regions; - const RegionInfoList & retry_regions; + Context & context; tipb::SelectResponse * dag_response; diff --git a/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.cpp b/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.cpp index a284dfa14c2..43695f1ad1e 100644 --- a/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.cpp +++ b/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.cpp @@ -44,20 +44,16 @@ DAGQueryBlockInterpreter::DAGQueryBlockInterpreter( const DAGQueryBlock & query_block_, size_t max_streams_, bool keep_session_timezone_info_, - const DAGQuerySource & dag_, std::vector & subqueries_for_sets_, - const std::unordered_map> & exchange_receiver_map_, - const LogWithPrefixPtr & log_) + const std::unordered_map> & exchange_receiver_map_) : context(context_) , input_streams_vec(input_streams_vec_) , query_block(query_block_) , keep_session_timezone_info(keep_session_timezone_info_) - , rqst(dag_.getDAGRequest()) , max_streams(max_streams_) - , dag(dag_) , subqueries_for_sets(subqueries_for_sets_) , exchange_receiver_map(exchange_receiver_map_) - , log(log_) + , log(getMPPTaskLog(dagContext(), "DAGQueryBlockInterpreter")) { if (query_block.selection != nullptr) { @@ -66,17 +62,6 @@ DAGQueryBlockInterpreter::DAGQueryBlockInterpreter( } } -BlockInputStreamPtr combinedNonJoinedDataStream(DAGPipeline & pipeline, size_t max_threads, const LogWithPrefixPtr & log) -{ - BlockInputStreamPtr ret = nullptr; - if (pipeline.streams_with_non_joined_data.size() == 1) - ret = pipeline.streams_with_non_joined_data.at(0); - else if (pipeline.streams_with_non_joined_data.size() > 1) - ret = std::make_shared>(pipeline.streams_with_non_joined_data, nullptr, max_threads, log); - pipeline.streams_with_non_joined_data.clear(); - return ret; -} - namespace { struct AnalysisResult @@ -275,12 +260,12 @@ void DAGQueryBlockInterpreter::executeTS(const tipb::TableScan & ts, DAGPipeline // do not have table id throw TiFlashException("Table id not specified in table scan executor", Errors::Coprocessor::BadRequest); } - if (dag.getRegions().empty() && dag.getRegionsForRemoteRead().empty()) + if (dagContext().getRegionsForLocalRead().empty() && dagContext().getRegionsForRemoteRead().empty()) { throw TiFlashException("Dag Request does not have region to read. ", Errors::Coprocessor::BadRequest); } - DAGStorageInterpreter storage_interpreter(context, dag, query_block, ts, conditions, max_streams, log); + DAGStorageInterpreter storage_interpreter(context, query_block, ts, conditions, max_streams); storage_interpreter.execute(pipeline); analyzer = std::move(storage_interpreter.analyzer); @@ -344,7 +329,7 @@ void DAGQueryBlockInterpreter::prepareJoin( ExpressionActionsChain chain; if (dag_analyzer.appendJoinKeyAndJoinFilters(chain, keys, key_types, key_names, left, is_right_out_join, filters, filter_column_name)) { - pipeline.transform([&](auto & stream) { stream = std::make_shared(stream, chain.getLastActions(), log); }); + pipeline.transform([&](auto & stream) { stream = std::make_shared(stream, chain.getLastActions(), taskLogger()); }); } } @@ -587,13 +572,13 @@ void DAGQueryBlockInterpreter::executeJoin(const tipb::Join & join, DAGPipeline // add a HashJoinBuildBlockInputStream to build a shared hash table size_t stream_index = 0; right_pipeline.transform( - [&](auto & stream) { stream = std::make_shared(stream, join_ptr, stream_index++, log); }); - executeUnion(right_pipeline, max_streams, log); + [&](auto & stream) { stream = std::make_shared(stream, join_ptr, stream_index++, taskLogger()); }); + executeUnion(right_pipeline, max_streams, taskLogger()); right_query.source = right_pipeline.firstStream(); right_query.join = join_ptr; right_query.join->setSampleBlock(right_query.source->getHeader()); - dag.getDAGContext().getProfileStreamsMapForJoinBuildSide()[query_block.qb_join_subquery_alias].push_back(right_query.source); + dagContext().getProfileStreamsMapForJoinBuildSide()[query_block.qb_join_subquery_alias].push_back(right_query.source); std::vector source_columns; for (const auto & p : left_pipeline.streams[0]->getHeader().getNamesAndTypesList()) @@ -613,7 +598,7 @@ void DAGQueryBlockInterpreter::executeJoin(const tipb::Join & join, DAGPipeline settings.max_block_size)); } for (auto & stream : pipeline.streams) - stream = std::make_shared(stream, chain.getLastActions(), log); + stream = std::make_shared(stream, chain.getLastActions(), taskLogger()); /// add a project to remove all the useless column NamesWithAliases project_cols; @@ -629,7 +614,7 @@ void DAGQueryBlockInterpreter::executeJoin(const tipb::Join & join, DAGPipeline void DAGQueryBlockInterpreter::executeWhere(DAGPipeline & pipeline, const ExpressionActionsPtr & expr, String & filter_column) { - pipeline.transform([&](auto & stream) { stream = std::make_shared(stream, expr, filter_column, log); }); + pipeline.transform([&](auto & stream) { stream = std::make_shared(stream, expr, filter_column, taskLogger()); }); } void DAGQueryBlockInterpreter::executeAggregation( @@ -639,7 +624,7 @@ void DAGQueryBlockInterpreter::executeAggregation( TiDB::TiDBCollators & collators, AggregateDescriptions & aggregate_descriptions) { - pipeline.transform([&](auto & stream) { stream = std::make_shared(stream, expression_actions_ptr, log); }); + pipeline.transform([&](auto & stream) { stream = std::make_shared(stream, expression_actions_ptr, taskLogger()); }); Block header = pipeline.firstStream()->getHeader(); ColumnNumbers keys; @@ -685,7 +670,7 @@ void DAGQueryBlockInterpreter::executeAggregation( if (pipeline.streams.size() > 1) { before_agg_streams = pipeline.streams.size(); - BlockInputStreamPtr stream_with_non_joined_data = combinedNonJoinedDataStream(pipeline, max_streams, log); + BlockInputStreamPtr stream_with_non_joined_data = combinedNonJoinedDataStream(pipeline, max_streams, taskLogger()); pipeline.firstStream() = std::make_shared( pipeline.streams, stream_with_non_joined_data, @@ -694,12 +679,12 @@ void DAGQueryBlockInterpreter::executeAggregation( true, max_streams, settings.aggregation_memory_efficient_merge_threads ? static_cast(settings.aggregation_memory_efficient_merge_threads) : static_cast(settings.max_threads), - log); + taskLogger()); pipeline.streams.resize(1); } else { - BlockInputStreamPtr stream_with_non_joined_data = combinedNonJoinedDataStream(pipeline, max_streams, log); + BlockInputStreamPtr stream_with_non_joined_data = combinedNonJoinedDataStream(pipeline, max_streams, taskLogger()); BlockInputStreams inputs; if (!pipeline.streams.empty()) inputs.push_back(pipeline.firstStream()); @@ -708,11 +693,11 @@ void DAGQueryBlockInterpreter::executeAggregation( if (stream_with_non_joined_data) inputs.push_back(stream_with_non_joined_data); pipeline.firstStream() = std::make_shared( - std::make_shared(inputs, log), + std::make_shared(inputs, taskLogger()), params, context.getFileProvider(), true, - log); + taskLogger()); } // add cast } @@ -721,23 +706,7 @@ void DAGQueryBlockInterpreter::executeExpression(DAGPipeline & pipeline, const E { if (!expressionActionsPtr->getActions().empty()) { - pipeline.transform([&](auto & stream) { stream = std::make_shared(stream, expressionActionsPtr, log); }); - } -} - -void DAGQueryBlockInterpreter::executeUnion(DAGPipeline & pipeline, size_t max_streams, const LogWithPrefixPtr & log) -{ - if (pipeline.streams.size() == 1 && pipeline.streams_with_non_joined_data.empty()) - return; - auto non_joined_data_stream = combinedNonJoinedDataStream(pipeline, max_streams, log); - if (!pipeline.streams.empty()) - { - pipeline.firstStream() = std::make_shared>(pipeline.streams, non_joined_data_stream, max_streams, log); - pipeline.streams.resize(1); - } - else if (non_joined_data_stream != nullptr) - { - pipeline.streams.push_back(non_joined_data_stream); + pipeline.transform([&](auto & stream) { stream = std::make_shared(stream, expressionActionsPtr, taskLogger()); }); } } @@ -748,7 +717,7 @@ void DAGQueryBlockInterpreter::executeOrder(DAGPipeline & pipeline, const std::v Int64 limit = query_block.limitOrTopN->topn().limit(); pipeline.transform([&](auto & stream) { - auto sorting_stream = std::make_shared(stream, order_descr, log, limit); + auto sorting_stream = std::make_shared(stream, order_descr, taskLogger(), limit); /// Limits on sorting IProfilingBlockInputStream::LocalLimits limits; @@ -760,7 +729,7 @@ void DAGQueryBlockInterpreter::executeOrder(DAGPipeline & pipeline, const std::v }); /// If there are several streams, we merge them into one - executeUnion(pipeline, max_streams, log); + executeUnion(pipeline, max_streams, taskLogger()); /// Merge the sorted blocks. pipeline.firstStream() = std::make_shared( @@ -770,18 +739,18 @@ void DAGQueryBlockInterpreter::executeOrder(DAGPipeline & pipeline, const std::v limit, settings.max_bytes_before_external_sort, context.getTemporaryPath(), - log); + taskLogger()); } void DAGQueryBlockInterpreter::recordProfileStreams(DAGPipeline & pipeline, const String & key) { - dag.getDAGContext().getProfileStreamsMap()[key].qb_id = query_block.id; + dagContext().getProfileStreamsMap()[key].qb_id = query_block.id; for (auto & stream : pipeline.streams) { - dag.getDAGContext().getProfileStreamsMap()[key].input_streams.push_back(stream); + dagContext().getProfileStreamsMap()[key].input_streams.push_back(stream); } for (auto & stream : pipeline.streams_with_non_joined_data) - dag.getDAGContext().getProfileStreamsMap()[key].input_streams.push_back(stream); + dagContext().getProfileStreamsMap()[key].input_streams.push_back(stream); } void copyExecutorTreeWithLocalTableScan( @@ -883,7 +852,7 @@ void DAGQueryBlockInterpreter::executeRemoteQuery(DAGPipeline & pipeline) ::tipb::DAGRequest dag_req; - copyExecutorTreeWithLocalTableScan(dag_req, query_block.root, rqst); + copyExecutorTreeWithLocalTableScan(dag_req, query_block.root, *dagContext().dag_request); DAGSchema schema; ColumnsWithTypeAndName columns; BoolVec is_ts_column; @@ -899,7 +868,7 @@ void DAGQueryBlockInterpreter::executeRemoteQuery(DAGPipeline & pipeline) final_project.emplace_back(col_name, ""); } - dag_req.set_collect_execution_summaries(dag.getDAGContext().collect_execution_summaries); + dag_req.set_collect_execution_summaries(dagContext().collect_execution_summaries); executeRemoteQueryImpl(pipeline, cop_key_ranges, dag_req, schema); analyzer = std::make_unique(std::move(source_columns), context); @@ -935,9 +904,9 @@ void DAGQueryBlockInterpreter::executeRemoteQueryImpl( std::vector tasks(all_tasks.begin() + task_start, all_tasks.begin() + task_end); auto coprocessor_reader = std::make_shared(schema, cluster, tasks, has_enforce_encode_type, 1); - BlockInputStreamPtr input = std::make_shared(coprocessor_reader, log); + BlockInputStreamPtr input = std::make_shared(coprocessor_reader, taskLogger()); pipeline.streams.push_back(input); - dag.getDAGContext().getRemoteInputStreams().push_back(input); + dagContext().getRemoteInputStreams().push_back(input); task_start = task_end; } } @@ -977,9 +946,9 @@ void DAGQueryBlockInterpreter::executeImpl(DAGPipeline & pipeline) // todo choose a more reasonable stream number for (size_t i = 0; i < max_streams; i++) { - BlockInputStreamPtr stream = std::make_shared(it->second, log); - dag.getDAGContext().getRemoteInputStreams().push_back(stream); - stream = std::make_shared(stream, 8192, 0, log); + BlockInputStreamPtr stream = std::make_shared(it->second, taskLogger()); + dagContext().getRemoteInputStreams().push_back(stream); + stream = std::make_shared(stream, 8192, 0, taskLogger()); pipeline.streams.push_back(stream); } std::vector source_columns; @@ -1013,7 +982,7 @@ void DAGQueryBlockInterpreter::executeImpl(DAGPipeline & pipeline) output_columns.emplace_back(alias, col.type); project_cols.emplace_back(col.name, alias); } - pipeline.transform([&](auto & stream) { stream = std::make_shared(stream, chain.getLastActions(), log); }); + pipeline.transform([&](auto & stream) { stream = std::make_shared(stream, chain.getLastActions(), taskLogger()); }); executeProject(pipeline, project_cols); analyzer = std::make_unique(std::move(output_columns), context); recordProfileStreams(pipeline, query_block.source_name); @@ -1022,7 +991,7 @@ void DAGQueryBlockInterpreter::executeImpl(DAGPipeline & pipeline) { executeTS(query_block.source->tbl_scan(), pipeline); recordProfileStreams(pipeline, query_block.source_name); - dag.getDAGContext().table_scan_executor_id = query_block.source_name; + dagContext().table_scan_executor_id = query_block.source_name; } auto res = analyzeExpressions( @@ -1049,20 +1018,20 @@ void DAGQueryBlockInterpreter::executeImpl(DAGPipeline & pipeline) { project_for_cop_read = generateProjectExpressionActions(stream, context, res.project_after_ts_and_filter_for_remote_read); } - stream = std::make_shared(stream, project_for_cop_read, log); + stream = std::make_shared(stream, project_for_cop_read, taskLogger()); } } else { /// execute timezone cast or duration cast if needed if (res.extra_cast) - stream = std::make_shared(stream, res.extra_cast, log); + stream = std::make_shared(stream, res.extra_cast, taskLogger()); /// execute selection if needed if (res.before_where) { - stream = std::make_shared(stream, res.before_where, res.filter_column_name, log); + stream = std::make_shared(stream, res.before_where, res.filter_column_name, taskLogger()); if (res.project_after_where) - stream = std::make_shared(stream, res.project_after_where, log); + stream = std::make_shared(stream, res.project_after_where, taskLogger()); } } } @@ -1071,9 +1040,9 @@ void DAGQueryBlockInterpreter::executeImpl(DAGPipeline & pipeline) /// execute selection if needed if (res.before_where) { - stream = std::make_shared(stream, res.before_where, res.filter_column_name, log); + stream = std::make_shared(stream, res.before_where, res.filter_column_name, taskLogger()); if (res.project_after_where) - stream = std::make_shared(stream, res.project_after_where, log); + stream = std::make_shared(stream, res.project_after_where, taskLogger()); } } } @@ -1086,7 +1055,7 @@ void DAGQueryBlockInterpreter::executeImpl(DAGPipeline & pipeline) LOG_INFO(log, "execution stream size for query block(before aggregation) " << query_block.qb_column_prefix << " is " << pipeline.streams.size()); - dag.getDAGContext().final_concurrency = std::max(dag.getDAGContext().final_concurrency, pipeline.streams.size()); + dagContext().final_concurrency = std::max(dagContext().final_concurrency, pipeline.streams.size()); if (res.before_aggregation) { @@ -1137,7 +1106,7 @@ void DAGQueryBlockInterpreter::executeProject(DAGPipeline & pipeline, NamesWithA if (project_cols.empty()) return; ExpressionActionsPtr project = generateProjectExpressionActions(pipeline.firstStream(), context, project_cols); - pipeline.transform([&](auto & stream) { stream = std::make_shared(stream, project, log); }); + pipeline.transform([&](auto & stream) { stream = std::make_shared(stream, project, taskLogger()); }); } void DAGQueryBlockInterpreter::executeLimit(DAGPipeline & pipeline) @@ -1147,18 +1116,18 @@ void DAGQueryBlockInterpreter::executeLimit(DAGPipeline & pipeline) limit = query_block.limitOrTopN->limit().limit(); else limit = query_block.limitOrTopN->topn().limit(); - pipeline.transform([&](auto & stream) { stream = std::make_shared(stream, limit, 0, log, false); }); + pipeline.transform([&](auto & stream) { stream = std::make_shared(stream, limit, 0, taskLogger(), false); }); if (pipeline.hasMoreThanOneStream()) { - executeUnion(pipeline, max_streams, log); - pipeline.transform([&](auto & stream) { stream = std::make_shared(stream, limit, 0, log, false); }); + executeUnion(pipeline, max_streams, taskLogger()); + pipeline.transform([&](auto & stream) { stream = std::make_shared(stream, limit, 0, taskLogger(), false); }); } } void DAGQueryBlockInterpreter::executeExchangeSender(DAGPipeline & pipeline) { /// only run in MPP - assert(dag.getDAGContext().isMPPTask() && context.getDAGContext()->tunnel_set != nullptr); + assert(dagContext().isMPPTask() && dagContext().tunnel_set != nullptr); /// exchange sender should be at the top of operators const auto & exchange_sender = query_block.exchangeSender->exchange_sender(); /// get partition column ids @@ -1188,7 +1157,7 @@ void DAGQueryBlockInterpreter::executeExchangeSender(DAGPipeline & pipeline) collators.emplace_back(nullptr); } } - restoreConcurrency(pipeline, dag.getDAGContext().final_concurrency, log); + restoreConcurrency(pipeline, dagContext().final_concurrency, log); int stream_id = 0; pipeline.transform([&](auto & stream) { // construct writer @@ -1200,10 +1169,7 @@ void DAGQueryBlockInterpreter::executeExchangeSender(DAGPipeline & pipeline) context.getSettings().dag_records_per_chunk, context.getSettings().batch_send_min_limit, stream_id++ == 0, /// only one stream needs to sending execution summaries for the last response - dag.getEncodeType(), - dag.getResultFieldTypes(), - dag.getDAGContext(), - log); + dagContext()); stream = std::make_shared(stream, std::move(response_writer), log); }); } @@ -1215,14 +1181,14 @@ BlockInputStreams DAGQueryBlockInterpreter::execute() if (!pipeline.streams_with_non_joined_data.empty()) { size_t concurrency = pipeline.streams.size(); - executeUnion(pipeline, max_streams, log); + executeUnion(pipeline, max_streams, taskLogger()); if (!query_block.isRootQueryBlock()) - restoreConcurrency(pipeline, concurrency, log); + restoreConcurrency(pipeline, concurrency, taskLogger()); } /// expand concurrency after agg if (!query_block.isRootQueryBlock()) - restoreConcurrency(pipeline, before_agg_streams, log); + restoreConcurrency(pipeline, before_agg_streams, taskLogger()); return pipeline.streams; } diff --git a/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.h b/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.h index e0fe3d46871..c1ebcaebd59 100644 --- a/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.h +++ b/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.h @@ -1,29 +1,22 @@ #pragma once -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wunused-parameter" -#include -#include -#pragma GCC diagnostic pop - #include #include #include #include #include +#include #include #include #include #include +#include #include +#include namespace DB { -class Context; - -class DAGQuerySource; class DAGQueryBlock; -struct RegionInfo; class ExchangeReceiver; class DAGExpressionAnalyzer; @@ -38,17 +31,13 @@ class DAGQueryBlockInterpreter const DAGQueryBlock & query_block_, size_t max_streams_, bool keep_session_timezone_info_, - const DAGQuerySource & dag_, std::vector & subqueries_for_sets_, - const std::unordered_map> & exchange_receiver_map, - const LogWithPrefixPtr & log_); + const std::unordered_map> & exchange_receiver_map); ~DAGQueryBlockInterpreter() = default; BlockInputStreams execute(); - static void executeUnion(DAGPipeline & pipeline, size_t max_streams, const LogWithPrefixPtr & log); - private: void executeRemoteQuery(DAGPipeline & pipeline); void executeImpl(DAGPipeline & pipeline); @@ -89,11 +78,13 @@ class DAGQueryBlockInterpreter ::tipb::DAGRequest & dag_req, const DAGSchema & schema); + DAGContext & dagContext() const { return *context.getDAGContext(); } + const LogWithPrefixPtr & taskLogger() const { return dagContext().log; } + Context & context; std::vector input_streams_vec; const DAGQueryBlock & query_block; const bool keep_session_timezone_info; - const tipb::DAGRequest & rqst; NamesWithAliases final_project; @@ -108,11 +99,10 @@ class DAGQueryBlockInterpreter std::unique_ptr analyzer; std::vector conditions; - const DAGQuerySource & dag; std::vector & subqueries_for_sets; const std::unordered_map> & exchange_receiver_map; std::vector need_add_cast_column_flag_for_tablescan; - const LogWithPrefixPtr log; + LogWithPrefixPtr log; }; } // namespace DB diff --git a/dbms/src/Flash/Coprocessor/DAGQuerySource.cpp b/dbms/src/Flash/Coprocessor/DAGQuerySource.cpp index e06e03c6e3e..9c3375a2f6a 100644 --- a/dbms/src/Flash/Coprocessor/DAGQuerySource.cpp +++ b/dbms/src/Flash/Coprocessor/DAGQuerySource.cpp @@ -9,20 +9,10 @@ namespace ErrorCodes extern const int COP_BAD_DAG_REQUEST; } // namespace ErrorCodes -DAGQuerySource::DAGQuerySource( - Context & context_, - const RegionInfoMap & regions_, - const RegionInfoList & regions_for_remote_read_, - const tipb::DAGRequest & dag_request_, - const LogWithPrefixPtr & log_, - const bool is_batch_cop_or_mpp_) +DAGQuerySource::DAGQuerySource(Context & context_) : context(context_) - , regions(regions_) - , regions_for_remote_read(regions_for_remote_read_) - , dag_request(dag_request_) - , is_batch_cop_or_mpp(is_batch_cop_or_mpp_) - , log(log_) { + const tipb::DAGRequest & dag_request = *getDAGContext().dag_request; if (dag_request.has_root_executor()) { QueryBlockIDGenerator id_generator; @@ -32,7 +22,7 @@ DAGQuerySource::DAGQuerySource( { root_query_block = std::make_shared(1, dag_request.executors()); } - root_query_block->collectAllPossibleChildrenJoinSubqueryAlias(context.getDAGContext()->getQBIdToJoinAliasMap()); + root_query_block->collectAllPossibleChildrenJoinSubqueryAlias(getDAGContext().getQBIdToJoinAliasMap()); for (Int32 i : dag_request.output_offsets()) root_query_block->output_offsets.push_back(i); for (UInt32 i : dag_request.output_offsets()) @@ -41,32 +31,34 @@ DAGQuerySource::DAGQuerySource( throw TiFlashException(std::string(__PRETTY_FUNCTION__) + ": Invalid output offset(schema has " + std::to_string(root_query_block->output_field_types.size()) + " columns, access index " + std::to_string(i), Errors::Coprocessor::BadRequest); - result_field_types.push_back(root_query_block->output_field_types[i]); + getDAGContext().result_field_types.push_back(root_query_block->output_field_types[i]); } - analyzeDAGEncodeType(); + auto encode_type = analyzeDAGEncodeType(); + getDAGContext().encode_type = encode_type; + getDAGContext().keep_session_timezone_info = encode_type == tipb::EncodeType::TypeChunk || encode_type == tipb::EncodeType::TypeCHBlock; } -void DAGQuerySource::analyzeDAGEncodeType() +tipb::EncodeType DAGQuerySource::analyzeDAGEncodeType() { + const tipb::DAGRequest & dag_request = *getDAGContext().dag_request; + const tipb::EncodeType encode_type = dag_request.encode_type(); if (getDAGContext().isMPPTask() && !getDAGContext().isRootMPPTask()) { /// always use CHBlock encode type for data exchange between TiFlash nodes - encode_type = tipb::EncodeType::TypeCHBlock; - return; + return tipb::EncodeType::TypeCHBlock; } if (dag_request.has_force_encode_type() && dag_request.force_encode_type()) { - encode_type = dag_request.encode_type(); assert(encode_type == tipb::EncodeType::TypeCHBlock); - return; + return encode_type; } - encode_type = dag_request.encode_type(); - if (isUnsupportedEncodeType(getResultFieldTypes(), encode_type)) - encode_type = tipb::EncodeType::TypeDefault; + if (isUnsupportedEncodeType(getDAGContext().result_field_types, encode_type)) + return tipb::EncodeType::TypeDefault; if (encode_type == tipb::EncodeType::TypeChunk && dag_request.has_chunk_memory_layout() && dag_request.chunk_memory_layout().has_endian() && dag_request.chunk_memory_layout().endian() == tipb::Endian::BigEndian) // todo support BigEndian encode for chunk encode type - encode_type = tipb::EncodeType::TypeDefault; + return tipb::EncodeType::TypeDefault; + return encode_type; } std::tuple DAGQuerySource::parse(size_t) @@ -74,17 +66,17 @@ std::tuple DAGQuerySource::parse(size_t) // this is a WAR to avoid NPE when the MergeTreeDataSelectExecutor trying // to extract key range of the query. // todo find a way to enable key range extraction for dag query - return {dag_request.DebugString(), makeDummyQuery()}; + return {getDAGContext().dag_request->DebugString(), makeDummyQuery()}; } String DAGQuerySource::str(size_t) { - return dag_request.DebugString(); + return getDAGContext().dag_request->DebugString(); } std::unique_ptr DAGQuerySource::interpreter(Context &, QueryProcessingStage::Enum) { - return std::make_unique(context, *this, log); + return std::make_unique(context, *this); } } // namespace DB diff --git a/dbms/src/Flash/Coprocessor/DAGQuerySource.h b/dbms/src/Flash/Coprocessor/DAGQuerySource.h index 668ff1bc34a..fec159586a1 100644 --- a/dbms/src/Flash/Coprocessor/DAGQuerySource.h +++ b/dbms/src/Flash/Coprocessor/DAGQuerySource.h @@ -2,7 +2,6 @@ #include #include -#include #include #include #include @@ -18,54 +17,27 @@ namespace DB class DAGQuerySource : public IQuerySource { public: - DAGQuerySource( - Context & context_, - const RegionInfoMap & regions_, - const RegionInfoList & regions_needs_remote_read_, - const tipb::DAGRequest & dag_request_, - const LogWithPrefixPtr & log_, - const bool is_batch_cop_or_mpp_ = false); + explicit DAGQuerySource(Context & context_); std::tuple parse(size_t) override; String str(size_t max_query_size) override; std::unique_ptr interpreter(Context & context, QueryProcessingStage::Enum stage) override; - const tipb::DAGRequest & getDAGRequest() const { return dag_request; }; - - const std::vector & getResultFieldTypes() const { return result_field_types; } - ASTPtr getAST() const { return ast; }; - tipb::EncodeType getEncodeType() const { return encode_type; } - std::shared_ptr getRootQueryBlock() const { return root_query_block; } - const RegionInfoMap & getRegions() const { return regions; } - const RegionInfoList & getRegionsForRemoteRead() const { return regions_for_remote_read; } - - bool isBatchCopOrMpp() const { return is_batch_cop_or_mpp; } DAGContext & getDAGContext() const { return *context.getDAGContext(); } std::string getExecutorNames() const; -protected: - void analyzeDAGEncodeType(); +private: + tipb::EncodeType analyzeDAGEncodeType(); -protected: Context & context; - - const RegionInfoMap & regions; - const RegionInfoList & regions_for_remote_read; - - const tipb::DAGRequest & dag_request; - - std::vector result_field_types; - tipb::EncodeType encode_type; std::shared_ptr root_query_block; ASTPtr ast; - const bool is_batch_cop_or_mpp; - LogWithPrefixPtr log; }; diff --git a/dbms/src/Flash/Coprocessor/DAGResponseWriter.cpp b/dbms/src/Flash/Coprocessor/DAGResponseWriter.cpp index a1b5a16cf3a..ce82caf7698 100644 --- a/dbms/src/Flash/Coprocessor/DAGResponseWriter.cpp +++ b/dbms/src/Flash/Coprocessor/DAGResponseWriter.cpp @@ -118,24 +118,20 @@ void DAGResponseWriter::addExecuteSummaries(tipb::SelectResponse & response, boo DAGResponseWriter::DAGResponseWriter( Int64 records_per_chunk_, - tipb::EncodeType encode_type_, - std::vector result_field_types_, DAGContext & dag_context_) : records_per_chunk(records_per_chunk_) - , encode_type(encode_type_) - , result_field_types(std::move(result_field_types_)) , dag_context(dag_context_) { for (auto & p : dag_context.getProfileStreamsMap()) { local_executors.insert(p.first); } - if (encode_type == tipb::EncodeType::TypeCHBlock) + if (dag_context.encode_type == tipb::EncodeType::TypeCHBlock) { records_per_chunk = -1; } - if (encode_type != tipb::EncodeType::TypeCHBlock && encode_type != tipb::EncodeType::TypeChunk - && encode_type != tipb::EncodeType::TypeDefault) + if (dag_context.encode_type != tipb::EncodeType::TypeCHBlock && dag_context.encode_type != tipb::EncodeType::TypeChunk + && dag_context.encode_type != tipb::EncodeType::TypeDefault) { throw TiFlashException( "Only Default/Arrow/CHBlock encode type is supported in DAGBlockOutputStream.", diff --git a/dbms/src/Flash/Coprocessor/DAGResponseWriter.h b/dbms/src/Flash/Coprocessor/DAGResponseWriter.h index 1ce0c9ba8a8..3f612b9a87e 100644 --- a/dbms/src/Flash/Coprocessor/DAGResponseWriter.h +++ b/dbms/src/Flash/Coprocessor/DAGResponseWriter.h @@ -48,7 +48,9 @@ struct ExecutionSummary class DAGResponseWriter { public: - DAGResponseWriter(Int64 records_per_chunk_, tipb::EncodeType encode_type_, std::vector result_field_types_, DAGContext & dag_context_); + DAGResponseWriter( + Int64 records_per_chunk_, + DAGContext & dag_context_); void fillTiExecutionSummary( tipb::ExecutorExecutionSummary * execution_summary, ExecutionSummary & current, @@ -61,8 +63,6 @@ class DAGResponseWriter protected: Int64 records_per_chunk; - tipb::EncodeType encode_type; - std::vector result_field_types; DAGContext & dag_context; std::unordered_map previous_execution_stats; std::unordered_set local_executors; diff --git a/dbms/src/Flash/Coprocessor/DAGStorageInterpreter.cpp b/dbms/src/Flash/Coprocessor/DAGStorageInterpreter.cpp index 590f8fb651b..cd2a64e3e22 100644 --- a/dbms/src/Flash/Coprocessor/DAGStorageInterpreter.cpp +++ b/dbms/src/Flash/Coprocessor/DAGStorageInterpreter.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -119,30 +120,27 @@ MakeRegionQueryInfos( DAGStorageInterpreter::DAGStorageInterpreter( Context & context_, - const DAGQuerySource & dag_, const DAGQueryBlock & query_block_, const tipb::TableScan & ts, const std::vector & conditions_, - size_t max_streams_, - const LogWithPrefixPtr & log_) + size_t max_streams_) : context(context_) - , dag(dag_) , query_block(query_block_) , table_scan(ts) , conditions(conditions_) , max_streams(max_streams_) - , log(log_) + , log(getMPPTaskLog(*context.getDAGContext(), "DAGStorageInterpreter")) , table_id(ts.table_id()) , settings(context.getSettingsRef()) , tmt(context.getTMTContext()) , mvcc_query_info(new MvccQueryInfo(true, settings.read_tso)) { - log = log_ != nullptr ? log_ : std::make_shared(&Poco::Logger::get("DAGStorageInterpreter"), ""); } void DAGStorageInterpreter::execute(DAGPipeline & pipeline) { - if (dag.isBatchCopOrMpp()) + const DAGContext & dag_context = *context.getDAGContext(); + if (dag_context.isBatchCop() || dag_context.isMPPTask()) learner_read_snapshot = doBatchCopLearnerRead(); else learner_read_snapshot = doCopLearnerRead(); @@ -158,7 +156,7 @@ void DAGStorageInterpreter::execute(DAGPipeline & pipeline) if (!mvcc_query_info->regions_query_info.empty()) doLocalRead(pipeline, settings.max_block_size); - for (auto & region_info : dag.getRegionsForRemoteRead()) + for (auto & region_info : dag_context.getRegionsForRemoteRead()) region_retry.emplace_back(region_info); null_stream_if_empty = std::make_shared(storage->getSampleBlockForColumns(required_columns)); @@ -170,7 +168,7 @@ void DAGStorageInterpreter::execute(DAGPipeline & pipeline) LearnerReadSnapshot DAGStorageInterpreter::doCopLearnerRead() { auto [info_retry, status] = MakeRegionQueryInfos( - dag.getRegions(), + context.getDAGContext()->getRegionsForLocalRead(), {}, tmt, *mvcc_query_info, @@ -186,7 +184,8 @@ LearnerReadSnapshot DAGStorageInterpreter::doCopLearnerRead() /// Will assign region_retry LearnerReadSnapshot DAGStorageInterpreter::doBatchCopLearnerRead() { - if (dag.getRegions().empty()) + const auto & regions_for_local_read = context.getDAGContext()->getRegionsForLocalRead(); + if (regions_for_local_read.empty()) return {}; std::unordered_set force_retry; for (;;) @@ -195,7 +194,7 @@ LearnerReadSnapshot DAGStorageInterpreter::doBatchCopLearnerRead() { region_retry.clear(); auto [retry, status] = MakeRegionQueryInfos( - dag.getRegions(), + regions_for_local_read, force_retry, tmt, *mvcc_query_info, @@ -237,6 +236,7 @@ LearnerReadSnapshot DAGStorageInterpreter::doBatchCopLearnerRead() void DAGStorageInterpreter::doLocalRead(DAGPipeline & pipeline, size_t max_block_size) { + const DAGContext & dag_context = *context.getDAGContext(); SelectQueryInfo query_info; /// to avoid null point exception query_info.query = makeDummyQuery(); @@ -286,11 +286,11 @@ void DAGStorageInterpreter::doLocalRead(DAGPipeline & pipeline, size_t max_block catch (RegionException & e) { /// Recover from region exception when super batch is enable - if (dag.isBatchCopOrMpp()) + if (dag_context.isBatchCop() || dag_context.isMPPTask()) { // clean all streams from local because we are not sure the correctness of those streams pipeline.streams.clear(); - const auto & dag_regions = dag.getRegions(); + const auto & dag_regions = dag_context.getRegionsForLocalRead(); FmtBuffer buffer; // Normally there is only few regions need to retry when super batch is enabled. Retry to read // from local first. However, too many retry in different places may make the whole process @@ -506,6 +506,7 @@ std::tuple, String> DAGS std::tuple, std::optional> DAGStorageInterpreter::buildRemoteTS() { + const DAGContext & dag_context = *context.getDAGContext(); if (region_retry.empty()) return std::make_tuple(std::nullopt, std::nullopt); @@ -574,10 +575,11 @@ std::tuple, std::optional> DAGStorage /// do not collect execution summaries because in this case because the execution summaries /// will be collected by CoprocessorBlockInputStream dag_req.set_collect_execution_summaries(false); - if (dag.getDAGRequest().has_time_zone_name() && !dag.getDAGRequest().time_zone_name().empty()) - dag_req.set_time_zone_name(dag.getDAGRequest().time_zone_name()); - if (dag.getDAGRequest().has_time_zone_offset()) - dag_req.set_time_zone_offset(dag.getDAGRequest().time_zone_offset()); + const auto & original_dag_req = *dag_context.dag_request; + if (original_dag_req.has_time_zone_name() && !original_dag_req.time_zone_name().empty()) + dag_req.set_time_zone_name(original_dag_req.time_zone_name()); + if (original_dag_req.has_time_zone_offset()) + dag_req.set_time_zone_offset(original_dag_req.time_zone_offset()); return std::make_tuple(dag_req, schema); } diff --git a/dbms/src/Flash/Coprocessor/DAGStorageInterpreter.h b/dbms/src/Flash/Coprocessor/DAGStorageInterpreter.h index 45a22bc3141..41d9874ecaa 100644 --- a/dbms/src/Flash/Coprocessor/DAGStorageInterpreter.h +++ b/dbms/src/Flash/Coprocessor/DAGStorageInterpreter.h @@ -33,12 +33,10 @@ class DAGStorageInterpreter public: DAGStorageInterpreter( Context & context_, - const DAGQuerySource & dag_, const DAGQueryBlock & query_block_, const tipb::TableScan & ts, const std::vector & conditions_, - size_t max_streams_, - const LogWithPrefixPtr & log_); + size_t max_streams_); DAGStorageInterpreter(DAGStorageInterpreter &&) = delete; DAGStorageInterpreter & operator=(DAGStorageInterpreter &&) = delete; @@ -76,7 +74,6 @@ class DAGStorageInterpreter /// passed from caller, doesn't change during DAGStorageInterpreter's lifetime Context & context; - const DAGQuerySource & dag; const DAGQueryBlock & query_block; const tipb::TableScan & table_scan; const std::vector & conditions; diff --git a/dbms/src/Flash/Coprocessor/InterpreterDAG.cpp b/dbms/src/Flash/Coprocessor/InterpreterDAG.cpp index f509ead9a2f..c2229dfc0c4 100644 --- a/dbms/src/Flash/Coprocessor/InterpreterDAG.cpp +++ b/dbms/src/Flash/Coprocessor/InterpreterDAG.cpp @@ -11,15 +11,12 @@ namespace DB { -InterpreterDAG::InterpreterDAG(Context & context_, const DAGQuerySource & dag_, const LogWithPrefixPtr & log_) +InterpreterDAG::InterpreterDAG(Context & context_, const DAGQuerySource & dag_) : context(context_) , dag(dag_) - , keep_session_timezone_info( - dag.getEncodeType() == tipb::EncodeType::TypeChunk || dag.getEncodeType() == tipb::EncodeType::TypeCHBlock) - , log(log_) { const Settings & settings = context.getSettingsRef(); - if (dag.isBatchCopOrMpp()) + if (dagContext().isBatchCop() || dagContext().isMPPTask()) max_streams = settings.max_threads; else max_streams = 1; @@ -45,11 +42,9 @@ BlockInputStreams InterpreterDAG::executeQueryBlock(DAGQueryBlock & query_block, input_streams_vec, query_block, max_streams, - keep_session_timezone_info || !query_block.isRootQueryBlock(), - dag, + dagContext().keep_session_timezone_info || !query_block.isRootQueryBlock(), subqueries_for_sets, - mpp_exchange_receiver_maps, - log); + mpp_exchange_receiver_maps); return query_block_interpreter.execute(); } @@ -62,19 +57,20 @@ void InterpreterDAG::initMPPExchangeReceiver(const DAGQueryBlock & dag_query_blo if (dag_query_block.source->tp() == tipb::ExecType::TypeExchangeReceiver) { mpp_exchange_receiver_maps[dag_query_block.source_name] = std::make_shared( - std::make_shared(context.getTMTContext().getKVCluster(), - context.getTMTContext().getMPPTaskManager(), - context.getSettings().enable_local_tunnel), + std::make_shared( + context.getTMTContext().getKVCluster(), + context.getTMTContext().getMPPTaskManager(), + context.getSettings().enable_local_tunnel), dag_query_block.source->exchange_receiver(), - dag.getDAGContext().getMPPTaskMeta(), + dagContext().getMPPTaskMeta(), max_streams, - log); + dagContext().log); } } BlockIO InterpreterDAG::execute() { - if (dag.getDAGContext().isMPPTask()) + if (dagContext().isMPPTask()) /// Due to learner read, DAGQueryBlockInterpreter may take a long time to build /// the query plan, so we init mpp exchange receiver before executeQueryBlock initMPPExchangeReceiver(*dag.getRootQueryBlock()); @@ -87,7 +83,7 @@ BlockIO InterpreterDAG::execute() pipeline.streams = streams; /// add union to run in parallel if needed - DAGQueryBlockInterpreter::executeUnion(pipeline, max_streams, log); + executeUnion(pipeline, max_streams, dagContext().log); if (!subqueries_for_sets.empty()) { const Settings & settings = context.getSettingsRef(); @@ -95,8 +91,8 @@ BlockIO InterpreterDAG::execute() pipeline.firstStream(), std::move(subqueries_for_sets), SizeLimits(settings.max_rows_to_transfer, settings.max_bytes_to_transfer, settings.transfer_overflow_mode), - dag.getDAGContext().getMPPTaskId(), - log); + dagContext().getMPPTaskId(), + dagContext().log); } BlockIO res; diff --git a/dbms/src/Flash/Coprocessor/InterpreterDAG.h b/dbms/src/Flash/Coprocessor/InterpreterDAG.h index 53926716a88..fa5baa75f6e 100644 --- a/dbms/src/Flash/Coprocessor/InterpreterDAG.h +++ b/dbms/src/Flash/Coprocessor/InterpreterDAG.h @@ -29,7 +29,7 @@ using RegionPtr = std::shared_ptr; class InterpreterDAG : public IInterpreter { public: - InterpreterDAG(Context & context_, const DAGQuerySource & dag_, const LogWithPrefixPtr & log_); + InterpreterDAG(Context & context_, const DAGQuerySource & dag_); ~InterpreterDAG() = default; @@ -39,19 +39,13 @@ class InterpreterDAG : public IInterpreter BlockInputStreams executeQueryBlock(DAGQueryBlock & query_block, std::vector & subqueries_for_sets); void initMPPExchangeReceiver(const DAGQueryBlock & dag_query_block); -private: - Context & context; + DAGContext & dagContext() const { return *context.getDAGContext(); } + Context & context; const DAGQuerySource & dag; - /// How many streams we ask for storage to produce, and in how many threads we will do further processing. size_t max_streams = 1; - // key: source_name of ExchangeReceiver nodes in dag. std::unordered_map> mpp_exchange_receiver_maps; - - const bool keep_session_timezone_info; - - LogWithPrefixPtr log; }; } // namespace DB diff --git a/dbms/src/Flash/Coprocessor/InterpreterUtils.cpp b/dbms/src/Flash/Coprocessor/InterpreterUtils.cpp index 281f339f193..0cb7d30349b 100644 --- a/dbms/src/Flash/Coprocessor/InterpreterUtils.cpp +++ b/dbms/src/Flash/Coprocessor/InterpreterUtils.cpp @@ -1,4 +1,5 @@ #include +#include #include namespace DB @@ -12,4 +13,31 @@ void restoreConcurrency(DAGPipeline & pipeline, size_t concurrency, const LogWit pipeline.streams.assign(concurrency, shared_query_block_input_stream); } } + +BlockInputStreamPtr combinedNonJoinedDataStream(DAGPipeline & pipeline, size_t max_threads, const LogWithPrefixPtr & log) +{ + BlockInputStreamPtr ret = nullptr; + if (pipeline.streams_with_non_joined_data.size() == 1) + ret = pipeline.streams_with_non_joined_data.at(0); + else if (pipeline.streams_with_non_joined_data.size() > 1) + ret = std::make_shared>(pipeline.streams_with_non_joined_data, nullptr, max_threads, log); + pipeline.streams_with_non_joined_data.clear(); + return ret; +} + +void executeUnion(DAGPipeline & pipeline, size_t max_streams, const LogWithPrefixPtr & log) +{ + if (pipeline.streams.size() == 1 && pipeline.streams_with_non_joined_data.empty()) + return; + auto non_joined_data_stream = combinedNonJoinedDataStream(pipeline, max_streams, log); + if (!pipeline.streams.empty()) + { + pipeline.firstStream() = std::make_shared>(pipeline.streams, non_joined_data_stream, max_streams, log); + pipeline.streams.resize(1); + } + else if (non_joined_data_stream != nullptr) + { + pipeline.streams.push_back(non_joined_data_stream); + } +} } // namespace DB diff --git a/dbms/src/Flash/Coprocessor/InterpreterUtils.h b/dbms/src/Flash/Coprocessor/InterpreterUtils.h index 221377b332a..1f991865515 100644 --- a/dbms/src/Flash/Coprocessor/InterpreterUtils.h +++ b/dbms/src/Flash/Coprocessor/InterpreterUtils.h @@ -6,4 +6,6 @@ namespace DB { void restoreConcurrency(DAGPipeline & pipeline, size_t concurrency, const LogWithPrefixPtr & log); +BlockInputStreamPtr combinedNonJoinedDataStream(DAGPipeline & pipeline, size_t max_threads, const LogWithPrefixPtr & log); +void executeUnion(DAGPipeline & pipeline, size_t max_streams, const LogWithPrefixPtr & log); } // namespace DB diff --git a/dbms/src/Flash/Coprocessor/StreamingDAGResponseWriter.cpp b/dbms/src/Flash/Coprocessor/StreamingDAGResponseWriter.cpp index 242bb29c8e9..cacad61e0b0 100644 --- a/dbms/src/Flash/Coprocessor/StreamingDAGResponseWriter.cpp +++ b/dbms/src/Flash/Coprocessor/StreamingDAGResponseWriter.cpp @@ -32,11 +32,8 @@ StreamingDAGResponseWriter::StreamingDAGResponseWriter( Int64 records_per_chunk_, Int64 batch_send_min_limit_, bool should_send_exec_summary_at_last_, - tipb::EncodeType encode_type_, - std::vector result_field_types_, - DAGContext & dag_context_, - const LogWithPrefixPtr & log_) - : DAGResponseWriter(records_per_chunk_, encode_type_, result_field_types_, dag_context_) + DAGContext & dag_context_) + : DAGResponseWriter(records_per_chunk_, dag_context_) , batch_send_min_limit(batch_send_min_limit_) , should_send_exec_summary_at_last(should_send_exec_summary_at_last_) , exchange_type(exchange_type_) @@ -44,8 +41,6 @@ StreamingDAGResponseWriter::StreamingDAGResponseWriter( , partition_col_ids(std::move(partition_col_ids_)) , collators(std::move(collators_)) { - log = log_ != nullptr ? log_ : std::make_shared(&Poco::Logger::get("StreamingDAGResponseWriter"), ""); - rows_in_blocks = 0; partition_num = writer_->getPartitionNum(); } @@ -62,7 +57,7 @@ void StreamingDAGResponseWriter::finishWrite() template void StreamingDAGResponseWriter::write(const Block & block) { - if (block.columns() != result_field_types.size()) + if (block.columns() != dag_context.result_field_types.size()) throw TiFlashException("Output column size mismatch with field type size", Errors::Coprocessor::Internal); size_t rows = block.rows(); rows_in_blocks += rows; @@ -70,7 +65,7 @@ void StreamingDAGResponseWriter::write(const Block & block) { blocks.push_back(block); } - if ((Int64)rows_in_blocks > (encode_type == tipb::EncodeType::TypeCHBlock ? batch_send_min_limit : records_per_chunk - 1)) + if ((Int64)rows_in_blocks > (dag_context.encode_type == tipb::EncodeType::TypeCHBlock ? batch_send_min_limit : records_per_chunk - 1)) { batchWrite(); } @@ -83,20 +78,20 @@ void StreamingDAGResponseWriter::encodeThenWriteBlocks( tipb::SelectResponse & response) const { std::unique_ptr chunk_codec_stream = nullptr; - if (encode_type == tipb::EncodeType::TypeDefault) + if (dag_context.encode_type == tipb::EncodeType::TypeDefault) { - chunk_codec_stream = std::make_unique()->newCodecStream(result_field_types); + chunk_codec_stream = std::make_unique()->newCodecStream(dag_context.result_field_types); } - else if (encode_type == tipb::EncodeType::TypeChunk) + else if (dag_context.encode_type == tipb::EncodeType::TypeChunk) { - chunk_codec_stream = std::make_unique()->newCodecStream(result_field_types); + chunk_codec_stream = std::make_unique()->newCodecStream(dag_context.result_field_types); } - else if (encode_type == tipb::EncodeType::TypeCHBlock) + else if (dag_context.encode_type == tipb::EncodeType::TypeCHBlock) { - chunk_codec_stream = std::make_unique()->newCodecStream(result_field_types); + chunk_codec_stream = std::make_unique()->newCodecStream(dag_context.result_field_types); } - if (encode_type == tipb::EncodeType::TypeCHBlock) + if (dag_context.encode_type == tipb::EncodeType::TypeCHBlock) { if (dag_context.isMPPTask()) /// broadcast data among TiFlash nodes in MPP { @@ -123,7 +118,7 @@ void StreamingDAGResponseWriter::encodeThenWriteBlocks( } else /// passthrough data to a non-TiFlash node, like sending data to TiSpark { - response.set_encode_type(encode_type); + response.set_encode_type(dag_context.encode_type); if (input_blocks.empty()) { if constexpr (send_exec_summary_at_last) @@ -144,7 +139,7 @@ void StreamingDAGResponseWriter::encodeThenWriteBlocks( } else /// passthrough data to a TiDB node { - response.set_encode_type(encode_type); + response.set_encode_type(dag_context.encode_type); if (input_blocks.empty()) { if constexpr (send_exec_summary_at_last) @@ -197,17 +192,17 @@ void StreamingDAGResponseWriter::partitionAndEncodeThenWriteBlo std::vector responses_row_count(partition_num); for (auto i = 0; i < partition_num; ++i) { - if (encode_type == tipb::EncodeType::TypeDefault) + if (dag_context.encode_type == tipb::EncodeType::TypeDefault) { - chunk_codec_stream[i] = DefaultChunkCodec().newCodecStream(result_field_types); + chunk_codec_stream[i] = DefaultChunkCodec().newCodecStream(dag_context.result_field_types); } - else if (encode_type == tipb::EncodeType::TypeChunk) + else if (dag_context.encode_type == tipb::EncodeType::TypeChunk) { - chunk_codec_stream[i] = ArrowChunkCodec().newCodecStream(result_field_types); + chunk_codec_stream[i] = ArrowChunkCodec().newCodecStream(dag_context.result_field_types); } - else if (encode_type == tipb::EncodeType::TypeCHBlock) + else if (dag_context.encode_type == tipb::EncodeType::TypeCHBlock) { - chunk_codec_stream[i] = CHBlockChunkCodec().newCodecStream(result_field_types); + chunk_codec_stream[i] = CHBlockChunkCodec().newCodecStream(dag_context.result_field_types); } if constexpr (send_exec_summary_at_last) { diff --git a/dbms/src/Flash/Coprocessor/StreamingDAGResponseWriter.h b/dbms/src/Flash/Coprocessor/StreamingDAGResponseWriter.h index a22a88dfb80..496d5f01ed2 100644 --- a/dbms/src/Flash/Coprocessor/StreamingDAGResponseWriter.h +++ b/dbms/src/Flash/Coprocessor/StreamingDAGResponseWriter.h @@ -31,10 +31,7 @@ class StreamingDAGResponseWriter : public DAGResponseWriter Int64 records_per_chunk_, Int64 batch_send_min_limit_, bool should_send_exec_summary_at_last, - tipb::EncodeType encodeType_, - std::vector result_field_types, - DAGContext & dag_context_, - const LogWithPrefixPtr & log_); + DAGContext & dag_context_); void write(const Block & block) override; void finishWrite() override; @@ -55,7 +52,6 @@ class StreamingDAGResponseWriter : public DAGResponseWriter TiDB::TiDBCollators collators; size_t rows_in_blocks; uint16_t partition_num; - LogWithPrefixPtr log; }; } // namespace DB diff --git a/dbms/src/Flash/Coprocessor/UnaryDAGResponseWriter.cpp b/dbms/src/Flash/Coprocessor/UnaryDAGResponseWriter.cpp index 5cc80589054..8176ae1a766 100644 --- a/dbms/src/Flash/Coprocessor/UnaryDAGResponseWriter.cpp +++ b/dbms/src/Flash/Coprocessor/UnaryDAGResponseWriter.cpp @@ -15,25 +15,23 @@ extern const int LOGICAL_ERROR; UnaryDAGResponseWriter::UnaryDAGResponseWriter( tipb::SelectResponse * dag_response_, Int64 records_per_chunk_, - tipb::EncodeType encode_type_, - std::vector result_field_types_, DAGContext & dag_context_) - : DAGResponseWriter(records_per_chunk_, encode_type_, result_field_types_, dag_context_) + : DAGResponseWriter(records_per_chunk_, dag_context_) , dag_response(dag_response_) { - if (encode_type == tipb::EncodeType::TypeDefault) + if (dag_context.encode_type == tipb::EncodeType::TypeDefault) { - chunk_codec_stream = std::make_unique()->newCodecStream(result_field_types); + chunk_codec_stream = std::make_unique()->newCodecStream(dag_context.result_field_types); } - else if (encode_type == tipb::EncodeType::TypeChunk) + else if (dag_context.encode_type == tipb::EncodeType::TypeChunk) { - chunk_codec_stream = std::make_unique()->newCodecStream(result_field_types); + chunk_codec_stream = std::make_unique()->newCodecStream(dag_context.result_field_types); } - else if (encode_type == tipb::EncodeType::TypeCHBlock) + else if (dag_context.encode_type == tipb::EncodeType::TypeCHBlock) { - chunk_codec_stream = std::make_unique()->newCodecStream(result_field_types); + chunk_codec_stream = std::make_unique()->newCodecStream(dag_context.result_field_types); } - dag_response->set_encode_type(encode_type); + dag_response->set_encode_type(dag_context.encode_type); current_records_num = 0; } @@ -70,7 +68,7 @@ void UnaryDAGResponseWriter::finishWrite() void UnaryDAGResponseWriter::write(const Block & block) { - if (block.columns() != result_field_types.size()) + if (block.columns() != dag_context.result_field_types.size()) throw TiFlashException("Output column size mismatch with field type size", Errors::Coprocessor::Internal); if (records_per_chunk == -1) { diff --git a/dbms/src/Flash/Coprocessor/UnaryDAGResponseWriter.h b/dbms/src/Flash/Coprocessor/UnaryDAGResponseWriter.h index 063df0c8814..e30f7ff67f3 100644 --- a/dbms/src/Flash/Coprocessor/UnaryDAGResponseWriter.h +++ b/dbms/src/Flash/Coprocessor/UnaryDAGResponseWriter.h @@ -21,8 +21,6 @@ class UnaryDAGResponseWriter : public DAGResponseWriter UnaryDAGResponseWriter( tipb::SelectResponse * response_, Int64 records_per_chunk_, - tipb::EncodeType encodeType_, - std::vector result_field_types, DAGContext & dag_context_); void write(const Block & block) override; diff --git a/dbms/src/Flash/CoprocessorHandler.cpp b/dbms/src/Flash/CoprocessorHandler.cpp index 241a9281d1d..778ebdf12da 100644 --- a/dbms/src/Flash/CoprocessorHandler.cpp +++ b/dbms/src/Flash/CoprocessorHandler.cpp @@ -74,14 +74,19 @@ grpc::Status CoprocessorHandler::execute() Errors::Coprocessor::Unimplemented); tipb::SelectResponse dag_response; RegionInfoMap regions; - RegionInfoList retry_regions; const std::unordered_set bypass_lock_ts( cop_context.kv_context.resolved_locks().begin(), cop_context.kv_context.resolved_locks().end()); regions.emplace(cop_context.kv_context.region_id(), RegionInfo(cop_context.kv_context.region_id(), cop_context.kv_context.region_epoch().version(), cop_context.kv_context.region_epoch().conf_ver(), GenCopKeyRange(cop_request->ranges()), &bypass_lock_ts)); - DAGDriver driver(cop_context.db_context, dag_request, regions, retry_regions, cop_request->start_ts() > 0 ? cop_request->start_ts() : dag_request.start_ts_fallback(), cop_request->schema_ver(), &dag_response); + + DAGContext dag_context(dag_request); + dag_context.regions_for_local_read = std::move(regions); + dag_context.log = std::make_shared(log, ""); + cop_context.db_context.setDAGContext(&dag_context); + + DAGDriver driver(cop_context.db_context, cop_request->start_ts() > 0 ? cop_request->start_ts() : dag_request.start_ts_fallback(), cop_request->schema_ver(), &dag_response); driver.execute(); cop_response->set_data(dag_response.SerializeAsString()); LOG_DEBUG(log, __PRETTY_FUNCTION__ << ": Handle DAG request done"); diff --git a/dbms/src/Flash/FlashService.cpp b/dbms/src/Flash/FlashService.cpp index 13def9be67d..8495dbd3944 100644 --- a/dbms/src/Flash/FlashService.cpp +++ b/dbms/src/Flash/FlashService.cpp @@ -74,7 +74,7 @@ grpc::Status FlashService::Coprocessor( { return status; } - CoprocessorContext cop_context(context, request->context(), *grpc_context); + CoprocessorContext cop_context(*context, request->context(), *grpc_context); CoprocessorHandler cop_handler(cop_context, request, response); return cop_handler.execute(); }); @@ -108,7 +108,7 @@ ::grpc::Status FlashService::BatchCoprocessor(::grpc::ServerContext * grpc_conte { return status; } - CoprocessorContext cop_context(context, request->context(), *grpc_context); + CoprocessorContext cop_context(*context, request->context(), *grpc_context); BatchCoprocessorHandler cop_handler(cop_context, request, writer); return cop_handler.execute(); }); @@ -164,7 +164,7 @@ ::grpc::Status FlashService::IsAlive(::grpc::ServerContext * grpc_context [[mayb return status; } - auto & tmt_context = context.getTMTContext(); + auto & tmt_context = context->getTMTContext(); response->set_available(tmt_context.checkRunning()); return ::grpc::Status::OK; } @@ -197,7 +197,7 @@ ::grpc::Status FlashService::EstablishMPPConnection(::grpc::ServerContext * grpc return status; } - auto & tmt_context = context.getTMTContext(); + auto & tmt_context = context->getTMTContext(); auto task_manager = tmt_context.getMPPTaskManager(); std::chrono::seconds timeout(10); std::string err_msg; @@ -262,7 +262,7 @@ ::grpc::Status FlashService::CancelMPPTask( response->set_allocated_error(err.release()); return status; } - auto & tmt_context = context.getTMTContext(); + auto & tmt_context = context->getTMTContext(); auto task_manager = tmt_context.getMPPTaskManager(); task_manager->cancelMPPQuery(request->meta().start_ts(), "Receive cancel request from TiDB"); return grpc::Status::OK; @@ -302,7 +302,7 @@ grpc::Status FlashService::BatchCommands( LOG_DEBUG(log, __PRETTY_FUNCTION__ << ": Handling batch commands: " << request.DebugString()); BatchCommandsContext batch_commands_context( - context, + *context, [this](const grpc::ServerContext * grpc_server_context) { return createDBContext(grpc_server_context); }, *grpc_context); BatchCommandsHandler batch_commands_handler(batch_commands_context, request, response); @@ -342,13 +342,13 @@ grpc::Status FlashService::executeInThreadPool(const std::unique_ptr return future.get(); } -std::tuple FlashService::createDBContext(const grpc::ServerContext * grpc_context) const +std::tuple FlashService::createDBContext(const grpc::ServerContext * grpc_context) const { try { /// Create DB context. - Context context = server.context(); - context.setGlobalContext(server.context()); + auto context = std::make_shared(server.context()); + context->setGlobalContext(server.context()); /// Set a bunch of client information. std::string user = getClientMetaVarWithDefault(grpc_context, "user", "default"); @@ -363,12 +363,12 @@ std::tuple FlashService::createDBContext(const grpc::Serv std::string client_ip = peer.substr(pos + 1); Poco::Net::SocketAddress client_address(client_ip); - context.setUser(user, password, client_address, quota_key); + context->setUser(user, password, client_address, quota_key); String query_id = getClientMetaVarWithDefault(grpc_context, "query_id", ""); - context.setCurrentQueryId(query_id); + context->setCurrentQueryId(query_id); - ClientInfo & client_info = context.getClientInfo(); + ClientInfo & client_info = context->getClientInfo(); client_info.query_kind = ClientInfo::QueryKind::INITIAL_QUERY; client_info.interface = ClientInfo::Interface::GRPC; @@ -376,7 +376,7 @@ std::tuple FlashService::createDBContext(const grpc::Serv std::string dag_records_per_chunk_str = getClientMetaVarWithDefault(grpc_context, "dag_records_per_chunk", ""); if (!dag_records_per_chunk_str.empty()) { - context.setSetting("dag_records_per_chunk", dag_records_per_chunk_str); + context->setSetting("dag_records_per_chunk", dag_records_per_chunk_str); } return std::make_tuple(context, grpc::Status::OK); @@ -384,17 +384,17 @@ std::tuple FlashService::createDBContext(const grpc::Serv catch (Exception & e) { LOG_ERROR(log, __PRETTY_FUNCTION__ << ": DB Exception: " << e.message()); - return std::make_tuple(server.context(), grpc::Status(tiflashErrorCodeToGrpcStatusCode(e.code()), e.message())); + return std::make_tuple(std::make_shared(server.context()), grpc::Status(tiflashErrorCodeToGrpcStatusCode(e.code()), e.message())); } catch (const std::exception & e) { LOG_ERROR(log, __PRETTY_FUNCTION__ << ": std exception: " << e.what()); - return std::make_tuple(server.context(), grpc::Status(grpc::StatusCode::INTERNAL, e.what())); + return std::make_tuple(std::make_shared(server.context()), grpc::Status(grpc::StatusCode::INTERNAL, e.what())); } catch (...) { LOG_ERROR(log, __PRETTY_FUNCTION__ << ": other exception"); - return std::make_tuple(server.context(), grpc::Status(grpc::StatusCode::INTERNAL, "other exception")); + return std::make_tuple(std::make_shared(server.context()), grpc::Status(grpc::StatusCode::INTERNAL, "other exception")); } } diff --git a/dbms/src/Flash/FlashService.h b/dbms/src/Flash/FlashService.h index 3f511a43220..a0ba4d9eeab 100644 --- a/dbms/src/Flash/FlashService.h +++ b/dbms/src/Flash/FlashService.h @@ -51,7 +51,7 @@ class FlashService final : public tikvpb::Tikv::Service ::grpc::Status CancelMPPTask(::grpc::ServerContext * context, const ::mpp::CancelTaskRequest * request, ::mpp::CancelTaskResponse * response) override; private: - std::tuple createDBContext(const grpc::ServerContext * grpc_context) const; + std::tuple createDBContext(const grpc::ServerContext * grpc_context) const; // Use executeInThreadPool to submit job to thread pool which return grpc::Status. grpc::Status executeInThreadPool(const std::unique_ptr & pool, std::function); diff --git a/dbms/src/Flash/Mpp/MPPHandler.cpp b/dbms/src/Flash/Mpp/MPPHandler.cpp index 4c8be4b0501..20851cd9932 100644 --- a/dbms/src/Flash/Mpp/MPPHandler.cpp +++ b/dbms/src/Flash/Mpp/MPPHandler.cpp @@ -25,7 +25,7 @@ void MPPHandler::handleError(const MPPTaskPtr & task, String error) } } // execute is responsible for making plan , register tasks and tunnels and start the running thread. -grpc::Status MPPHandler::execute(Context & context, mpp::DispatchTaskResponse * response) +grpc::Status MPPHandler::execute(const ContextPtr & context, mpp::DispatchTaskResponse * response) { MPPTaskPtr task = nullptr; current_memory_tracker = nullptr; /// to avoid reusing threads in gRPC @@ -34,8 +34,8 @@ grpc::Status MPPHandler::execute(Context & context, mpp::DispatchTaskResponse * Stopwatch stopwatch; task = MPPTask::newTask(task_request.meta(), context); - auto remote_regions = task->prepare(task_request); - for (const auto & region : remote_regions) + task->prepare(task_request); + for (const auto & region : context->getDAGContext()->getRegionsForRemoteRead()) { auto * retry_region = response->add_retry_regions(); retry_region->set_id(region.region_id); diff --git a/dbms/src/Flash/Mpp/MPPHandler.h b/dbms/src/Flash/Mpp/MPPHandler.h index 6a32cbd5f4d..1578cab785c 100644 --- a/dbms/src/Flash/Mpp/MPPHandler.h +++ b/dbms/src/Flash/Mpp/MPPHandler.h @@ -17,7 +17,7 @@ class MPPHandler : task_request(task_request_) , log(&Poco::Logger::get("MPPHandler")) {} - grpc::Status execute(Context & context, mpp::DispatchTaskResponse * response); + grpc::Status execute(const ContextPtr & context, mpp::DispatchTaskResponse * response); void handleError(const MPPTaskPtr & task, String error); }; diff --git a/dbms/src/Flash/Mpp/MPPTask.cpp b/dbms/src/Flash/Mpp/MPPTask.cpp index 9d8afd0f34e..d2b81da031d 100644 --- a/dbms/src/Flash/Mpp/MPPTask.cpp +++ b/dbms/src/Flash/Mpp/MPPTask.cpp @@ -38,7 +38,7 @@ extern const char exception_during_mpp_write_err_to_tunnel[]; extern const char force_no_local_region_for_mpp_task[]; } // namespace FailPoints -MPPTask::MPPTask(const mpp::TaskMeta & meta_, const Context & context_) +MPPTask::MPPTask(const mpp::TaskMeta & meta_, const ContextPtr & context_) : context(context_) , meta(meta_) , id(meta.start_ts(), meta.task_id()) @@ -135,10 +135,13 @@ bool needRemoteRead(const RegionInfo & region_info, const TMTContext & tmt_conte return meta_snap.ver != region_info.region_version; } -std::vector MPPTask::prepare(const mpp::DispatchTaskRequest & task_request) +void MPPTask::prepare(const mpp::DispatchTaskRequest & task_request) { + RegionInfoMap local_regions; + RegionInfoList remote_regions; + dag_req = getDAGRequestFromStringWithRetry(task_request.encoded_plan()); - TMTContext & tmt_context = context.getTMTContext(); + TMTContext & tmt_context = context->getTMTContext(); /// MPP task will only use key ranges in mpp::DispatchTaskRequest::regions. The ones defined in tipb::TableScan /// will never be used and can be removed later. /// Each MPP task will contain at most one TableScan operator belonging to one table. For those tasks without @@ -172,25 +175,25 @@ std::vector MPPTask::prepare(const mpp::DispatchTaskRequest & task_r auto schema_ver = task_request.schema_ver(); auto start_ts = task_request.meta().start_ts(); - context.setSetting("read_tso", start_ts); - context.setSetting("schema_version", schema_ver); + context->setSetting("read_tso", start_ts); + context->setSetting("schema_version", schema_ver); if (unlikely(task_request.timeout() < 0)) { /// this is only for test - context.setSetting("mpp_task_timeout", static_cast(5)); - context.setSetting("mpp_task_running_timeout", static_cast(10)); + context->setSetting("mpp_task_timeout", static_cast(5)); + context->setSetting("mpp_task_running_timeout", static_cast(10)); } else { - context.setSetting("mpp_task_timeout", task_request.timeout()); + context->setSetting("mpp_task_timeout", task_request.timeout()); if (task_request.timeout() > 0) { /// in the implementation, mpp_task_timeout is actually the task writing tunnel timeout /// so make the mpp_task_running_timeout a little bigger than mpp_task_timeout - context.setSetting("mpp_task_running_timeout", task_request.timeout() + 30); + context->setSetting("mpp_task_running_timeout", task_request.timeout() + 30); } } - context.getTimezoneInfo().resetByDAGRequest(dag_req); + context->getTimezoneInfo().resetByDAGRequest(dag_req); bool is_root_mpp_task = false; const auto & exchange_sender = dag_req.root_executor().exchange_sender(); @@ -206,8 +209,10 @@ std::vector MPPTask::prepare(const mpp::DispatchTaskRequest & task_r is_root_mpp_task = task_meta.task_id() == -1; } dag_context = std::make_unique(dag_req, task_request.meta(), is_root_mpp_task); - dag_context->mpp_task_log = log; - context.setDAGContext(dag_context.get()); + dag_context->log = log; + dag_context->regions_for_local_read = std::move(local_regions); + dag_context->regions_for_remote_read = std::move(remote_regions); + context->setDAGContext(dag_context.get()); if (dag_context->isRootMPPTask()) { @@ -233,8 +238,8 @@ std::vector MPPTask::prepare(const mpp::DispatchTaskRequest & task_r mpp::TaskMeta task_meta; if (!task_meta.ParseFromString(exchange_sender.encoded_task_meta(i))) throw TiFlashException("Failed to decode task meta info in ExchangeSender", Errors::Coprocessor::BadRequest); - bool is_local = context.getSettings().enable_local_tunnel && meta.address() == task_meta.address(); - MPPTunnelPtr tunnel = std::make_shared(task_meta, task_request.meta(), timeout, task_cancelled_callback, context.getSettings().max_threads, is_local, log); + bool is_local = context->getSettings().enable_local_tunnel && meta.address() == task_meta.address(); + MPPTunnelPtr tunnel = std::make_shared(task_meta, task_request.meta(), timeout, task_cancelled_callback, context->getSettings().max_threads, is_local, log); LOG_DEBUG(log, "begin to register the tunnel " << tunnel->id()); registerTunnel(MPPTaskId{task_meta.start_ts(), task_meta.task_id()}, tunnel); tunnel_set->addTunnel(tunnel); @@ -260,15 +265,13 @@ std::vector MPPTask::prepare(const mpp::DispatchTaskRequest & task_r { throw TiFlashException(std::string(__PRETTY_FUNCTION__) + ": Failed to register MPP Task", Errors::Coprocessor::BadRequest); } - - return remote_regions; } void MPPTask::preprocess() { auto start_time = Clock::now(); - DAGQuerySource dag(context, local_regions, remote_regions, dag_req, log, true); - io = executeQuery(dag, context, false, QueryProcessingStage::Complete); + DAGQuerySource dag(*context); + io = executeQuery(dag, *context, false, QueryProcessingStage::Complete); auto end_time = Clock::now(); dag_context->compile_time_ns = std::chrono::duration_cast(end_time - start_time).count(); } @@ -347,7 +350,7 @@ void MPPTask::runImpl() auto throughput = dag_context->getTableScanThroughput(); if (throughput.first) GET_METRIC(tiflash_storage_logical_throughput_bytes).Observe(throughput.second); - auto process_info = context.getProcessListElement()->getInfo(); + auto process_info = context->getProcessListElement()->getInfo(); auto peak_memory = process_info.peak_memory_usage > 0 ? process_info.peak_memory_usage : 0; GET_METRIC(tiflash_coprocessor_request_memory_usage, type_run_mpp_task).Observe(peak_memory); } @@ -402,7 +405,7 @@ void MPPTask::cancel(const String & reason) } else if (previous_status == RUNNING && switchStatus(RUNNING, CANCELLED)) { - context.getProcessList().sendCancelToQuery(context.getCurrentQueryId(), context.getClientInfo().current_user, true); + context->getProcessList().sendCancelToQuery(context->getCurrentQueryId(), context->getClientInfo().current_user, true); closeAllTunnels(reason); /// runImpl is running, leave remaining work to runImpl LOG_WARNING(log, "Finish cancel task from running"); diff --git a/dbms/src/Flash/Mpp/MPPTask.h b/dbms/src/Flash/Mpp/MPPTask.h index 4a7f70c9e93..8ed7629bbb1 100644 --- a/dbms/src/Flash/Mpp/MPPTask.h +++ b/dbms/src/Flash/Mpp/MPPTask.h @@ -43,7 +43,7 @@ class MPPTask : public std::enable_shared_from_this void cancel(const String & reason); - std::vector prepare(const mpp::DispatchTaskRequest & task_request); + void prepare(const mpp::DispatchTaskRequest & task_request); void preprocess(); @@ -57,7 +57,7 @@ class MPPTask : public std::enable_shared_from_this ~MPPTask(); private: - MPPTask(const mpp::TaskMeta & meta_, const Context & context_); + MPPTask(const mpp::TaskMeta & meta_, const ContextPtr & context_); void runImpl(); @@ -73,12 +73,9 @@ class MPPTask : public std::enable_shared_from_this bool switchStatus(TaskStatus from, TaskStatus to); - RegionInfoMap local_regions; - RegionInfoList remote_regions; - tipb::DAGRequest dag_req; - Context context; + ContextPtr context; /// store io in MPPTask to keep the life cycle of memory_tracker for the current query /// BlockIO contains some information stored in Context, so need deconstruct it before Context BlockIO io; diff --git a/dbms/src/Flash/Mpp/getMPPTaskLog.cpp b/dbms/src/Flash/Mpp/getMPPTaskLog.cpp new file mode 100644 index 00000000000..c4d42e31476 --- /dev/null +++ b/dbms/src/Flash/Mpp/getMPPTaskLog.cpp @@ -0,0 +1,24 @@ +#include + +namespace DB +{ +LogWithPrefixPtr getMPPTaskLog(const String & name, const MPPTaskId & mpp_task_id_) +{ + return std::make_shared(&Poco::Logger::get(name), mpp_task_id_.toString()); +} + +LogWithPrefixPtr getMPPTaskLog(const DAGContext & dag_context, const String & name) +{ + return getMPPTaskLog(dag_context.log, name, dag_context.getMPPTaskId()); +} + +LogWithPrefixPtr getMPPTaskLog(const LogWithPrefixPtr & log, const String & name, const MPPTaskId & mpp_task_id_) +{ + if (log == nullptr) + { + return getMPPTaskLog(name, mpp_task_id_); + } + + return log->append(name); +} +} // namespace DB diff --git a/dbms/src/Flash/Mpp/getMPPTaskLog.h b/dbms/src/Flash/Mpp/getMPPTaskLog.h index 783b6530dec..40cbb03c17d 100644 --- a/dbms/src/Flash/Mpp/getMPPTaskLog.h +++ b/dbms/src/Flash/Mpp/getMPPTaskLog.h @@ -1,24 +1,16 @@ #pragma once #include +#include #include #include namespace DB { -inline LogWithPrefixPtr getMPPTaskLog(const String & name, const MPPTaskId & mpp_task_id_) -{ - return std::make_shared(&Poco::Logger::get(name), mpp_task_id_.toString()); -} +LogWithPrefixPtr getMPPTaskLog(const String & name, const MPPTaskId & mpp_task_id_); -inline LogWithPrefixPtr getMPPTaskLog(const LogWithPrefixPtr & log, const String & name, const MPPTaskId & mpp_task_id_ = MPPTaskId::unknown_mpp_task_id) -{ - if (log == nullptr) - { - return getMPPTaskLog(name, mpp_task_id_); - } +LogWithPrefixPtr getMPPTaskLog(const DAGContext & dag_context, const String & name); - return log->append(name); -} +LogWithPrefixPtr getMPPTaskLog(const LogWithPrefixPtr & log, const String & name, const MPPTaskId & mpp_task_id_ = MPPTaskId::unknown_mpp_task_id); } // namespace DB diff --git a/dbms/src/Interpreters/Context.h b/dbms/src/Interpreters/Context.h index 167325d694e..b3c0f22efa9 100644 --- a/dbms/src/Interpreters/Context.h +++ b/dbms/src/Interpreters/Context.h @@ -467,6 +467,8 @@ class Context void scheduleCloseSession(const SessionKey & key, std::chrono::steady_clock::duration timeout); }; +using ContextPtr = std::shared_ptr; + /// Puts an element into the map, erases it in the destructor. /// If the element already exists in the map, throws an exception containing provided message. diff --git a/dbms/src/Storages/StorageMerge.cpp b/dbms/src/Storages/StorageMerge.cpp index 8798d48d059..9b46899125f 100644 --- a/dbms/src/Storages/StorageMerge.cpp +++ b/dbms/src/Storages/StorageMerge.cpp @@ -293,7 +293,7 @@ BlockInputStreams StorageMerge::read( BlockInputStreamPtr stream = streams.size() > 1 ? std::make_shared( streams, - context.getDAGContext() ? context.getDAGContext()->mpp_task_log : nullptr) + context.getDAGContext() ? context.getDAGContext()->log : nullptr) : streams[0]; if (has_table_virtual_column) diff --git a/dbms/src/Storages/tests/gtest_filter_parser.cpp b/dbms/src/Storages/tests/gtest_filter_parser.cpp index a2ace785a9c..40fceb883fe 100644 --- a/dbms/src/Storages/tests/gtest_filter_parser.cpp +++ b/dbms/src/Storages/tests/gtest_filter_parser.cpp @@ -71,7 +71,7 @@ DM::RSOperatorPtr FilterParserTest::generateRsOperator(const String table_info_j DAGContext dag_context(dag_request); ctx.setDAGContext(&dag_context); // Don't care about regions information in this test - DAGQuerySource dag(ctx, /*regions*/ RegionInfoMap{}, /*retry_regions*/ RegionInfoList{}, dag_request, std::make_shared(log, ""), false); + DAGQuerySource dag(ctx); auto query_block = *dag.getRootQueryBlock(); std::vector conditions; if (query_block.children[0]->selection != nullptr) From 560184c04751524e620a4d25b7fbcb503257f8e8 Mon Sep 17 00:00:00 2001 From: Schrodinger ZHU Yifan Date: Tue, 14 Dec 2021 00:44:35 +0800 Subject: [PATCH 2/4] apply fixes for llvm (#3640) --- dbms/src/DataStreams/PKColumnIterator.hpp | 53 ------------ .../RangesFilterBlockInputStream.cpp | 81 ------------------- .../RangesFilterBlockInputStream.h | 41 ---------- dbms/src/Debug/dbgFuncMockRaftSnapshot.cpp | 3 + dbms/src/Functions/FunctionsConversion.h | 11 +-- dbms/src/Functions/FunctionsTiDBConversion.h | 11 ++- dbms/src/Interpreters/Context.h | 8 ++ dbms/src/Interpreters/Quota.h | 67 ++++++++------- dbms/src/Storages/Transaction/TiDB.h | 2 + format-diff.py | 1 + 10 files changed, 60 insertions(+), 218 deletions(-) delete mode 100644 dbms/src/DataStreams/PKColumnIterator.hpp delete mode 100644 dbms/src/DataStreams/RangesFilterBlockInputStream.cpp delete mode 100644 dbms/src/DataStreams/RangesFilterBlockInputStream.h diff --git a/dbms/src/DataStreams/PKColumnIterator.hpp b/dbms/src/DataStreams/PKColumnIterator.hpp deleted file mode 100644 index ee4f5aa05a0..00000000000 --- a/dbms/src/DataStreams/PKColumnIterator.hpp +++ /dev/null @@ -1,53 +0,0 @@ -#pragma once - -#include - -namespace DB -{ -struct PKColumnIterator : public std::iterator -{ - PKColumnIterator & operator++() - { - ++pos; - return *this; - } - - PKColumnIterator & operator--() - { - --pos; - return *this; - } - - PKColumnIterator & operator=(const PKColumnIterator & itr) - { - copy(itr); - return *this; - } - - UInt64 operator*() const { return column->getUInt(pos); } - - size_t operator-(const PKColumnIterator & itr) const { return pos - itr.pos; } - - PKColumnIterator(const int pos_, const IColumn * column_) : pos(pos_), column(column_) {} - - PKColumnIterator(const PKColumnIterator & itr) { copy(itr); } - - void operator+=(size_t n) { pos += n; } - - size_t pos; - const IColumn * column; - -private: - inline void copy(const PKColumnIterator & itr) - { - pos = itr.pos; - column = itr.column; - } -}; - -template -inline bool PkCmp(const UInt64 & a, const TiKVHandle::Handle & b) -{ - return static_cast(a) < b; -} -} diff --git a/dbms/src/DataStreams/RangesFilterBlockInputStream.cpp b/dbms/src/DataStreams/RangesFilterBlockInputStream.cpp deleted file mode 100644 index fe7c15eda15..00000000000 --- a/dbms/src/DataStreams/RangesFilterBlockInputStream.cpp +++ /dev/null @@ -1,81 +0,0 @@ -#include -#include -#include -#include - -namespace DB -{ - -namespace ErrorCodes -{ -extern const int LOGICAL_ERROR; -} - -template -Block RangesFilterBlockInputStream::readImpl() -{ - static const auto func_cmp = PkCmp; - - while (true) - { - Block block = input->read(); - if (!block) - return block; - - const ColumnWithTypeAndName & handle_column = block.getByPosition(handle_column_index); - const auto * column = handle_column.column.get(); - - size_t rows = block.rows(); - - auto handle_begin = static_cast(column->getUInt(0)); - auto handle_end = static_cast(column->getUInt(rows - 1)); - - if (handle_begin >= ranges.second || ranges.first > handle_end) - continue; - - if (handle_begin >= ranges.first) - { - if (handle_end < ranges.second) - { - return block; - } - else - { - size_t pos = std::lower_bound(PKColumnIterator(0, column), PKColumnIterator(rows, column), ranges.second, func_cmp).pos; - size_t pop_num = rows - pos; - for (size_t i = 0; i < block.columns(); i++) - { - ColumnWithTypeAndName & ori_column = block.getByPosition(i); - MutableColumnPtr mutable_holder = (*std::move(ori_column.column)).mutate(); - mutable_holder->popBack(pop_num); - ori_column.column = std::move(mutable_holder); - } - } - } - else - { - size_t pos_begin = std::lower_bound(PKColumnIterator(0, column), PKColumnIterator(rows, column), ranges.first, func_cmp).pos; - size_t pos_end = rows; - if (handle_end >= ranges.second) - pos_end = std::lower_bound(PKColumnIterator(0, column), PKColumnIterator(rows, column), ranges.second, func_cmp).pos; - - size_t len = pos_end - pos_begin; - if (!len) - continue; - for (size_t i = 0; i < block.columns(); i++) - { - ColumnWithTypeAndName & ori_column = block.getByPosition(i); - auto new_column = ori_column.column->cloneEmpty(); - new_column->insertRangeFrom(*ori_column.column, pos_begin, len); - ori_column.column = std::move(new_column); - } - } - - return block; - } -} - -template class RangesFilterBlockInputStream; -template class RangesFilterBlockInputStream; - -} // namespace DB diff --git a/dbms/src/DataStreams/RangesFilterBlockInputStream.h b/dbms/src/DataStreams/RangesFilterBlockInputStream.h deleted file mode 100644 index 8b4d1c60b4e..00000000000 --- a/dbms/src/DataStreams/RangesFilterBlockInputStream.h +++ /dev/null @@ -1,41 +0,0 @@ -#pragma once - -#include -#include -#include - -namespace DB -{ - -template -class RangesFilterBlockInputStream : public IProfilingBlockInputStream -{ - using Handle = TiKVHandle::Handle; - -public: - RangesFilterBlockInputStream( - const BlockInputStreamPtr & input_, const HandleRange & ranges_, const size_t handle_column_index_) - : input(input_), ranges(ranges_), handle_column_index(handle_column_index_) - {} - -protected: - Block getHeader() const override { return input->getHeader(); } - - bool isGroupedOutput() const override { return input->isGroupedOutput(); } - - bool isSortedOutput() const override { return input->isSortedOutput(); } - - const SortDescription & getSortDescription() const override { return input->getSortDescription(); } - - String getName() const override { return "RangesFilter"; } - - Block readImpl() override; - -private: - BlockInputStreamPtr input; - const HandleRange ranges; - const size_t handle_column_index; - Poco::Logger * log = &Poco::Logger::get("RangesFilterBlockInputStream"); -}; - -} // namespace DB diff --git a/dbms/src/Debug/dbgFuncMockRaftSnapshot.cpp b/dbms/src/Debug/dbgFuncMockRaftSnapshot.cpp index 21c3f0bf496..45e0e0dbcb6 100644 --- a/dbms/src/Debug/dbgFuncMockRaftSnapshot.cpp +++ b/dbms/src/Debug/dbgFuncMockRaftSnapshot.cpp @@ -236,6 +236,9 @@ struct MockSSTReader struct Data : std::vector> { Data(const Data &) = delete; + Data & operator=(const Data &) = delete; + Data(Data &&) = default; + Data & operator=(Data &&) = default; Data() = default; }; diff --git a/dbms/src/Functions/FunctionsConversion.h b/dbms/src/Functions/FunctionsConversion.h index 3b2baf1276c..6f056d51a7d 100644 --- a/dbms/src/Functions/FunctionsConversion.h +++ b/dbms/src/Functions/FunctionsConversion.h @@ -1926,11 +1926,12 @@ struct ToIntMonotonicity if (checkDataType(&type) || checkDataType(&type)) { - Float64 left_float = left.get(); - Float64 right_float = right.get(); - - if (left_float >= std::numeric_limits::min() && left_float <= std::numeric_limits::max() - && right_float >= std::numeric_limits::min() && right_float <= std::numeric_limits::max()) + auto left_float = left.get(); + auto right_float = right.get(); + auto float_min = static_cast(std::numeric_limits::min()); + auto float_max = static_cast(std::numeric_limits::max()); + if (left_float >= float_min && left_float <= float_max + && right_float >= float_min && right_float <= float_max) return {true}; return {}; diff --git a/dbms/src/Functions/FunctionsTiDBConversion.h b/dbms/src/Functions/FunctionsTiDBConversion.h index cbf6563935d..d4c0ec6742c 100644 --- a/dbms/src/Functions/FunctionsTiDBConversion.h +++ b/dbms/src/Functions/FunctionsTiDBConversion.h @@ -258,12 +258,13 @@ struct TiDBConvertToInteger return static_cast(0); return static_cast(rounded_value); } - if (rounded_value > std::numeric_limits::max()) + auto field_max = static_cast(std::numeric_limits::max()); + if (rounded_value > field_max) { context.getDAGContext()->handleOverflowError("Cast real as integer", Errors::Types::Truncated); return std::numeric_limits::max(); } - else if (rounded_value == std::numeric_limits::max()) + else if (rounded_value == field_max) { context.getDAGContext()->handleOverflowError("cast real as int", Errors::Types::Truncated); return std::numeric_limits::max(); @@ -276,12 +277,14 @@ struct TiDBConvertToInteger static std::enable_if_t, ToFieldType> toInt(const T & value, const Context & context) { T rounded_value = std::round(value); - if (rounded_value < std::numeric_limits::min()) + auto field_min = static_cast(std::numeric_limits::min()); + auto field_max = static_cast(std::numeric_limits::max()); + if (rounded_value < field_min) { context.getDAGContext()->handleOverflowError("cast real as int", Errors::Types::Truncated); return std::numeric_limits::min(); } - if (rounded_value >= std::numeric_limits::max()) + if (rounded_value >= field_max) { context.getDAGContext()->handleOverflowError("cast real as int", Errors::Types::Truncated); return std::numeric_limits::max(); diff --git a/dbms/src/Interpreters/Context.h b/dbms/src/Interpreters/Context.h index b3c0f22efa9..2d2ce0e7548 100644 --- a/dbms/src/Interpreters/Context.h +++ b/dbms/src/Interpreters/Context.h @@ -7,7 +7,15 @@ #include #include #include + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wdeprecated-declarations" +#include +#pragma clang diagnostic pop +#else #include +#endif #include #include diff --git a/dbms/src/Interpreters/Quota.h b/dbms/src/Interpreters/Quota.h index 5fb69fc123a..51b97531aac 100644 --- a/dbms/src/Interpreters/Quota.h +++ b/dbms/src/Interpreters/Quota.h @@ -1,25 +1,21 @@ #pragma once -#include -#include -#include -#include - +#include +#include +#include +#include #include - -#include #include +#include -#include - -#include -#include -#include +#include +#include +#include +#include namespace DB { - /** Quota for resources consumption for specific interval. * Used to limit resource usage by user. * Quota is applied "softly" - could be slightly exceed, because it is checked usually only on each block of processed data. @@ -30,17 +26,17 @@ namespace DB */ /// Used both for maximum allowed values and for counters of current accumulated values. -template /// either size_t or std::atomic +template /// either size_t or std::atomic struct QuotaValues { /// Zero values (for maximums) means no limit. - Counter queries; /// Number of queries. - Counter errors; /// Number of queries with exceptions. - Counter result_rows; /// Number of rows returned as result. - Counter result_bytes; /// Number of bytes returned as result. - Counter read_rows; /// Number of rows read from tables. - Counter read_bytes; /// Number of bytes read from tables. - Counter execution_time_usec; /// Total amount of query execution time in microseconds. + Counter queries; /// Number of queries. + Counter errors; /// Number of queries with exceptions. + Counter result_rows; /// Number of rows returned as result. + Counter result_bytes; /// Number of bytes returned as result. + Counter read_rows; /// Number of rows read from tables. + Counter read_bytes; /// Number of bytes read from tables. + Counter execution_time_usec; /// Total amount of query execution time in microseconds. QuotaValues() { @@ -65,7 +61,7 @@ struct QuotaValues void initFromConfig(const String & config_elem, Poco::Util::AbstractConfiguration & config); - bool operator== (const QuotaValues & rhs) const + bool operator==(const QuotaValues & rhs) const { return tuple() == rhs.tuple(); } @@ -101,15 +97,17 @@ struct QuotaForInterval { constexpr static const char * DEFAULT_QUOTA_NAME = "default"; - std::atomic rounded_time {0}; + std::atomic rounded_time{0}; size_t duration = 0; bool randomize = false; - time_t offset = 0; /// Offset of interval for randomization (to avoid DoS if intervals for many users end at one time). + time_t offset = 0; /// Offset of interval for randomization (to avoid DoS if intervals for many users end at one time). QuotaValues max; QuotaValues> used; QuotaForInterval() = default; - QuotaForInterval(time_t duration_) : duration(duration_) {} + QuotaForInterval(time_t duration_) + : duration(duration_) + {} void initFromConfig(const String & config_elem, time_t duration_, bool randomize_, time_t offset_, Poco::Util::AbstractConfiguration & config); @@ -129,14 +127,14 @@ struct QuotaForInterval String toString() const; /// Only compare configuration, not accumulated (used) values or random offsets. - bool operator== (const QuotaForInterval & rhs) const + bool operator==(const QuotaForInterval & rhs) const { return randomize == rhs.randomize && duration == rhs.duration && max == rhs.max; } - QuotaForInterval & operator= (const QuotaForInterval & rhs) + QuotaForInterval & operator=(const QuotaForInterval & rhs) { rounded_time.store(rhs.rounded_time.load(std::memory_order_relaxed)); duration = rhs.duration; @@ -155,8 +153,7 @@ struct QuotaForInterval private: /// Reset counters of used resources, if interval for quota is expired. void updateTime(time_t current_time); - void check(size_t max_amount, size_t used_amount, - const String & quota_name, const String & user_name, const char * resource_name); + void check(size_t max_amount, size_t used_amount, const String & quota_name, const String & user_name, const char * resource_name); }; @@ -172,11 +169,13 @@ class QuotaForIntervals Container cont; std::string quota_name; - std::string user_name; /// user name is set only for current counters for user, not for object that contain maximum values (limits). + std::string user_name; /// user name is set only for current counters for user, not for object that contain maximum values (limits). public: QuotaForIntervals(const std::string & quota_name_, const std::string & user_name_) - : quota_name(quota_name_), user_name(user_name_) {} + : quota_name(quota_name_) + , user_name(user_name_) + {} QuotaForIntervals(const QuotaForIntervals & other, const std::string & user_name_) : QuotaForIntervals(other) @@ -186,6 +185,7 @@ class QuotaForIntervals QuotaForIntervals() = default; QuotaForIntervals(const QuotaForIntervals & other) = default; + QuotaForIntervals & operator=(const QuotaForIntervals & other) = default; /// Is there at least one interval for counting quota? bool empty() const @@ -257,8 +257,7 @@ class Quotas public: void loadFromConfig(Poco::Util::AbstractConfiguration & config); - QuotaForIntervalsPtr get(const String & name, const String & quota_key, - const String & user_name, const Poco::Net::IPAddress & ip); + QuotaForIntervalsPtr get(const String & name, const String & quota_key, const String & user_name, const Poco::Net::IPAddress & ip); }; -} +} // namespace DB diff --git a/dbms/src/Storages/Transaction/TiDB.h b/dbms/src/Storages/Transaction/TiDB.h index 486155f6603..09a7d69c75f 100644 --- a/dbms/src/Storages/Transaction/TiDB.h +++ b/dbms/src/Storages/Transaction/TiDB.h @@ -315,6 +315,8 @@ struct TableInfo TableInfo(const TableInfo &) = default; + TableInfo & operator=(const TableInfo &) = default; + explicit TableInfo(const String & table_info_json); String serialize() const; diff --git a/format-diff.py b/format-diff.py index b7090713cdd..1571d24744d 100755 --- a/format-diff.py +++ b/format-diff.py @@ -69,6 +69,7 @@ def main(): if diff_res: print('Error: found files NOT formatted') print(''.join(diff_res)) + print(''.join(run_cmd('git diff'))) exit(-1) else: print("Format check passed") From 4b310e002f774ab1a76fbf43f35945557383a537 Mon Sep 17 00:00:00 2001 From: Yilong Li Date: Tue, 14 Dec 2021 01:32:35 +0800 Subject: [PATCH 3/4] add unit test cases for casting integers (#3494) --- dbms/src/DataTypes/DataTypeDecimal.h | 8 +- dbms/src/Functions/FunctionsTiDBConversion.h | 7 +- dbms/src/Functions/tests/gtest_coalesce.cpp | 2 +- dbms/src/Functions/tests/gtest_logical.cpp | 12 +- .../Functions/tests/gtest_tidb_conversion.cpp | 758 +++++++++++++++++- dbms/src/TestUtils/FunctionTestUtils.cpp | 26 +- dbms/src/TestUtils/FunctionTestUtils.h | 4 + 7 files changed, 767 insertions(+), 50 deletions(-) diff --git a/dbms/src/DataTypes/DataTypeDecimal.h b/dbms/src/DataTypes/DataTypeDecimal.h index 469c81efd4c..26bd6bf4dac 100644 --- a/dbms/src/DataTypes/DataTypeDecimal.h +++ b/dbms/src/DataTypes/DataTypeDecimal.h @@ -15,7 +15,8 @@ namespace ErrorCodes extern const int ARGUMENT_OUT_OF_BOUND; } -// Implements Decimal(P, S), where P is precision, S is scale. +// Implements Decimal(P, S), where P is precision (significant digits), and S is scale (digits following the decimal point). +// For example, Decimal(5, 2) can represent numbers from -999.99 to 999.99 // Maximum precisions for underlying types are: // Int32 9 // Int64 18 @@ -42,9 +43,10 @@ class DataTypeDecimal final : public IDataType static constexpr size_t maxPrecision() { return maxDecimalPrecision(); } - // If scale is omitted, the default is 0. If precision is omitted, the default is 10. + // default values DataTypeDecimal() - : DataTypeDecimal(10, 0) + : precision(10) + , scale(0) {} DataTypeDecimal(size_t precision_, size_t scale_) diff --git a/dbms/src/Functions/FunctionsTiDBConversion.h b/dbms/src/Functions/FunctionsTiDBConversion.h index d4c0ec6742c..c206446e764 100644 --- a/dbms/src/Functions/FunctionsTiDBConversion.h +++ b/dbms/src/Functions/FunctionsTiDBConversion.h @@ -806,7 +806,7 @@ struct TiDBConvertToFloat } }; -/// cast int/real/decimal/time/string as decimal +/// cast int/real/decimal/enum/string/time/string as decimal // todo TiKV does not check unsigned flag but TiDB checks, currently follow TiKV's code, maybe changed latter template struct TiDBConvertToDecimal @@ -1118,7 +1118,7 @@ struct TiDBConvertToDecimal return static_cast(is_negative ? -v : v); } - /// cast int/real/time/decimal as decimal + /// cast int/real/enum/string/time/decimal as decimal static void execute(Block & block, const ColumnNumbers & arguments, size_t result, PrecType prec [[maybe_unused]], ScaleType scale, bool, const tipb::FieldType &, const Context & context) { size_t size = block.getByPosition(arguments[0]).column->size(); @@ -1927,7 +1927,7 @@ class FunctionTiDBCast final : public IFunctionBase return createWrapper(to_type); if (checkAndGetDataType(from_type.get())) return createWrapper(to_type); - if (const auto from_actual_type = checkAndGetDataType(from_type.get())) + if (checkAndGetDataType(from_type.get())) return createWrapper(to_type); // todo support convert to duration/json type @@ -2090,6 +2090,7 @@ class FunctionBuilderTiDBCast : public IFunctionBuilder return std::make_shared(context, name, std::move(monotonicity), data_types, return_type, in_union, tidb_tp); } + // use the last const string column's value as the return type name, in string representation like "Float64" DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override { const auto * type_col = checkAndGetColumnConst(arguments.back().column.get()); diff --git a/dbms/src/Functions/tests/gtest_coalesce.cpp b/dbms/src/Functions/tests/gtest_coalesce.cpp index 5d338e65c1f..ea7463ea0df 100644 --- a/dbms/src/Functions/tests/gtest_coalesce.cpp +++ b/dbms/src/Functions/tests/gtest_coalesce.cpp @@ -28,7 +28,7 @@ try executeFunction( func_name, createColumn>({"a"}), - createOnlyNullColumn(1))); + createOnlyNullColumnConst(1))); } CATCH } // namespace DB::tests diff --git a/dbms/src/Functions/tests/gtest_logical.cpp b/dbms/src/Functions/tests/gtest_logical.cpp index d03c6c4b491..5d82d2c8b0b 100644 --- a/dbms/src/Functions/tests/gtest_logical.cpp +++ b/dbms/src/Functions/tests/gtest_logical.cpp @@ -42,7 +42,7 @@ try createColumn>({{}, 0}), executeFunction( func_name, - createOnlyNullColumn(2), + createOnlyNullColumnConst(2), createColumn>({1, 0}))); } CATCH @@ -78,7 +78,7 @@ try createColumn>({1, {}}), executeFunction( func_name, - createOnlyNullColumn(2), + createOnlyNullColumnConst(2), createColumn>({1, 0}))); } CATCH @@ -111,10 +111,10 @@ try createConstColumn>(1, 1))); // only null ASSERT_COLUMN_EQ( - createOnlyNullColumn(2), + createOnlyNullColumnConst(2), executeFunction( func_name, - createOnlyNullColumn(2), + createOnlyNullColumnConst(2), createColumn>({1, 0}))); } CATCH @@ -138,10 +138,10 @@ try createConstColumn>(1, 1))); // only null ASSERT_COLUMN_EQ( - createOnlyNullColumn(1), + createOnlyNullColumnConst(1), executeFunction( func_name, - createOnlyNullColumn(1))); + createOnlyNullColumnConst(1))); } CATCH diff --git a/dbms/src/Functions/tests/gtest_tidb_conversion.cpp b/dbms/src/Functions/tests/gtest_tidb_conversion.cpp index 315fdaf5191..f17cc5811e8 100644 --- a/dbms/src/Functions/tests/gtest_tidb_conversion.cpp +++ b/dbms/src/Functions/tests/gtest_tidb_conversion.cpp @@ -1,3 +1,5 @@ +#include + #include "Columns/ColumnsNumber.h" #include "Core/ColumnWithTypeAndName.h" #include "DataTypes/DataTypeMyDateTime.h" @@ -14,29 +16,713 @@ namespace DB { namespace tests { +namespace +{ +auto getDatetimeColumn(bool single_field = false) +{ + MyDateTime datetime(2021, 10, 26, 16, 8, 59, 0); + MyDateTime datetime_frac(2021, 10, 26, 16, 8, 59, 123456); + + auto col_datetime = ColumnUInt64::create(); + col_datetime->insert(Field(datetime.toPackedUInt())); + if (!single_field) + col_datetime->insert(Field(datetime_frac.toPackedUInt())); + return col_datetime; +} + +auto createCastTypeConstColumn(String str) +{ + return createConstColumn(1, str); +} + +const std::string func_name = "tidb_cast"; + +const Int8 MAX_INT8 = std::numeric_limits::max(); +const Int8 MIN_INT8 = std::numeric_limits::min(); +const Int16 MAX_INT16 = std::numeric_limits::max(); +const Int16 MIN_INT16 = std::numeric_limits::min(); +const Int32 MAX_INT32 = std::numeric_limits::max(); +const Int32 MIN_INT32 = std::numeric_limits::min(); +const Int64 MAX_INT64 = std::numeric_limits::max(); +const Int64 MIN_INT64 = std::numeric_limits::min(); +const UInt8 MAX_UINT8 = std::numeric_limits::max(); +const UInt16 MAX_UINT16 = std::numeric_limits::max(); +const UInt32 MAX_UINT32 = std::numeric_limits::max(); +const UInt64 MAX_UINT64 = std::numeric_limits::max(); + class TestTidbConversion : public DB::tests::FunctionTest { -public: - static auto getDatetimeColumn(bool single_field = false) - { - MyDateTime datetime(2021, 10, 26, 16, 8, 59, 0); - MyDateTime datetime_frac(2021, 10, 26, 16, 8, 59, 123456); - - auto col_datetime = ColumnUInt64::create(); - col_datetime->insert(Field(datetime.toPackedUInt())); - if (!single_field) - col_datetime->insert(Field(datetime_frac.toPackedUInt())); - return col_datetime; - } }; -TEST_F(TestTidbConversion, castTimestampAsReal) +using DecimalField32 = DecimalField; +using DecimalField64 = DecimalField; +using DecimalField128 = DecimalField; +using DecimalField256 = DecimalField; + +TEST_F(TestTidbConversion, castIntAsInt) +try +{ + /// null only cases + ASSERT_COLUMN_EQ( + createColumn>({{}}), + executeFunction(func_name, + {createOnlyNullColumn(1), + createCastTypeConstColumn("Nullable(UInt64)")})); + ASSERT_COLUMN_EQ( + createColumn>({{}}), + executeFunction(func_name, + {createOnlyNullColumn(1), + createCastTypeConstColumn("Nullable(Int64)")})); + + /// const cases + // uint8/16/32/64 -> uint64, no overflow + ASSERT_COLUMN_EQ( + createConstColumn(1, MAX_UINT8), + executeFunction(func_name, + {createConstColumn(1, MAX_UINT8), + createCastTypeConstColumn("UInt64")})); + ASSERT_COLUMN_EQ( + createConstColumn(1, MAX_UINT16), + executeFunction(func_name, + {createConstColumn(1, MAX_UINT16), + createCastTypeConstColumn("UInt64")})); + ASSERT_COLUMN_EQ( + createConstColumn(1, MAX_UINT32), + executeFunction(func_name, + {createConstColumn(1, MAX_UINT32), + createCastTypeConstColumn("UInt64")})); + ASSERT_COLUMN_EQ( + createConstColumn(1, MAX_UINT64), + executeFunction(func_name, + {createConstColumn(1, MAX_UINT64), + createCastTypeConstColumn("UInt64")})); + // int8/16/32/64 -> uint64, no overflow + ASSERT_COLUMN_EQ( + createConstColumn(1, MAX_INT8), + executeFunction(func_name, + {createConstColumn(1, MAX_INT8), + createCastTypeConstColumn("UInt64")})); + ASSERT_COLUMN_EQ( + createConstColumn(1, MAX_INT16), + executeFunction(func_name, + {createConstColumn(1, MAX_INT16), + createCastTypeConstColumn("UInt64")})); + ASSERT_COLUMN_EQ( + createConstColumn(1, MAX_INT32), + executeFunction(func_name, + {createConstColumn(1, MAX_INT32), + createCastTypeConstColumn("UInt64")})); + ASSERT_COLUMN_EQ( + createConstColumn(1, MAX_INT64), + executeFunction(func_name, + {createConstColumn(1, MAX_INT64), + createCastTypeConstColumn("UInt64")})); + // uint8/16/32 -> int64, no overflow + ASSERT_COLUMN_EQ( + createConstColumn(1, MAX_UINT8), + executeFunction(func_name, + {createConstColumn(1, MAX_UINT8), + createCastTypeConstColumn("Int64")})); + ASSERT_COLUMN_EQ( + createConstColumn(1, MAX_UINT16), + executeFunction(func_name, + {createConstColumn(1, MAX_UINT16), + createCastTypeConstColumn("Int64")})); + ASSERT_COLUMN_EQ( + createConstColumn(1, MAX_UINT32), + executeFunction(func_name, + {createConstColumn(1, MAX_UINT32), + createCastTypeConstColumn("Int64")})); + // uint64 -> int64, will overflow + ASSERT_COLUMN_EQ( + createConstColumn(1, -1), + executeFunction(func_name, + {createConstColumn(1, MAX_UINT64), + createCastTypeConstColumn("Int64")})); + // int8/16/32/64 -> int64, no overflow + ASSERT_COLUMN_EQ( + createConstColumn(1, MAX_INT8), + executeFunction(func_name, + {createConstColumn(1, MAX_INT8), + createCastTypeConstColumn("Int64")})); + ASSERT_COLUMN_EQ( + createConstColumn(1, MAX_INT16), + executeFunction(func_name, + {createConstColumn(1, MAX_INT16), + createCastTypeConstColumn("Int64")})); + ASSERT_COLUMN_EQ( + createConstColumn(1, MAX_INT32), + executeFunction(func_name, + {createConstColumn(1, MAX_INT32), + createCastTypeConstColumn("Int64")})); + ASSERT_COLUMN_EQ( + createConstColumn(1, MAX_INT64), + executeFunction(func_name, + {createConstColumn(1, MAX_INT64), + createCastTypeConstColumn("Int64")})); + + /// normal cases + // uint8/16/32/64 -> uint64, no overflow + ASSERT_COLUMN_EQ( + createColumn>({0, 1, MAX_UINT8, {}}), + executeFunction(func_name, + {createColumn>({0, 1, MAX_UINT8, {}}), + createCastTypeConstColumn("Nullable(UInt64)")})); + ASSERT_COLUMN_EQ( + createColumn>({0, 1, MAX_UINT16, {}}), + executeFunction(func_name, + {createColumn>({0, 1, MAX_UINT16, {}}), + createCastTypeConstColumn("Nullable(UInt64)")})); + ASSERT_COLUMN_EQ( + createColumn>({0, 1, MAX_UINT32, {}}), + executeFunction(func_name, + {createColumn>({0, 1, MAX_UINT32, {}}), + createCastTypeConstColumn("Nullable(UInt64)")})); + ASSERT_COLUMN_EQ( + createColumn>({0, 1, MAX_UINT64, {}}), + executeFunction(func_name, + {createColumn>({0, 1, MAX_UINT64, {}}), + createCastTypeConstColumn("Nullable(UInt64)")})); + // int8/16/32/64 -> uint64, no overflow + ASSERT_COLUMN_EQ( + createColumn>({0, MAX_INT8, MAX_UINT64, MAX_UINT64 - MAX_INT8, {}}), + executeFunction(func_name, + {createColumn>({0, MAX_INT8, -1, MIN_INT8, {}}), + createCastTypeConstColumn("Nullable(UInt64)")})); + ASSERT_COLUMN_EQ( + createColumn>({0, MAX_INT16, MAX_UINT64, MAX_UINT64 - MAX_INT16, {}}), + executeFunction(func_name, + {createColumn>({0, MAX_INT16, -1, MIN_INT16, {}}), + createCastTypeConstColumn("Nullable(UInt64)")})); + ASSERT_COLUMN_EQ( + createColumn>({0, MAX_INT32, MAX_UINT64, MAX_UINT64 - MAX_INT32, {}}), + executeFunction(func_name, + {createColumn>({0, MAX_INT32, -1, MIN_INT32, {}}), + createCastTypeConstColumn("Nullable(UInt64)")})); + ASSERT_COLUMN_EQ( + createColumn>({0, MAX_INT64, MAX_UINT64, MAX_UINT64 - MAX_INT64, {}}), + executeFunction(func_name, + {createColumn>({0, MAX_INT64, -1, MIN_INT64, {}}), + createCastTypeConstColumn("Nullable(UInt64)")})); + // uint8/16/32 -> int64, no overflow + ASSERT_COLUMN_EQ( + createColumn>({0, MAX_INT8, MAX_UINT8, {}}), + executeFunction(func_name, + {createColumn>({0, MAX_INT8, MAX_UINT8, {}}), + createCastTypeConstColumn("Nullable(Int64)")})); + ASSERT_COLUMN_EQ( + createColumn>({0, MAX_INT16, MAX_UINT16, {}}), + executeFunction(func_name, + {createColumn>({0, MAX_INT16, MAX_UINT16, {}}), + createCastTypeConstColumn("Nullable(Int64)")})); + ASSERT_COLUMN_EQ( + createColumn>({0, MAX_INT32, MAX_UINT32, {}}), + executeFunction(func_name, + {createColumn>({0, MAX_INT32, MAX_UINT32, {}}), + createCastTypeConstColumn("Nullable(Int64)")})); + // uint64 -> int64, overflow may happen + ASSERT_COLUMN_EQ( + createColumn>({0, MAX_INT64, -1, {}}), + executeFunction(func_name, + {createColumn>({0, MAX_INT64, MAX_UINT64, {}}), + createCastTypeConstColumn("Nullable(Int64)")})); + // int8/16/32/64 -> int64, no overflow + ASSERT_COLUMN_EQ( + createColumn>({0, MAX_INT8, -1, MIN_INT8, {}}), + executeFunction(func_name, + {createColumn>({0, MAX_INT8, -1, MIN_INT8, {}}), + createCastTypeConstColumn("Nullable(Int64)")})); + ASSERT_COLUMN_EQ( + createColumn>({0, MAX_INT16, -1, MIN_INT16, {}}), + executeFunction(func_name, + {createColumn>({0, MAX_INT16, -1, MIN_INT16, {}}), + createCastTypeConstColumn("Nullable(Int64)")})); + ASSERT_COLUMN_EQ( + createColumn>({0, MAX_INT32, -1, MIN_INT32, {}}), + executeFunction(func_name, + {createColumn>({0, MAX_INT32, -1, MIN_INT32, {}}), + createCastTypeConstColumn("Nullable(Int64)")})); + ASSERT_COLUMN_EQ( + createColumn>({0, MAX_INT64, -1, MIN_INT64, {}}), + executeFunction(func_name, + {createColumn>({0, MAX_INT64, -1, MIN_INT64, {}}), + createCastTypeConstColumn("Nullable(Int64)")})); +} +CATCH + +TEST_F(TestTidbConversion, castIntAsReal) +try +{ + // uint64/int64 -> float64, may be not precise + ASSERT_COLUMN_EQ( + createColumn>( + {1234567890.0, + 123456789012345680.0, + 0.0, + {}}), + executeFunction(func_name, + {createColumn>( + {1234567890, // this is fine + 123456789012345678, // but this cannot be represented precisely in the IEEE 754 64-bit float format + 0, + {}}), + createCastTypeConstColumn("Nullable(Float64)")})); + ASSERT_COLUMN_EQ( + createColumn>( + {1234567890.0, + 123456789012345680.0, + 0.0, + {}}), + executeFunction(func_name, + {createColumn>( + {1234567890, // this is fine + 123456789012345678, // but this cannot be represented precisely in the IEEE 754 64-bit float format + 0, + {}}), + createCastTypeConstColumn("Nullable(Float64)")})); + // uint32/16/8 and int32/16/8 -> float64, precise + ASSERT_COLUMN_EQ( + createColumn>({MAX_UINT32, 0, {}}), + executeFunction(func_name, + {createColumn>({MAX_UINT32, 0, {}}), + createCastTypeConstColumn("Nullable(Float64)")})); + ASSERT_COLUMN_EQ( + createColumn>({MAX_UINT16, 0, {}}), + executeFunction(func_name, + {createColumn>({MAX_UINT16, 0, {}}), + createCastTypeConstColumn("Nullable(Float64)")})); + ASSERT_COLUMN_EQ( + createColumn>({MAX_UINT8, 0, {}}), + executeFunction(func_name, + {createColumn>({MAX_UINT8, 0, {}}), + createCastTypeConstColumn("Nullable(Float64)")})); + ASSERT_COLUMN_EQ( + createColumn>({MAX_INT32, MIN_INT32, 0, {}}), + executeFunction(func_name, + {createColumn>({MAX_INT32, MIN_INT32, 0, {}}), + createCastTypeConstColumn("Nullable(Float64)")})); + ASSERT_COLUMN_EQ( + createColumn>({MAX_INT16, MIN_INT16, 0, {}}), + executeFunction(func_name, + {createColumn>({MAX_INT16, MIN_INT16, 0, {}}), + createCastTypeConstColumn("Nullable(Float64)")})); + ASSERT_COLUMN_EQ( + createColumn>({MAX_INT8, MIN_INT8, 0, {}}), + executeFunction(func_name, + {createColumn>({MAX_INT8, MIN_INT8, 0, {}}), + createCastTypeConstColumn("Nullable(Float64)")})); +} +CATCH + +TEST_F(TestTidbConversion, castIntAsString) +try +{ + /// null only cases + ASSERT_COLUMN_EQ( + createColumn>({{}}), + executeFunction(func_name, + {createOnlyNullColumn(1), + createCastTypeConstColumn("Nullable(String)")})); + + /// const cases + // uint64/32/16/8 -> string + ASSERT_COLUMN_EQ( + createConstColumn(1, "18446744073709551615"), + executeFunction(func_name, + {createConstColumn(1, MAX_UINT64), + createCastTypeConstColumn("String")})); + ASSERT_COLUMN_EQ( + createConstColumn(1, "4294967295"), + executeFunction(func_name, + {createConstColumn(1, MAX_UINT32), + createCastTypeConstColumn("String")})); + ASSERT_COLUMN_EQ( + createConstColumn(1, "65535"), + executeFunction(func_name, + {createConstColumn(1, MAX_UINT16), + createCastTypeConstColumn("String")})); + ASSERT_COLUMN_EQ( + createConstColumn(1, "255"), + executeFunction(func_name, + {createConstColumn(1, MAX_UINT8), + createCastTypeConstColumn("String")})); + // int64/32/16/8 -> string + ASSERT_COLUMN_EQ( + createConstColumn(1, "9223372036854775807"), + executeFunction(func_name, + {createConstColumn(1, MAX_INT64), + createCastTypeConstColumn("String")})); + ASSERT_COLUMN_EQ( + createConstColumn(1, "2147483647"), + executeFunction(func_name, + {createConstColumn(1, MAX_INT32), + createCastTypeConstColumn("String")})); + ASSERT_COLUMN_EQ( + createConstColumn(1, "32767"), + executeFunction(func_name, + {createConstColumn(1, MAX_INT16), + createCastTypeConstColumn("String")})); + ASSERT_COLUMN_EQ( + createConstColumn(1, "127"), + executeFunction(func_name, + {createConstColumn(1, MAX_INT8), + createCastTypeConstColumn("String")})); + + /// normal cases + // uint64/32/16/8 -> string + ASSERT_COLUMN_EQ( + createColumn>({"18446744073709551615", "0", {}}), + executeFunction(func_name, + {createColumn>({MAX_UINT64, 0, {}}), + createCastTypeConstColumn("Nullable(String)")})); + ASSERT_COLUMN_EQ( + createColumn>({"4294967295", "0", {}}), + executeFunction(func_name, + {createColumn>({MAX_UINT32, 0, {}}), + createCastTypeConstColumn("Nullable(String)")})); + ASSERT_COLUMN_EQ( + createColumn>({"65535", "0", {}}), + executeFunction(func_name, + {createColumn>({MAX_UINT16, 0, {}}), + createCastTypeConstColumn("Nullable(String)")})); + ASSERT_COLUMN_EQ( + createColumn>({"255", "0", {}}), + executeFunction(func_name, + {createColumn>({MAX_UINT8, 0, {}}), + createCastTypeConstColumn("Nullable(String)")})); + // int64/32/16/8 -> string + ASSERT_COLUMN_EQ( + createColumn>({"9223372036854775807", "-9223372036854775808", "0", {}}), + executeFunction(func_name, + {createColumn>({MAX_INT64, MIN_INT64, 0, {}}), + createCastTypeConstColumn("Nullable(String)")})); + ASSERT_COLUMN_EQ( + createColumn>({"2147483647", "-2147483648", "0", {}}), + executeFunction(func_name, + {createColumn>({MAX_INT32, MIN_INT32, 0, {}}), + createCastTypeConstColumn("Nullable(String)")})); + ASSERT_COLUMN_EQ( + createColumn>({"32767", "-32768", "0", {}}), + executeFunction(func_name, + {createColumn>({MAX_INT16, MIN_INT16, 0, {}}), + createCastTypeConstColumn("Nullable(String)")})); + ASSERT_COLUMN_EQ( + createColumn>({"127", "-128", "0", {}}), + executeFunction(func_name, + {createColumn>({MAX_INT8, MIN_INT8, 0, {}}), + createCastTypeConstColumn("Nullable(String)")})); +} +CATCH + +TEST_F(TestTidbConversion, castIntAsDecimal) +try +{ + // int8 -> decimal32/64/128/256 + ASSERT_COLUMN_EQ( + createColumn>( + std::make_tuple(9, 0), + {DecimalField32(MAX_INT8, 0), DecimalField32(MIN_INT8, 0), {}}), + executeFunction(func_name, + {createColumn>({MAX_INT8, MIN_INT8, {}}), + createCastTypeConstColumn("Nullable(Decimal(9,0))")})); + ASSERT_COLUMN_EQ( + createColumn>( + std::make_tuple(18, 0), + {DecimalField64(MAX_INT8, 0), DecimalField64(MIN_INT8, 0), {}}), + executeFunction(func_name, + {createColumn>({MAX_INT8, MIN_INT8, {}}), + createCastTypeConstColumn("Nullable(Decimal(18,0))")})); + ASSERT_COLUMN_EQ( + createColumn>( + std::make_tuple(38, 0), + {DecimalField128(MAX_INT8, 0), DecimalField128(MIN_INT8, 0), {}}), + executeFunction(func_name, + {createColumn>({MAX_INT8, MIN_INT8, {}}), + createCastTypeConstColumn("Nullable(Decimal(38,0))")})); + ASSERT_COLUMN_EQ( + createColumn>( + std::make_tuple(65, 0), + {DecimalField256(static_cast(MAX_INT8), 0), DecimalField256(static_cast(MIN_INT8), 0), {}}), + executeFunction(func_name, + {createColumn>({MAX_INT8, MIN_INT8, {}}), + createCastTypeConstColumn("Nullable(Decimal(65,0))")})); + // int16 -> decimal32/64/128/256 + ASSERT_COLUMN_EQ( + createColumn>( + std::make_tuple(9, 0), + {DecimalField32(MAX_INT16, 0), DecimalField32(MIN_INT16, 0), {}}), + executeFunction(func_name, + {createColumn>({MAX_INT16, MIN_INT16, {}}), + createCastTypeConstColumn("Nullable(Decimal(9,0))")})); + ASSERT_COLUMN_EQ( + createColumn>( + std::make_tuple(18, 0), + {DecimalField64(MAX_INT16, 0), DecimalField64(MIN_INT16, 0), {}}), + executeFunction(func_name, + {createColumn>({MAX_INT16, MIN_INT16, {}}), + createCastTypeConstColumn("Nullable(Decimal(18,0))")})); + ASSERT_COLUMN_EQ( + createColumn>( + std::make_tuple(38, 0), + {DecimalField128(MAX_INT16, 0), DecimalField128(MIN_INT16, 0), {}}), + executeFunction(func_name, + {createColumn>({MAX_INT16, MIN_INT16, {}}), + createCastTypeConstColumn("Nullable(Decimal(38,0))")})); + ASSERT_COLUMN_EQ( + createColumn>( + std::make_tuple(65, 0), + {DecimalField256(static_cast(MAX_INT16), 0), DecimalField256(static_cast(MIN_INT16), 0), {}}), + executeFunction(func_name, + {createColumn>({MAX_INT16, MIN_INT16, {}}), + createCastTypeConstColumn("Nullable(Decimal(65,0))")})); + // int32 -> decimal32/64/128/256 + ASSERT_COLUMN_EQ( + createColumn>( + std::make_tuple(9, 0), + {DecimalField32(999999999, 0), DecimalField32(-999999999, 0), {}}), + executeFunction(func_name, + {createColumn>({999999999, -999999999, {}}), + createCastTypeConstColumn("Nullable(Decimal(9,0))")})); + ASSERT_THROW(executeFunction(func_name, + {createColumn>({1000000000, -1000000000, {}}), + createCastTypeConstColumn("Nullable(Decimal(9,0))")}), + TiFlashException); + ASSERT_COLUMN_EQ( + createColumn>( + std::make_tuple(18, 0), + {DecimalField64(MAX_INT32, 0), DecimalField64(MIN_INT32, 0), {}}), + executeFunction(func_name, + {createColumn>({MAX_INT32, MIN_INT32, {}}), + createCastTypeConstColumn("Nullable(Decimal(18,0))")})); + ASSERT_COLUMN_EQ( + createColumn>( + std::make_tuple(38, 0), + {DecimalField128(MAX_INT32, 0), DecimalField128(MIN_INT32, 0), {}}), + executeFunction(func_name, + {createColumn>({MAX_INT32, MIN_INT32, {}}), + createCastTypeConstColumn("Nullable(Decimal(38,0))")})); + ASSERT_COLUMN_EQ( + createColumn>( + std::make_tuple(65, 0), + {DecimalField256(static_cast(MAX_INT32), 0), DecimalField256(static_cast(MIN_INT32), 0), {}}), + executeFunction(func_name, + {createColumn>({MAX_INT32, MIN_INT32, {}}), + createCastTypeConstColumn("Nullable(Decimal(65,0))")})); + // int64 -> decimal32/64/128/256 + ASSERT_COLUMN_EQ( + createColumn>( + std::make_tuple(9, 0), + {DecimalField32(999999999, 0), DecimalField32(-999999999, 0), {}}), + executeFunction(func_name, + {createColumn>({999999999, -999999999, {}}), + createCastTypeConstColumn("Nullable(Decimal(9,0))")})); + ASSERT_THROW(executeFunction(func_name, + {createColumn>({1000000000, -1000000000, {}}), + createCastTypeConstColumn("Nullable(Decimal(9,0))")}), + TiFlashException); + ASSERT_COLUMN_EQ( + createColumn>( + std::make_tuple(18, 0), + {DecimalField64(999999999999999999, 0), DecimalField64(-999999999999999999, 0), {}}), + executeFunction(func_name, + {createColumn>({999999999999999999, -999999999999999999, {}}), + createCastTypeConstColumn("Nullable(Decimal(18,0))")})); + ASSERT_THROW(executeFunction(func_name, + {createColumn>({1000000000000000000, -1000000000000000000, {}}), + createCastTypeConstColumn("Nullable(Decimal(18,0))")}), + TiFlashException); + ASSERT_COLUMN_EQ( + createColumn>( + std::make_tuple(38, 0), + {DecimalField128(MAX_INT64, 0), DecimalField128(MIN_INT64, 0), {}}), + executeFunction(func_name, + {createColumn>({MAX_INT64, MIN_INT64, {}}), + createCastTypeConstColumn("Nullable(Decimal(38,0))")})); + ASSERT_COLUMN_EQ( + createColumn>( + std::make_tuple(65, 0), + {DecimalField256(static_cast(MAX_INT64), 0), DecimalField256(static_cast(MIN_INT64), 0), {}}), + executeFunction(func_name, + {createColumn>({MAX_INT64, MIN_INT64, {}}), + createCastTypeConstColumn("Nullable(Decimal(65,0))")})); + // uint8 -> decimal32/64/128/256 + ASSERT_COLUMN_EQ( + createColumn>( + std::make_tuple(9, 0), + {DecimalField32(MAX_UINT8, 0), {}}), + executeFunction(func_name, + {createColumn>({MAX_UINT8, {}}), + createCastTypeConstColumn("Nullable(Decimal(9,0))")})); + ASSERT_COLUMN_EQ( + createColumn>( + std::make_tuple(18, 0), + {DecimalField64(MAX_UINT8, 0), {}}), + executeFunction(func_name, + {createColumn>({MAX_UINT8, {}}), + createCastTypeConstColumn("Nullable(Decimal(18,0))")})); + ASSERT_COLUMN_EQ( + createColumn>( + std::make_tuple(38, 0), + {DecimalField128(MAX_UINT8, 0), {}}), + executeFunction(func_name, + {createColumn>({MAX_UINT8, {}}), + createCastTypeConstColumn("Nullable(Decimal(38,0))")})); + ASSERT_COLUMN_EQ( + createColumn>( + std::make_tuple(65, 0), + {DecimalField256(static_cast(MAX_UINT8), 0), {}}), + executeFunction(func_name, + {createColumn>({MAX_UINT8, {}}), + createCastTypeConstColumn("Nullable(Decimal(65,0))")})); + // uint16 -> decimal32/64/128/256 + ASSERT_COLUMN_EQ( + createColumn>( + std::make_tuple(9, 0), + {DecimalField32(MAX_UINT16, 0), {}}), + executeFunction(func_name, + {createColumn>({MAX_UINT16, {}}), + createCastTypeConstColumn("Nullable(Decimal(9,0))")})); + ASSERT_COLUMN_EQ( + createColumn>( + std::make_tuple(18, 0), + {DecimalField64(MAX_UINT16, 0), {}}), + executeFunction(func_name, + {createColumn>({MAX_UINT16, {}}), + createCastTypeConstColumn("Nullable(Decimal(18,0))")})); + ASSERT_COLUMN_EQ( + createColumn>( + std::make_tuple(38, 0), + {DecimalField128(MAX_UINT16, 0), {}}), + executeFunction(func_name, + {createColumn>({MAX_UINT16, {}}), + createCastTypeConstColumn("Nullable(Decimal(38,0))")})); + ASSERT_COLUMN_EQ( + createColumn>( + std::make_tuple(65, 0), + {DecimalField256(static_cast(MAX_UINT16), 0), {}}), + executeFunction(func_name, + {createColumn>({MAX_UINT16, {}}), + createCastTypeConstColumn("Nullable(Decimal(65,0))")})); + // uint32 -> decimal32/64/128/256 + ASSERT_COLUMN_EQ( + createColumn>( + std::make_tuple(9, 0), + {DecimalField32(999999999, 0), {}}), + executeFunction(func_name, + {createColumn>({999999999, {}}), + createCastTypeConstColumn("Nullable(Decimal(9,0))")})); + ASSERT_THROW(executeFunction(func_name, + {createColumn>({1000000000, {}}), + createCastTypeConstColumn("Nullable(Decimal(9,0))")}), + TiFlashException); + ASSERT_COLUMN_EQ( + createColumn>( + std::make_tuple(18, 0), + {DecimalField64(MAX_UINT32, 0), {}}), + executeFunction(func_name, + {createColumn>({MAX_UINT32, {}}), + createCastTypeConstColumn("Nullable(Decimal(18,0))")})); + ASSERT_COLUMN_EQ( + createColumn>( + std::make_tuple(38, 0), + {DecimalField128(MAX_UINT32, 0), {}}), + executeFunction(func_name, + {createColumn>({MAX_UINT32, {}}), + createCastTypeConstColumn("Nullable(Decimal(38,0))")})); + ASSERT_COLUMN_EQ( + createColumn>( + std::make_tuple(65, 0), + {DecimalField256(static_cast(MAX_UINT32), 0), {}}), + executeFunction(func_name, + {createColumn>({MAX_UINT32, {}}), + createCastTypeConstColumn("Nullable(Decimal(65,0))")})); + // uint64 -> decimal32/64/128/256 + ASSERT_COLUMN_EQ( + createColumn>( + std::make_tuple(9, 0), + {DecimalField32(999999999, 0), {}}), + executeFunction(func_name, + {createColumn>({999999999, {}}), + createCastTypeConstColumn("Nullable(Decimal(9,0))")})); + ASSERT_THROW(executeFunction(func_name, + {createColumn>({1000000000, {}}), + createCastTypeConstColumn("Nullable(Decimal(9,0))")}), + TiFlashException); + ASSERT_COLUMN_EQ( + createColumn>( + std::make_tuple(18, 0), + {DecimalField64(999999999999999999, 0), {}}), + executeFunction(func_name, + {createColumn>({999999999999999999, {}}), + createCastTypeConstColumn("Nullable(Decimal(18,0))")})); + ASSERT_THROW(executeFunction(func_name, + {createColumn>({1000000000000000000, {}}), + createCastTypeConstColumn("Nullable(Decimal(18,0))")}), + TiFlashException); + ASSERT_COLUMN_EQ( + createColumn>( + std::make_tuple(38, 0), + {DecimalField128(MAX_INT64, 0), {}}), + executeFunction(func_name, + {createColumn>({MAX_INT64, {}}), + createCastTypeConstColumn("Nullable(Decimal(38,0))")})); + ASSERT_COLUMN_EQ( + createColumn>( + std::make_tuple(65, 0), + {DecimalField256(static_cast(MAX_INT64), 0), {}}), + executeFunction(func_name, + {createColumn>({MAX_INT64, {}}), + createCastTypeConstColumn("Nullable(Decimal(65,0))")})); +} +CATCH + +TEST_F(TestTidbConversion, castIntAsTime) +try +{ + ASSERT_COLUMN_EQ( + createNullableDateTimeColumn({{}, {{2021, 10, 26, 16, 8, 59, 0}}}, 6), + executeFunction(func_name, + {createColumn>({{}, 20211026160859}), + createCastTypeConstColumn("Nullable(MyDateTime(6))")})); + ASSERT_COLUMN_EQ( + createNullableDateTimeColumn({{}, {{2021, 10, 26, 16, 8, 59, 0}}}, 6), + executeFunction(func_name, + {createColumn>({{}, 20211026160859}), + createCastTypeConstColumn("Nullable(MyDateTime(6))")})); + ASSERT_THROW( + executeFunction(func_name, + {createColumn>({MAX_UINT8}), + createCastTypeConstColumn("Nullable(MyDateTime(6))")}), + TiFlashException); + ASSERT_THROW( + executeFunction(func_name, + {createColumn>({MAX_UINT16}), + createCastTypeConstColumn("Nullable(MyDateTime(6))")}), + TiFlashException); + ASSERT_THROW( + executeFunction(func_name, + {createColumn>({MAX_UINT32}), + createCastTypeConstColumn("Nullable(MyDateTime(6))")}), + TiFlashException); + ASSERT_COLUMN_EQ( + createNullableDateTimeColumn({{}}, 6), + executeFunction(func_name, + {createColumn>({0}), + createCastTypeConstColumn("Nullable(MyDateTime(6))")})); + ASSERT_THROW( + executeFunction(func_name, + {createColumn>({{}, -20211026160859}), + createCastTypeConstColumn("Nullable(MyDateTime(6))")}), + TiFlashException); +} +CATCH + +TEST_F(TestTidbConversion, castTimeAsReal) try { - static const std::string func_name = "tidb_cast"; - static const auto data_type_ptr = std::make_shared(6); - static const Float64 datetime_float = 20211026160859; - static const Float64 datetime_frac_float = 20211026160859.125; + const auto data_type_ptr = std::make_shared(6); + const Float64 datetime_float = 20211026160859; + const Float64 datetime_frac_float = 20211026160859.125; // cast datetime to float auto col_datetime1 = getDatetimeColumn(); @@ -45,7 +731,7 @@ try createColumn({datetime_float, datetime_frac_float}), executeFunction(func_name, {ctn_datetime1, - createConstColumn(1, "Float64")})); + createCastTypeConstColumn("Float64")})); // cast datetime to nullable float auto col_datetime2 = getDatetimeColumn(); @@ -54,7 +740,7 @@ try createColumn>({datetime_float, datetime_frac_float}), executeFunction(func_name, {ctn_datetime2, - createConstColumn(1, "Nullable(Float64)")})); + createCastTypeConstColumn("Nullable(Float64)")})); // cast nullable datetime to nullable float auto col_datetime3 = getDatetimeColumn(); @@ -66,7 +752,7 @@ try createColumn>({datetime_float, {}}), executeFunction(func_name, {ctn_datetime3_nullable, - createConstColumn(1, "Nullable(Float64)")})); + createCastTypeConstColumn("Nullable(Float64)")})); // cast const datetime to float auto col_datetime4_const = ColumnConst::create(getDatetimeColumn(true), 1); @@ -75,7 +761,7 @@ try createConstColumn(1, datetime_float), executeFunction(func_name, {ctn_datetime4_const, - createConstColumn(1, "Float64")})); + createCastTypeConstColumn("Float64")})); // cast nullable const datetime to float auto col_datetime5 = getDatetimeColumn(true); @@ -87,18 +773,17 @@ try createConstColumn>(1, datetime_float), executeFunction(func_name, {ctn_datetime5_nullable_const, - createConstColumn(1, "Nullable(Float64)")})); + createCastTypeConstColumn("Nullable(Float64)")})); } CATCH TEST_F(TestTidbConversion, castDurationAsDuration) try { - static const std::string func_name = "tidb_cast"; - static const auto from_type = std::make_shared(3); - static const auto to_type_1 = std::make_shared(5); // from_fsp < to_fsp - static const auto to_type_2 = std::make_shared(3); // from_fsp == to_fsp - static const auto to_type_3 = std::make_shared(2); // from_fsp < to_fsp + const auto from_type = std::make_shared(3); + const auto to_type_1 = std::make_shared(5); // from_fsp < to_fsp + const auto to_type_2 = std::make_shared(3); // from_fsp == to_fsp + const auto to_type_3 = std::make_shared(2); // from_fsp < to_fsp ColumnWithTypeAndName input( createColumn({(20 * 3600 + 20 * 60 + 20) * 1000000000L + 555000000L, @@ -124,9 +809,9 @@ try to_type_3, "output3"); - ASSERT_COLUMN_EQ(output1, executeFunction(func_name, {input, createConstColumn(1, to_type_1->getName())})); - ASSERT_COLUMN_EQ(output2, executeFunction(func_name, {input, createConstColumn(1, to_type_2->getName())})); - ASSERT_COLUMN_EQ(output3, executeFunction(func_name, {input, createConstColumn(1, to_type_3->getName())})); + ASSERT_COLUMN_EQ(output1, executeFunction(func_name, {input, createCastTypeConstColumn(to_type_1->getName())})); + ASSERT_COLUMN_EQ(output2, executeFunction(func_name, {input, createCastTypeConstColumn(to_type_2->getName())})); + ASSERT_COLUMN_EQ(output3, executeFunction(func_name, {input, createCastTypeConstColumn(to_type_3->getName())})); // Test Nullable ColumnWithTypeAndName input_nullable( @@ -156,9 +841,9 @@ try makeNullable(to_type_3), "output3_output"); - ASSERT_COLUMN_EQ(output1_nullable, executeFunction(func_name, {input_nullable, createConstColumn(1, makeNullable(to_type_1)->getName())})); - ASSERT_COLUMN_EQ(output2_nullable, executeFunction(func_name, {input_nullable, createConstColumn(1, makeNullable(to_type_2)->getName())})); - ASSERT_COLUMN_EQ(output3_nullable, executeFunction(func_name, {input_nullable, createConstColumn(1, makeNullable(to_type_3)->getName())})); + ASSERT_COLUMN_EQ(output1_nullable, executeFunction(func_name, {input_nullable, createCastTypeConstColumn(makeNullable(to_type_1)->getName())})); + ASSERT_COLUMN_EQ(output2_nullable, executeFunction(func_name, {input_nullable, createCastTypeConstColumn(makeNullable(to_type_2)->getName())})); + ASSERT_COLUMN_EQ(output3_nullable, executeFunction(func_name, {input_nullable, createCastTypeConstColumn(makeNullable(to_type_3)->getName())})); // Test Const ColumnWithTypeAndName input_const(createConstColumn(1, (20 * 3600 + 20 * 60 + 20) * 1000000000L + 999000000L).column, from_type, "input_const"); @@ -166,11 +851,12 @@ try ColumnWithTypeAndName output2_const(input_const.column, to_type_2, "output2_const"); ColumnWithTypeAndName output3_const(createConstColumn(1, (20 * 3600 + 20 * 60 + 21) * 1000000000L + 000000000L).column, to_type_3, "output3_const"); - ASSERT_COLUMN_EQ(output1_const, executeFunction(func_name, {input_const, createConstColumn(1, to_type_1->getName())})); - ASSERT_COLUMN_EQ(output2_const, executeFunction(func_name, {input_const, createConstColumn(1, to_type_2->getName())})); - ASSERT_COLUMN_EQ(output3_const, executeFunction(func_name, {input_const, createConstColumn(1, to_type_3->getName())})); + ASSERT_COLUMN_EQ(output1_const, executeFunction(func_name, {input_const, createCastTypeConstColumn(to_type_1->getName())})); + ASSERT_COLUMN_EQ(output2_const, executeFunction(func_name, {input_const, createCastTypeConstColumn(to_type_2->getName())})); + ASSERT_COLUMN_EQ(output3_const, executeFunction(func_name, {input_const, createCastTypeConstColumn(to_type_3->getName())})); } CATCH +} // namespace } // namespace tests } // namespace DB diff --git a/dbms/src/TestUtils/FunctionTestUtils.cpp b/dbms/src/TestUtils/FunctionTestUtils.cpp index 4d835cbe187..6bbd91f78c5 100644 --- a/dbms/src/TestUtils/FunctionTestUtils.cpp +++ b/dbms/src/TestUtils/FunctionTestUtils.cpp @@ -118,10 +118,34 @@ ColumnWithTypeAndName FunctionTest::executeFunction(const String & func_name, co return block.getByPosition(columns.size()); } -ColumnWithTypeAndName createOnlyNullColumn(size_t size, const String & name) +ColumnWithTypeAndName createOnlyNullColumnConst(size_t size, const String & name) { DataTypePtr data_type = std::make_shared(std::make_shared()); return {data_type->createColumnConst(size, Null()), data_type, name}; } + +ColumnWithTypeAndName createOnlyNullColumn(size_t size, const String & name) +{ + DataTypePtr data_type = std::make_shared(std::make_shared()); + auto col = data_type->createColumn(); + for (size_t i = 0; i < size; i++) + col->insert(Null()); + return {std::move(col), data_type, name}; +} + +ColumnWithTypeAndName createNullableDateTimeColumn(std::initializer_list> init, int fraction) +{ + auto data_type_ptr = makeNullable(std::make_shared(fraction)); + auto col = data_type_ptr->createColumn(); + for (const auto dt : init) + { + if (dt.has_value()) + col->insert(Field(dt->toPackedUInt())); + else + col->insert(Null()); + } + return ColumnWithTypeAndName(std::move(col), data_type_ptr, "datetime"); +} + } // namespace tests } // namespace DB diff --git a/dbms/src/TestUtils/FunctionTestUtils.h b/dbms/src/TestUtils/FunctionTestUtils.h index 6cc9d821386..deaf4ca5097 100644 --- a/dbms/src/TestUtils/FunctionTestUtils.h +++ b/dbms/src/TestUtils/FunctionTestUtils.h @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -193,6 +194,7 @@ ColumnPtr makeConstColumn(const DataTypePtr & data_type, size_t size, const Infe return data_type->createColumnConst(size, makeField(value)); } +ColumnWithTypeAndName createOnlyNullColumnConst(size_t size, const String & name = ""); ColumnWithTypeAndName createOnlyNullColumn(size_t size, const String & name = ""); template @@ -247,6 +249,8 @@ ColumnWithTypeAndName createConstColumn( return {makeConstColumn(data_type, size, value), data_type, name}; } +ColumnWithTypeAndName createNullableDateTimeColumn(std::initializer_list> init, int fraction); + // parse a string into decimal field. template typename TypeTraits::FieldType parseDecimal( From b2d3ba17e7c5772015e86d32bf78b28058018a01 Mon Sep 17 00:00:00 2001 From: SeaRise Date: Tue, 14 Dec 2021 14:54:35 +0800 Subject: [PATCH 4/4] refactor DAGQueryBlockInterpreter (#3639) --- .../Coprocessor/DAGQueryBlockInterpreter.cpp | 192 ++++++++++-------- .../Coprocessor/DAGQueryBlockInterpreter.h | 9 + 2 files changed, 118 insertions(+), 83 deletions(-) diff --git a/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.cpp b/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.cpp index 43695f1ad1e..c12281dc21a 100644 --- a/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.cpp +++ b/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.cpp @@ -911,6 +911,105 @@ void DAGQueryBlockInterpreter::executeRemoteQueryImpl( } } +void DAGQueryBlockInterpreter::executeExchangeReceiver(DAGPipeline & pipeline) +{ + auto it = exchange_receiver_map.find(query_block.source_name); + if (unlikely(it == exchange_receiver_map.end())) + throw Exception("Can not find exchange receiver for " + query_block.source_name, ErrorCodes::LOGICAL_ERROR); + // todo choose a more reasonable stream number + for (size_t i = 0; i < max_streams; ++i) + { + BlockInputStreamPtr stream = std::make_shared(it->second, taskLogger()); + dagContext().getRemoteInputStreams().push_back(stream); + stream = std::make_shared(stream, 8192, 0, taskLogger()); + pipeline.streams.push_back(stream); + } + std::vector source_columns; + Block block = pipeline.firstStream()->getHeader(); + for (const auto & col : block.getColumnsWithTypeAndName()) + { + source_columns.emplace_back(NameAndTypePair(col.name, col.type)); + } + analyzer = std::make_unique(std::move(source_columns), context); +} + +void DAGQueryBlockInterpreter::executeSourceProjection(DAGPipeline & pipeline, const tipb::Projection & projection) +{ + std::vector input_columns; + pipeline.streams = input_streams_vec[0]; + for (auto const & p : pipeline.firstStream()->getHeader().getNamesAndTypesList()) + input_columns.emplace_back(p.name, p.type); + DAGExpressionAnalyzer dag_analyzer(std::move(input_columns), context); + ExpressionActionsChain chain; + dag_analyzer.initChain(chain, dag_analyzer.getCurrentInputColumns()); + ExpressionActionsChain::Step & last_step = chain.steps.back(); + std::vector output_columns; + NamesWithAliases project_cols; + UniqueNameGenerator unique_name_generator; + for (const auto & expr : projection.exprs()) + { + auto expr_name = dag_analyzer.getActions(expr, last_step.actions); + last_step.required_output.emplace_back(expr_name); + const auto & col = last_step.actions->getSampleBlock().getByName(expr_name); + String alias = unique_name_generator.toUniqueName(col.name); + output_columns.emplace_back(alias, col.type); + project_cols.emplace_back(col.name, alias); + } + pipeline.transform([&](auto & stream) { stream = std::make_shared(stream, chain.getLastActions(), taskLogger()); }); + executeProject(pipeline, project_cols); + analyzer = std::make_unique(std::move(output_columns), context); +} + +void DAGQueryBlockInterpreter::executeExtraCastAndSelection( + DAGPipeline & pipeline, + const ExpressionActionsPtr & extra_cast, + const NamesWithAliases & project_after_ts_and_filter_for_remote_read, + const ExpressionActionsPtr & before_where, + const ExpressionActionsPtr & project_after_where, + const String & filter_column_name) +{ + /// execute timezone cast and the selection + ExpressionActionsPtr project_for_cop_read; + for (auto & stream : pipeline.streams) + { + if (dynamic_cast(stream.get()) != nullptr) + { + /// for cop read, just execute the project is enough, because timezone cast and the selection are already done in remote TiFlash + if (!project_after_ts_and_filter_for_remote_read.empty()) + { + if (project_for_cop_read == nullptr) + { + project_for_cop_read = generateProjectExpressionActions(stream, context, project_after_ts_and_filter_for_remote_read); + } + stream = std::make_shared(stream, project_for_cop_read, taskLogger()); + } + } + else + { + /// execute timezone cast or duration cast if needed + if (extra_cast) + stream = std::make_shared(stream, extra_cast, taskLogger()); + /// execute selection if needed + if (before_where) + { + stream = std::make_shared(stream, before_where, filter_column_name, taskLogger()); + if (project_after_where) + stream = std::make_shared(stream, project_after_where, taskLogger()); + } + } + } + for (auto & stream : pipeline.streams_with_non_joined_data) + { + /// execute selection if needed + if (before_where) + { + stream = std::make_shared(stream, before_where, filter_column_name, taskLogger()); + if (project_after_where) + stream = std::make_shared(stream, project_after_where, taskLogger()); + } + } +} + // To execute a query block, you have to: // 1. generate the date stream and push it to pipeline. // 2. assign the analyzer @@ -940,51 +1039,12 @@ void DAGQueryBlockInterpreter::executeImpl(DAGPipeline & pipeline) } else if (query_block.source->tp() == tipb::ExecType::TypeExchangeReceiver) { - auto it = exchange_receiver_map.find(query_block.source_name); - if (unlikely(it == exchange_receiver_map.end())) - throw Exception("Can not find exchange receiver for " + query_block.source_name, ErrorCodes::LOGICAL_ERROR); - // todo choose a more reasonable stream number - for (size_t i = 0; i < max_streams; i++) - { - BlockInputStreamPtr stream = std::make_shared(it->second, taskLogger()); - dagContext().getRemoteInputStreams().push_back(stream); - stream = std::make_shared(stream, 8192, 0, taskLogger()); - pipeline.streams.push_back(stream); - } - std::vector source_columns; - Block block = pipeline.firstStream()->getHeader(); - for (const auto & col : block.getColumnsWithTypeAndName()) - { - source_columns.emplace_back(NameAndTypePair(col.name, col.type)); - } - analyzer = std::make_unique(std::move(source_columns), context); + executeExchangeReceiver(pipeline); recordProfileStreams(pipeline, query_block.source_name); } else if (query_block.source->tp() == tipb::ExecType::TypeProjection) { - std::vector input_columns; - pipeline.streams = input_streams_vec[0]; - for (auto const & p : pipeline.firstStream()->getHeader().getNamesAndTypesList()) - input_columns.emplace_back(p.name, p.type); - DAGExpressionAnalyzer dag_analyzer(std::move(input_columns), context); - ExpressionActionsChain chain; - dag_analyzer.initChain(chain, dag_analyzer.getCurrentInputColumns()); - ExpressionActionsChain::Step & last_step = chain.steps.back(); - std::vector output_columns; - NamesWithAliases project_cols; - UniqueNameGenerator unique_name_generator; - for (const auto & expr : query_block.source->projection().exprs()) - { - auto expr_name = dag_analyzer.getActions(expr, last_step.actions); - last_step.required_output.emplace_back(expr_name); - const auto & col = last_step.actions->getSampleBlock().getByName(expr_name); - String alias = unique_name_generator.toUniqueName(col.name); - output_columns.emplace_back(alias, col.type); - project_cols.emplace_back(col.name, alias); - } - pipeline.transform([&](auto & stream) { stream = std::make_shared(stream, chain.getLastActions(), taskLogger()); }); - executeProject(pipeline, project_cols); - analyzer = std::make_unique(std::move(output_columns), context); + executeSourceProjection(pipeline, query_block.source->projection()); recordProfileStreams(pipeline, query_block.source_name); } else @@ -1005,46 +1065,13 @@ void DAGQueryBlockInterpreter::executeImpl(DAGPipeline & pipeline) if (res.extra_cast || res.before_where) { - /// execute timezone cast and the selection - ExpressionActionsPtr project_for_cop_read; - for (auto & stream : pipeline.streams) - { - if (dynamic_cast(stream.get()) != nullptr) - { - /// for cop read, just execute the project is enough, because timezone cast and the selection are already done in remote TiFlash - if (!res.project_after_ts_and_filter_for_remote_read.empty()) - { - if (project_for_cop_read == nullptr) - { - project_for_cop_read = generateProjectExpressionActions(stream, context, res.project_after_ts_and_filter_for_remote_read); - } - stream = std::make_shared(stream, project_for_cop_read, taskLogger()); - } - } - else - { - /// execute timezone cast or duration cast if needed - if (res.extra_cast) - stream = std::make_shared(stream, res.extra_cast, taskLogger()); - /// execute selection if needed - if (res.before_where) - { - stream = std::make_shared(stream, res.before_where, res.filter_column_name, taskLogger()); - if (res.project_after_where) - stream = std::make_shared(stream, res.project_after_where, taskLogger()); - } - } - } - for (auto & stream : pipeline.streams_with_non_joined_data) - { - /// execute selection if needed - if (res.before_where) - { - stream = std::make_shared(stream, res.before_where, res.filter_column_name, taskLogger()); - if (res.project_after_where) - stream = std::make_shared(stream, res.project_after_where, taskLogger()); - } - } + executeExtraCastAndSelection( + pipeline, + res.extra_cast, + res.project_after_ts_and_filter_for_remote_read, + res.before_where, + res.project_after_where, + res.filter_column_name); } if (res.before_where) { @@ -1052,8 +1079,7 @@ void DAGQueryBlockInterpreter::executeImpl(DAGPipeline & pipeline) } // this log measures the concurrent degree in this mpp task - LOG_INFO(log, - "execution stream size for query block(before aggregation) " << query_block.qb_column_prefix << " is " << pipeline.streams.size()); + LOG_INFO(log, "execution stream size for query block(before aggregation) " << query_block.qb_column_prefix << " is " << pipeline.streams.size()); dagContext().final_concurrency = std::max(dagContext().final_concurrency, pipeline.streams.size()); diff --git a/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.h b/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.h index c1ebcaebd59..834a2bd49ad 100644 --- a/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.h +++ b/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.h @@ -52,6 +52,15 @@ class DAGQueryBlockInterpreter bool is_right_out_join, const google::protobuf::RepeatedPtrField & filters, String & filter_column_name); + void executeExchangeReceiver(DAGPipeline & pipeline); + void executeSourceProjection(DAGPipeline & pipeline, const tipb::Projection & projection); + void executeExtraCastAndSelection( + DAGPipeline & pipeline, + const ExpressionActionsPtr & extra_cast, + const NamesWithAliases & project_after_ts_and_filter_for_remote_read, + const ExpressionActionsPtr & before_where, + const ExpressionActionsPtr & project_after_where, + const String & filter_column_name); ExpressionActionsPtr genJoinOtherConditionAction( const tipb::Join & join, std::vector & source_columns,