From 8cbe372b55a1fa865327991aa0d83c3155f0a10b Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Wed, 6 Jan 2021 19:01:54 +0800 Subject: [PATCH] ColumnarWSCG optimization: only GetValue when field is used Signed-off-by: Chendi Xue --- .../arrow_compute/ext/array_item_index.h | 10 +- .../ext/basic_physical_kernels.cc | 78 ++--- .../arrow_compute/ext/codegen_context.h | 4 +- .../ext/conditioned_merge_join_kernel.cc | 200 +++++------ .../ext/conditioned_probe_kernel.cc | 316 ++++++++++-------- .../ext/expression_codegen_visitor.cc | 39 ++- .../ext/expression_codegen_visitor.h | 26 +- .../codegen/arrow_compute/ext/kernels_ext.h | 40 ++- .../ext/whole_stage_codegen_kernel.cc | 11 +- 9 files changed, 395 insertions(+), 329 deletions(-) diff --git a/cpp/src/codegen/arrow_compute/ext/array_item_index.h b/cpp/src/codegen/arrow_compute/ext/array_item_index.h index 34d635224..8d5897a2d 100644 --- a/cpp/src/codegen/arrow_compute/ext/array_item_index.h +++ b/cpp/src/codegen/arrow_compute/ext/array_item_index.h @@ -25,18 +25,14 @@ namespace extra { struct ArrayItemIndex { uint16_t id = 0; uint16_t array_id = 0; - bool valid = true; - ArrayItemIndex() : array_id(0), id(0), valid(true) {} - ArrayItemIndex(bool valid) : array_id(0), id(0), valid(valid) {} - ArrayItemIndex(uint16_t array_id, uint16_t id) - : array_id(array_id), id(id), valid(true) {} + ArrayItemIndex() : array_id(0), id(0) {} + ArrayItemIndex(uint16_t array_id, uint16_t id) : array_id(array_id), id(id) {} }; struct ArrayItemIndexS { uint16_t id = 0; uint16_t array_id = 0; ArrayItemIndexS() : array_id(0), id(0) {} - ArrayItemIndexS(uint16_t array_id, uint16_t id) - : array_id(array_id), id(id) {} + ArrayItemIndexS(uint16_t array_id, uint16_t id) : array_id(array_id), id(id) {} }; } // namespace extra diff --git a/cpp/src/codegen/arrow_compute/ext/basic_physical_kernels.cc b/cpp/src/codegen/arrow_compute/ext/basic_physical_kernels.cc index 1f8d60006..af0c018b7 100644 --- a/cpp/src/codegen/arrow_compute/ext/basic_physical_kernels.cc +++ b/cpp/src/codegen/arrow_compute/ext/basic_physical_kernels.cc @@ -65,16 +65,20 @@ class ProjectKernel::Impl { std::string GetSignature() { return signature_; } - arrow::Status DoCodeGen(int level, const std::vector input, - std::shared_ptr* codegen_ctx_out, int* var_id) { + arrow::Status DoCodeGen( + int level, + std::vector, gandiva::DataTypePtr>> + input, + std::shared_ptr* codegen_ctx_out, int* var_id) { auto codegen_ctx = std::make_shared(); int idx = 0; for (auto project : project_list_) { std::shared_ptr project_node_visitor; std::vector input_list; std::vector indices_list; - RETURN_NOT_OK(MakeExpressionCodegenVisitor(project, input, {input_field_list_}, -1, - var_id, &input_list, + auto is_local = false; + RETURN_NOT_OK(MakeExpressionCodegenVisitor(project, &input, {input_field_list_}, -1, + var_id, is_local, &input_list, &project_node_visitor)); codegen_ctx->process_codes += project_node_visitor->GetPrepare(); auto name = project_node_visitor->GetResult(); @@ -83,26 +87,19 @@ class ProjectKernel::Impl { auto output_name = "project_" + std::to_string(level) + "_output_col_" + std::to_string(idx++); auto output_validity = output_name + "_validity"; - codegen_ctx->output_list.push_back( - std::make_pair(output_name, project->return_type())); + std::stringstream output_get_ss; + output_get_ss << "auto " << output_name << " = " << name << ";" << std::endl; + output_get_ss << "auto " << output_validity << " = " << validity << ";" + << std::endl; + + codegen_ctx->output_list.push_back(std::make_pair( + std::make_pair(output_name, output_get_ss.str()), project->return_type())); for (auto header : project_node_visitor->GetHeaders()) { if (std::find(codegen_ctx->header_codes.begin(), codegen_ctx->header_codes.end(), header) == codegen_ctx->header_codes.end()) { codegen_ctx->header_codes.push_back(header); } } - - std::stringstream process_ss; - std::stringstream define_ss; - - process_ss << output_name << " = " << name << ";" << std::endl; - process_ss << output_validity << " = " << validity << ";" << std::endl; - codegen_ctx->process_codes += process_ss.str(); - - define_ss << GetCTypeString(project->return_type()) << " " << output_name << ";" - << std::endl; - define_ss << "bool " << output_validity << ";" << std::endl; - codegen_ctx->definition_codes += define_ss.str(); } *codegen_ctx_out = codegen_ctx; return arrow::Status::OK(); @@ -139,9 +136,11 @@ arrow::Status ProjectKernel::MakeResultIterator( std::string ProjectKernel::GetSignature() { return impl_->GetSignature(); } -arrow::Status ProjectKernel::DoCodeGen(int level, std::vector input, - std::shared_ptr* codegen_ctx, - int* var_id) { +arrow::Status ProjectKernel::DoCodeGen( + int level, + std::vector, gandiva::DataTypePtr>> + input, + std::shared_ptr* codegen_ctx, int* var_id) { return impl_->DoCodeGen(level, input, codegen_ctx, var_id); } @@ -166,14 +165,18 @@ class FilterKernel::Impl { std::string GetSignature() { return signature_; } - arrow::Status DoCodeGen(int level, const std::vector input, - std::shared_ptr* codegen_ctx_out, int* var_id) { + arrow::Status DoCodeGen( + int level, + std::vector, gandiva::DataTypePtr>> + input, + std::shared_ptr* codegen_ctx_out, int* var_id) { auto codegen_ctx = std::make_shared(); std::shared_ptr condition_node_visitor; std::vector input_list; std::vector indices_list; - RETURN_NOT_OK(MakeExpressionCodegenVisitor(condition_, input, {input_field_list_}, -1, - var_id, &input_list, + auto is_local = false; + RETURN_NOT_OK(MakeExpressionCodegenVisitor(condition_, &input, {input_field_list_}, + -1, var_id, is_local, &input_list, &condition_node_visitor)); codegen_ctx->process_codes += condition_node_visitor->GetPrepare(); for (auto header : condition_node_visitor->GetHeaders()) { @@ -185,27 +188,16 @@ class FilterKernel::Impl { auto condition_codes = condition_node_visitor->GetResult(); std::stringstream process_ss; - std::stringstream define_ss; process_ss << "if (!(" << condition_codes << ")) {" << std::endl; process_ss << "continue;" << std::endl; process_ss << "}" << std::endl; int idx = 0; for (auto field : input_field_list_) { - auto output_name = - "filter_" + std::to_string(level) + "_output_col_" + std::to_string(idx); - auto output_validity = output_name + "_validity"; - codegen_ctx->output_list.push_back(std::make_pair(output_name, field->type())); - - define_ss << GetCTypeString(field->type()) << " " << output_name << ";" - << std::endl; - define_ss << "bool " << output_validity << ";" << std::endl; - - process_ss << output_name << " = " << input[idx] << ";" << std::endl; - process_ss << output_validity << " = " << input[idx] << "_validity" - << ";" << std::endl; + codegen_ctx->output_list.push_back( + std::make_pair(std::make_pair(input[idx].first.first, input[idx].first.second), + field->type())); idx++; } - codegen_ctx->definition_codes += define_ss.str(); codegen_ctx->process_codes += process_ss.str(); *codegen_ctx_out = codegen_ctx; @@ -244,9 +236,11 @@ arrow::Status FilterKernel::MakeResultIterator( std::string FilterKernel::GetSignature() { return impl_->GetSignature(); } -arrow::Status FilterKernel::DoCodeGen(int level, std::vector input, - std::shared_ptr* codegen_ctx, - int* var_id) { +arrow::Status FilterKernel::DoCodeGen( + int level, + std::vector, gandiva::DataTypePtr>> + input, + std::shared_ptr* codegen_ctx, int* var_id) { return impl_->DoCodeGen(level, input, codegen_ctx, var_id); } diff --git a/cpp/src/codegen/arrow_compute/ext/codegen_context.h b/cpp/src/codegen/arrow_compute/ext/codegen_context.h index fc210f475..8d05977d8 100644 --- a/cpp/src/codegen/arrow_compute/ext/codegen_context.h +++ b/cpp/src/codegen/arrow_compute/ext/codegen_context.h @@ -29,5 +29,7 @@ struct CodeGenContext { std::string finish_codes; std::string definition_codes; std::vector function_list; - std::vector>> output_list; + std::vector< + std::pair, std::shared_ptr>> + output_list; }; \ No newline at end of file diff --git a/cpp/src/codegen/arrow_compute/ext/conditioned_merge_join_kernel.cc b/cpp/src/codegen/arrow_compute/ext/conditioned_merge_join_kernel.cc index 13a3435d0..f81da748d 100644 --- a/cpp/src/codegen/arrow_compute/ext/conditioned_merge_join_kernel.cc +++ b/cpp/src/codegen/arrow_compute/ext/conditioned_merge_join_kernel.cc @@ -99,8 +99,11 @@ class ConditionedMergeJoinKernel::Impl { std::string GetSignature() { return ""; } - arrow::Status DoCodeGen(int level, std::vector input, - std::shared_ptr* codegen_ctx_out, int* var_id) { + arrow::Status DoCodeGen( + int level, + std::vector, gandiva::DataTypePtr>> + input, + std::shared_ptr* codegen_ctx_out, int* var_id) { auto codegen_ctx = std::make_shared(); bool use_relation_for_stream = input.empty(); @@ -144,19 +147,6 @@ class ConditionedMergeJoinKernel::Impl { idx++; } codegen_ctx->relation_prepare_codes = sort_prepare_ss.str(); - - // define output list here, which will also be defined in class variables definition - idx = 0; - for (auto field : result_schema_) { - auto output_name = "sort_merge_join_" + std::to_string(relation_id_[0]) + - "_output_col_" + std::to_string(idx++); - auto output_validity = output_name + "_validity"; - codegen_ctx->output_list.push_back(std::make_pair(output_name, field->type())); - sort_define_ss << GetCTypeString(field->type()) << " " << output_name << ";" - << std::endl; - sort_define_ss << "bool " << output_validity << ";" << std::endl; - } - codegen_ctx->definition_codes = sort_define_ss.str(); ///// Prepare compare function ///// @@ -172,9 +162,10 @@ class ConditionedMergeJoinKernel::Impl { for (int i = 0; i < 2; i++) { for (auto expr : project_codegen_list[i]) { std::shared_ptr project_node_visitor; - RETURN_NOT_OK(MakeExpressionCodegenVisitor(expr->root(), input, field_list, - relation_id_[0], var_id, &input_list, - &project_node_visitor, true)); + auto is_local = true; + RETURN_NOT_OK(MakeExpressionCodegenVisitor( + expr->root(), &input, field_list, relation_id_[0], var_id, is_local, + &input_list, &project_node_visitor, true)); prepare_ss << project_node_visitor->GetPrepare(); auto key_name = project_node_visitor->GetResult(); auto validity_name = project_node_visitor->GetPreCheck(); @@ -244,9 +235,10 @@ class ConditionedMergeJoinKernel::Impl { // 3. do continue if not exists if (cond_check) { std::shared_ptr condition_node_visitor; - RETURN_NOT_OK(MakeExpressionCodegenVisitor(condition_, input, field_list, - relation_id_[0], var_id, &prepare_list, - &condition_node_visitor, true)); + auto is_local = true; + RETURN_NOT_OK(MakeExpressionCodegenVisitor( + condition_, &input, field_list, relation_id_[0], var_id, is_local, + &prepare_list, &condition_node_visitor, true)); auto function_name = "ConditionCheck_" + std::to_string(relation_id_[0]); if (use_relation_for_stream) { function_define_ss << "inline bool " << function_name @@ -301,7 +293,6 @@ class ConditionedMergeJoinKernel::Impl { std::vector cached_; arrow::Status GetInnerJoin(bool cond_check, bool use_relation_for_stream, - std::string set_value, std::shared_ptr* output) { std::stringstream shuffle_ss; std::stringstream codes_ss; @@ -313,7 +304,6 @@ class ConditionedMergeJoinKernel::Impl { auto range_id = "range_" + relation_id + "_i"; auto streamed_range_name = "streamed_range_" + relation_id; auto streamed_range_id = "streamed_range_" + relation_id + "_i"; - auto fill_null_name = "fill_null_" + relation_id; auto build_relation = "sort_relation_" + relation_id + "_"; auto streamed_relation = "sort_relation_" + std::to_string(relation_id_[1]) + "_"; auto left_index_name = "left_index_" + relation_id; @@ -321,7 +311,6 @@ class ConditionedMergeJoinKernel::Impl { ///// Get Matched row ///// codes_ss << "int " << range_name << " = 0;" << std::endl; - codes_ss << "bool " << fill_null_name << " = false;" << std::endl; codes_ss << "auto " << function_name << "_res = " << function_name << "();" << std::endl; codes_ss << "while (" << function_name << "_res < 0) {" << std::endl; @@ -348,10 +337,15 @@ class ConditionedMergeJoinKernel::Impl { codes_ss << "for (int " << streamed_range_id << " = 0; " << streamed_range_id << " < " << streamed_range_name << "; " << streamed_range_id << "++) {" << std::endl; - codes_ss << "auto " << right_index_name << " = " << streamed_relation + codes_ss << right_index_name << " = " << streamed_relation << "->GetItemIndexWithShift(" << streamed_range_id << ");" << std::endl; + std::stringstream prepare_ss; + prepare_ss << "ArrayItemIndexS " << right_index_name << ";" << std::endl; + (*output)->definition_codes += prepare_ss.str(); } - codes_ss << "ArrayItemIndexS " << left_index_name << ";" << std::endl; + std::stringstream prepare_ss; + prepare_ss << "ArrayItemIndexS " << left_index_name << ";" << std::endl; + (*output)->definition_codes += prepare_ss.str(); codes_ss << "for (int " << range_id << " = 0; " << range_id << " < " << range_name << "; " << range_id << "++) {" << std::endl; codes_ss << left_index_name << " = " << build_relation << "->GetItemIndexWithShift(" @@ -367,9 +361,6 @@ class ConditionedMergeJoinKernel::Impl { } codes_ss << " continue;" << std::endl; codes_ss << "}" << std::endl; - codes_ss << set_value << std::endl; - } else { - codes_ss << set_value << std::endl; } finish_codes_ss << "} // end of Inner Join" << std::endl; if (use_relation_for_stream) { @@ -383,7 +374,6 @@ class ConditionedMergeJoinKernel::Impl { return arrow::Status::OK(); } arrow::Status GetOuterJoin(bool cond_check, bool use_relation_for_stream, - std::string set_value, std::shared_ptr* output) { std::stringstream shuffle_ss; std::stringstream codes_ss; @@ -392,7 +382,7 @@ class ConditionedMergeJoinKernel::Impl { auto function_name = "JoinCompare_" + relation_id; auto condition_name = "ConditionCheck_" + relation_id; auto range_name = "range_" + relation_id; - auto fill_null_name = "fill_null_" + relation_id; + auto fill_null_name = "is_outer_null_" + relation_id; auto range_id = "range_" + relation_id + "_i"; auto streamed_range_name = "streamed_range_" + relation_id; auto streamed_range_id = "streamed_range_" + relation_id + "_i"; @@ -403,7 +393,7 @@ class ConditionedMergeJoinKernel::Impl { ///// Get Matched row ///// codes_ss << "int " << range_name << " = 0;" << std::endl; - codes_ss << "bool " << fill_null_name << " = false;" << std::endl; + codes_ss << fill_null_name << " = false;" << std::endl; codes_ss << "auto " << function_name << "_res = " << function_name << "();" << std::endl; codes_ss << "while (" << function_name << "_res < 0) {" << std::endl; @@ -427,10 +417,16 @@ class ConditionedMergeJoinKernel::Impl { codes_ss << "for (int " << streamed_range_id << " = 0; " << streamed_range_id << " < " << streamed_range_name << "; " << streamed_range_id << "++) {" << std::endl; - codes_ss << "auto " << right_index_name << " = " << streamed_relation + codes_ss << right_index_name << " = " << streamed_relation << "->GetItemIndexWithShift(" << streamed_range_id << ");" << std::endl; + std::stringstream prepare_ss; + prepare_ss << "ArrayItemIndexS " << right_index_name << ";" << std::endl; + (*output)->definition_codes += prepare_ss.str(); } - codes_ss << "ArrayItemIndexS " << left_index_name << ";" << std::endl; + std::stringstream prepare_ss; + prepare_ss << "ArrayItemIndexS " << left_index_name << ";" << std::endl; + prepare_ss << "bool " << fill_null_name << ";" << std::endl; + (*output)->definition_codes += prepare_ss.str(); codes_ss << "for (int " << range_id << " = 0; " << range_id << " < " << range_name << "; " << range_id << "++) {" << std::endl; codes_ss << "if(!" << fill_null_name << "){" << std::endl; @@ -448,9 +444,6 @@ class ConditionedMergeJoinKernel::Impl { } codes_ss << fill_null_name << " = true;" << std::endl; codes_ss << "}" << std::endl; - codes_ss << set_value << std::endl; - } else { - codes_ss << set_value << std::endl; } finish_codes_ss << "} // end of Outer Join" << std::endl; if (use_relation_for_stream) { @@ -464,7 +457,6 @@ class ConditionedMergeJoinKernel::Impl { return arrow::Status::OK(); } arrow::Status GetAntiJoin(bool cond_check, bool use_relation_for_stream, - std::string set_value, std::shared_ptr* output) { std::stringstream shuffle_ss; std::stringstream codes_ss; @@ -473,7 +465,6 @@ class ConditionedMergeJoinKernel::Impl { auto function_name = "JoinCompare_" + relation_id; auto condition_name = "ConditionCheck_" + relation_id; auto range_name = "range_" + relation_id; - auto fill_null_name = "fill_null_" + relation_id; auto found_match_name = "found_" + relation_id; auto range_id = "range_" + relation_id + "_i"; auto streamed_range_name = "streamed_range_" + relation_id; @@ -485,7 +476,6 @@ class ConditionedMergeJoinKernel::Impl { ///// Get Matched row ///// codes_ss << "int " << range_name << " = 1;" << std::endl; - codes_ss << "bool " << fill_null_name << " = false;" << std::endl; codes_ss << "bool " << found_match_name << " = false;" << std::endl; codes_ss << "auto " << function_name << "_res = " << function_name << "();" << std::endl; @@ -516,9 +506,15 @@ class ConditionedMergeJoinKernel::Impl { codes_ss << "for (int " << streamed_range_id << " = 0; " << streamed_range_id << " < " << streamed_range_name << "; " << streamed_range_id << "++) {" << std::endl; - codes_ss << "auto " << right_index_name << " = " << streamed_relation + codes_ss << right_index_name << " = " << streamed_relation << "->GetItemIndexWithShift(" << streamed_range_id << ");" << std::endl; + std::stringstream prepare_ss; + prepare_ss << "ArrayItemIndexS " << right_index_name << ";" << std::endl; + (*output)->definition_codes += prepare_ss.str(); } + std::stringstream prepare_ss; + prepare_ss << "ArrayItemIndexS " << left_index_name << ";" << std::endl; + (*output)->definition_codes += prepare_ss.str(); codes_ss << "for (int " << range_id << " = 0; " << range_id << " < 1;" << range_id << "++) {" << std::endl; if (cond_check) { @@ -526,7 +522,7 @@ class ConditionedMergeJoinKernel::Impl { codes_ss << "if (" << found_match_name << ") {" << std::endl; codes_ss << found_match_name << " = false;" << std::endl; codes_ss << "for (int j = 0; j < " << range_name << "; j++) {" << std::endl; - codes_ss << "auto " << left_index_name << " = " << build_relation + codes_ss << left_index_name << " = " << build_relation << "->GetItemIndexWithShift(j);" << std::endl; if (use_relation_for_stream) { codes_ss << "if (" << condition_name << "(" << left_index_name << ", " @@ -543,9 +539,6 @@ class ConditionedMergeJoinKernel::Impl { codes_ss << "continue;" << std::endl; codes_ss << "}" << std::endl; codes_ss << "}" << std::endl; - codes_ss << set_value << std::endl; - } else { - codes_ss << set_value << std::endl; } finish_codes_ss << "} // end of anti Join" << std::endl; if (use_relation_for_stream) { @@ -559,7 +552,7 @@ class ConditionedMergeJoinKernel::Impl { return arrow::Status::OK(); } arrow::Status GetSemiJoin(bool cond_check, bool use_relation_for_stream, - std::string set_value, + std::shared_ptr* output) { std::stringstream shuffle_ss; std::stringstream codes_ss; @@ -568,7 +561,6 @@ class ConditionedMergeJoinKernel::Impl { auto function_name = "JoinCompare_" + relation_id; auto condition_name = "ConditionCheck_" + relation_id; auto range_name = "range_" + relation_id; - auto fill_null_name = "fill_null_" + relation_id; auto found_match_name = "found_" + relation_id; auto range_id = "range_" + relation_id + "_i"; auto streamed_range_name = "streamed_range_" + relation_id; @@ -580,7 +572,6 @@ class ConditionedMergeJoinKernel::Impl { ///// Get Matched row ///// codes_ss << "int " << range_name << " = 1;" << std::endl; - codes_ss << "bool " << fill_null_name << " = false;" << std::endl; codes_ss << "bool " << found_match_name << " = true;" << std::endl; codes_ss << "auto " << function_name << "_res = " << function_name << "();" << std::endl; @@ -609,16 +600,22 @@ class ConditionedMergeJoinKernel::Impl { codes_ss << "for (int " << streamed_range_id << " = 0; " << streamed_range_id << " < " << streamed_range_name << "; " << streamed_range_id << "++) {" << std::endl; - codes_ss << "auto " << right_index_name << " = " << streamed_relation + codes_ss << right_index_name << " = " << streamed_relation << "->GetItemIndexWithShift(" << streamed_range_id << ");" << std::endl; + std::stringstream prepare_ss; + prepare_ss << "ArrayItemIndexS " << right_index_name << ";" << std::endl; + (*output)->definition_codes += prepare_ss.str(); } + std::stringstream prepare_ss; + prepare_ss << "ArrayItemIndexS " << left_index_name << ";" << std::endl; + (*output)->definition_codes += prepare_ss.str(); codes_ss << "for (int " << range_id << " = 0; " << range_id << " < 1;" << range_id << "++) {" << std::endl; if (cond_check) { auto condition_name = "ConditionCheck_" + std::to_string(relation_id_[0]); codes_ss << found_match_name << " = false;" << std::endl; codes_ss << "for (int j = 0; j < " << range_name << "; j++) {" << std::endl; - codes_ss << "auto " << left_index_name << " = " << build_relation + codes_ss << left_index_name << " = " << build_relation << "->GetItemIndexWithShift(j);" << std::endl; if (use_relation_for_stream) { codes_ss << "if (" << condition_name << "(" << left_index_name << ", " @@ -634,9 +631,6 @@ class ConditionedMergeJoinKernel::Impl { codes_ss << "if (!" << found_match_name << ") {" << std::endl; codes_ss << "continue;" << std::endl; codes_ss << "}" << std::endl; - codes_ss << set_value << std::endl; - } else { - codes_ss << set_value << std::endl; } finish_codes_ss << "} // end of semi Join" << std::endl; if (use_relation_for_stream) { @@ -651,7 +645,6 @@ class ConditionedMergeJoinKernel::Impl { } arrow::Status GetExistenceJoin(bool cond_check, bool use_relation_for_stream, - std::string set_value, std::shared_ptr* output) { std::stringstream shuffle_ss; std::stringstream codes_ss; @@ -660,7 +653,6 @@ class ConditionedMergeJoinKernel::Impl { auto function_name = "JoinCompare_" + relation_id; auto condition_name = "ConditionCheck_" + relation_id; auto range_name = "range_" + relation_id; - auto fill_null_name = "fill_null_" + relation_id; auto found_match_name = "found_" + relation_id; auto range_id = "range_" + relation_id + "_i"; auto streamed_range_name = "streamed_range_" + relation_id; @@ -674,7 +666,6 @@ class ConditionedMergeJoinKernel::Impl { ///// Get Matched row ///// codes_ss << "int " << range_name << " = 1;" << std::endl; - codes_ss << "bool " << fill_null_name << " = false;" << std::endl; codes_ss << "bool " << found_match_name << " = true;" << std::endl; codes_ss << "auto " << function_name << "_res = " << function_name << "();" << std::endl; @@ -700,9 +691,15 @@ class ConditionedMergeJoinKernel::Impl { codes_ss << "for (int " << streamed_range_id << " = 0; " << streamed_range_id << " < " << streamed_range_name << "; " << streamed_range_id << "++) {" << std::endl; - codes_ss << "auto " << right_index_name << " = " << streamed_relation + codes_ss << right_index_name << " = " << streamed_relation << "->GetItemIndexWithShift(" << streamed_range_id << ");" << std::endl; + std::stringstream prepare_ss; + prepare_ss << "ArrayItemIndexS " << right_index_name << ";" << std::endl; + (*output)->definition_codes += prepare_ss.str(); } + std::stringstream prepare_ss; + prepare_ss << "ArrayItemIndexS " << left_index_name << ";" << std::endl; + (*output)->definition_codes += prepare_ss.str(); codes_ss << "for (int " << range_id << " = 0; " << range_id << " < 1;" << range_id << "++) {" << std::endl; if (cond_check) { @@ -710,7 +707,7 @@ class ConditionedMergeJoinKernel::Impl { codes_ss << "if (" << found_match_name << ") {" << std::endl; codes_ss << found_match_name << " = false;" << std::endl; codes_ss << "for (int j = 0; j < " << range_name << "; j++) {" << std::endl; - codes_ss << "auto " << left_index_name << " = " << build_relation + codes_ss << left_index_name << " = " << build_relation << "->GetItemIndexWithShift(j);" << std::endl; if (use_relation_for_stream) { codes_ss << "if (" << condition_name << "(" << left_index_name << ", " @@ -726,11 +723,9 @@ class ConditionedMergeJoinKernel::Impl { codes_ss << "}" << std::endl; codes_ss << "auto " << exist_name << " = " << found_match_name << ";" << std::endl; codes_ss << "bool " << exist_validity << " = true;" << std::endl; - codes_ss << set_value << std::endl; } else { codes_ss << "auto " << exist_name << " = " << found_match_name << ";" << std::endl; codes_ss << "bool " << exist_validity << " = true;" << std::endl; - codes_ss << set_value << std::endl; } finish_codes_ss << "} // end of Existence Join" << std::endl; if (use_relation_for_stream) { @@ -744,24 +739,29 @@ class ConditionedMergeJoinKernel::Impl { return arrow::Status::OK(); } - arrow::Status GetProcessJoin(int join_type, bool cond_check, - std::vector input, - std::shared_ptr* output) { + arrow::Status GetProcessJoin( + int join_type, bool cond_check, + std::vector, gandiva::DataTypePtr>> + input, + std::shared_ptr* output) { // General codes when found matched rows - std::stringstream valid_ss; auto relation_id = std::to_string(relation_id_[0]); auto left_index_name = "left_index_" + relation_id; auto right_index_name = "right_index_" + relation_id; - auto fill_null_name = "fill_null_" + relation_id; + auto fill_null_name = "is_outer_null_" + relation_id; bool use_relation_for_stream = input.empty(); + // define output list here, which will also be defined in class variables definition + + int right_index_shift = 0; for (int idx = 0; idx < result_schema_index_list_.size(); idx++) { - auto output_name = - "sort_merge_join_" + relation_id + "_output_col_" + std::to_string(idx); - auto output_validity = output_name + "_validity"; std::string name; std::string arguments; std::shared_ptr type; + std::stringstream valid_ss; + auto output_name = "sort_merge_join_" + std::to_string(relation_id_[0]) + + "_output_col_" + std::to_string(idx); + auto output_validity = output_name + "_validity"; auto i = result_schema_index_list_[idx].second; if (result_schema_index_list_[idx].first == 0) { /*left(streamed) table*/ if (join_type != 0 && join_type != 1) continue; @@ -769,10 +769,15 @@ class ConditionedMergeJoinKernel::Impl { "sort_relation_" + std::to_string(relation_id_[0]) + "_" + std::to_string(i); type = left_field_list_[i]->type(); arguments = left_index_name + ".array_id, " + left_index_name + ".id"; - valid_ss << output_validity << " = !" << fill_null_name << " && !" << name - << "->IsNull(" << arguments << ");" << std::endl; - valid_ss << output_name << " = " << name << "->GetValue(" << arguments << ");" - << std::endl; + if (join_type == 1) { + valid_ss << "auto " << output_validity << " = !" << fill_null_name << " && !" + << name << "->IsNull(" << arguments << ");" << std::endl; + } else { + valid_ss << "auto " << output_validity << " = !" << name << "->IsNull(" + << arguments << ");" << std::endl; + } + valid_ss << "auto " << output_name << " = " << name << "->GetValue(" << arguments + << ");" << std::endl; } else { /*right(streamed) table*/ if (use_relation_for_stream) { /* use sort relation in streamed side*/ @@ -780,53 +785,58 @@ class ConditionedMergeJoinKernel::Impl { name = "sort_relation_" + std::to_string(relation_id_[0]) + "_existence_value"; type = arrow::boolean(); - valid_ss << output_validity << " = " << name << "_validity;" << std::endl; - valid_ss << output_name << " = " << name << ";" << std::endl; + valid_ss << "auto " << output_validity << " = " << name << "_validity;" + << std::endl; + valid_ss << "auto " << output_name << " = " << name << ";" << std::endl; + right_index_shift = -1; } else { + i += right_index_shift; name = "sort_relation_" + std::to_string(relation_id_[1]) + "_" + std::to_string(i); type = right_field_list_[i]->type(); arguments = right_index_name + ".array_id, " + right_index_name + ".id"; - valid_ss << output_validity << " = !" << name << "->IsNull(" << arguments - << ");" << std::endl; - valid_ss << output_name << " = " << name << "->GetValue(" << arguments << ");" - << std::endl; + valid_ss << "auto " << output_validity << " = !" << name << "->IsNull(" + << arguments << ");" << std::endl; + valid_ss << "auto " << output_name << " = " << name << "->GetValue(" + << arguments << ");" << std::endl; } } else { /* use previous output in streamed side*/ if (exist_index_ != -1 && exist_index_ == i) { name = "sort_relation_" + std::to_string(relation_id_[0]) + "_existence_value"; - valid_ss << output_validity << " = " << name << "_validity;" << std::endl; - valid_ss << output_name << " = " << name << ";" << std::endl; + valid_ss << "auto " << output_validity << " = " << name << "_validity;" + << std::endl; + valid_ss << "auto " << output_name << " = " << name << ";" << std::endl; + type = arrow::boolean(); + right_index_shift = -1; } else { - type = right_field_list_[i]->type(); - valid_ss << output_validity << " = " << input[i] << "_validity;" << std::endl; - valid_ss << output_name << " = " << input[i] << ";" << std::endl; + i += right_index_shift; + output_name = input[i].first.first; + output_validity = output_name + "_validity"; + valid_ss << input[i].first.second; + type = input[i].second; } } } + (*output)->output_list.push_back( + std::make_pair(std::make_pair(output_name, valid_ss.str()), type)); } switch (join_type) { case 0: { /* inner join */ - RETURN_NOT_OK( - GetInnerJoin(cond_check, use_relation_for_stream, valid_ss.str(), output)); + RETURN_NOT_OK(GetInnerJoin(cond_check, use_relation_for_stream, output)); } break; case 1: { /* Outer join */ - RETURN_NOT_OK( - GetOuterJoin(cond_check, use_relation_for_stream, valid_ss.str(), output)); + RETURN_NOT_OK(GetOuterJoin(cond_check, use_relation_for_stream, output)); } break; case 2: { /* Anti join */ - RETURN_NOT_OK( - GetAntiJoin(cond_check, use_relation_for_stream, valid_ss.str(), output)); + RETURN_NOT_OK(GetAntiJoin(cond_check, use_relation_for_stream, output)); } break; case 3: { /* Semi join */ - RETURN_NOT_OK( - GetSemiJoin(cond_check, use_relation_for_stream, valid_ss.str(), output)); + RETURN_NOT_OK(GetSemiJoin(cond_check, use_relation_for_stream, output)); } break; case 4: { /* Existence join */ - RETURN_NOT_OK(GetExistenceJoin(cond_check, use_relation_for_stream, - valid_ss.str(), output)); + RETURN_NOT_OK(GetExistenceJoin(cond_check, use_relation_for_stream, output)); } break; default: { } break; @@ -871,7 +881,9 @@ arrow::Status ConditionedMergeJoinKernel::MakeResultIterator( std::string ConditionedMergeJoinKernel::GetSignature() { return impl_->GetSignature(); } arrow::Status ConditionedMergeJoinKernel::DoCodeGen( - int level, std::vector input, + int level, + std::vector, gandiva::DataTypePtr>> + input, std::shared_ptr* codegen_ctx_out, int* var_id) { return impl_->DoCodeGen(level, input, codegen_ctx_out, var_id); } diff --git a/cpp/src/codegen/arrow_compute/ext/conditioned_probe_kernel.cc b/cpp/src/codegen/arrow_compute/ext/conditioned_probe_kernel.cc index 94024f2fe..f9ebcce0b 100644 --- a/cpp/src/codegen/arrow_compute/ext/conditioned_probe_kernel.cc +++ b/cpp/src/codegen/arrow_compute/ext/conditioned_probe_kernel.cc @@ -164,8 +164,11 @@ class ConditionedProbeKernel::Impl { std::string GetSignature() { return ""; } - arrow::Status DoCodeGen(int level, std::vector input, - std::shared_ptr* codegen_ctx_out, int* var_id) { + arrow::Status DoCodeGen( + int level, + std::vector, gandiva::DataTypePtr>> + input, + std::shared_ptr* codegen_ctx_out, int* var_id) { auto codegen_ctx = std::make_shared(); codegen_ctx->header_codes.push_back( @@ -205,24 +208,13 @@ class ConditionedProbeKernel::Impl { } codegen_ctx->relation_prepare_codes = hash_prepare_ss.str(); - // define output list here, which will also be defined in class variables definition - int idx = 0; - for (auto field : result_schema_) { - auto output_name = "hash_relation_" + std::to_string(hash_relation_id_) + - "_output_col_" + std::to_string(idx++); - auto output_validity = output_name + "_validity"; - codegen_ctx->output_list.push_back(std::make_pair(output_name, field->type())); - hash_define_ss << GetCTypeString(field->type()) << " " << output_name << ";" - << std::endl; - hash_define_ss << "bool " << output_validity << ";" << std::endl; - } - codegen_ctx->definition_codes = hash_define_ss.str(); // 1.1 prepare probe key column, name is key_0 and key_0_validity std::stringstream prepare_ss; std::vector input_list; - std::vector project_output_list; + std::vector, gandiva::DataTypePtr>> + project_output_list; auto unsafe_row_name = "unsafe_row_" + std::to_string(hash_relation_id_); bool do_unsafe_row = true; if (right_key_project_codegen_.size() == 1) { @@ -239,12 +231,13 @@ class ConditionedProbeKernel::Impl { codegen_ctx->unsafe_row_prepare_codes = unsafe_row_define_ss.str(); prepare_ss << unsafe_row_name << "->reset();" << std::endl; } - idx = 0; + int idx = 0; for (auto expr : right_key_project_codegen_) { std::shared_ptr project_node_visitor; - RETURN_NOT_OK(MakeExpressionCodegenVisitor(expr->root(), input, {right_field_list_}, - -1, var_id, &input_list, - &project_node_visitor)); + auto is_local = false; + RETURN_NOT_OK(MakeExpressionCodegenVisitor( + expr->root(), &input, {right_field_list_}, -1, var_id, is_local, &input_list, + &project_node_visitor)); prepare_ss << project_node_visitor->GetPrepare(); auto key_name = project_node_visitor->GetResult(); auto validity_name = project_node_visitor->GetPreCheck(); @@ -262,7 +255,8 @@ class ConditionedProbeKernel::Impl { prepare_ss << "}" << std::endl; } - project_output_list.push_back(project_node_visitor->GetResult()); + project_output_list.push_back( + std::make_pair(std::make_pair(key_name, ""), nullptr)); for (auto header : project_node_visitor->GetHeaders()) { if (std::find(codegen_ctx->header_codes.begin(), codegen_ctx->header_codes.end(), header) == codegen_ctx->header_codes.end()) { @@ -272,9 +266,10 @@ class ConditionedProbeKernel::Impl { idx++; } std::shared_ptr hash_node_visitor; + auto is_local = false; RETURN_NOT_OK(MakeExpressionCodegenVisitor( - right_key_hash_codegen_->root(), project_output_list, {key_hash_field_list_}, -1, - var_id, &input_list, &hash_node_visitor)); + right_key_hash_codegen_->root(), &project_output_list, {key_hash_field_list_}, -1, + var_id, is_local, &input_list, &hash_node_visitor)); prepare_ss << hash_node_visitor->GetPrepare(); auto key_name = hash_node_visitor->GetResult(); auto validity_name = hash_node_visitor->GetPreCheck(); @@ -295,9 +290,10 @@ class ConditionedProbeKernel::Impl { // 3. do continue if not exists if (cond_check) { std::shared_ptr condition_node_visitor; + auto is_local = true; RETURN_NOT_OK(MakeExpressionCodegenVisitor( - condition_, input, {left_field_list_, right_field_list_}, hash_relation_id_, - var_id, &prepare_list, &condition_node_visitor)); + condition_, &input, {left_field_list_, right_field_list_}, hash_relation_id_, + var_id, is_local, &prepare_list, &condition_node_visitor)); auto function_name = "ConditionCheck_" + std::to_string(hash_relation_id_); std::stringstream function_define_ss; function_define_ss << "bool " << function_name << "(ArrayItemIndex x, int y) {" @@ -1440,78 +1436,92 @@ class ConditionedProbeKernel::Impl { std::shared_ptr probe_func_; }; - arrow::Status GetInnerJoin(const std::vector input, bool cond_check, - std::string set_value, std::string index_name, + arrow::Status GetInnerJoin(bool cond_check, std::string index_name, std::string hash_relation_name, std::shared_ptr* output) { std::stringstream shuffle_ss; std::stringstream codes_ss; std::stringstream finish_codes_ss; + auto tmp_name = "tmp_" + std::to_string(hash_relation_id_); + auto item_index_list_name = index_name + "_item_list"; + auto range_index_name = "range_" + std::to_string(hash_relation_id_) + "_i"; codes_ss << "int32_t " << index_name << ";" << std::endl; codes_ss << index_name << " = " << hash_relation_name << "->Get(key_" << hash_relation_id_ << ", unsafe_row_" << hash_relation_id_ << ");" << std::endl; codes_ss << "if (" << index_name << " == -1) { continue; }" << std::endl; - codes_ss << "for (auto tmp : " << hash_relation_name << "->GetItemListByIndex(" - << index_name << ")) {" << std::endl; + codes_ss << "auto " << item_index_list_name << " = " << hash_relation_name + << "->GetItemListByIndex(" << index_name << ");" << std::endl; + codes_ss << "for (int " << range_index_name << " = 0; " << range_index_name << " < " + << item_index_list_name << ".size(); " << range_index_name << "++) {" + << std::endl; + codes_ss << tmp_name << " = " << item_index_list_name << "[" << range_index_name + << "];" << std::endl; if (cond_check) { auto condition_name = "ConditionCheck_" + std::to_string(hash_relation_id_); - codes_ss << "if (!" << condition_name << "(tmp, i)) {" << std::endl; + codes_ss << "if (!" << condition_name << "(" << tmp_name << ", i)) {" << std::endl; codes_ss << " continue;" << std::endl; codes_ss << "}" << std::endl; - codes_ss << set_value << std::endl; - } else { - codes_ss << set_value << std::endl; } finish_codes_ss << "} // end of Inner Join" << std::endl; (*output)->process_codes += codes_ss.str(); (*output)->finish_codes += finish_codes_ss.str(); return arrow::Status::OK(); } - arrow::Status GetOuterJoin(const std::vector input, bool cond_check, - std::string set_value, std::string index_name, + arrow::Status GetOuterJoin(bool cond_check, std::string index_name, std::string hash_relation_name, std::shared_ptr* output) { std::stringstream codes_ss; std::stringstream finish_codes_ss; + + auto tmp_name = "tmp_" + std::to_string(hash_relation_id_); + auto is_outer_null_name = "is_outer_null_" + std::to_string(hash_relation_id_); auto condition_name = "ConditionCheck_" + std::to_string(hash_relation_id_); - auto matched_index_list_name = - "hash_relation_matched_" + std::to_string(hash_relation_id_); + auto item_index_list_name = index_name + "_item_list"; + auto range_index_name = "range_" + std::to_string(hash_relation_id_) + "_i"; + auto range_size_name = "range_" + std::to_string(hash_relation_id_) + "_size"; + codes_ss << "int32_t " << index_name << ";" << std::endl; + codes_ss << "std::vector " << item_index_list_name << ";" + << std::endl; codes_ss << index_name << " = " << hash_relation_name << "->Get(key_" << hash_relation_id_ << ", unsafe_row_" << hash_relation_id_ << ");" << std::endl; - codes_ss << "std::vector " << matched_index_list_name << ";" + codes_ss << "auto " << range_size_name << " = 1;" << std::endl; + codes_ss << "if (" << index_name << " != -1) {" << std::endl; + codes_ss << item_index_list_name << " = " << hash_relation_name + << "->GetItemListByIndex(" << index_name << ");" << std::endl; + codes_ss << range_size_name << " = " << item_index_list_name << ".size();" << std::endl; - codes_ss << "if (" << index_name << " == -1) {" << std::endl; - codes_ss << matched_index_list_name << " = {ArrayItemIndex(false)};" << std::endl; + codes_ss << "}" << std::endl; + codes_ss << "for (int " << range_index_name << " = 0; " << range_index_name << " < " + << range_size_name << "; " << range_index_name << "++) {" << std::endl; + codes_ss << "if (!" << item_index_list_name << ".empty()) {" << std::endl; + codes_ss << tmp_name << " = " << item_index_list_name << "[" << range_index_name + << "];" << std::endl; + codes_ss << is_outer_null_name << " = false;" << std::endl; codes_ss << "} else {" << std::endl; - codes_ss << matched_index_list_name << " = " << hash_relation_name - << "->GetItemListByIndex(" << index_name << ");" << std::endl; + codes_ss << is_outer_null_name << " = true;" << std::endl; codes_ss << "}" << std::endl; - codes_ss << "for (auto tmp : " << matched_index_list_name << ") {" << std::endl; if (cond_check) { - codes_ss << "if (!" << condition_name << "(tmp, i)) {" << std::endl; + codes_ss << "if (!" << condition_name << "(" << tmp_name << ", i)) {" << std::endl; codes_ss << " continue;" << std::endl; codes_ss << "}" << std::endl; - codes_ss << set_value << std::endl; - } else { - codes_ss << set_value << std::endl; } finish_codes_ss << "} // end of Outer Join" << std::endl; (*output)->process_codes += codes_ss.str(); (*output)->finish_codes += finish_codes_ss.str(); return arrow::Status::OK(); } - arrow::Status GetAntiJoin(const std::vector input, bool cond_check, - std::string set_value, std::string index_name, + arrow::Status GetAntiJoin(bool cond_check, std::string index_name, std::string hash_relation_name, std::shared_ptr* output) { std::stringstream codes_ss; std::stringstream finish_codes_ss; + auto tmp_name = "tmp_" + std::to_string(hash_relation_id_); auto condition_name = "ConditionCheck_" + std::to_string(hash_relation_id_); - auto matched_index_list_name = - "hash_relation_matched_" + std::to_string(hash_relation_id_); + auto item_index_list_name = index_name + "_item_list"; + auto range_index_name = "range_" + std::to_string(hash_relation_id_) + "_i"; codes_ss << "int32_t " << index_name << ";" << std::endl; if (cond_check) { codes_ss << index_name << " = " << hash_relation_name << "->Get(key_" @@ -1522,43 +1532,46 @@ class ConditionedProbeKernel::Impl { << hash_relation_id_ << ", unsafe_row_" << hash_relation_id_ << ");" << std::endl; } - codes_ss << "std::vector " << matched_index_list_name << ";" - << std::endl; - codes_ss << "if (" << index_name << " == -1) {" << std::endl; - codes_ss << matched_index_list_name << " = {ArrayItemIndex(false)};" << std::endl; if (cond_check) { - codes_ss << "} else {" << std::endl; + codes_ss << "if (" << index_name << " != -1) {" << std::endl; codes_ss << " bool found = false;" << std::endl; - codes_ss << " for (auto tmp : " << hash_relation_name << "->GetItemListByIndex(" - << index_name << ")" - << ") {" << std::endl; - codes_ss << " if (" << condition_name << "(tmp, i)) {" << std::endl; + codes_ss << "auto " << item_index_list_name << " = " << hash_relation_name + << "->GetItemListByIndex(" << index_name << ");" << std::endl; + codes_ss << "for (int " << range_index_name << " = 0; " << range_index_name << " < " + << item_index_list_name << ".size(); " << range_index_name << "++) {" + << std::endl; + codes_ss << tmp_name << " = " << item_index_list_name << "[" << range_index_name + << "];" << std::endl; + codes_ss << " if (" << condition_name << "(" << tmp_name << ", i)) {" + << std::endl; codes_ss << " found = true;" << std::endl; codes_ss << " break;" << std::endl; codes_ss << " }" << std::endl; codes_ss << " }" << std::endl; - codes_ss << " if (!found) {" << std::endl; - codes_ss << matched_index_list_name << " = {ArrayItemIndex(false)};" << std::endl; - codes_ss << " }" << std::endl; + codes_ss << "if (found) continue;" << std::endl; + codes_ss << "}" << std::endl; + } else { + codes_ss << "if (" << index_name << " != -1) {" << std::endl; + codes_ss << " continue;" << std::endl; + codes_ss << "}" << std::endl; } - codes_ss << "}" << std::endl; - codes_ss << "for (auto tmp : " << matched_index_list_name << ") {" << std::endl; - codes_ss << set_value << std::endl; + codes_ss << "for (int " << range_index_name << " = 0; " << range_index_name << " < 1;" + << range_index_name << "++) {" << std::endl; finish_codes_ss << "} // end of Anti Join" << std::endl; (*output)->process_codes += codes_ss.str(); (*output)->finish_codes += finish_codes_ss.str(); return arrow::Status::OK(); } - arrow::Status GetSemiJoin(const std::vector input, bool cond_check, - std::string set_value, std::string index_name, + arrow::Status GetSemiJoin(bool cond_check, std::string index_name, std::string hash_relation_name, std::shared_ptr* output) { std::stringstream shuffle_ss; std::stringstream codes_ss; std::stringstream finish_codes_ss; + auto tmp_name = "tmp_" + std::to_string(hash_relation_id_); + auto item_index_list_name = index_name + "_item_list"; + auto range_index_name = "range_" + std::to_string(hash_relation_id_) + "_i"; auto condition_name = "ConditionCheck_" + std::to_string(hash_relation_id_); - auto matched_index_list_name = - "hash_relation_matched_" + std::to_string(hash_relation_id_); codes_ss << "int32_t " << index_name << ";" << std::endl; if (cond_check) { codes_ss << index_name << " = " << hash_relation_name << "->Get(key_" @@ -1569,46 +1582,45 @@ class ConditionedProbeKernel::Impl { << hash_relation_id_ << ", unsafe_row_" << hash_relation_id_ << ");" << std::endl; } - codes_ss << "std::vector " << matched_index_list_name << ";" - << std::endl; codes_ss << "if (" << index_name << " == -1) {" << std::endl; codes_ss << "continue;" << std::endl; if (cond_check) { codes_ss << "} else {" << std::endl; codes_ss << " bool found = false;" << std::endl; - codes_ss << " for (auto tmp : " << hash_relation_name << "->GetItemListByIndex(" - << index_name << ")" - << ") {" << std::endl; - codes_ss << " if (" << condition_name << "(tmp, i)) {" << std::endl; + codes_ss << "auto " << item_index_list_name << " = " << hash_relation_name + << "->GetItemListByIndex(" << index_name << ");" << std::endl; + codes_ss << "for (int " << range_index_name << " = 0; " << range_index_name << " < " + << item_index_list_name << ".size(); " << range_index_name << "++) {" + << std::endl; + codes_ss << tmp_name << " = " << item_index_list_name << "[" << range_index_name + << "];" << std::endl; + codes_ss << " if (" << condition_name << "(" << tmp_name << ", i)) {" + << std::endl; codes_ss << " found = true;" << std::endl; codes_ss << " break;" << std::endl; codes_ss << " }" << std::endl; codes_ss << " }" << std::endl; codes_ss << " if (found) {" << std::endl; - codes_ss << matched_index_list_name << " = {ArrayItemIndex(false)};" << std::endl; codes_ss << " }" << std::endl; - } else { - codes_ss << "} else {" << std::endl; - codes_ss << matched_index_list_name << " = {ArrayItemIndex(false)};" << std::endl; } codes_ss << "}" << std::endl; - codes_ss << "for (auto tmp : " << matched_index_list_name << ") {" << std::endl; - codes_ss << set_value << std::endl; + codes_ss << "for (int " << range_index_name << " = 0; " << range_index_name << " < 1;" + << range_index_name << "++) {" << std::endl; finish_codes_ss << "} // end of Semi Join" << std::endl; (*output)->process_codes += codes_ss.str(); (*output)->finish_codes += finish_codes_ss.str(); return arrow::Status::OK(); } - arrow::Status GetExistenceJoin(const std::vector input, bool cond_check, - std::string set_value, std::string index_name, + arrow::Status GetExistenceJoin(bool cond_check, std::string index_name, std::string hash_relation_name, std::shared_ptr* output) { std::stringstream shuffle_ss; std::stringstream codes_ss; std::stringstream finish_codes_ss; + auto tmp_name = "tmp_" + std::to_string(hash_relation_id_); auto condition_name = "ConditionCheck_" + std::to_string(hash_relation_id_); - auto matched_index_list_name = - "hash_relation_matched_" + std::to_string(hash_relation_id_); + auto item_index_list_name = index_name + "_item_list"; + auto range_index_name = "range_" + std::to_string(hash_relation_id_) + "_i"; auto exist_name = "hash_relation_" + std::to_string(hash_relation_id_) + "_existence_value"; auto exist_validity = exist_name + "_validity"; @@ -1628,10 +1640,15 @@ class ConditionedProbeKernel::Impl { codes_ss << exist_name << " = false;" << std::endl; if (cond_check) { codes_ss << "} else {" << std::endl; - codes_ss << " for (auto tmp : " << hash_relation_name << "->GetItemListByIndex(" - << index_name << ")" - << ") {" << std::endl; - codes_ss << " if (" << condition_name << "(tmp, i)) {" << std::endl; + codes_ss << "auto " << item_index_list_name << " = " << hash_relation_name + << "->GetItemListByIndex(" << index_name << ");" << std::endl; + codes_ss << "for (int " << range_index_name << " = 0; " << range_index_name << " < " + << item_index_list_name << ".size(); " << range_index_name << "++) {" + << std::endl; + codes_ss << tmp_name << " = " << item_index_list_name << "[" << range_index_name + << "];" << std::endl; + codes_ss << " if (" << condition_name << "(" << tmp_name << ", i)) {" + << std::endl; codes_ss << " " << exist_name << " = true;" << std::endl; codes_ss << " break;" << std::endl; codes_ss << " }" << std::endl; @@ -1641,85 +1658,94 @@ class ConditionedProbeKernel::Impl { codes_ss << exist_name << " = true;" << std::endl; } codes_ss << "}" << std::endl; - codes_ss << "std::vector " << matched_index_list_name - << " = {ArrayItemIndex(false)};" << std::endl; - codes_ss << "for (auto tmp : " << matched_index_list_name << ") {" << std::endl; - codes_ss << set_value << std::endl; + codes_ss << "for (int " << range_index_name << " = 0; " << range_index_name << " < 1;" + << range_index_name << "++) {" << std::endl; finish_codes_ss << "} // end of Existence Join" << std::endl; (*output)->process_codes += codes_ss.str(); (*output)->finish_codes += finish_codes_ss.str(); return arrow::Status::OK(); } - arrow::Status GetProcessProbe(const std::vector input, int join_type, - bool cond_check, - std::shared_ptr* output) { + arrow::Status GetProcessProbe( + const std::vector< + std::pair, gandiva::DataTypePtr>> + input, + int join_type, bool cond_check, std::shared_ptr* output) { auto hash_relation_name = "hash_relation_list_" + std::to_string(hash_relation_id_) + "_"; auto index_name = "hash_relation_" + std::to_string(hash_relation_id_) + "_index"; - std::vector> output_name_list = {{}, {}}; - std::stringstream valid_ss; - for (int i = 0; i < left_field_list_.size(); i++) { - auto type = left_field_list_[i]->type(); - auto name = - "hash_relation_" + std::to_string(hash_relation_id_) + "_" + std::to_string(i); - auto output_name = name + "_value"; - auto output_validity = output_name + "_validity"; - valid_ss << "auto " << output_validity << " = tmp.valid ? !" << name - << "->IsNull(tmp.array_id, tmp.id) : false;" << std::endl; - valid_ss << GetCTypeString(type) << " " << output_name << ";" << std::endl; - valid_ss << "if (" << output_validity << ") {" << std::endl; - valid_ss << output_name << " = " << name << "->GetValue(tmp.array_id, tmp.id);" - << std::endl; - valid_ss << "}" << std::endl; - - output_name_list[0].push_back(output_name); - } - for (int i = 0; i < right_field_list_.size(); i++) { - if (exist_index_ != -1 && exist_index_ == i) { - auto exist_name = - "hash_relation_" + std::to_string(hash_relation_id_) + "_existence_value"; - output_name_list[1].push_back(exist_name); - } - output_name_list[1].push_back(input[i]); - } - if (exist_index_ != -1 && exist_index_ == right_field_list_.size()) { - auto exist_name = - "hash_relation_" + std::to_string(hash_relation_id_) + "_existence_value"; - output_name_list[1].push_back(exist_name); - } int output_idx = 0; std::stringstream ss; + auto tmp_name = "tmp_" + std::to_string(hash_relation_id_); + auto is_outer_null_name = "is_outer_null_" + std::to_string(hash_relation_id_); + std::stringstream prepare_ss; + if (join_type == 1) { + prepare_ss << "bool " << is_outer_null_name << ";" << std::endl; + } + prepare_ss << "ArrayItemIndex " << tmp_name << ";" << std::endl; + (*output)->definition_codes += prepare_ss.str(); + + int right_index_shift = 0; for (auto pair : result_schema_index_list_) { // set result to output list - auto name = (*output)->output_list[output_idx++].first; - ss << name << " = " << output_name_list[pair.first][pair.second] << ";" - << std::endl; - ss << name << "_validity = " << output_name_list[pair.first][pair.second] - << "_validity;" << std::endl; + auto output_name = "hash_relation_" + std::to_string(hash_relation_id_) + + "_output_col_" + std::to_string(output_idx++); + auto output_validity = output_name + "_validity"; + + gandiva::DataTypePtr type; + std::stringstream valid_ss; + if (pair.first == 0) { /* left_table */ + auto name = "hash_relation_" + std::to_string(hash_relation_id_) + "_" + + std::to_string(pair.second); + type = left_field_list_[pair.second]->type(); + if (join_type == 1) { + valid_ss << "auto " << output_validity << " = !" << is_outer_null_name + << " && !" << name << "->IsNull(" << tmp_name << ".array_id, " + << tmp_name << ".id);" << std::endl; + + } else { + valid_ss << "auto " << output_validity << " = !" << name << "->IsNull(" + << tmp_name << ".array_id, " << tmp_name << ".id);" << std::endl; + } + valid_ss << "auto " << output_name << " = " << name << "->GetValue(" << tmp_name + << ".array_id, " << tmp_name << ".id);" << std::endl; + + } else { /* right table */ + std::string name; + if (exist_index_ != -1 && exist_index_ == pair.second) { + name = + "hash_relation_" + std::to_string(hash_relation_id_) + "_existence_value"; + valid_ss << "auto " << output_validity << " = true;" << std::endl; + valid_ss << "auto " << output_name << " = " << name << ";" << std::endl; + type = arrow::boolean(); + right_index_shift = -1; + } else { + auto i = pair.second + right_index_shift; + output_name = input[i].first.first; + output_validity = output_name + "_validity"; + valid_ss << input[i].first.second; + type = input[i].second; + } + } + (*output)->output_list.push_back( + std::make_pair(std::make_pair(output_name, valid_ss.str()), type)); } - valid_ss << ss.str(); switch (join_type) { case 0: { /*Inner Join*/ - return GetInnerJoin(input, cond_check, valid_ss.str(), index_name, - hash_relation_name, output); + return GetInnerJoin(cond_check, index_name, hash_relation_name, output); } break; case 1: { /*Outer Join*/ - return GetOuterJoin(input, cond_check, valid_ss.str(), index_name, - hash_relation_name, output); + return GetOuterJoin(cond_check, index_name, hash_relation_name, output); } break; case 2: { /*Anti Join*/ - return GetAntiJoin(input, cond_check, valid_ss.str(), index_name, - hash_relation_name, output); + return GetAntiJoin(cond_check, index_name, hash_relation_name, output); } break; case 3: { /*Semi Join*/ - return GetSemiJoin(input, cond_check, valid_ss.str(), index_name, - hash_relation_name, output); + return GetSemiJoin(cond_check, index_name, hash_relation_name, output); } break; case 4: { /*Existence Join*/ - return GetExistenceJoin(input, cond_check, valid_ss.str(), index_name, - hash_relation_name, output); + return GetExistenceJoin(cond_check, index_name, hash_relation_name, output); } break; default: return arrow::Status::NotImplemented( @@ -1766,7 +1792,9 @@ arrow::Status ConditionedProbeKernel::MakeResultIterator( std::string ConditionedProbeKernel::GetSignature() { return impl_->GetSignature(); } arrow::Status ConditionedProbeKernel::DoCodeGen( - int level, std::vector input, + int level, + std::vector, gandiva::DataTypePtr>> + input, std::shared_ptr* codegen_ctx_out, int* var_id) { return impl_->DoCodeGen(level, input, codegen_ctx_out, var_id); } diff --git a/cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.cc b/cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.cc index 01811f88e..1bdc5adcf 100644 --- a/cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.cc +++ b/cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.cc @@ -50,7 +50,7 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node) *func_count_ = *func_count_ + 1; RETURN_NOT_OK(MakeExpressionCodegenVisitor(child, input_list, field_list_v_, - hash_relation_id_, func_count_, + hash_relation_id_, func_count_, is_local_, prepared_list_, &child_visitor, is_smj_)); child_visitor_list.push_back(child_visitor); if (field_type_ == unknown || field_type_ == literal) { @@ -610,7 +610,7 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FieldNode& node) { auto index_pair = GetFieldIndex(this_field, field_list_v_); auto index = index_pair.first; auto arg_id = index_pair.second; - if (is_smj_ && input_list_.empty()) { + if (is_smj_ && (*input_list_).empty()) { ///// For inputs are SortRelation ///// codes_str_ = "sort_relation_" + std::to_string(hash_relation_id_ + index) + "_" + std::to_string(arg_id) + "_value"; @@ -651,14 +651,22 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FieldNode& node) { prepare_ss << " }" << std::endl; field_type_ = sort_relation; } else { - codes_str_ = input_list_[arg_id]; + prepare_ss << (*input_list_)[arg_id].first.second; + if (!is_local_) { + (*input_list_)[arg_id].first.second = ""; + } + codes_str_ = (*input_list_)[arg_id].first.first; codes_validity_str_ = GetValidityName(codes_str_); field_type_ = right; } } else { ///// For Inputs are one side HashRelation and other side regular array ///// if (field_list_v_.size() == 1) { - codes_str_ = input_list_[arg_id]; + prepare_ss << (*input_list_)[arg_id].first.second; + if (!is_local_) { + (*input_list_)[arg_id].first.second = ""; + } + codes_str_ = (*input_list_)[arg_id].first.first; codes_validity_str_ = GetValidityName(codes_str_); } else { if (index == 0) { @@ -680,7 +688,11 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FieldNode& node) { field_type_ = left; } else { - codes_str_ = input_list_[arg_id]; + prepare_ss << (*input_list_)[arg_id].first.second; + if (!is_local_) { + (*input_list_)[arg_id].first.second = ""; + } + codes_str_ = (*input_list_)[arg_id].first.first; codes_validity_str_ = GetValidityName(codes_str_); field_type_ = right; } @@ -693,7 +705,7 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FieldNode& node) { if (std::find((*prepared_list_).begin(), (*prepared_list_).end(), codes_str_) == (*prepared_list_).end()) { (*prepared_list_).push_back(codes_str_); - prepare_str_ = prepare_ss.str(); + prepare_str_ += prepare_ss.str(); } } return arrow::Status::OK(); @@ -710,7 +722,7 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::IfNode& node) { std::shared_ptr child_visitor; *func_count_ = *func_count_ + 1; RETURN_NOT_OK(MakeExpressionCodegenVisitor(child, input_list_, field_list_v_, - hash_relation_id_, func_count_, + hash_relation_id_, func_count_, is_local_, prepared_list_, &child_visitor, is_smj_)); child_visitor_list.push_back(child_visitor); if (field_type_ == unknown || field_type_ == literal) { @@ -779,7 +791,7 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::BooleanNode& node) std::shared_ptr child_visitor; *func_count_ = *func_count_ + 1; RETURN_NOT_OK(MakeExpressionCodegenVisitor(child, input_list_, field_list_v_, - hash_relation_id_, func_count_, + hash_relation_id_, func_count_, is_local_, prepared_list_, &child_visitor, is_smj_)); prepare_str_ += child_visitor->GetPrepare(); @@ -818,7 +830,7 @@ arrow::Status ExpressionCodegenVisitor::Visit( *func_count_ = *func_count_ + 1; RETURN_NOT_OK(MakeExpressionCodegenVisitor(node.eval_expr(), input_list_, field_list_v_, - hash_relation_id_, func_count_, + hash_relation_id_, func_count_, is_local_, prepared_list_, &child_visitor, is_smj_)); std::stringstream prepare_ss; prepare_ss << "std::vector in_list_" << cur_func_id << " = {"; @@ -858,7 +870,7 @@ arrow::Status ExpressionCodegenVisitor::Visit( *func_count_ = *func_count_ + 1; RETURN_NOT_OK(MakeExpressionCodegenVisitor(node.eval_expr(), input_list_, field_list_v_, - hash_relation_id_, func_count_, + hash_relation_id_, func_count_, is_local_, prepared_list_, &child_visitor, is_smj_)); std::stringstream prepare_ss; prepare_ss << "std::vector in_list_" << cur_func_id << " = {"; @@ -898,7 +910,7 @@ arrow::Status ExpressionCodegenVisitor::Visit( *func_count_ = *func_count_ + 1; RETURN_NOT_OK(MakeExpressionCodegenVisitor(node.eval_expr(), input_list_, field_list_v_, - hash_relation_id_, func_count_, + hash_relation_id_, func_count_, is_local_, prepared_list_, &child_visitor, is_smj_)); std::stringstream prepare_ss; prepare_ss << "std::vector in_list_" << cur_func_id << " = {"; @@ -959,11 +971,12 @@ std::string ExpressionCodegenVisitor::GetValidityName(std::string name) { } } -std::string ExpressionCodegenVisitor::GetNaNCheckStr(std::string left, std::string right, +std::string ExpressionCodegenVisitor::GetNaNCheckStr(std::string left, std::string right, std::string func) { std::stringstream ss; func = " " + func + " "; - ss << "((std::isnan(" << left << ") && std::isnan(" << right << ")) ? (1.0 / 0.0" << func << "1.0 / 0.0) : " + ss << "((std::isnan(" << left << ") && std::isnan(" << right << ")) ? (1.0 / 0.0" + << func << "1.0 / 0.0) : " << "(std::isnan(" << left << ")) ? (1.0 / 0.0" << func << right << ") : " << "(std::isnan(" << right << ")) ? (" << left << func << "1.0 / 0.0) : " << "(" << left << func << right << "))"; diff --git a/cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.h b/cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.h index 5b7c514ab..51305cef9 100644 --- a/cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.h +++ b/cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.h @@ -28,17 +28,20 @@ namespace extra { class ExpressionCodegenVisitor : public VisitorBase { public: ExpressionCodegenVisitor( - std::shared_ptr func, std::vector input_list, + std::shared_ptr func, + std::vector, gandiva::DataTypePtr>>* + input_list, std::vector>> field_list_v, - int hash_relation_id, int* func_count, std::vector* prepared_list, - bool is_smj) + int hash_relation_id, int* func_count, bool is_local, + std::vector* prepared_list, bool is_smj) : func_(func), field_list_v_(field_list_v), func_count_(func_count), input_list_(input_list), prepared_list_(prepared_list), hash_relation_id_(hash_relation_id), - is_smj_(is_smj) {} + is_smj_(is_smj), + is_local_(is_local) {} enum FieldType { left, right, sort_relation, literal, mixed, unknown }; @@ -70,10 +73,12 @@ class ExpressionCodegenVisitor : public VisitorBase { std::shared_ptr func_; std::vector>> field_list_v_; int hash_relation_id_; - std::vector input_list_; + std::vector, gandiva::DataTypePtr>>* + input_list_; int* func_count_; FieldType field_type_ = unknown; bool is_smj_; + bool is_local_; // output std::vector* prepared_list_; std::vector header_list_; @@ -91,13 +96,16 @@ class ExpressionCodegenVisitor : public VisitorBase { }; static arrow::Status MakeExpressionCodegenVisitor( - std::shared_ptr func, std::vector input_list, + std::shared_ptr func, + std::vector, gandiva::DataTypePtr>>* + input_list, std::vector>> field_list_v, - int hash_relation_id, int* func_count, std::vector* prepared_list, + int hash_relation_id, int* func_count, bool is_local, + std::vector* prepared_list, std::shared_ptr* out, bool is_smj = false) { auto visitor = std::make_shared( - func, input_list, field_list_v, hash_relation_id, func_count, prepared_list, - is_smj); + func, input_list, field_list_v, hash_relation_id, func_count, is_local, + prepared_list, is_smj); RETURN_NOT_OK(visitor->Eval()); *out = visitor; return arrow::Status::OK(); diff --git a/cpp/src/codegen/arrow_compute/ext/kernels_ext.h b/cpp/src/codegen/arrow_compute/ext/kernels_ext.h index 7aa1cc778..821b483c9 100644 --- a/cpp/src/codegen/arrow_compute/ext/kernels_ext.h +++ b/cpp/src/codegen/arrow_compute/ext/kernels_ext.h @@ -93,9 +93,11 @@ class KernalBase { return arrow::Status::NotImplemented("MakeResultIterator is abstract interface for ", kernel_name_); } - virtual arrow::Status DoCodeGen(int level, std::vector input, - std::shared_ptr* codegen_ctx, - int* var_id) { + virtual arrow::Status DoCodeGen( + int level, + std::vector, gandiva::DataTypePtr>> + input, + std::shared_ptr* codegen_ctx, int* var_id) { return arrow::Status::NotImplemented("DoCodeGen is abstract interface for ", kernel_name_); } @@ -608,9 +610,11 @@ class ConditionedProbeKernel : public KernalBase { arrow::Status MakeResultIterator( std::shared_ptr schema, std::shared_ptr>* out) override; - arrow::Status DoCodeGen(int level, std::vector input, - std::shared_ptr* codegen_ctx_out, - int* var_id) override; + arrow::Status DoCodeGen( + int level, + std::vector, gandiva::DataTypePtr>> + input, + std::shared_ptr* codegen_ctx_out, int* var_id) override; std::string GetSignature() override; class Impl; @@ -640,9 +644,11 @@ class ConditionedMergeJoinKernel : public KernalBase { arrow::Status MakeResultIterator( std::shared_ptr schema, std::shared_ptr>* out) override; - arrow::Status DoCodeGen(int level, std::vector input, - std::shared_ptr* codegen_ctx_out, - int* var_id) override; + arrow::Status DoCodeGen( + int level, + std::vector, gandiva::DataTypePtr>> + input, + std::shared_ptr* codegen_ctx_out, int* var_id) override; std::string GetSignature() override; class Impl; @@ -662,9 +668,11 @@ class ProjectKernel : public KernalBase { arrow::Status MakeResultIterator( std::shared_ptr schema, std::shared_ptr>* out) override; - arrow::Status DoCodeGen(int level, std::vector input, - std::shared_ptr* codegen_ctx, - int* var_id) override; + arrow::Status DoCodeGen( + int level, + std::vector, gandiva::DataTypePtr>> + input, + std::shared_ptr* codegen_ctx, int* var_id) override; std::string GetSignature() override; class Impl; @@ -684,9 +692,11 @@ class FilterKernel : public KernalBase { arrow::Status MakeResultIterator( std::shared_ptr schema, std::shared_ptr>* out) override; - arrow::Status DoCodeGen(int level, std::vector input, - std::shared_ptr* codegen_ctx, - int* var_id) override; + arrow::Status DoCodeGen( + int level, + std::vector, gandiva::DataTypePtr>> + input, + std::shared_ptr* codegen_ctx, int* var_id) override; std::string GetSignature() override; class Impl; diff --git a/cpp/src/codegen/arrow_compute/ext/whole_stage_codegen_kernel.cc b/cpp/src/codegen/arrow_compute/ext/whole_stage_codegen_kernel.cc index e8d6d87e8..e1ca1cf51 100644 --- a/cpp/src/codegen/arrow_compute/ext/whole_stage_codegen_kernel.cc +++ b/cpp/src/codegen/arrow_compute/ext/whole_stage_codegen_kernel.cc @@ -219,10 +219,12 @@ class WholeStageCodeGenKernel::Impl { int argument_id = 0; int level = 0; std::vector> codegen_ctx_list; - std::vector input_list; + std::vector, gandiva::DataTypePtr>> + input_list; for (int i = 0; i < input_field_list.size(); i++) { auto name = "typed_in_col_" + std::to_string(i); - input_list.push_back(name); + auto type = input_field_list[i]->type(); + input_list.push_back(std::make_pair(std::make_pair(name, ""), type)); } for (auto kernel : kernel_list) { std::shared_ptr child_codegen_ctx; @@ -231,7 +233,7 @@ class WholeStageCodeGenKernel::Impl { codegen_ctx_list.push_back(child_codegen_ctx); input_list.clear(); for (auto pair : child_codegen_ctx->output_list) { - input_list.push_back(pair.first); + input_list.push_back(pair); } } std::string codes; @@ -474,9 +476,10 @@ extern "C" void MakeCodeGen(arrow::compute::FunctionContext *ctx, std::stringstream codes_ss; int i = 0; for (auto pair : codegen_ctx->output_list) { - auto name = pair.first; + auto name = pair.first.first; auto type = pair.second; auto validity = name + "_validity"; + codes_ss << pair.first.second << std::endl; codes_ss << "if (" << validity << ") {" << std::endl; if (type->id() == arrow::Type::STRING) { codes_ss << " RETURN_NOT_OK(builder_" << i << "_->AppendString(" << name << "));"