Skip to content

Commit

Permalink
fix ut issue
Browse files Browse the repository at this point in the history
  • Loading branch information
zzcclp committed Feb 8, 2024
1 parent 296aaa0 commit 17c9586
Show file tree
Hide file tree
Showing 7 changed files with 44 additions and 11 deletions.
14 changes: 12 additions & 2 deletions cpp-ch/local-engine/Parser/AggregateFunctionParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,10 +147,20 @@ std::pair<String, DB::DataTypes> AggregateFunctionParser::tryApplyCHCombinator(
const DB::ActionsDAG::Node * AggregateFunctionParser::convertNodeTypeIfNeeded(
const CommonFunctionInfo & func_info,
const DB::ActionsDAG::Node * func_node,
DB::ActionsDAGPtr & actions_dag) const
DB::ActionsDAGPtr & actions_dag,
bool withNullability) const
{
const auto & output_type = func_info.output_type;
if (!TypeParser::isTypeMatched(output_type, func_node->result_type))
bool notToConvertNodeType = true;
if (withNullability)
{
notToConvertNodeType = TypeParser::isTypeMatchedWithNullability(output_type, func_node->result_type);
}
else
{
notToConvertNodeType = TypeParser::isTypeMatched(output_type, func_node->result_type);
}
if (!notToConvertNodeType)
{
func_node = ActionsDAGUtil::convertNodeType(
actions_dag, func_node, TypeParser::parseType(output_type)->getName(), func_node->result_name);
Expand Down
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Parser/AggregateFunctionParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ class AggregateFunctionParser

/// Make a postprojection for the function result.
virtual const DB::ActionsDAG::Node * convertNodeTypeIfNeeded(
const CommonFunctionInfo & func_info, const DB::ActionsDAG::Node * func_node, DB::ActionsDAGPtr & actions_dag) const;
const CommonFunctionInfo & func_info, const DB::ActionsDAG::Node * func_node, DB::ActionsDAGPtr & actions_dag, bool withNullability) const;

/// Parameters are only used in aggregate functions at present. e.g. percentiles(0.5)(x).
/// 0.5 is the parameter of percentiles function.
Expand Down
28 changes: 22 additions & 6 deletions cpp-ch/local-engine/Parser/AggregateRelParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -463,15 +463,31 @@ void AggregateRelParser::addPostProjection()
auto input_header = plan->getCurrentDataStream().header;
ActionsDAGPtr project_actions_dag = std::make_shared<ActionsDAG>(input_header.getColumnsWithTypeAndName());
auto dag_footprint = project_actions_dag->dumpDAG();
for (const auto & agg_info : aggregates)

if (has_final_stage)
{
for (const auto & agg_info : aggregates)
{
for (const auto * input_node : project_actions_dag->getInputs())
{
if (input_node->result_name == agg_info.measure_column_name)
{
agg_info.function_parser->convertNodeTypeIfNeeded(agg_info.parser_func_info, input_node, project_actions_dag, false);
}
}
}
}
if (has_complete_stage)
{
/// For final stage, the aggregate function's input is only one intermediate result columns.
/// The final result columm's position is the same as the intermediate result column's position.
for (const auto * input_node : project_actions_dag->getInputs())
// on the complete mode, it must consider the nullability when converting node type
for (const auto & agg_info : aggregates)
{
if (input_node->result_name == agg_info.measure_column_name)
for (const auto * output_node : project_actions_dag->getOutputs())
{
agg_info.function_parser->convertNodeTypeIfNeeded(agg_info.parser_func_info, input_node, project_actions_dag);
if (output_node->result_name == agg_info.measure_column_name)
{
agg_info.function_parser->convertNodeTypeIfNeeded(agg_info.parser_func_info, output_node, project_actions_dag, true);
}
}
}
}
Expand Down
6 changes: 6 additions & 0 deletions cpp-ch/local-engine/Parser/TypeParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,12 @@ bool TypeParser::isTypeMatched(const substrait::Type & substrait_type, const Dat
return a->equals(*b);
}

bool TypeParser::isTypeMatchedWithNullability(const substrait::Type & substrait_type, const DataTypePtr & ch_type)
{
const auto parsed_ch_type = TypeParser::parseType(substrait_type);
return parsed_ch_type->equals(*ch_type);
}

DB::DataTypePtr TypeParser::tryWrapNullable(substrait::Type_Nullability nullable, DB::DataTypePtr nested_type)
{
if (nullable == substrait::Type_Nullability::Type_Nullability_NULLABILITY_NULLABLE && !nested_type->isNullable())
Expand Down
1 change: 1 addition & 0 deletions cpp-ch/local-engine/Parser/TypeParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class TypeParser
static DB::Block buildBlockFromNamedStructWithoutDFS(const substrait::NamedStruct & struct_);

static bool isTypeMatched(const substrait::Type & substrait_type, const DB::DataTypePtr & ch_type);
static bool isTypeMatchedWithNullability(const substrait::Type & substrait_type, const DB::DataTypePtr & ch_type);
private:
/// Mapping spark type names to CH type names.
static std::unordered_map<String, String> type_names_mapping;
Expand Down
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Parser/WindowRelParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ void WindowRelParser::tryAddProjectionAfterWindow()
{
auto & win_info = win_infos[i];
const auto * win_result_node = &actions_dag->findInOutputs(win_info.result_column_name);
win_info.function_parser->convertNodeTypeIfNeeded(win_info.parser_func_info, win_result_node, actions_dag);
win_info.function_parser->convertNodeTypeIfNeeded(win_info.parser_func_info, win_result_node, actions_dag, false);
}

if (actions_dag->dumpDAG() != dag_footprint)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class CollectFunctionParser : public AggregateFunctionParser
throw DB::Exception(DB::ErrorCodes::NOT_IMPLEMENTED, "Not implement");
}
const DB::ActionsDAG::Node * convertNodeTypeIfNeeded(
const CommonFunctionInfo &, const DB::ActionsDAG::Node * func_node, DB::ActionsDAGPtr & actions_dag) const override
const CommonFunctionInfo &, const DB::ActionsDAG::Node * func_node, DB::ActionsDAGPtr & actions_dag, bool /* withNullability */) const override
{
const DB::ActionsDAG::Node * ret_node = func_node;
if (func_node->result_type->isNullable())
Expand Down

0 comments on commit 17c9586

Please sign in to comment.