From 17c95861ca5fa91fff4b79cdc8ae19c9bf6552e7 Mon Sep 17 00:00:00 2001 From: Zhichao Zhang Date: Thu, 8 Feb 2024 21:07:01 +0800 Subject: [PATCH] fix ut issue --- .../Parser/AggregateFunctionParser.cpp | 14 ++++++++-- .../Parser/AggregateFunctionParser.h | 2 +- .../Parser/AggregateRelParser.cpp | 28 +++++++++++++++---- cpp-ch/local-engine/Parser/TypeParser.cpp | 6 ++++ cpp-ch/local-engine/Parser/TypeParser.h | 1 + .../local-engine/Parser/WindowRelParser.cpp | 2 +- .../CollectListParser.h | 2 +- 7 files changed, 44 insertions(+), 11 deletions(-) diff --git a/cpp-ch/local-engine/Parser/AggregateFunctionParser.cpp b/cpp-ch/local-engine/Parser/AggregateFunctionParser.cpp index 126c31b12cd3e..af5a4309dd3bd 100644 --- a/cpp-ch/local-engine/Parser/AggregateFunctionParser.cpp +++ b/cpp-ch/local-engine/Parser/AggregateFunctionParser.cpp @@ -147,10 +147,20 @@ std::pair AggregateFunctionParser::tryApplyCHCombinator( const DB::ActionsDAG::Node * AggregateFunctionParser::convertNodeTypeIfNeeded( const CommonFunctionInfo & func_info, const DB::ActionsDAG::Node * func_node, - DB::ActionsDAGPtr & actions_dag) const + DB::ActionsDAGPtr & actions_dag, + bool withNullability) const { const auto & output_type = func_info.output_type; - if (!TypeParser::isTypeMatched(output_type, func_node->result_type)) + bool notToConvertNodeType = true; + if (withNullability) + { + notToConvertNodeType = TypeParser::isTypeMatchedWithNullability(output_type, func_node->result_type); + } + else + { + notToConvertNodeType = TypeParser::isTypeMatched(output_type, func_node->result_type); + } + if (!notToConvertNodeType) { func_node = ActionsDAGUtil::convertNodeType( actions_dag, func_node, TypeParser::parseType(output_type)->getName(), func_node->result_name); diff --git a/cpp-ch/local-engine/Parser/AggregateFunctionParser.h b/cpp-ch/local-engine/Parser/AggregateFunctionParser.h index bfa932b819f8a..a9840eeef8253 100644 --- a/cpp-ch/local-engine/Parser/AggregateFunctionParser.h +++ b/cpp-ch/local-engine/Parser/AggregateFunctionParser.h @@ -106,7 +106,7 @@ class AggregateFunctionParser /// Make a postprojection for the function result. virtual const DB::ActionsDAG::Node * convertNodeTypeIfNeeded( - const CommonFunctionInfo & func_info, const DB::ActionsDAG::Node * func_node, DB::ActionsDAGPtr & actions_dag) const; + const CommonFunctionInfo & func_info, const DB::ActionsDAG::Node * func_node, DB::ActionsDAGPtr & actions_dag, bool withNullability) const; /// Parameters are only used in aggregate functions at present. e.g. percentiles(0.5)(x). /// 0.5 is the parameter of percentiles function. diff --git a/cpp-ch/local-engine/Parser/AggregateRelParser.cpp b/cpp-ch/local-engine/Parser/AggregateRelParser.cpp index de74b341c4f52..396137752d897 100644 --- a/cpp-ch/local-engine/Parser/AggregateRelParser.cpp +++ b/cpp-ch/local-engine/Parser/AggregateRelParser.cpp @@ -463,15 +463,31 @@ void AggregateRelParser::addPostProjection() auto input_header = plan->getCurrentDataStream().header; ActionsDAGPtr project_actions_dag = std::make_shared(input_header.getColumnsWithTypeAndName()); auto dag_footprint = project_actions_dag->dumpDAG(); - for (const auto & agg_info : aggregates) + + if (has_final_stage) + { + for (const auto & agg_info : aggregates) + { + for (const auto * input_node : project_actions_dag->getInputs()) + { + if (input_node->result_name == agg_info.measure_column_name) + { + agg_info.function_parser->convertNodeTypeIfNeeded(agg_info.parser_func_info, input_node, project_actions_dag, false); + } + } + } + } + if (has_complete_stage) { - /// For final stage, the aggregate function's input is only one intermediate result columns. - /// The final result columm's position is the same as the intermediate result column's position. - for (const auto * input_node : project_actions_dag->getInputs()) + // on the complete mode, it must consider the nullability when converting node type + for (const auto & agg_info : aggregates) { - if (input_node->result_name == agg_info.measure_column_name) + for (const auto * output_node : project_actions_dag->getOutputs()) { - agg_info.function_parser->convertNodeTypeIfNeeded(agg_info.parser_func_info, input_node, project_actions_dag); + if (output_node->result_name == agg_info.measure_column_name) + { + agg_info.function_parser->convertNodeTypeIfNeeded(agg_info.parser_func_info, output_node, project_actions_dag, true); + } } } } diff --git a/cpp-ch/local-engine/Parser/TypeParser.cpp b/cpp-ch/local-engine/Parser/TypeParser.cpp index 958a5fb4518f1..2edd8c1c83ec7 100644 --- a/cpp-ch/local-engine/Parser/TypeParser.cpp +++ b/cpp-ch/local-engine/Parser/TypeParser.cpp @@ -315,6 +315,12 @@ bool TypeParser::isTypeMatched(const substrait::Type & substrait_type, const Dat return a->equals(*b); } +bool TypeParser::isTypeMatchedWithNullability(const substrait::Type & substrait_type, const DataTypePtr & ch_type) +{ + const auto parsed_ch_type = TypeParser::parseType(substrait_type); + return parsed_ch_type->equals(*ch_type); +} + DB::DataTypePtr TypeParser::tryWrapNullable(substrait::Type_Nullability nullable, DB::DataTypePtr nested_type) { if (nullable == substrait::Type_Nullability::Type_Nullability_NULLABILITY_NULLABLE && !nested_type->isNullable()) diff --git a/cpp-ch/local-engine/Parser/TypeParser.h b/cpp-ch/local-engine/Parser/TypeParser.h index 7793ae198b860..a25b2f50afe83 100644 --- a/cpp-ch/local-engine/Parser/TypeParser.h +++ b/cpp-ch/local-engine/Parser/TypeParser.h @@ -48,6 +48,7 @@ class TypeParser static DB::Block buildBlockFromNamedStructWithoutDFS(const substrait::NamedStruct & struct_); static bool isTypeMatched(const substrait::Type & substrait_type, const DB::DataTypePtr & ch_type); + static bool isTypeMatchedWithNullability(const substrait::Type & substrait_type, const DB::DataTypePtr & ch_type); private: /// Mapping spark type names to CH type names. static std::unordered_map type_names_mapping; diff --git a/cpp-ch/local-engine/Parser/WindowRelParser.cpp b/cpp-ch/local-engine/Parser/WindowRelParser.cpp index 969959c3aec7a..a1787a2c93c5c 100644 --- a/cpp-ch/local-engine/Parser/WindowRelParser.cpp +++ b/cpp-ch/local-engine/Parser/WindowRelParser.cpp @@ -357,7 +357,7 @@ void WindowRelParser::tryAddProjectionAfterWindow() { auto & win_info = win_infos[i]; const auto * win_result_node = &actions_dag->findInOutputs(win_info.result_column_name); - win_info.function_parser->convertNodeTypeIfNeeded(win_info.parser_func_info, win_result_node, actions_dag); + win_info.function_parser->convertNodeTypeIfNeeded(win_info.parser_func_info, win_result_node, actions_dag, false); } if (actions_dag->dumpDAG() != dag_footprint) diff --git a/cpp-ch/local-engine/Parser/aggregate_function_parser/CollectListParser.h b/cpp-ch/local-engine/Parser/aggregate_function_parser/CollectListParser.h index a75e9ee2ad3c4..d7a9c1a5c1889 100644 --- a/cpp-ch/local-engine/Parser/aggregate_function_parser/CollectListParser.h +++ b/cpp-ch/local-engine/Parser/aggregate_function_parser/CollectListParser.h @@ -52,7 +52,7 @@ class CollectFunctionParser : public AggregateFunctionParser throw DB::Exception(DB::ErrorCodes::NOT_IMPLEMENTED, "Not implement"); } const DB::ActionsDAG::Node * convertNodeTypeIfNeeded( - const CommonFunctionInfo &, const DB::ActionsDAG::Node * func_node, DB::ActionsDAGPtr & actions_dag) const override + const CommonFunctionInfo &, const DB::ActionsDAG::Node * func_node, DB::ActionsDAGPtr & actions_dag, bool /* withNullability */) const override { const DB::ActionsDAG::Node * ret_node = func_node; if (func_node->result_type->isNullable())