From 58c5531f043e8134cd213fdb5242073ea934a7c4 Mon Sep 17 00:00:00 2001 From: xufei Date: Thu, 28 Dec 2023 20:51:28 +0800 Subject: [PATCH] Support finalize for join (#8366) close pingcap/tiflash#8296 --- .../ScanHashMapAfterProbeBlockInputStream.cpp | 151 ++-- .../ScanHashMapAfterProbeBlockInputStream.h | 8 +- .../Coprocessor/DAGQueryBlockInterpreter.cpp | 26 +- .../Coprocessor/JoinInterpreterHelper.cpp | 93 ++- .../Flash/Coprocessor/JoinInterpreterHelper.h | 21 +- dbms/src/Flash/Planner/PhysicalPlanNode.cpp | 29 +- dbms/src/Flash/Planner/PhysicalPlanNode.h | 5 +- .../Planner/Plans/PhysicalAggregation.cpp | 2 +- .../Flash/Planner/Plans/PhysicalAggregation.h | 2 +- .../Plans/PhysicalExchangeReceiver.cpp | 2 +- .../Planner/Plans/PhysicalExchangeReceiver.h | 2 +- .../Planner/Plans/PhysicalExchangeSender.cpp | 2 +- .../Planner/Plans/PhysicalExchangeSender.h | 2 +- .../Flash/Planner/Plans/PhysicalExpand.cpp | 2 +- dbms/src/Flash/Planner/Plans/PhysicalExpand.h | 2 +- .../Flash/Planner/Plans/PhysicalExpand2.cpp | 2 +- .../src/Flash/Planner/Plans/PhysicalExpand2.h | 2 +- .../Flash/Planner/Plans/PhysicalFilter.cpp | 2 +- dbms/src/Flash/Planner/Plans/PhysicalFilter.h | 2 +- .../Planner/Plans/PhysicalGetResultSink.h | 2 +- dbms/src/Flash/Planner/Plans/PhysicalJoin.cpp | 66 +- dbms/src/Flash/Planner/Plans/PhysicalJoin.h | 8 +- .../src/Flash/Planner/Plans/PhysicalLimit.cpp | 2 +- dbms/src/Flash/Planner/Plans/PhysicalLimit.h | 2 +- .../Plans/PhysicalMockExchangeReceiver.cpp | 2 +- .../Plans/PhysicalMockExchangeReceiver.h | 2 +- .../Plans/PhysicalMockExchangeSender.cpp | 2 +- .../Plans/PhysicalMockExchangeSender.h | 2 +- .../Planner/Plans/PhysicalMockTableScan.cpp | 2 +- .../Planner/Plans/PhysicalMockTableScan.h | 2 +- .../Planner/Plans/PhysicalProjection.cpp | 2 +- .../Flash/Planner/Plans/PhysicalProjection.h | 2 +- .../Flash/Planner/Plans/PhysicalTableScan.cpp | 2 +- .../Flash/Planner/Plans/PhysicalTableScan.h | 2 +- dbms/src/Flash/Planner/Plans/PhysicalTopN.cpp | 2 +- dbms/src/Flash/Planner/Plans/PhysicalTopN.h | 2 +- .../Flash/Planner/Plans/PhysicalWindow.cpp | 2 +- dbms/src/Flash/Planner/Plans/PhysicalWindow.h | 2 +- .../Planner/Plans/PhysicalWindowSort.cpp | 2 +- .../Flash/Planner/Plans/PhysicalWindowSort.h | 2 +- .../Planner/Plans/PipelineBreakerHelper.h | 15 +- dbms/src/Flash/Planner/optimize.cpp | 2 +- dbms/src/Flash/tests/gtest_join.h | 13 + dbms/src/Flash/tests/gtest_join_executor.cpp | 708 ++++++++++++++++-- .../Flash/tests/gtest_planner_interpreter.out | 22 +- dbms/src/Flash/tests/gtest_spill_join.cpp | 49 +- .../src/Interpreters/CrossJoinProbeHelper.cpp | 90 ++- dbms/src/Interpreters/ExpressionActions.cpp | 28 +- dbms/src/Interpreters/ExpressionActions.h | 5 +- dbms/src/Interpreters/Join.cpp | 477 ++++++++---- dbms/src/Interpreters/Join.h | 30 +- dbms/src/Interpreters/JoinPartition.cpp | 13 +- .../Interpreters/NullAwareSemiJoinHelper.cpp | 20 +- .../Interpreters/NullAwareSemiJoinHelper.h | 4 +- dbms/src/Interpreters/ProbeProcessInfo.cpp | 120 ++- dbms/src/Interpreters/ProbeProcessInfo.h | 107 ++- dbms/src/Interpreters/SemiJoinHelper.cpp | 11 +- dbms/src/Interpreters/SemiJoinHelper.h | 4 +- 58 files changed, 1611 insertions(+), 576 deletions(-) diff --git a/dbms/src/DataStreams/ScanHashMapAfterProbeBlockInputStream.cpp b/dbms/src/DataStreams/ScanHashMapAfterProbeBlockInputStream.cpp index fc73aa5282f..cccc91b7c3b 100644 --- a/dbms/src/DataStreams/ScanHashMapAfterProbeBlockInputStream.cpp +++ b/dbms/src/DataStreams/ScanHashMapAfterProbeBlockInputStream.cpp @@ -39,7 +39,7 @@ struct AdderMapEntry size_t key_num, size_t num_columns_left, MutableColumns & columns_left, - size_t num_columns_right, + const ColumnNumbers & column_indices_right, MutableColumns & columns_right, const void *&, size_t, @@ -52,8 +52,10 @@ struct AdderMapEntry /// for detailed explanation columns_left[j]->insertDefault(); - for (size_t j = 0; j < num_columns_right; ++j) - columns_right[j]->insertFrom(*mapped.block->getByPosition(key_num + j).column.get(), mapped.row_num); + for (size_t j = 0; j < column_indices_right.size(); ++j) + columns_right[j]->insertFrom( + *mapped.block->getByPosition(key_num + column_indices_right[j]).column.get(), + mapped.row_num); return 1; } }; @@ -66,7 +68,7 @@ struct AdderMapEntry size_t key_num, size_t num_columns_left, MutableColumns & columns_left, - size_t num_columns_right, + const ColumnNumbers & column_indices_right, MutableColumns & columns_right, const void *& next_element_in_row_list, size_t probe_cached_rows_threshold, @@ -78,9 +80,9 @@ struct AdderMapEntry auto add_one_row = [&]() { /// handle left columns later to utilize insertManyDefaults - for (size_t j = 0; j < num_columns_right; ++j) + for (size_t j = 0; j < column_indices_right.size(); ++j) columns_right[j]->insertFrom( - *current->block->getByPosition(key_num + j).column.get(), + *current->block->getByPosition(key_num + column_indices_right[j]).column.get(), current->row_num); ++rows_added; }; @@ -124,7 +126,7 @@ struct AdderRowFlaggedMapEntry size_t key_num, size_t num_columns_left, MutableColumns & columns_left, - size_t num_columns_right, + const ColumnNumbers & column_indices_right, MutableColumns & columns_right, const void *& next_element_in_row_list, size_t probe_cached_rows_threshold, @@ -141,9 +143,9 @@ struct AdderRowFlaggedMapEntry if (flag) { /// handle left columns later to utilize insertManyDefaults if any - for (size_t j = 0; j < num_columns_right; ++j) + for (size_t j = 0; j < column_indices_right.size(); ++j) columns_right[j]->insertFrom( - *current->block->getByPosition(key_num + j).column.get(), + *current->block->getByPosition(key_num + column_indices_right[j]).column.get(), current->row_num); ++rows_added; } @@ -199,51 +201,50 @@ ScanHashMapAfterProbeBlockInputStream::ScanHashMapAfterProbeBlockInputStream( * result_sample_block - keys, "left" columns, and "right" columns. */ - size_t num_columns_left = left_sample_block.columns(); - if (isRightSemiFamily(parent.getKind())) - num_columns_left = 0; - else - result_sample_block = materializeBlock(left_sample_block); - - size_t num_columns_right = parent.sample_block_with_columns_to_add.columns(); - /// Add columns from the right-side table to the block. - for (size_t i = 0; i < num_columns_right; ++i) + column_indices_left.reserve(left_sample_block.columns()); + if (!isRightSemiFamily(parent.getKind())) { - const ColumnWithTypeAndName & src_column = parent.sample_block_with_columns_to_add.getByPosition(i); - result_sample_block.insert(src_column.cloneEmpty()); + auto left_full_block = materializeBlock(left_sample_block); + for (size_t i = 0; i < left_full_block.columns(); ++i) + { + auto & column = left_full_block.getByPosition(i); + if (parent.output_column_names_set_after_finalize.contains(column.name)) + { + result_sample_block.insert(column.cloneEmpty()); + column_indices_left.push_back(i); + } + } } - column_indices_left.reserve(num_columns_left); - column_indices_right.reserve(num_columns_right); - - for (size_t i = 0; i < num_columns_left; ++i) + column_indices_right.reserve(parent.sample_block_without_keys.columns()); + /// Add columns from the right-side table to the block. + for (size_t i = 0; i < parent.sample_block_without_keys.columns(); ++i) { - column_indices_left.push_back(i); + const ColumnWithTypeAndName & src_column = parent.sample_block_without_keys.getByPosition(i); + if (parent.output_column_names_set_after_finalize.contains(src_column.name)) + { + result_sample_block.insert(src_column.cloneEmpty()); + column_indices_right.push_back(i); + } } - for (size_t i = 0; i < num_columns_right; ++i) - column_indices_right.push_back(num_columns_left + i); - - for (size_t i = 0; i < num_columns_left; ++i) + for (size_t i = 0; i < column_indices_left.size(); ++i) { - const auto & column_with_type_and_name = result_sample_block.getByPosition(column_indices_left[i]); + const auto & column_with_type_and_name = result_sample_block.getByPosition(i); if (parent.key_names_left.end() == std::find(parent.key_names_left.begin(), parent.key_names_left.end(), column_with_type_and_name.name)) /// if it is not the key, then convert to nullable, if it is key, then just keep the original type /// actually we don't care about the key column now refer to https://github.com/pingcap/tiflash/blob/v6.5.0/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp#L953 /// for detailed explanation - convertColumnToNullable(result_sample_block.getByPosition(column_indices_left[i])); + convertColumnToNullable(result_sample_block.getByPosition(i)); } - columns_left.resize(num_columns_left); - columns_right.resize(num_columns_right); + columns_left.resize(column_indices_left.size()); + columns_right.resize(column_indices_right.size()); current_partition_index = index; + projected_sample_block = result_sample_block; - for (const auto & name : parent.tidb_output_column_names) - { - auto & column = result_sample_block.getByName(name); - projected_sample_block.insert(column); - } + projected_sample_block = parent.removeUselessColumn(projected_sample_block); } Block ScanHashMapAfterProbeBlockInputStream::readImpl() @@ -271,7 +272,7 @@ Block ScanHashMapAfterProbeBlockInputStream::readImpl() for (size_t i = 0; i < num_columns_left; ++i) { - const auto & src_col = result_sample_block.safeGetByPosition(column_indices_left[i]); + const auto & src_col = result_sample_block.safeGetByPosition(i); columns_left[i] = src_col.type->createColumn(); if (row_counter_column == nullptr) row_counter_column = columns_left[i].get(); @@ -279,7 +280,7 @@ Block ScanHashMapAfterProbeBlockInputStream::readImpl() for (size_t i = 0; i < num_columns_right; ++i) { - const auto & src_col = result_sample_block.safeGetByPosition(column_indices_right[i]); + const auto & src_col = result_sample_block.safeGetByPosition(num_columns_left + i); columns_right[i] = src_col.type->createColumn(); if (row_counter_column == nullptr) row_counter_column = columns_right[i].get(); @@ -292,44 +293,19 @@ Block ScanHashMapAfterProbeBlockInputStream::readImpl() { case ASTTableJoin::Kind::RightSemi: if (parent.has_other_condition) - fillColumnsUsingCurrentPartition( - num_columns_left, - columns_left, - num_columns_right, - columns_right, - row_counter_column); + fillColumnsUsingCurrentPartition(columns_left, columns_right, row_counter_column); else - fillColumnsUsingCurrentPartition( - num_columns_left, - columns_left, - num_columns_right, - columns_right, - row_counter_column); + fillColumnsUsingCurrentPartition(columns_left, columns_right, row_counter_column); break; case ASTTableJoin::Kind::RightAnti: case ASTTableJoin::Kind::RightOuter: if (parent.has_other_condition) - fillColumnsUsingCurrentPartition( - num_columns_left, - columns_left, - num_columns_right, - columns_right, - row_counter_column); + fillColumnsUsingCurrentPartition(columns_left, columns_right, row_counter_column); else - fillColumnsUsingCurrentPartition( - num_columns_left, - columns_left, - num_columns_right, - columns_right, - row_counter_column); + fillColumnsUsingCurrentPartition(columns_left, columns_right, row_counter_column); break; default: - fillColumnsUsingCurrentPartition( - num_columns_left, - columns_left, - num_columns_right, - columns_right, - row_counter_column); + fillColumnsUsingCurrentPartition(columns_left, columns_right, row_counter_column); } } @@ -338,25 +314,16 @@ Block ScanHashMapAfterProbeBlockInputStream::readImpl() Block res = result_sample_block.cloneEmpty(); for (size_t i = 0; i < num_columns_left; ++i) - res.getByPosition(column_indices_left[i]).column = std::move(columns_left[i]); + res.getByPosition(i).column = std::move(columns_left[i]); for (size_t i = 0; i < num_columns_right; ++i) - res.getByPosition(column_indices_right[i]).column = std::move(columns_right[i]); + res.getByPosition(num_columns_left + i).column = std::move(columns_right[i]); - /// remove useless columns - Block projected_block; - for (const auto & name : parent.tidb_output_column_names) - { - auto & column = res.getByName(name); - projected_block.insert(std::move(column)); - } - return projected_block; + return parent.removeUselessColumn(res); } template void ScanHashMapAfterProbeBlockInputStream::fillColumnsUsingCurrentPartition( - size_t num_columns_left, MutableColumns & mutable_columns_left, - size_t num_columns_right, MutableColumns & mutable_columns_right, IColumn * row_counter_column) { @@ -384,9 +351,7 @@ void ScanHashMapAfterProbeBlockInputStream::fillColumnsUsingCurrentPartition( case JoinMapMethod::METHOD: \ fillColumns( \ *partition->maps_all_full_with_row_flag.METHOD, \ - num_columns_left, \ mutable_columns_left, \ - num_columns_right, \ mutable_columns_right, \ row_counter_column); \ break; @@ -405,9 +370,7 @@ void ScanHashMapAfterProbeBlockInputStream::fillColumnsUsingCurrentPartition( case JoinMapMethod::METHOD: \ fillColumns( \ *partition->maps_all_full.METHOD, \ - num_columns_left, \ mutable_columns_left, \ - num_columns_right, \ mutable_columns_right, \ row_counter_column); \ break; @@ -448,9 +411,7 @@ struct RowCountInfo template void ScanHashMapAfterProbeBlockInputStream::fillColumns( const Map & map, - size_t num_columns_left, MutableColumns & mutable_columns_left, - size_t num_columns_right, MutableColumns & mutable_columns_right, IColumn * row_counter_column) { @@ -461,9 +422,9 @@ void ScanHashMapAfterProbeBlockInputStream::fillColumns( { row_count_info.inc(1); /// handle left columns later to utilize insertManyDefaults - for (size_t j = 0; j < num_columns_right; ++j) + for (size_t j = 0; j < column_indices_right.size(); ++j) mutable_columns_right[j]->insertFrom( - *not_mapped_row_pos->block->getByPosition(key_num + j).column.get(), + *not_mapped_row_pos->block->getByPosition(key_num + column_indices_right[j]).column.get(), not_mapped_row_pos->row_num); not_mapped_row_pos = not_mapped_row_pos->next; @@ -471,7 +432,7 @@ void ScanHashMapAfterProbeBlockInputStream::fillColumns( break; } /// Fill left columns with defaults - for (size_t j = 0; j < num_columns_left; ++j) + for (size_t j = 0; j < column_indices_left.size(); ++j) /// should fill the key column with key columns from right block /// but we don't care about the key column now so just insert a default value is ok. /// refer to https://github.com/pingcap/tiflash/blob/v6.5.0/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp#L953 @@ -500,9 +461,9 @@ void ScanHashMapAfterProbeBlockInputStream::fillColumns( row_count_info.inc(AdderRowFlaggedMapEntry::add( (*it)->getMapped(), key_num, - num_columns_left, + column_indices_left.size(), mutable_columns_left, - num_columns_right, + column_indices_right, mutable_columns_right, next_element_in_row_list, parent.probe_cache_column_threshold, @@ -521,9 +482,9 @@ void ScanHashMapAfterProbeBlockInputStream::fillColumns( row_count_info.inc(AdderMapEntry::add( (*it)->getMapped(), key_num, - num_columns_left, + column_indices_left.size(), mutable_columns_left, - num_columns_right, + column_indices_right, mutable_columns_right, next_element_in_row_list, parent.probe_cache_column_threshold, diff --git a/dbms/src/DataStreams/ScanHashMapAfterProbeBlockInputStream.h b/dbms/src/DataStreams/ScanHashMapAfterProbeBlockInputStream.h index adc892ae04d..aaf018c11fe 100644 --- a/dbms/src/DataStreams/ScanHashMapAfterProbeBlockInputStream.h +++ b/dbms/src/DataStreams/ScanHashMapAfterProbeBlockInputStream.h @@ -60,9 +60,9 @@ class ScanHashMapAfterProbeBlockInputStream : public IProfilingBlockInputStream Block result_sample_block; Block projected_sample_block; /// same schema with join's final schema - /// Indices of columns in result_sample_block that come from the left-side table (except key columns). + /// Indices of columns in left sample block ColumnNumbers column_indices_left; - /// Indices of columns that come from the right-side table. + /// Indices of columns in right sample block /// Order is significant: it is the same as the order of columns in the blocks of the right-side table that are saved in parent.blocks. ColumnNumbers column_indices_right; /// Columns of the current output block corresponding to column_indices_left. @@ -82,17 +82,13 @@ class ScanHashMapAfterProbeBlockInputStream : public IProfilingBlockInputStream template void fillColumns( const Map & map, - size_t num_columns_left, MutableColumns & mutable_columns_left, - size_t num_columns_right, MutableColumns & mutable_columns_right, IColumn * row_counter_column); template void fillColumnsUsingCurrentPartition( - size_t num_columns_left, MutableColumns & mutable_columns_left, - size_t num_columns_right, MutableColumns & mutable_columns_right, IColumn * row_counter_column); }; diff --git a/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.cpp b/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.cpp index 67947e74ca4..1e0eb7b1790 100644 --- a/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.cpp +++ b/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.cpp @@ -248,16 +248,25 @@ void DAGQueryBlockInterpreter::handleJoin( DAGPipeline build_pipeline; probe_pipeline.streams = input_streams_vec[1 - tiflash_join.build_side_index]; build_pipeline.streams = input_streams_vec[tiflash_join.build_side_index]; + /// for DAGQueryBlockInterpreter, the schema is already aligned to TiDB's schema after appendFinalProjectForNonRootQueryBlock + const auto probe_source_columns + = JoinInterpreterHelper::genDAGExpressionAnalyzerSourceColumns(probe_pipeline.firstStream()->getHeader(), {}); + const auto build_source_columns + = JoinInterpreterHelper::genDAGExpressionAnalyzerSourceColumns(build_pipeline.firstStream()->getHeader(), {}); RUNTIME_ASSERT(!input_streams_vec[0].empty(), log, "left input streams cannot be empty"); const Block & left_input_header = input_streams_vec[0].back()->getHeader(); + const auto left_source_columns + = JoinInterpreterHelper::genDAGExpressionAnalyzerSourceColumns(left_input_header, {}); RUNTIME_ASSERT(!input_streams_vec[1].empty(), log, "right input streams cannot be empty"); const Block & right_input_header = input_streams_vec[1].back()->getHeader(); + const auto right_source_columns + = JoinInterpreterHelper::genDAGExpressionAnalyzerSourceColumns(right_input_header, {}); String match_helper_name = tiflash_join.genMatchHelperName(left_input_header, right_input_header); NamesAndTypes join_output_columns - = tiflash_join.genJoinOutputColumns(left_input_header, right_input_header, match_helper_name); + = tiflash_join.genJoinOutputColumns(left_source_columns, right_source_columns, match_helper_name); /// add necessary transformation if the join key is an expression bool is_tiflash_right_join = isRightOuterJoin(tiflash_join.kind); @@ -267,7 +276,7 @@ void DAGQueryBlockInterpreter::handleJoin( auto [probe_side_prepare_actions, probe_key_names, original_probe_key_names, probe_filter_column_name] = JoinInterpreterHelper::prepareJoin( context, - probe_pipeline.firstStream()->getHeader(), + probe_source_columns, tiflash_join.getProbeJoinKeys(), tiflash_join.join_key_types, true, @@ -280,7 +289,7 @@ void DAGQueryBlockInterpreter::handleJoin( auto [build_side_prepare_actions, build_key_names, original_build_key_names, build_filter_column_name] = JoinInterpreterHelper::prepareJoin( context, - build_pipeline.firstStream()->getHeader(), + build_source_columns, tiflash_join.getBuildJoinKeys(), tiflash_join.join_key_types, false, @@ -291,8 +300,8 @@ void DAGQueryBlockInterpreter::handleJoin( tiflash_join.fillJoinOtherConditionsAction( context, - left_input_header, - right_input_header, + left_source_columns, + right_source_columns, probe_side_prepare_actions, original_probe_key_names, original_build_key_names, @@ -324,9 +333,6 @@ void DAGQueryBlockInterpreter::handleJoin( left_input_header, right_input_header, join_non_equal_conditions.other_cond_expr != nullptr); - Names join_output_column_names; - for (const auto & col : join_output_columns) - join_output_column_names.emplace_back(col.name); JoinPtr join_ptr = std::make_shared( probe_key_names, build_key_names, @@ -337,7 +343,7 @@ void DAGQueryBlockInterpreter::handleJoin( build_spill_config, probe_spill_config, RestoreConfig{settings.join_restore_concurrency, 0, 0}, - join_output_column_names, + join_output_columns, [&](const OperatorSpillContextPtr & operator_spill_context) { if (context.getDAGContext() != nullptr) { @@ -394,6 +400,8 @@ void DAGQueryBlockInterpreter::handleJoin( right_query.source = build_pipeline.firstStream(); right_query.join = join_ptr; + + join_ptr->finalize(DB::toNames(join_output_columns)); join_ptr->initBuild(right_query.source->getHeader(), join_build_concurrency); /// probe side streams diff --git a/dbms/src/Flash/Coprocessor/JoinInterpreterHelper.cpp b/dbms/src/Flash/Coprocessor/JoinInterpreterHelper.cpp index 01b7cf24719..9acf8a582c8 100644 --- a/dbms/src/Flash/Coprocessor/JoinInterpreterHelper.cpp +++ b/dbms/src/Flash/Coprocessor/JoinInterpreterHelper.cpp @@ -250,23 +250,21 @@ String TiFlashJoin::genFlagMappedEntryHelperName(const Block & header1, const Bl } NamesAndTypes TiFlashJoin::genColumnsForOtherJoinFilter( - const Block & left_input_header, - const Block & right_input_header, + const NamesAndTypes & left_cols, + const NamesAndTypes & right_cols, const ExpressionActionsPtr & probe_prepare_join_actions) const { #ifndef NDEBUG - auto is_prepare_actions_valid = [](const Block & origin_block, const ExpressionActionsPtr & prepare_actions) { + auto is_prepare_actions_valid = [](const NamesAndTypes & cols, const ExpressionActionsPtr & prepare_actions) { const Block & prepare_sample_block = prepare_actions->getSampleBlock(); - for (const auto & p : origin_block) + for (const auto & p : cols) { if (!prepare_sample_block.has(p.name)) return false; } return true; }; - if (unlikely(!is_prepare_actions_valid( - build_side_index == 1 ? left_input_header : right_input_header, - probe_prepare_join_actions))) + if (unlikely(!is_prepare_actions_valid(build_side_index == 1 ? left_cols : right_cols, probe_prepare_join_actions))) { throw TiFlashException("probe_prepare_join_actions isn't valid", Errors::Coprocessor::Internal); } @@ -289,16 +287,16 @@ NamesAndTypes TiFlashJoin::genColumnsForOtherJoinFilter( NamesAndTypes columns_for_other_join_filter; std::unordered_set column_set_for_origin_columns; - auto append_origin_columns - = [&columns_for_other_join_filter, &column_set_for_origin_columns](const Block & header, bool make_nullable) { - for (const auto & p : header) - { - columns_for_other_join_filter.emplace_back(p.name, make_nullable ? makeNullable(p.type) : p.type); - column_set_for_origin_columns.emplace(p.name); - } - }; - append_origin_columns(left_input_header, join.join_type() == tipb::JoinType::TypeRightOuterJoin); - append_origin_columns(right_input_header, join.join_type() == tipb::JoinType::TypeLeftOuterJoin); + auto append_origin_columns = [&columns_for_other_join_filter, + &column_set_for_origin_columns](const NamesAndTypes & cols, bool make_nullable) { + for (const auto & p : cols) + { + columns_for_other_join_filter.emplace_back(p.name, make_nullable ? makeNullable(p.type) : p.type); + column_set_for_origin_columns.emplace(p.name); + } + }; + append_origin_columns(left_cols, join.join_type() == tipb::JoinType::TypeRightOuterJoin); + append_origin_columns(right_cols, join.join_type() == tipb::JoinType::TypeLeftOuterJoin); /// append the columns generated by probe side prepare join actions. /// the new columns are @@ -320,23 +318,23 @@ NamesAndTypes TiFlashJoin::genColumnsForOtherJoinFilter( } NamesAndTypes TiFlashJoin::genJoinOutputColumns( - const Block & left_input_header, - const Block & right_input_header, + const NamesAndTypes & left_cols, + const NamesAndTypes & right_cols, const String & match_helper_name) const { NamesAndTypes join_output_columns; - auto append_output_columns = [&join_output_columns](const Block & header, bool make_nullable) { - for (auto const & p : header) + auto append_output_columns = [&join_output_columns](const NamesAndTypes & cols, bool make_nullable) { + for (auto const & p : cols) { join_output_columns.emplace_back(p.name, make_nullable ? makeNullable(p.type) : p.type); } }; - append_output_columns(left_input_header, join.join_type() == tipb::JoinType::TypeRightOuterJoin); + append_output_columns(left_cols, join.join_type() == tipb::JoinType::TypeRightOuterJoin); if (!isSemiFamily() && !isLeftOuterSemiFamily()) { /// for (left outer) semi join, the columns from right table will be ignored - append_output_columns(right_input_header, join.join_type() == tipb::JoinType::TypeLeftOuterJoin); + append_output_columns(right_cols, join.join_type() == tipb::JoinType::TypeLeftOuterJoin); } if (!match_helper_name.empty()) @@ -349,15 +347,14 @@ NamesAndTypes TiFlashJoin::genJoinOutputColumns( void TiFlashJoin::fillJoinOtherConditionsAction( const Context & context, - const Block & left_input_header, - const Block & right_input_header, + const NamesAndTypes & left_cols, + const NamesAndTypes & right_cols, const ExpressionActionsPtr & probe_side_prepare_join, const Names & probe_key_names, const Names & build_key_names, JoinNonEqualConditions & join_non_equal_conditions) const { - auto columns_for_other_join_filter - = genColumnsForOtherJoinFilter(left_input_header, right_input_header, probe_side_prepare_join); + auto columns_for_other_join_filter = genColumnsForOtherJoinFilter(left_cols, right_cols, probe_side_prepare_join); if (join.other_conditions_size() == 0 && join.other_eq_conditions_from_in_size() == 0 && !join.is_null_aware_semi_join()) @@ -389,17 +386,14 @@ void TiFlashJoin::fillJoinOtherConditionsAction( std::tuple prepareJoin( const Context & context, - const Block & input_header, + const NamesAndTypes & source_columns, const google::protobuf::RepeatedPtrField & keys, const JoinKeyTypes & join_key_types, bool left, bool is_right_out_join, const google::protobuf::RepeatedPtrField & filters) { - NamesAndTypes source_columns; - for (auto const & p : input_header) - source_columns.emplace_back(p.name, p.type); - DAGExpressionAnalyzer dag_analyzer(std::move(source_columns), context); + DAGExpressionAnalyzer dag_analyzer(source_columns, context); ExpressionActionsChain chain; Names key_names; Names original_key_names; @@ -419,7 +413,8 @@ std::tuple prepareJoin( std::vector TiFlashJoin::genRuntimeFilterList( const Context & context, - const Block & input_header, + const NamesAndTypes & source_columns, + const std::unordered_map & key_names_map, const LoggerPtr & log) { std::vector result; @@ -428,11 +423,7 @@ std::vector TiFlashJoin::genRuntimeFilterList( return result; } result.reserve(join.runtime_filter_list().size()); - NamesAndTypes source_columns; - source_columns.reserve(input_header.columns()); - for (auto const & p : input_header) - source_columns.emplace_back(p.name, p.type); - DAGExpressionAnalyzer dag_analyzer(std::move(source_columns), context); + DAGExpressionAnalyzer dag_analyzer(source_columns, context); LOG_DEBUG(log, "before gen runtime filter, pb rf size:{}", join.runtime_filter_list().size()); for (auto rf_pb : join.runtime_filter_list()) { @@ -442,6 +433,12 @@ std::vector TiFlashJoin::genRuntimeFilterList( { runtime_filter->build(); dag_analyzer.appendRuntimeFilterProperties(runtime_filter); + /// update the source column name to use the join key as source column + const auto & updated_key_name_it = key_names_map.find(runtime_filter->getSourceColumnName()); + RUNTIME_CHECK_MSG( + updated_key_name_it != key_names_map.end(), + "rf source column is not join key, which is not expected"); + runtime_filter->setSourceColumnName(updated_key_name_it->second); } catch (TiFlashException & e) { @@ -453,5 +450,25 @@ std::vector TiFlashJoin::genRuntimeFilterList( } return result; } + +NamesAndTypes genDAGExpressionAnalyzerSourceColumns(Block block, const NamesAndTypes & tidb_schema) +{ + /// generate source_columns that is used to compile tipb::Expr, the rule is columns in `tidb_schema` + /// must be the first part of the source_columns, and the column order must be exactly the same as + /// in `tidb_schema` + NamesAndTypes source_columns = tidb_schema; + /// remove columns that already in tidb_schema + for (const auto & name_and_type : tidb_schema) + { + if (block.has(name_and_type.name)) + block.erase(name_and_type.name); + } + /// insert the remaining columns, this is to avoid duplicate column error + for (const auto & col : block) + { + source_columns.emplace_back(col.name, col.type); + } + return source_columns; +} } // namespace JoinInterpreterHelper } // namespace DB diff --git a/dbms/src/Flash/Coprocessor/JoinInterpreterHelper.h b/dbms/src/Flash/Coprocessor/JoinInterpreterHelper.h index 6fe8423e510..6271e3238df 100644 --- a/dbms/src/Flash/Coprocessor/JoinInterpreterHelper.h +++ b/dbms/src/Flash/Coprocessor/JoinInterpreterHelper.h @@ -175,14 +175,14 @@ struct TiFlashJoin /// The columns output by join will be: /// {columns of left_input, columns of right_input, match_helper_name} NamesAndTypes genJoinOutputColumns( - const Block & left_input_header, - const Block & right_input_header, + const NamesAndTypes & left_cols, + const NamesAndTypes & right_cols, const String & match_helper_name) const; void fillJoinOtherConditionsAction( const Context & context, - const Block & left_input_header, - const Block & right_input_header, + const NamesAndTypes & left_cols, + const NamesAndTypes & right_cols, const ExpressionActionsPtr & probe_side_prepare_join, const Names & probe_key_names, const Names & build_key_names, @@ -198,13 +198,14 @@ struct TiFlashJoin /// -`other_columns_added_by_probe_join_actions` is added to avoid duplicated columns error /// -`match_helper_col` is not handled explicitly because `other_condition` never use that column. NamesAndTypes genColumnsForOtherJoinFilter( - const Block & left_input_header, - const Block & right_input_header, + const NamesAndTypes & left_cols, + const NamesAndTypes & right_cols, const ExpressionActionsPtr & probe_prepare_join_actions) const; std::vector genRuntimeFilterList( const Context & context, - const Block & input_header, + const NamesAndTypes & source_columns, + const std::unordered_map & key_names_map, const LoggerPtr & log); }; @@ -214,11 +215,15 @@ struct TiFlashJoin /// @filter_column_name: column name of `and(filters)` std::tuple prepareJoin( const Context & context, - const Block & input_header, + const NamesAndTypes & source_columns, const google::protobuf::RepeatedPtrField & keys, const JoinKeyTypes & join_key_types, bool left, bool is_right_out_join, const google::protobuf::RepeatedPtrField & filters); + +/// generate source_columns that is used to compile tipb::Expr, the rule is columns in `tidb_schema` +/// must be the first part of the source_columns +NamesAndTypes genDAGExpressionAnalyzerSourceColumns(Block block, const NamesAndTypes & tidb_schema); } // namespace JoinInterpreterHelper } // namespace DB diff --git a/dbms/src/Flash/Planner/PhysicalPlanNode.cpp b/dbms/src/Flash/Planner/PhysicalPlanNode.cpp index 8b8849530bc..517baf024e4 100644 --- a/dbms/src/Flash/Planner/PhysicalPlanNode.cpp +++ b/dbms/src/Flash/Planner/PhysicalPlanNode.cpp @@ -63,9 +63,34 @@ String PhysicalPlanNode::toSimpleString() return fmt::format("{}|{}", type.toString(), isTiDBOperator() ? executor_id : "NonTiDBOperator"); } -void PhysicalPlanNode::finalize() +void PhysicalPlanNode::finalize(const Names & parent_require) { - finalize(DB::toNames(schema)); + if unlikely (finalized) + { + LOG_WARNING(log, "Should not reach here, {}-{} already finalized", type.toString(), executor_id); + return; + } + auto block_to_schema_string = [&](const Block & block) { + FmtBuffer buffer; + buffer.joinStr( + block.cbegin(), + block.cend(), + [](const auto & item, FmtBuffer & buf) { buf.fmtAppend("<{}, {}>", item.name, item.type->getName()); }, + ", "); + return buffer.toString(); + }; + auto block_before_finalize = getSampleBlock(); + finalizeImpl(parent_require); + finalized = true; + auto block_after_finalize = getSampleBlock(); + if (block_before_finalize.columns() != block_after_finalize.columns()) + { + LOG_DEBUG( + log, + "Finalize pruned some columns: before finalize: {}, after finalize: {}", + block_to_schema_string(block_before_finalize), + block_to_schema_string(block_after_finalize)); + } } void PhysicalPlanNode::recordProfileStreams(DAGPipeline & pipeline, const Context & context) diff --git a/dbms/src/Flash/Planner/PhysicalPlanNode.h b/dbms/src/Flash/Planner/PhysicalPlanNode.h index cfc96724e85..ea33264eef0 100644 --- a/dbms/src/Flash/Planner/PhysicalPlanNode.h +++ b/dbms/src/Flash/Planner/PhysicalPlanNode.h @@ -80,8 +80,8 @@ class PhysicalPlanNode : public std::enable_shared_from_this EventPtr sinkComplete(PipelineExecutorContext & exec_context); - virtual void finalize(const Names & parent_require) = 0; - void finalize(); + virtual void finalizeImpl(const Names & parent_require) = 0; + void finalize(const Names & parent_require); /// Obtain a sample block that contains the names and types of result columns. virtual const Block & getSampleBlock() const = 0; @@ -127,6 +127,7 @@ class PhysicalPlanNode : public std::enable_shared_from_this bool is_tidb_operator = true; bool is_restore_concurrency = true; + bool finalized = false; LoggerPtr log; }; diff --git a/dbms/src/Flash/Planner/Plans/PhysicalAggregation.cpp b/dbms/src/Flash/Planner/Plans/PhysicalAggregation.cpp index 3cf5b5fc3d7..f68f463d443 100644 --- a/dbms/src/Flash/Planner/Plans/PhysicalAggregation.cpp +++ b/dbms/src/Flash/Planner/Plans/PhysicalAggregation.cpp @@ -288,7 +288,7 @@ void PhysicalAggregation::buildPipeline( } } -void PhysicalAggregation::finalize(const Names & parent_require) +void PhysicalAggregation::finalizeImpl(const Names & parent_require) { // schema.size() >= parent_require.size() FinalizeHelper::checkSchemaContainsParentRequire(schema, parent_require); diff --git a/dbms/src/Flash/Planner/Plans/PhysicalAggregation.h b/dbms/src/Flash/Planner/Plans/PhysicalAggregation.h index 935a4a71214..f65b965b6ac 100644 --- a/dbms/src/Flash/Planner/Plans/PhysicalAggregation.h +++ b/dbms/src/Flash/Planner/Plans/PhysicalAggregation.h @@ -55,7 +55,7 @@ class PhysicalAggregation : public PhysicalUnary void buildPipeline(PipelineBuilder & builder, Context & context, PipelineExecutorContext & exec_context) override; - void finalize(const Names & parent_require) override; + void finalizeImpl(const Names & parent_require) override; const Block & getSampleBlock() const override; diff --git a/dbms/src/Flash/Planner/Plans/PhysicalExchangeReceiver.cpp b/dbms/src/Flash/Planner/Plans/PhysicalExchangeReceiver.cpp index b636a759218..25c6437b468 100644 --- a/dbms/src/Flash/Planner/Plans/PhysicalExchangeReceiver.cpp +++ b/dbms/src/Flash/Planner/Plans/PhysicalExchangeReceiver.cpp @@ -113,7 +113,7 @@ void PhysicalExchangeReceiver::buildPipelineExecGroupImpl( context.getDAGContext()->addInboundIOProfileInfos(executor_id, group_builder.getCurIOProfileInfos()); } -void PhysicalExchangeReceiver::finalize(const Names & parent_require) +void PhysicalExchangeReceiver::finalizeImpl(const Names & parent_require) { FinalizeHelper::checkSchemaContainsParentRequire(schema, parent_require); } diff --git a/dbms/src/Flash/Planner/Plans/PhysicalExchangeReceiver.h b/dbms/src/Flash/Planner/Plans/PhysicalExchangeReceiver.h index 24fd812bfb0..944af880743 100644 --- a/dbms/src/Flash/Planner/Plans/PhysicalExchangeReceiver.h +++ b/dbms/src/Flash/Planner/Plans/PhysicalExchangeReceiver.h @@ -38,7 +38,7 @@ class PhysicalExchangeReceiver : public PhysicalLeaf const Block & sample_block_, const std::shared_ptr & mpp_exchange_receiver_); - void finalize(const Names & parent_require) override; + void finalizeImpl(const Names & parent_require) override; const Block & getSampleBlock() const override; diff --git a/dbms/src/Flash/Planner/Plans/PhysicalExchangeSender.cpp b/dbms/src/Flash/Planner/Plans/PhysicalExchangeSender.cpp index b50ff61bf4c..caaba97d211 100644 --- a/dbms/src/Flash/Planner/Plans/PhysicalExchangeSender.cpp +++ b/dbms/src/Flash/Planner/Plans/PhysicalExchangeSender.cpp @@ -133,7 +133,7 @@ void PhysicalExchangeSender::buildPipelineExecGroupImpl( }); } -void PhysicalExchangeSender::finalize(const Names & parent_require) +void PhysicalExchangeSender::finalizeImpl(const Names & parent_require) { child->finalize(parent_require); } diff --git a/dbms/src/Flash/Planner/Plans/PhysicalExchangeSender.h b/dbms/src/Flash/Planner/Plans/PhysicalExchangeSender.h index bd2e9ee8f90..6147312127e 100644 --- a/dbms/src/Flash/Planner/Plans/PhysicalExchangeSender.h +++ b/dbms/src/Flash/Planner/Plans/PhysicalExchangeSender.h @@ -48,7 +48,7 @@ class PhysicalExchangeSender : public PhysicalUnary , compression_mode(compression_mode_) {} - void finalize(const Names & parent_require) override; + void finalizeImpl(const Names & parent_require) override; const Block & getSampleBlock() const override; diff --git a/dbms/src/Flash/Planner/Plans/PhysicalExpand.cpp b/dbms/src/Flash/Planner/Plans/PhysicalExpand.cpp index 7b8ef3e0cc9..3464e497d48 100644 --- a/dbms/src/Flash/Planner/Plans/PhysicalExpand.cpp +++ b/dbms/src/Flash/Planner/Plans/PhysicalExpand.cpp @@ -98,7 +98,7 @@ void PhysicalExpand::buildBlockInputStreamImpl(DAGPipeline & pipeline, Context & expandTransform(pipeline); } -void PhysicalExpand::finalize(const Names & parent_require) +void PhysicalExpand::finalizeImpl(const Names & parent_require) { FinalizeHelper::checkSchemaContainsParentRequire(schema, parent_require); Names required_output = parent_require; diff --git a/dbms/src/Flash/Planner/Plans/PhysicalExpand.h b/dbms/src/Flash/Planner/Plans/PhysicalExpand.h index 9d1923c0d01..fbd5e8fbf5a 100644 --- a/dbms/src/Flash/Planner/Plans/PhysicalExpand.h +++ b/dbms/src/Flash/Planner/Plans/PhysicalExpand.h @@ -44,7 +44,7 @@ class PhysicalExpand : public PhysicalUnary , expand_actions(expand_actions) {} - void finalize(const Names & parent_require) override; + void finalizeImpl(const Names & parent_require) override; void expandTransform(DAGPipeline & child_pipeline); diff --git a/dbms/src/Flash/Planner/Plans/PhysicalExpand2.cpp b/dbms/src/Flash/Planner/Plans/PhysicalExpand2.cpp index 367212ade70..0ad1c067b66 100644 --- a/dbms/src/Flash/Planner/Plans/PhysicalExpand2.cpp +++ b/dbms/src/Flash/Planner/Plans/PhysicalExpand2.cpp @@ -213,7 +213,7 @@ void PhysicalExpand2::buildBlockInputStreamImpl(DAGPipeline & pipeline, Context expandTransform(pipeline); } -void PhysicalExpand2::finalize(const Names & parent_require) +void PhysicalExpand2::finalizeImpl(const Names & parent_require) { FinalizeHelper::checkSchemaContainsParentRequire(schema, parent_require); child->finalize(shared_expand->getBeforeExpandActions()->getRequiredColumns()); diff --git a/dbms/src/Flash/Planner/Plans/PhysicalExpand2.h b/dbms/src/Flash/Planner/Plans/PhysicalExpand2.h index e38ef79fb5b..1d16c95a401 100644 --- a/dbms/src/Flash/Planner/Plans/PhysicalExpand2.h +++ b/dbms/src/Flash/Planner/Plans/PhysicalExpand2.h @@ -44,7 +44,7 @@ class PhysicalExpand2 : public PhysicalUnary sample_block = Block(schema_); } - void finalize(const Names & parent_require) override; + void finalizeImpl(const Names & parent_require) override; void expandTransform(DAGPipeline & child_pipeline); diff --git a/dbms/src/Flash/Planner/Plans/PhysicalFilter.cpp b/dbms/src/Flash/Planner/Plans/PhysicalFilter.cpp index f9a60500442..7cb84b0ba89 100644 --- a/dbms/src/Flash/Planner/Plans/PhysicalFilter.cpp +++ b/dbms/src/Flash/Planner/Plans/PhysicalFilter.cpp @@ -78,7 +78,7 @@ void PhysicalFilter::buildPipelineExecGroupImpl( }); } -void PhysicalFilter::finalize(const Names & parent_require) +void PhysicalFilter::finalizeImpl(const Names & parent_require) { Names required_output = parent_require; required_output.emplace_back(filter_column); diff --git a/dbms/src/Flash/Planner/Plans/PhysicalFilter.h b/dbms/src/Flash/Planner/Plans/PhysicalFilter.h index e29434923d5..dcd34b32586 100644 --- a/dbms/src/Flash/Planner/Plans/PhysicalFilter.h +++ b/dbms/src/Flash/Planner/Plans/PhysicalFilter.h @@ -43,7 +43,7 @@ class PhysicalFilter : public PhysicalUnary , before_filter_actions(before_filter_actions_) {} - void finalize(const Names & parent_require) override; + void finalizeImpl(const Names & parent_require) override; const Block & getSampleBlock() const override; diff --git a/dbms/src/Flash/Planner/Plans/PhysicalGetResultSink.h b/dbms/src/Flash/Planner/Plans/PhysicalGetResultSink.h index 730cf937e52..8ff3b598b4d 100644 --- a/dbms/src/Flash/Planner/Plans/PhysicalGetResultSink.h +++ b/dbms/src/Flash/Planner/Plans/PhysicalGetResultSink.h @@ -40,7 +40,7 @@ class PhysicalGetResultSink : public PhysicalUnary assert(result_queue); } - void finalize(const Names &) override { throw Exception("Unsupport"); } + void finalizeImpl(const Names &) override { throw Exception("Unsupport"); } const Block & getSampleBlock() const override { throw Exception("Unsupport"); } diff --git a/dbms/src/Flash/Planner/Plans/PhysicalJoin.cpp b/dbms/src/Flash/Planner/Plans/PhysicalJoin.cpp index be2c1bc3e2c..e912ecfc8c7 100644 --- a/dbms/src/Flash/Planner/Plans/PhysicalJoin.cpp +++ b/dbms/src/Flash/Planner/Plans/PhysicalJoin.cpp @@ -68,9 +68,6 @@ PhysicalPlanNodePtr PhysicalJoin::build( RUNTIME_CHECK(left); RUNTIME_CHECK(right); - left->finalize(); - right->finalize(); - const Block & left_input_header = left->getSampleBlock(); const Block & right_input_header = right->getSampleBlock(); @@ -78,13 +75,16 @@ PhysicalPlanNodePtr PhysicalJoin::build( const auto & probe_plan = tiflash_join.build_side_index == 0 ? right : left; const auto & build_plan = tiflash_join.build_side_index == 0 ? left : right; - - const Block & probe_side_header = probe_plan->getSampleBlock(); - const Block & build_side_header = build_plan->getSampleBlock(); + const auto probe_source_columns = tiflash_join.build_side_index == 0 + ? JoinInterpreterHelper::genDAGExpressionAnalyzerSourceColumns(right_input_header, right->getSchema()) + : JoinInterpreterHelper::genDAGExpressionAnalyzerSourceColumns(left_input_header, left->getSchema()); + const auto & build_source_columns = tiflash_join.build_side_index == 0 + ? JoinInterpreterHelper::genDAGExpressionAnalyzerSourceColumns(left_input_header, left->getSchema()) + : JoinInterpreterHelper::genDAGExpressionAnalyzerSourceColumns(right_input_header, right->getSchema()); String match_helper_name = tiflash_join.genMatchHelperName(left_input_header, right_input_header); NamesAndTypes join_output_schema - = tiflash_join.genJoinOutputColumns(left_input_header, right_input_header, match_helper_name); + = tiflash_join.genJoinOutputColumns(left->getSchema(), right->getSchema(), match_helper_name); auto & dag_context = *context.getDAGContext(); /// add necessary transformation if the join key is an expression @@ -96,7 +96,7 @@ PhysicalPlanNodePtr PhysicalJoin::build( auto [probe_side_prepare_actions, probe_key_names, original_probe_key_names, probe_filter_column_name] = JoinInterpreterHelper::prepareJoin( context, - probe_side_header, + probe_source_columns, tiflash_join.getProbeJoinKeys(), tiflash_join.join_key_types, /*left=*/true, @@ -110,7 +110,7 @@ PhysicalPlanNodePtr PhysicalJoin::build( auto [build_side_prepare_actions, build_key_names, original_build_key_names, build_filter_column_name] = JoinInterpreterHelper::prepareJoin( context, - build_side_header, + build_source_columns, tiflash_join.getBuildJoinKeys(), tiflash_join.join_key_types, /*left=*/false, @@ -122,8 +122,8 @@ PhysicalPlanNodePtr PhysicalJoin::build( tiflash_join.fillJoinOtherConditionsAction( context, - left_input_header, - right_input_header, + left->getSchema(), + right->getSchema(), probe_side_prepare_actions, original_probe_key_names, original_build_key_names, @@ -157,11 +157,15 @@ PhysicalPlanNodePtr PhysicalJoin::build( left_input_header, right_input_header, join_non_equal_conditions.other_cond_expr != nullptr); - Names join_output_column_names; - for (const auto & col : join_output_schema) - join_output_column_names.emplace_back(col.name); - auto runtime_filter_list = tiflash_join.genRuntimeFilterList(context, build_side_header, log); + assert(build_key_names.size() == original_build_key_names.size()); + std::unordered_map build_key_names_map; + for (size_t i = 0; i < original_build_key_names.size(); ++i) + { + build_key_names_map[original_build_key_names[i]] = build_key_names[i]; + } + auto runtime_filter_list + = tiflash_join.genRuntimeFilterList(context, build_source_columns, build_key_names_map, log); LOG_DEBUG(log, "before register runtime filter list, list size:{}", runtime_filter_list.size()); context.getDAGContext()->runtime_filter_mgr.registerRuntimeFilterList(runtime_filter_list); @@ -175,7 +179,7 @@ PhysicalPlanNodePtr PhysicalJoin::build( build_spill_config, probe_spill_config, RestoreConfig{settings.join_restore_concurrency, 0, 0}, - join_output_column_names, + join_output_schema, [&](const OperatorSpillContextPtr & operator_spill_context) { if (context.getDAGContext() != nullptr) { @@ -204,8 +208,7 @@ PhysicalPlanNodePtr PhysicalJoin::build( build_plan, join_ptr, probe_side_prepare_actions, - build_side_prepare_actions, - Block(join_output_schema)); + build_side_prepare_actions); return physical_join; } @@ -324,14 +327,35 @@ void PhysicalJoin::buildPipeline(PipelineBuilder & builder, Context & context, P builder.addPlanNode(join_probe); } -void PhysicalJoin::finalize(const Names & parent_require) +void PhysicalJoin::finalizeImpl(const Names & parent_require) { - // schema.size() >= parent_require.size() FinalizeHelper::checkSchemaContainsParentRequire(schema, parent_require); + join_ptr->finalize(parent_require); + auto required_input_columns = join_ptr->getRequiredColumns(); + + Names build_required; + Names probe_required; + const auto & build_sample_block = build_side_prepare_actions->getSampleBlock(); + for (const auto & name : required_input_columns) + { + if (build_sample_block.has(name)) + build_required.push_back(name); + else + /// if name not exists in probe side, it will throw error when call `probe_size_prepare_actions->finalize(probe_required)` + probe_required.push_back(name); + } + + build_side_prepare_actions->finalize(build_required); + build()->finalize(build_side_prepare_actions->getRequiredColumns()); + FinalizeHelper::prependProjectInputIfNeed(build_side_prepare_actions, build()->getSampleBlock().columns()); + + probe_side_prepare_actions->finalize(probe_required); + probe()->finalize(probe_side_prepare_actions->getRequiredColumns()); + FinalizeHelper::prependProjectInputIfNeed(probe_side_prepare_actions, probe()->getSampleBlock().columns()); } const Block & PhysicalJoin::getSampleBlock() const { - return sample_block; + return join_ptr->getOutputBlock(); } } // namespace DB diff --git a/dbms/src/Flash/Planner/Plans/PhysicalJoin.h b/dbms/src/Flash/Planner/Plans/PhysicalJoin.h index 2c8ef9a655a..7a829faf258 100644 --- a/dbms/src/Flash/Planner/Plans/PhysicalJoin.h +++ b/dbms/src/Flash/Planner/Plans/PhysicalJoin.h @@ -42,18 +42,16 @@ class PhysicalJoin : public PhysicalBinary const PhysicalPlanNodePtr & build_, const JoinPtr & join_ptr_, const ExpressionActionsPtr & probe_side_prepare_actions_, - const ExpressionActionsPtr & build_side_prepare_actions_, - const Block & sample_block_) + const ExpressionActionsPtr & build_side_prepare_actions_) : PhysicalBinary(executor_id_, PlanType::Join, schema_, fine_grained_shuffle_, req_id, probe_, build_) , join_ptr(join_ptr_) , probe_side_prepare_actions(probe_side_prepare_actions_) , build_side_prepare_actions(build_side_prepare_actions_) - , sample_block(sample_block_) {} void buildPipeline(PipelineBuilder & builder, Context & context, PipelineExecutorContext & exec_context) override; - void finalize(const Names & parent_require) override; + void finalizeImpl(const Names & parent_require) override; const Block & getSampleBlock() const override; @@ -73,7 +71,5 @@ class PhysicalJoin : public PhysicalBinary ExpressionActionsPtr probe_side_prepare_actions; ExpressionActionsPtr build_side_prepare_actions; - - Block sample_block; }; } // namespace DB diff --git a/dbms/src/Flash/Planner/Plans/PhysicalLimit.cpp b/dbms/src/Flash/Planner/Plans/PhysicalLimit.cpp index 56406e52329..a80e1dcc873 100644 --- a/dbms/src/Flash/Planner/Plans/PhysicalLimit.cpp +++ b/dbms/src/Flash/Planner/Plans/PhysicalLimit.cpp @@ -78,7 +78,7 @@ void PhysicalLimit::buildPipelineExecGroupImpl( }); } -void PhysicalLimit::finalize(const Names & parent_require) +void PhysicalLimit::finalizeImpl(const Names & parent_require) { child->finalize(parent_require); } diff --git a/dbms/src/Flash/Planner/Plans/PhysicalLimit.h b/dbms/src/Flash/Planner/Plans/PhysicalLimit.h index 878dfe9923d..1f4910dcf48 100644 --- a/dbms/src/Flash/Planner/Plans/PhysicalLimit.h +++ b/dbms/src/Flash/Planner/Plans/PhysicalLimit.h @@ -39,7 +39,7 @@ class PhysicalLimit : public PhysicalUnary , limit(limit_) {} - void finalize(const Names & parent_require) override; + void finalizeImpl(const Names & parent_require) override; const Block & getSampleBlock() const override; diff --git a/dbms/src/Flash/Planner/Plans/PhysicalMockExchangeReceiver.cpp b/dbms/src/Flash/Planner/Plans/PhysicalMockExchangeReceiver.cpp index 359a23fcdff..089d10a5d92 100644 --- a/dbms/src/Flash/Planner/Plans/PhysicalMockExchangeReceiver.cpp +++ b/dbms/src/Flash/Planner/Plans/PhysicalMockExchangeReceiver.cpp @@ -84,7 +84,7 @@ void PhysicalMockExchangeReceiver::buildPipelineExecGroupImpl( std::make_unique(exec_context, log->identifier(), mock_stream)); } -void PhysicalMockExchangeReceiver::finalize(const Names & parent_require) +void PhysicalMockExchangeReceiver::finalizeImpl(const Names & parent_require) { FinalizeHelper::checkSchemaContainsParentRequire(schema, parent_require); } diff --git a/dbms/src/Flash/Planner/Plans/PhysicalMockExchangeReceiver.h b/dbms/src/Flash/Planner/Plans/PhysicalMockExchangeReceiver.h index 34e2d8d59e7..af776f1c0d5 100644 --- a/dbms/src/Flash/Planner/Plans/PhysicalMockExchangeReceiver.h +++ b/dbms/src/Flash/Planner/Plans/PhysicalMockExchangeReceiver.h @@ -44,7 +44,7 @@ class PhysicalMockExchangeReceiver : public PhysicalLeaf const BlockInputStreams & mock_streams, size_t source_num_); - void finalize(const Names & parent_require) override; + void finalizeImpl(const Names & parent_require) override; const Block & getSampleBlock() const override; diff --git a/dbms/src/Flash/Planner/Plans/PhysicalMockExchangeSender.cpp b/dbms/src/Flash/Planner/Plans/PhysicalMockExchangeSender.cpp index 6a960e8d4bb..a60f54b0de1 100644 --- a/dbms/src/Flash/Planner/Plans/PhysicalMockExchangeSender.cpp +++ b/dbms/src/Flash/Planner/Plans/PhysicalMockExchangeSender.cpp @@ -49,7 +49,7 @@ void PhysicalMockExchangeSender::buildBlockInputStreamImpl( [&](auto & stream) { stream = std::make_shared(stream, log->identifier()); }); } -void PhysicalMockExchangeSender::finalize(const Names & parent_require) +void PhysicalMockExchangeSender::finalizeImpl(const Names & parent_require) { child->finalize(parent_require); } diff --git a/dbms/src/Flash/Planner/Plans/PhysicalMockExchangeSender.h b/dbms/src/Flash/Planner/Plans/PhysicalMockExchangeSender.h index dc1466b0bf0..6c491b205f8 100644 --- a/dbms/src/Flash/Planner/Plans/PhysicalMockExchangeSender.h +++ b/dbms/src/Flash/Planner/Plans/PhysicalMockExchangeSender.h @@ -37,7 +37,7 @@ class PhysicalMockExchangeSender : public PhysicalUnary : PhysicalUnary(executor_id_, PlanType::MockExchangeSender, schema_, fine_grained_shuffle_, req_id, child_) {} - void finalize(const Names & parent_require) override; + void finalizeImpl(const Names & parent_require) override; const Block & getSampleBlock() const override; diff --git a/dbms/src/Flash/Planner/Plans/PhysicalMockTableScan.cpp b/dbms/src/Flash/Planner/Plans/PhysicalMockTableScan.cpp index f888c2a7624..8554e434818 100644 --- a/dbms/src/Flash/Planner/Plans/PhysicalMockTableScan.cpp +++ b/dbms/src/Flash/Planner/Plans/PhysicalMockTableScan.cpp @@ -176,7 +176,7 @@ void PhysicalMockTableScan::buildPipelineExecGroupImpl( } } -void PhysicalMockTableScan::finalize(const Names & parent_require) +void PhysicalMockTableScan::finalizeImpl(const Names & parent_require) { FinalizeHelper::checkSchemaContainsParentRequire(schema, parent_require); } diff --git a/dbms/src/Flash/Planner/Plans/PhysicalMockTableScan.h b/dbms/src/Flash/Planner/Plans/PhysicalMockTableScan.h index 24dea7612f7..778b77e5932 100644 --- a/dbms/src/Flash/Planner/Plans/PhysicalMockTableScan.h +++ b/dbms/src/Flash/Planner/Plans/PhysicalMockTableScan.h @@ -46,7 +46,7 @@ class PhysicalMockTableScan : public PhysicalLeaf bool keep_order_, const std::vector & runtime_filter_ids_); - void finalize(const Names & parent_require) override; + void finalizeImpl(const Names & parent_require) override; const Block & getSampleBlock() const override; diff --git a/dbms/src/Flash/Planner/Plans/PhysicalProjection.cpp b/dbms/src/Flash/Planner/Plans/PhysicalProjection.cpp index 5af4bb1d817..08ca9ef7df8 100644 --- a/dbms/src/Flash/Planner/Plans/PhysicalProjection.cpp +++ b/dbms/src/Flash/Planner/Plans/PhysicalProjection.cpp @@ -154,7 +154,7 @@ void PhysicalProjection::buildPipelineExecGroupImpl( executeExpression(exec_context, group_builder, project_actions, log); } -void PhysicalProjection::finalize(const Names & parent_require) +void PhysicalProjection::finalizeImpl(const Names & parent_require) { FinalizeHelper::checkSampleBlockContainsParentRequire(getSampleBlock(), parent_require); diff --git a/dbms/src/Flash/Planner/Plans/PhysicalProjection.h b/dbms/src/Flash/Planner/Plans/PhysicalProjection.h index d32d447e0c8..3e97eb12899 100644 --- a/dbms/src/Flash/Planner/Plans/PhysicalProjection.h +++ b/dbms/src/Flash/Planner/Plans/PhysicalProjection.h @@ -63,7 +63,7 @@ class PhysicalProjection : public PhysicalUnary , project_actions(project_actions_) {} - void finalize(const Names & parent_require) override; + void finalizeImpl(const Names & parent_require) override; const Block & getSampleBlock() const override; diff --git a/dbms/src/Flash/Planner/Plans/PhysicalTableScan.cpp b/dbms/src/Flash/Planner/Plans/PhysicalTableScan.cpp index 2f7dfe5602a..364cd2e94a9 100644 --- a/dbms/src/Flash/Planner/Plans/PhysicalTableScan.cpp +++ b/dbms/src/Flash/Planner/Plans/PhysicalTableScan.cpp @@ -178,7 +178,7 @@ void PhysicalTableScan::buildProjection( executeExpression(exec_context, group_builder, schema_actions, log); } -void PhysicalTableScan::finalize(const Names & parent_require) +void PhysicalTableScan::finalizeImpl(const Names & parent_require) { FinalizeHelper::checkSchemaContainsParentRequire(schema, parent_require); } diff --git a/dbms/src/Flash/Planner/Plans/PhysicalTableScan.h b/dbms/src/Flash/Planner/Plans/PhysicalTableScan.h index 509953e1f44..284b75d0c73 100644 --- a/dbms/src/Flash/Planner/Plans/PhysicalTableScan.h +++ b/dbms/src/Flash/Planner/Plans/PhysicalTableScan.h @@ -38,7 +38,7 @@ class PhysicalTableScan : public PhysicalLeaf const TiDBTableScan & tidb_table_scan_, const Block & sample_block_); - void finalize(const Names & parent_require) override; + void finalizeImpl(const Names & parent_require) override; const Block & getSampleBlock() const override; diff --git a/dbms/src/Flash/Planner/Plans/PhysicalTopN.cpp b/dbms/src/Flash/Planner/Plans/PhysicalTopN.cpp index 6289fc6e815..e6b3e7022f8 100644 --- a/dbms/src/Flash/Planner/Plans/PhysicalTopN.cpp +++ b/dbms/src/Flash/Planner/Plans/PhysicalTopN.cpp @@ -94,7 +94,7 @@ void PhysicalTopN::buildPipelineExecGroupImpl( } } -void PhysicalTopN::finalize(const Names & parent_require) +void PhysicalTopN::finalizeImpl(const Names & parent_require) { Names required_output = parent_require; required_output.reserve(required_output.size() + order_descr.size()); diff --git a/dbms/src/Flash/Planner/Plans/PhysicalTopN.h b/dbms/src/Flash/Planner/Plans/PhysicalTopN.h index 7a9d1ae8531..7395ad1a508 100644 --- a/dbms/src/Flash/Planner/Plans/PhysicalTopN.h +++ b/dbms/src/Flash/Planner/Plans/PhysicalTopN.h @@ -46,7 +46,7 @@ class PhysicalTopN : public PhysicalUnary , limit(limit_) {} - void finalize(const Names & parent_require) override; + void finalizeImpl(const Names & parent_require) override; const Block & getSampleBlock() const override; diff --git a/dbms/src/Flash/Planner/Plans/PhysicalWindow.cpp b/dbms/src/Flash/Planner/Plans/PhysicalWindow.cpp index 35ed7669887..40b91628c8b 100644 --- a/dbms/src/Flash/Planner/Plans/PhysicalWindow.cpp +++ b/dbms/src/Flash/Planner/Plans/PhysicalWindow.cpp @@ -125,7 +125,7 @@ void PhysicalWindow::buildPipelineExecGroupImpl( executeExpression(exec_context, group_builder, window_description.after_window, log); } -void PhysicalWindow::finalize(const Names & parent_require) +void PhysicalWindow::finalizeImpl(const Names & parent_require) { FinalizeHelper::checkSchemaContainsParentRequire(schema, parent_require); diff --git a/dbms/src/Flash/Planner/Plans/PhysicalWindow.h b/dbms/src/Flash/Planner/Plans/PhysicalWindow.h index 7adfb721944..32c47b56d1d 100644 --- a/dbms/src/Flash/Planner/Plans/PhysicalWindow.h +++ b/dbms/src/Flash/Planner/Plans/PhysicalWindow.h @@ -43,7 +43,7 @@ class PhysicalWindow : public PhysicalUnary , window_description(window_description_) {} - void finalize(const Names & parent_require) override; + void finalizeImpl(const Names & parent_require) override; const Block & getSampleBlock() const override; diff --git a/dbms/src/Flash/Planner/Plans/PhysicalWindowSort.cpp b/dbms/src/Flash/Planner/Plans/PhysicalWindowSort.cpp index 0578fac99ac..2bf264b2441 100644 --- a/dbms/src/Flash/Planner/Plans/PhysicalWindowSort.cpp +++ b/dbms/src/Flash/Planner/Plans/PhysicalWindowSort.cpp @@ -70,7 +70,7 @@ void PhysicalWindowSort::buildPipelineExecGroupImpl( executeFinalSort(exec_context, group_builder, order_descr, {}, context, log); } -void PhysicalWindowSort::finalize(const Names & parent_require) +void PhysicalWindowSort::finalizeImpl(const Names & parent_require) { Names required_output = parent_require; required_output.reserve(required_output.size() + order_descr.size()); diff --git a/dbms/src/Flash/Planner/Plans/PhysicalWindowSort.h b/dbms/src/Flash/Planner/Plans/PhysicalWindowSort.h index e6dfc29a1a8..4e73c1026d7 100644 --- a/dbms/src/Flash/Planner/Plans/PhysicalWindowSort.h +++ b/dbms/src/Flash/Planner/Plans/PhysicalWindowSort.h @@ -43,7 +43,7 @@ class PhysicalWindowSort : public PhysicalUnary , order_descr(order_descr_) {} - void finalize(const Names & parent_require) override; + void finalizeImpl(const Names & parent_require) override; const Block & getSampleBlock() const override; diff --git a/dbms/src/Flash/Planner/Plans/PipelineBreakerHelper.h b/dbms/src/Flash/Planner/Plans/PipelineBreakerHelper.h index 7f7473f3dde..21fc087ac2b 100644 --- a/dbms/src/Flash/Planner/Plans/PipelineBreakerHelper.h +++ b/dbms/src/Flash/Planner/Plans/PipelineBreakerHelper.h @@ -21,7 +21,16 @@ namespace DB { \ throw Exception("Unsupport"); \ } \ - void finalize(const Names &) override { throw Exception("Unsupport"); } \ - const Block & getSampleBlock() const override { throw Exception("Unsupport"); } \ - void buildBlockInputStreamImpl(DAGPipeline &, Context &, size_t) override { throw Exception("Unsupport"); } + void finalizeImpl(const Names &) override \ + { \ + throw Exception("Unsupport"); \ + } \ + const Block & getSampleBlock() const override \ + { \ + throw Exception("Unsupport"); \ + } \ + void buildBlockInputStreamImpl(DAGPipeline &, Context &, size_t) override \ + { \ + throw Exception("Unsupport"); \ + } } // namespace DB diff --git a/dbms/src/Flash/Planner/optimize.cpp b/dbms/src/Flash/Planner/optimize.cpp index c1ba9132dc1..216dab04296 100644 --- a/dbms/src/Flash/Planner/optimize.cpp +++ b/dbms/src/Flash/Planner/optimize.cpp @@ -31,7 +31,7 @@ class FinalizeRule : public Rule public: PhysicalPlanNodePtr apply(const Context &, PhysicalPlanNodePtr plan, const LoggerPtr &) override { - plan->finalize(); + plan->finalize(toNames(plan->getSchema())); return plan; } diff --git a/dbms/src/Flash/tests/gtest_join.h b/dbms/src/Flash/tests/gtest_join.h index 6b3e621acc2..ea0d1b4fae1 100644 --- a/dbms/src/Flash/tests/gtest_join.h +++ b/dbms/src/Flash/tests/gtest_join.h @@ -168,6 +168,19 @@ class JoinTestRunner : public DB::tests::ExecutorTest right_partition_column_infos); } + ColumnsWithTypeAndName genScalarCountResults(const ColumnsWithTypeAndName & ref) + { + ColumnsWithTypeAndName ret; + ret.push_back(toVec({ref.empty() ? 0 : ref[0].column == nullptr ? 0 : ref[0].column->size()})); + return ret; + } + ColumnsWithTypeAndName genScalarCountResults(UInt64 result) + { + ColumnsWithTypeAndName ret; + ret.push_back(toVec({result})); + return ret; + } + static constexpr size_t join_type_num = 7; static constexpr tipb::JoinType join_types[join_type_num] = { diff --git a/dbms/src/Flash/tests/gtest_join_executor.cpp b/dbms/src/Flash/tests/gtest_join_executor.cpp index 4db03b70ad9..16454d625ab 100644 --- a/dbms/src/Flash/tests/gtest_join_executor.cpp +++ b/dbms/src/Flash/tests/gtest_join_executor.cpp @@ -163,6 +163,10 @@ try auto request = context.scan("simple_test", l) .join(context.scan("simple_test", r), join_types[i], {col(k)}) .build(context); + auto request_column_prune = context.scan("simple_test", l) + .join(context.scan("simple_test", r), join_types[i], {col(k)}) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); for (auto threshold : probe_cache_column_threshold) { @@ -170,6 +174,9 @@ try "join_probe_cache_columns_threshold", Field(static_cast(threshold))); executeAndAssertColumnsEqual(request, expected_cols[i * simple_test_num + j]); + ASSERT_COLUMNS_EQ_UR( + genScalarCountResults(expected_cols[i * simple_test_num + j]), + executeStreams(request_column_prune, 2)); } } } @@ -550,14 +557,29 @@ try { for (auto [j, jt2] : ext::enumerate(join_types)) { - auto t1 = context.scan("multi_test", "t1"); - auto t2 = context.scan("multi_test", "t2"); - auto t3 = context.scan("multi_test", "t3"); - auto t4 = context.scan("multi_test", "t4"); - auto request - = t1.join(t2, jt1, {col("a")}).join(t3.join(t4, jt1, {col("a")}), jt2, {col("b")}).build(context); + { + auto t1 = context.scan("multi_test", "t1"); + auto t2 = context.scan("multi_test", "t2"); + auto t3 = context.scan("multi_test", "t3"); + auto t4 = context.scan("multi_test", "t4"); + auto request + = t1.join(t2, jt1, {col("a")}).join(t3.join(t4, jt1, {col("a")}), jt2, {col("b")}).build(context); - executeAndAssertColumnsEqual(request, expected_cols[i * join_type_num + j]); + executeAndAssertColumnsEqual(request, expected_cols[i * join_type_num + j]); + } + { + auto t1 = context.scan("multi_test", "t1"); + auto t2 = context.scan("multi_test", "t2"); + auto t3 = context.scan("multi_test", "t3"); + auto t4 = context.scan("multi_test", "t4"); + auto request_column_prune = t1.join(t2, jt1, {col("a")}) + .join(t3.join(t4, jt1, {col("a")}), jt2, {col("b")}) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR( + genScalarCountResults(expected_cols[i * join_type_num + j]), + executeStreams(request_column_prune, 2)); + } } } } @@ -571,6 +593,15 @@ try .join(context.scan("cast", "t2"), tipb::JoinType::TypeInnerJoin, {col("a")}) .build(context); }; + auto cast_column_prune_request = [&]() { + return context.scan("cast", "t1") + .join(context.scan("cast", "t2"), tipb::JoinType::TypeInnerJoin, {col("a")}) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + }; + + ColumnsWithTypeAndName column_prune_ref_columns; + column_prune_ref_columns.push_back(toVec({1})); /// int(1) == float(1.0) context.addMockTable("cast", "t1", {{"a", TiDB::TP::TypeLong}}, {toVec("a", {1})}); @@ -578,6 +609,7 @@ try context.addMockTable("cast", "t2", {{"a", TiDB::TP::TypeFloat}}, {toVec("a", {1.0})}); executeAndAssertColumnsEqual(cast_request(), {toNullableVec({1}), toNullableVec({1.0})}); + ASSERT_COLUMNS_EQ_UR(column_prune_ref_columns, executeStreams(cast_column_prune_request(), 2)); /// int(1) == double(1.0) context.addMockTable("cast", "t1", {{"a", TiDB::TP::TypeLong}}, {toVec("a", {1})}); @@ -585,6 +617,7 @@ try context.addMockTable("cast", "t2", {{"a", TiDB::TP::TypeDouble}}, {toVec("a", {1.0})}); executeAndAssertColumnsEqual(cast_request(), {toNullableVec({1}), toNullableVec({1.0})}); + ASSERT_COLUMNS_EQ_UR(column_prune_ref_columns, executeStreams(cast_column_prune_request(), 2)); /// float(1) == double(1.0) context.addMockTable("cast", "t1", {{"a", TiDB::TP::TypeFloat}}, {toVec("a", {1})}); @@ -592,6 +625,7 @@ try context.addMockTable("cast", "t2", {{"a", TiDB::TP::TypeDouble}}, {toVec("a", {1})}); executeAndAssertColumnsEqual(cast_request(), {toNullableVec({1}), toNullableVec({1})}); + ASSERT_COLUMNS_EQ_UR(column_prune_ref_columns, executeStreams(cast_column_prune_request(), 2)); /// varchar('x') == char('x') context.addMockTable("cast", "t1", {{"a", TiDB::TP::TypeString}}, {toVec("a", {"x"})}); @@ -599,6 +633,7 @@ try context.addMockTable("cast", "t2", {{"a", TiDB::TP::TypeVarchar}}, {toVec("a", {"x"})}); executeAndAssertColumnsEqual(cast_request(), {toNullableVec({"x"}), toNullableVec({"x"})}); + ASSERT_COLUMNS_EQ_UR(column_prune_ref_columns, executeStreams(cast_column_prune_request(), 2)); /// tinyblob('x') == varchar('x') context.addMockTable("cast", "t1", {{"a", TiDB::TP::TypeTinyBlob}}, {toVec("a", {"x"})}); @@ -606,6 +641,7 @@ try context.addMockTable("cast", "t2", {{"a", TiDB::TP::TypeVarchar}}, {toVec("a", {"x"})}); executeAndAssertColumnsEqual(cast_request(), {toNullableVec({"x"}), toNullableVec({"x"})}); + ASSERT_COLUMNS_EQ_UR(column_prune_ref_columns, executeStreams(cast_column_prune_request(), 2)); /// mediumBlob('x') == varchar('x') context.addMockTable("cast", "t1", {{"a", TiDB::TP::TypeMediumBlob}}, {toVec("a", {"x"})}); @@ -613,6 +649,7 @@ try context.addMockTable("cast", "t2", {{"a", TiDB::TP::TypeVarchar}}, {toVec("a", {"x"})}); executeAndAssertColumnsEqual(cast_request(), {toNullableVec({"x"}), toNullableVec({"x"})}); + ASSERT_COLUMNS_EQ_UR(column_prune_ref_columns, executeStreams(cast_column_prune_request(), 2)); /// blob('x') == varchar('x') context.addMockTable("cast", "t1", {{"a", TiDB::TP::TypeBlob}}, {toVec("a", {"x"})}); @@ -620,6 +657,7 @@ try context.addMockTable("cast", "t2", {{"a", TiDB::TP::TypeVarchar}}, {toVec("a", {"x"})}); executeAndAssertColumnsEqual(cast_request(), {toNullableVec({"x"}), toNullableVec({"x"})}); + ASSERT_COLUMNS_EQ_UR(column_prune_ref_columns, executeStreams(cast_column_prune_request(), 2)); /// longBlob('x') == varchar('x') context.addMockTable("cast", "t1", {{"a", TiDB::TP::TypeLongBlob}}, {toVec("a", {"x"})}); @@ -627,6 +665,7 @@ try context.addMockTable("cast", "t2", {{"a", TiDB::TP::TypeVarchar}}, {toVec("a", {"x"})}); executeAndAssertColumnsEqual(cast_request(), {toNullableVec({"x"}), toNullableVec({"x"})}); + ASSERT_COLUMNS_EQ_UR(column_prune_ref_columns, executeStreams(cast_column_prune_request(), 2)); /// decimal with different scale context.addMockTable( @@ -645,6 +684,7 @@ try cast_request(), {createNullableColumn(std::make_tuple(65, 0), {"0.12"}, {0}), createNullableColumn(std::make_tuple(65, 0), {"0.12"}, {0})}); + ASSERT_COLUMNS_EQ_UR(column_prune_ref_columns, executeStreams(cast_column_prune_request(), 2)); /// datetime(1970-01-01 00:00:01) == timestamp(1970-01-01 00:00:01) context.addMockTable( @@ -664,9 +704,17 @@ try .join(context.scan("cast", "t2"), tipb::JoinType::TypeInnerJoin, {col("datetime")}) .build(context); }; + auto cast_column_prune_request_1 = [&]() { + return context.scan("cast", "t1") + .join(context.scan("cast", "t2"), tipb::JoinType::TypeInnerJoin, {col("datetime")}) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + }; executeAndAssertColumnsEqual( cast_request_1(), {createDateTimeColumn({{{1970, 1, 1, 0, 0, 1, 0}}}, 0), createDateTimeColumn({{{1970, 1, 1, 0, 0, 1, 0}}}, 0)}); + + ASSERT_COLUMNS_EQ_UR(column_prune_ref_columns, executeStreams(cast_column_prune_request_1(), 2)); } CATCH @@ -1187,34 +1235,72 @@ try for (const auto & join_type : join_types) { auto join_inputs = gen_join_inputs(); - for (auto & join_input : join_inputs) + auto join_inputs_column_prune = gen_join_inputs(); + for (size_t input_index = 0; input_index < join_inputs.size(); ++input_index) { - auto request - = join_input.first.join(join_input.second, join_type, {}, {}, {}, {cond_other}, {}).build(context); - executeAndAssertColumnsEqual(request, expected_cols[i++]); + auto request = join_inputs[input_index] + .first.join(join_inputs[input_index].second, join_type, {}, {}, {}, {cond_other}, {}) + .build(context); + const auto & expected_results = expected_cols[i]; + executeAndAssertColumnsEqual(request, expected_results); + auto request_column_prune + = join_inputs_column_prune[input_index] + .first + .join(join_inputs_column_prune[input_index].second, join_type, {}, {}, {}, {cond_other}, {}) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(expected_results), executeStreams(request_column_prune, 2)); + ++i; } /// extra tests for outer join if (join_type == tipb::TypeLeftOuterJoin) { /// left out join with left condition join_inputs = gen_join_inputs(); + join_inputs_column_prune = gen_join_inputs(); size_t left_join_index = 0; - for (auto & join_input : join_inputs) + for (size_t input_index = 0; input_index < join_inputs.size(); ++input_index) { - auto request - = join_input.first - .join(join_input.second, tipb::JoinType::TypeLeftOuterJoin, {}, {cond_left}, {}, {}, {}) - .build(context); - executeAndAssertColumnsEqual(request, left_join_expected_cols[left_join_index++]); + auto request = join_inputs[input_index] + .first + .join( + join_inputs[input_index].second, + tipb::JoinType::TypeLeftOuterJoin, + {}, + {cond_left}, + {}, + {}, + {}) + .build(context); + const auto & expected_results = left_join_expected_cols[left_join_index]; + executeAndAssertColumnsEqual(request, expected_results); + auto request_column_prune = join_inputs_column_prune[input_index] + .first + .join( + join_inputs_column_prune[input_index].second, + tipb::JoinType::TypeLeftOuterJoin, + {}, + {cond_left}, + {}, + {}, + {}) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR( + genScalarCountResults(expected_results), + executeStreams(request_column_prune, 2)); + ++left_join_index; } /// left out join with left condition and other condition join_inputs = gen_join_inputs(); + join_inputs_column_prune = gen_join_inputs(); i -= join_inputs.size(); - for (auto & join_input : join_inputs) + for (size_t input_index = 0; input_index < join_inputs.size(); ++input_index) { - auto request = join_input.first + auto request = join_inputs[input_index] + .first .join( - join_input.second, + join_inputs[input_index].second, tipb::JoinType::TypeLeftOuterJoin, {}, {cond_left}, @@ -1222,19 +1308,38 @@ try {cond_other}, {}) .build(context); - executeAndAssertColumnsEqual(request, expected_cols[i++]); + const auto & expected_results = expected_cols[i]; + executeAndAssertColumnsEqual(request, expected_results); + auto request_column_prune = join_inputs_column_prune[input_index] + .first + .join( + join_inputs_column_prune[input_index].second, + tipb::JoinType::TypeLeftOuterJoin, + {}, + {cond_left}, + {}, + {cond_other}, + {}) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR( + genScalarCountResults(expected_results), + executeStreams(request_column_prune, 2)); + ++i; } } else if (join_type == tipb::TypeRightOuterJoin) { /// right out join with right condition join_inputs = gen_join_inputs(); + join_inputs_column_prune = gen_join_inputs(); size_t right_join_index = 0; - for (auto & join_input : join_inputs) + for (size_t input_index = 0; input_index < join_inputs.size(); ++input_index) { - auto request = join_input.first + auto request = join_inputs[input_index] + .first .join( - join_input.second, + join_inputs[input_index].second, tipb::JoinType::TypeRightOuterJoin, {}, {}, @@ -1242,16 +1347,35 @@ try {}, {}) .build(context); - executeAndAssertColumnsEqual(request, right_join_expected_cols[right_join_index++]); + const auto & expected_results = right_join_expected_cols[right_join_index]; + executeAndAssertColumnsEqual(request, expected_results); + auto request_column_prune = join_inputs_column_prune[input_index] + .first + .join( + join_inputs_column_prune[input_index].second, + tipb::JoinType::TypeRightOuterJoin, + {}, + {}, + {gt(col("c"), lit(Field("2", 1)))}, + {}, + {}) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR( + genScalarCountResults(expected_results), + executeStreams(request_column_prune, 2)); + ++right_join_index; } /// right out join with right condition and other condition join_inputs = gen_join_inputs(); + join_inputs_column_prune = gen_join_inputs(); i -= join_inputs.size(); - for (auto & join_input : join_inputs) + for (size_t input_index = 0; input_index < join_inputs.size(); ++input_index) { - auto request = join_input.first + auto request = join_inputs[input_index] + .first .join( - join_input.second, + join_inputs[input_index].second, tipb::JoinType::TypeRightOuterJoin, {}, {}, @@ -1259,7 +1383,24 @@ try {cond_other}, {}) .build(context); - executeAndAssertColumnsEqual(request, expected_cols[i++]); + const auto & expected_results = expected_cols[i]; + executeAndAssertColumnsEqual(request, expected_results); + auto request_column_prune = join_inputs_column_prune[input_index] + .first + .join( + join_inputs_column_prune[input_index].second, + tipb::JoinType::TypeRightOuterJoin, + {}, + {}, + {cond_right}, + {cond_other}, + {}) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR( + genScalarCountResults(expected_results), + executeStreams(request_column_prune, 2)); + ++i; } } } @@ -1595,10 +1736,21 @@ try for (const auto & join_type : join_types) { auto join_inputs = gen_join_inputs(); - for (auto & join_input : join_inputs) + auto join_inputs_column_prune = gen_join_inputs(); + for (size_t input_index = 0; input_index < join_inputs.size(); ++input_index) { - auto request = join_input.first.join(join_input.second, join_type, {}, {}, {}, {}, {}).build(context); - executeAndAssertColumnsEqual(request, expected_cols[i++]); + auto request = join_inputs[input_index] + .first.join(join_inputs[input_index].second, join_type, {}, {}, {}, {}, {}) + .build(context); + const auto & expected_results = expected_cols[i]; + executeAndAssertColumnsEqual(request, expected_results); + auto request_column_prune + = join_inputs_column_prune[input_index] + .first.join(join_inputs_column_prune[input_index].second, join_type, {}, {}, {}, {}, {}) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(expected_results), executeStreams(request_column_prune, 2)); + ++i; } } } @@ -1611,6 +1763,11 @@ try auto request = context.scan("test_db", "l_table") .join(context.scan("test_db", "r_table"), tipb::JoinType::TypeLeftOuterJoin, {col("join_c")}) .build(context); + auto request_column_prune + = context.scan("test_db", "l_table") + .join(context.scan("test_db", "r_table"), tipb::JoinType::TypeLeftOuterJoin, {col("join_c")}) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); { executeAndAssertColumnsEqual( request, @@ -1618,6 +1775,7 @@ try toNullableVec({"apple", "banana"}), toNullableVec({"banana", "banana"}), toNullableVec({"apple", "banana"})}); + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(2), executeStreams(request_column_prune, 2)); } request = context.scan("test_db", "l_table") @@ -1695,6 +1853,7 @@ try {toVec("a", {}), toVec("b", {}), toVec("c", {})}); std::shared_ptr request; + std::shared_ptr request_column_prune; // inner join { @@ -1703,6 +1862,11 @@ try .join(context.scan("null_test", "t"), tipb::JoinType::TypeInnerJoin, {col("a")}) .build(context); executeAndAssertColumnsEqual(request, {}); + request_column_prune = context.scan("null_test", "null_table") + .join(context.scan("null_test", "t"), tipb::JoinType::TypeInnerJoin, {col("a")}) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(0), executeStreams(request_column_prune, 2)); // non-null table join null table request = context.scan("null_test", "t") @@ -1716,12 +1880,24 @@ try toNullableVec({}), toNullableVec({}), toNullableVec({})}); + request_column_prune + = context.scan("null_test", "t") + .join(context.scan("null_test", "null_table"), tipb::JoinType::TypeInnerJoin, {col("a")}) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(0), executeStreams(request_column_prune, 2)); // null table join null table request = context.scan("null_test", "null_table") .join(context.scan("null_test", "null_table"), tipb::JoinType::TypeInnerJoin, {col("a")}) .build(context); executeAndAssertColumnsEqual(request, {}); + request_column_prune + = context.scan("null_test", "null_table") + .join(context.scan("null_test", "null_table"), tipb::JoinType::TypeInnerJoin, {col("a")}) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(0), executeStreams(request_column_prune, 2)); } // cross join @@ -1740,6 +1916,12 @@ try toNullableVec({}), toNullableVec({}), toNullableVec({})}); + request_column_prune + = context.scan("null_test", "t") + .join(context.scan("null_test", "null_table"), tipb::JoinType::TypeInnerJoin, {}, {}, {}, {cond}, {}) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(0), executeStreams(request_column_prune, 2)); request = context.scan("null_test", "t") .join( @@ -1759,6 +1941,18 @@ try toNullableVec({{}, {}, {}, {}, {}, {}, {}, {}, {}, {}}), toNullableVec({{}, {}, {}, {}, {}, {}, {}, {}, {}, {}}), toNullableVec({{}, {}, {}, {}, {}, {}, {}, {}, {}, {}})}); + request_column_prune = context.scan("null_test", "t") + .join( + context.scan("null_test", "null_table"), + tipb::JoinType::TypeLeftOuterJoin, + {}, + {cond}, + {}, + {}, + {}) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(10), executeStreams(request_column_prune, 2)); request = context.scan("null_test", "t") .join( @@ -1771,6 +1965,18 @@ try {}) .build(context); executeAndAssertColumnsEqual(request, {}); + request_column_prune = context.scan("null_test", "t") + .join( + context.scan("null_test", "null_table"), + tipb::JoinType::TypeRightOuterJoin, + {}, + {}, + {cond}, + {}, + {}) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(0), executeStreams(request_column_prune, 2)); request = context.scan("null_test", "t") @@ -1779,6 +1985,12 @@ try executeAndAssertColumnsEqual( request, {toNullableVec({}), toNullableVec({}), toNullableVec({})}); + request_column_prune + = context.scan("null_test", "t") + .join(context.scan("null_test", "null_table"), tipb::JoinType::TypeSemiJoin, {}, {}, {}, {cond}, {}) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(0), executeStreams(request_column_prune, 2)); request = context.scan("null_test", "t") .join( @@ -1795,6 +2007,18 @@ try {toNullableVec({1, 2, 3, 4, 5, 6, 7, 8, 9, 0}), toNullableVec({1, 1, 1, 1, 1, 1, 1, 2, 2, 2}), toNullableVec({1, 1, 1, 1, 1, 2, 2, 2, 2, 2})}); + request_column_prune = context.scan("null_test", "t") + .join( + context.scan("null_test", "null_table"), + tipb::JoinType::TypeAntiSemiJoin, + {}, + {}, + {}, + {cond}, + {}) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(10), executeStreams(request_column_prune, 2)); request = context.scan("null_test", "t") .join( @@ -1814,6 +2038,19 @@ try toNullableVec({1, 1, 1, 1, 1, 2, 2, 2, 2, 2}), toNullableVec({0, 0, 0, 0, 0, 0, 0, 0, 0, 0})}); + request_column_prune = context.scan("null_test", "t") + .join( + context.scan("null_test", "null_table"), + tipb::JoinType::TypeLeftOuterSemiJoin, + {}, + {}, + {}, + {cond}, + {}) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(10), executeStreams(request_column_prune, 2)); + request = context.scan("null_test", "t") .join( context.scan("null_test", "null_table"), @@ -1831,6 +2068,18 @@ try toNullableVec({1, 1, 1, 1, 1, 1, 1, 2, 2, 2}), toNullableVec({1, 1, 1, 1, 1, 2, 2, 2, 2, 2}), toNullableVec({1, 1, 1, 1, 1, 1, 1, 1, 1, 1})}); + request_column_prune = context.scan("null_test", "t") + .join( + context.scan("null_test", "null_table"), + tipb::JoinType::TypeAntiLeftOuterSemiJoin, + {}, + {}, + {}, + {cond}, + {}) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(10), executeStreams(request_column_prune, 2)); } // null table join non-null table @@ -1839,11 +2088,23 @@ try .join(context.scan("null_test", "t"), tipb::JoinType::TypeInnerJoin, {}, {}, {}, {cond}, {}) .build(context); executeAndAssertColumnsEqual(request, {}); + request_column_prune + = context.scan("null_test", "null_table") + .join(context.scan("null_test", "t"), tipb::JoinType::TypeInnerJoin, {}, {}, {}, {cond}, {}) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(0), executeStreams(request_column_prune, 2)); request = context.scan("null_test", "null_table") .join(context.scan("null_test", "t"), tipb::JoinType::TypeLeftOuterJoin, {}, {cond}, {}, {}, {}) .build(context); executeAndAssertColumnsEqual(request, {}); + request_column_prune + = context.scan("null_test", "null_table") + .join(context.scan("null_test", "t"), tipb::JoinType::TypeLeftOuterJoin, {}, {cond}, {}, {}, {}) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(0), executeStreams(request_column_prune, 2)); request = context.scan("null_test", "null_table") .join(context.scan("null_test", "t"), tipb::JoinType::TypeRightOuterJoin, {}, {}, {cond}, {}, {}) @@ -1856,32 +2117,68 @@ try toNullableVec({1, 2, 3, 4, 5, 6, 7, 8, 9, 0}), toNullableVec({1, 1, 1, 1, 1, 1, 1, 2, 2, 2}), toNullableVec({1, 1, 1, 1, 1, 2, 2, 2, 2, 2})}); + request_column_prune + = context.scan("null_test", "null_table") + .join(context.scan("null_test", "t"), tipb::JoinType::TypeRightOuterJoin, {}, {}, {cond}, {}, {}) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(10), executeStreams(request_column_prune, 2)); request = context.scan("null_test", "null_table") .join(context.scan("null_test", "t"), tipb::JoinType::TypeSemiJoin, {}, {}, {}, {cond}, {}) .build(context); executeAndAssertColumnsEqual(request, {}); + request_column_prune + = context.scan("null_test", "null_table") + .join(context.scan("null_test", "t"), tipb::JoinType::TypeSemiJoin, {}, {}, {}, {cond}, {}) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(0), executeStreams(request_column_prune, 2)); request = context.scan("null_test", "null_table") .join(context.scan("null_test", "t"), tipb::JoinType::TypeSemiJoin, {}, {}, {}, {cond}, {}, 0) .build(context); executeAndAssertColumnsEqual(request, {}); + request_column_prune + = context.scan("null_test", "null_table") + .join(context.scan("null_test", "t"), tipb::JoinType::TypeSemiJoin, {}, {}, {}, {cond}, {}, 0) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(0), executeStreams(request_column_prune, 2)); request = context.scan("null_test", "null_table") .join(context.scan("null_test", "t"), tipb::JoinType::TypeAntiSemiJoin, {}, {}, {}, {cond}, {}) .build(context); executeAndAssertColumnsEqual(request, {}); + request_column_prune + = context.scan("null_test", "null_table") + .join(context.scan("null_test", "t"), tipb::JoinType::TypeAntiSemiJoin, {}, {}, {}, {cond}, {}) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(0), executeStreams(request_column_prune, 2)); request = context.scan("null_test", "null_table") .join(context.scan("null_test", "t"), tipb::JoinType::TypeAntiSemiJoin, {}, {}, {}, {cond}, {}, 0) .build(context); executeAndAssertColumnsEqual(request, {}); + request_column_prune + = context.scan("null_test", "null_table") + .join(context.scan("null_test", "t"), tipb::JoinType::TypeAntiSemiJoin, {}, {}, {}, {cond}, {}, 0) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(0), executeStreams(request_column_prune, 2)); request = context.scan("null_test", "null_table") .join(context.scan("null_test", "t"), tipb::JoinType::TypeLeftOuterSemiJoin, {}, {}, {}, {cond}, {}) .build(context); executeAndAssertColumnsEqual(request, {}); + request_column_prune + = context.scan("null_test", "null_table") + .join(context.scan("null_test", "t"), tipb::JoinType::TypeLeftOuterSemiJoin, {}, {}, {}, {cond}, {}) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(0), executeStreams(request_column_prune, 2)); request = context.scan("null_test", "null_table") .join( @@ -1894,6 +2191,18 @@ try {}) .build(context); executeAndAssertColumnsEqual(request, {}); + request_column_prune = context.scan("null_test", "null_table") + .join( + context.scan("null_test", "t"), + tipb::JoinType::TypeAntiLeftOuterSemiJoin, + {}, + {}, + {}, + {cond}, + {}) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(0), executeStreams(request_column_prune, 2)); } // null table join null table @@ -1903,6 +2212,12 @@ try .join(context.scan("null_test", "null_table"), tipb::JoinType::TypeInnerJoin, {}, {}, {}, {cond}, {}) .build(context); executeAndAssertColumnsEqual(request, {}); + request_column_prune + = context.scan("null_test", "null_table") + .join(context.scan("null_test", "null_table"), tipb::JoinType::TypeInnerJoin, {}, {}, {}, {cond}, {}) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(0), executeStreams(request_column_prune, 2)); request = context.scan("null_test", "null_table") .join( @@ -1915,6 +2230,18 @@ try {}) .build(context); executeAndAssertColumnsEqual(request, {}); + request_column_prune = context.scan("null_test", "null_table") + .join( + context.scan("null_test", "null_table"), + tipb::JoinType::TypeLeftOuterJoin, + {}, + {cond}, + {}, + {}, + {}) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(0), executeStreams(request_column_prune, 2)); request = context.scan("null_test", "null_table") .join( @@ -1927,12 +2254,30 @@ try {}) .build(context); executeAndAssertColumnsEqual(request, {}); + request_column_prune = context.scan("null_test", "null_table") + .join( + context.scan("null_test", "null_table"), + tipb::JoinType::TypeRightOuterJoin, + {}, + {}, + {cond}, + {}, + {}) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(0), executeStreams(request_column_prune, 2)); request = context.scan("null_test", "null_table") .join(context.scan("null_test", "null_table"), tipb::JoinType::TypeSemiJoin, {}, {}, {}, {cond}, {}) .build(context); executeAndAssertColumnsEqual(request, {}); + request_column_prune + = context.scan("null_test", "null_table") + .join(context.scan("null_test", "null_table"), tipb::JoinType::TypeSemiJoin, {}, {}, {}, {cond}, {}) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(0), executeStreams(request_column_prune, 2)); request = context.scan("null_test", "null_table") .join( @@ -1945,6 +2290,18 @@ try {}) .build(context); executeAndAssertColumnsEqual(request, {}); + request_column_prune = context.scan("null_test", "null_table") + .join( + context.scan("null_test", "null_table"), + tipb::JoinType::TypeAntiSemiJoin, + {}, + {}, + {}, + {cond}, + {}) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(0), executeStreams(request_column_prune, 2)); request = context.scan("null_test", "null_table") .join( @@ -1957,6 +2314,18 @@ try {}) .build(context); executeAndAssertColumnsEqual(request, {}); + request_column_prune = context.scan("null_test", "null_table") + .join( + context.scan("null_test", "null_table"), + tipb::JoinType::TypeLeftOuterSemiJoin, + {}, + {}, + {}, + {cond}, + {}) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(0), executeStreams(request_column_prune, 2)); request = context.scan("null_test", "null_table") .join( @@ -1969,6 +2338,18 @@ try {}) .build(context); executeAndAssertColumnsEqual(request, {}); + request_column_prune = context.scan("null_test", "null_table") + .join( + context.scan("null_test", "null_table"), + tipb::JoinType::TypeAntiLeftOuterSemiJoin, + {}, + {}, + {}, + {cond}, + {}) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(0), executeStreams(request_column_prune, 2)); } } CATCH @@ -2003,6 +2384,10 @@ try auto request = context.scan("split_test", "t1") .join(context.scan("split_test", "t2"), tipb::JoinType::TypeInnerJoin, {col("a")}) .build(context); + auto request_column_prune = context.scan("split_test", "t1") + .join(context.scan("split_test", "t2"), tipb::JoinType::TypeInnerJoin, {col("a")}) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); std::vector block_sizes{1, 2, 7, 25, 49, 50, 51, DEFAULT_BLOCK_SIZE}; std::vector> expect{ @@ -2024,6 +2409,7 @@ try { ASSERT_EQ(expect[i][j], blocks[j].rows()); } + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(50), executeStreams(request_column_prune, 2)); WRAP_FOR_JOIN_TEST_END } } @@ -2139,15 +2525,20 @@ try context.scan("outer_join_test", right_table_name), tipb::JoinType::TypeRightOuterJoin, {col("a")}) + .project({fmt::format("{}.a", left_table_name), fmt::format("{}.b", right_table_name)}) .build(context); + ColumnsWithTypeAndName ref; + ref.push_back(ref_columns[0]); + ref.push_back(ref_columns[3]); WRAP_FOR_JOIN_TEST_BEGIN for (auto threshold : probe_cache_column_threshold) { context.context->setSetting( "join_probe_cache_columns_threshold", Field(static_cast(threshold))); - ASSERT_COLUMNS_EQ_UR(ref_columns, executeStreams(request, original_max_streams)) - << "left_table_name = " << left_table_name << ", right_table_name = " << right_table_name; + ASSERT_COLUMNS_EQ_UR(ref, executeStreams(request, original_max_streams)) + << "left_table_name = " << left_table_name << ", right_table_name = " << right_table_name + << "probe cache threshold = " << threshold; } WRAP_FOR_JOIN_TEST_END } @@ -2157,11 +2548,10 @@ try { for (size_t exchange_concurrency : right_exchange_receiver_concurrency) { + auto right_name = fmt::format("right_exchange_receiver_{}_concurrency", exchange_concurrency); request = context.scan("outer_join_test", left_table_name) .join( - context.receive( - fmt::format("right_exchange_receiver_{}_concurrency", exchange_concurrency), - exchange_concurrency), + context.receive(right_name, exchange_concurrency), tipb::JoinType::TypeRightOuterJoin, {col("a")}, {}, @@ -2169,18 +2559,22 @@ try {}, {}, exchange_concurrency) + .project({fmt::format("{}.b", left_table_name), fmt::format("{}.a", right_name)}) .build(context); + ColumnsWithTypeAndName ref; + ref.push_back(ref_columns[1]); + ref.push_back(ref_columns[2]); WRAP_FOR_JOIN_TEST_BEGIN for (auto threshold : probe_cache_column_threshold) { context.context->setSetting( "join_probe_cache_columns_threshold", Field(static_cast(threshold))); - ASSERT_COLUMNS_EQ_UR(ref_columns, executeStreams(request, original_max_streams)) + ASSERT_COLUMNS_EQ_UR(ref, executeStreams(request, original_max_streams)) << "left_table_name = " << left_table_name << ", right_exchange_receiver_concurrency = " << exchange_concurrency; if (original_max_streams_small < exchange_concurrency) - ASSERT_COLUMNS_EQ_UR(ref_columns, executeStreams(request, original_max_streams_small)) + ASSERT_COLUMNS_EQ_UR(ref, executeStreams(request, original_max_streams_small)) << "left_table_name = " << left_table_name << ", right_exchange_receiver_concurrency = " << exchange_concurrency; } @@ -2222,14 +2616,18 @@ try {}, {}, 0) + .project({fmt::format("{}.a", left_table_name), fmt::format("{}.b", right_table_name)}) .build(context); + ColumnsWithTypeAndName ref; + ref.push_back(ref_columns[0]); + ref.push_back(ref_columns[3]); WRAP_FOR_JOIN_TEST_BEGIN for (auto threshold : probe_cache_column_threshold) { context.context->setSetting( "join_probe_cache_columns_threshold", Field(static_cast(threshold))); - ASSERT_COLUMNS_EQ_UR(ref_columns, executeStreams(request, original_max_streams)) + ASSERT_COLUMNS_EQ_UR(ref, executeStreams(request, original_max_streams)) << "left_table_name = " << left_table_name << ", right_table_name = " << right_table_name; } WRAP_FOR_JOIN_TEST_END @@ -2243,9 +2641,7 @@ try String exchange_name = fmt::format("right_exchange_receiver_{}_concurrency", exchange_concurrency); request = context.scan("outer_join_test", left_table_name) .join( - context.receive( - fmt::format("right_exchange_receiver_{}_concurrency", exchange_concurrency), - exchange_concurrency), + context.receive(exchange_name, exchange_concurrency), tipb::JoinType::TypeRightOuterJoin, {col("a")}, {}, @@ -2253,19 +2649,23 @@ try {}, {}, exchange_concurrency) + .project({fmt::format("{}.b", left_table_name), fmt::format("{}.a", exchange_name)}) .build(context); + ColumnsWithTypeAndName ref; + ref.push_back(ref_columns[1]); + ref.push_back(ref_columns[2]); WRAP_FOR_JOIN_TEST_BEGIN for (auto threshold : probe_cache_column_threshold) { context.context->setSetting( "join_probe_cache_columns_threshold", Field(static_cast(threshold))); - ASSERT_COLUMNS_EQ_UR(ref_columns, executeStreams(request, original_max_streams)) + ASSERT_COLUMNS_EQ_UR(ref, executeStreams(request, original_max_streams)) << "left_table_name = " << left_table_name << ", right_exchange_receiver_concurrency = " << exchange_concurrency << ", join_probe_cache_columns_threshold = " << threshold; if (original_max_streams_small < exchange_concurrency) - ASSERT_COLUMNS_EQ_UR(ref_columns, executeStreams(request, original_max_streams_small)) + ASSERT_COLUMNS_EQ_UR(ref, executeStreams(request, original_max_streams_small)) << "left_table_name = " << left_table_name << ", right_exchange_receiver_concurrency = " << exchange_concurrency << ", join_probe_cache_columns_threshold = " << threshold; @@ -2403,6 +2803,13 @@ try .join(context.scan("null_aware_semi", "s"), type, {col("a")}, {}, {}, {}, {}, 0, is_null_aware) .build(context); executeAndAssertColumnsEqual(request, reference); + auto request_column_prune + = context.scan("null_aware_semi", "t") + .join(context.scan("null_aware_semi", "s"), type, {col("a")}, {}, {}, {}, {}, 0, is_null_aware) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(reference), executeStreams(request_column_prune, 2)); + /// nullaware cross join for (const auto shallow_copy_threshold : cross_join_shallow_copy_thresholds) { context.context->setSetting( @@ -2421,6 +2828,20 @@ try false) .build(context); executeAndAssertColumnsEqual(request, reference); + request_column_prune = context.scan("null_aware_semi", "t") + .join( + context.scan("null_aware_semi", "s"), + type, + {}, + {}, + {}, + {}, + {eq(col("t.a"), col("s.a"))}, + 0, + false) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(reference), executeStreams(request_column_prune, 2)); } } } @@ -2485,6 +2906,20 @@ try is_null_aware) .build(context); executeAndAssertColumnsEqual(request, reference); + auto request_column_prune = context.scan("null_aware_semi", "t") + .join( + context.scan("null_aware_semi", "s"), + type, + {col("a")}, + {}, + {}, + {lt(col("t.c"), col("s.c"))}, + {}, + 0, + is_null_aware) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(reference), executeStreams(request_column_prune, 2)); for (const auto shallow_copy_threshold : cross_join_shallow_copy_thresholds) { context.context->setSetting( @@ -2503,6 +2938,20 @@ try false) .build(context); executeAndAssertColumnsEqual(request, reference); + request_column_prune = context.scan("null_aware_semi", "t") + .join( + context.scan("null_aware_semi", "s"), + type, + {}, + {}, + {}, + {lt(col("t.c"), col("s.c"))}, + {eq(col("t.a"), col("s.a"))}, + 0, + false) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(reference), executeStreams(request_column_prune, 2)); } } } @@ -2569,6 +3018,20 @@ try is_null_aware) .build(context); executeAndAssertColumnsEqual(request, reference); + auto request_column_prune = context.scan("null_aware_semi", "t") + .join( + context.scan("null_aware_semi", "s"), + type, + {col("a"), col("b")}, + {}, + {}, + {}, + {}, + 0, + is_null_aware) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(reference), executeStreams(request_column_prune, 2)); for (const auto shallow_copy_threshold : cross_join_shallow_copy_thresholds) { context.context->setSetting( @@ -2587,6 +3050,20 @@ try false) .build(context); executeAndAssertColumnsEqual(request, reference); + request = context.scan("null_aware_semi", "t") + .join( + context.scan("null_aware_semi", "s"), + type, + {}, + {}, + {}, + {}, + {And(eq(col("t.a"), col("s.a")), eq(col("t.b"), col("s.b")))}, + 0, + false) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(reference), executeStreams(request_column_prune, 2)); } } } @@ -2709,6 +3186,20 @@ try is_null_aware) .build(context); executeAndAssertColumnsEqual(request, reference); + auto request_column_prune = context.scan("null_aware_semi", "t") + .join( + context.scan("null_aware_semi", "s"), + type, + {col("a"), col("b")}, + {}, + {}, + {lt(col("t.c"), col("s.c"))}, + {}, + 0, + is_null_aware) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(reference), executeStreams(request_column_prune, 2)); for (const auto shallow_copy_threshold : cross_join_shallow_copy_thresholds) { context.context->setSetting( @@ -2727,6 +3218,20 @@ try false) .build(context); executeAndAssertColumnsEqual(request, reference); + request = context.scan("null_aware_semi", "t") + .join( + context.scan("null_aware_semi", "s"), + type, + {}, + {}, + {}, + {lt(col("t.c"), col("s.c"))}, + {eq(col("t.a"), col("s.a")), eq(col("t.b"), col("s.b"))}, + 0, + false) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(reference), executeStreams(request_column_prune, 2)); } } } @@ -2791,6 +3296,20 @@ try is_null_aware) .build(context); executeAndAssertColumnsEqual(request, reference); + auto request_column_prune = context.scan("null_aware_semi", "t") + .join( + context.scan("null_aware_semi", "s"), + type, + {col("a"), col("b")}, + {}, + {}, + {Or(lt(col("c"), col("d")), eq(col("t.a"), col("s.a")))}, + {}, + 0, + is_null_aware) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(reference), executeStreams(request_column_prune, 2)); for (const auto shallow_copy_threshold : cross_join_shallow_copy_thresholds) { context.context->setSetting( @@ -2809,6 +3328,20 @@ try false) .build(context); executeAndAssertColumnsEqual(request, reference); + request_column_prune = context.scan("null_aware_semi", "t") + .join( + context.scan("null_aware_semi", "s"), + type, + {}, + {}, + {}, + {Or(lt(col("c"), col("d")), eq(col("t.a"), col("s.a")))}, + {And(eq(col("t.a"), col("s.a")), eq(col("t.b"), col("s.b")))}, + 0, + false) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(reference), executeStreams(request_column_prune, 2)); } } } @@ -2864,6 +3397,20 @@ try is_null_aware) .build(context); executeAndAssertColumnsEqual(request, reference); + auto request_column_prune = context.scan("null_aware_semi", "t") + .join( + context.scan("null_aware_semi", "s"), + type, + {col("a"), col("b")}, + {}, + {}, + {}, + {}, + 0, + is_null_aware) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(reference), executeStreams(request_column_prune, 2)); for (const auto shallow_copy_threshold : cross_join_shallow_copy_thresholds) { context.context->setSetting( @@ -2882,6 +3429,20 @@ try false) .build(context); executeAndAssertColumnsEqual(request, reference); + request_column_prune = context.scan("null_aware_semi", "t") + .join( + context.scan("null_aware_semi", "s"), + type, + {}, + {}, + {}, + {}, + {And(eq(col("t.a"), col("s.a")), eq(col("t.b"), col("s.b")))}, + 0, + false) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(reference), executeStreams(request_column_prune, 2)); } } } @@ -3253,6 +3814,12 @@ try .join(context.scan("right_semi_family", "s"), type, {col("a")}, {}, {}, {}, {}, 0, false, 0) .build(context); executeAndAssertColumnsEqual(request, res); + auto request_column_prune + = context.scan("right_semi_family", "t") + .join(context.scan("right_semi_family", "s"), type, {col("a")}, {}, {}, {}, {}, 0, false, 0) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(res), executeStreams(request_column_prune, 2)); } /// One join key(t.a = s.a) + other condition(t.c < s.c). @@ -3315,6 +3882,21 @@ try 0) .build(context); executeAndAssertColumnsEqual(request, res); + auto request_column_prune = context.scan("right_semi_family", "t") + .join( + context.scan("right_semi_family", "s"), + type, + {col("a")}, + {}, + {}, + {lt(col("t.c"), col("s.c"))}, + {}, + 0, + false, + 0) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(res), executeStreams(request_column_prune, 2)); } } CATCH @@ -3383,6 +3965,12 @@ try .join(context.scan("right_outer", "s"), type, {col("a")}, {}, {}, {}, {}, 0, false, 1) .build(context); executeAndAssertColumnsEqual(request2, swap_expect); + auto request_column_prune + = context.scan("right_outer", "t") + .join(context.scan("right_outer", "s"), type, {col("a")}, {}, {}, {}, {}, 0, false, 1) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(swap_expect), executeStreams(request_column_prune, 2)); } /// One join key(t.a = s.a) + no left/right condition + other condition(t.c < s.c). @@ -3454,6 +4042,21 @@ try 1) .build(context); executeAndAssertColumnsEqual(request2, swap_expect); + auto request_column_prune = context.scan("right_outer", "t") + .join( + context.scan("right_outer", "s"), + type, + {col("a")}, + {}, + {}, + {lt(col("t.c"), col("s.c"))}, + {}, + 0, + false, + 1) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(swap_expect), executeStreams(request_column_prune, 2)); } /// One join key(t.a = s.a) + left/right condition + other condition(t.c < s.c). @@ -3526,6 +4129,21 @@ try 1) .build(context); executeAndAssertColumnsEqual(request2, swap_expect); + auto request_column_prune = context.scan("right_outer", "t") + .join( + context.scan("right_outer", "s"), + type, + {col("a")}, + {}, + {lt(col("s.a"), literal_integer)}, + {lt(col("t.c"), col("s.c"))}, + {}, + 0, + false, + 1) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(swap_expect), executeStreams(request_column_prune, 2)); } } CATCH diff --git a/dbms/src/Flash/tests/gtest_planner_interpreter.out b/dbms/src/Flash/tests/gtest_planner_interpreter.out index 93e65545c66..2aa695a93b9 100644 --- a/dbms/src/Flash/tests/gtest_planner_interpreter.out +++ b/dbms/src/Flash/tests/gtest_planner_interpreter.out @@ -617,8 +617,8 @@ CreatingSets Expression: SharedQuery: ParallelAggregating, max_threads: 10, final: true - Expression x 10: - HashJoinProbe: + HashJoinProbe x 10: + Expression: Expression: MockTableScan @ @@ -636,11 +636,10 @@ CreatingSets Expression: SharedQuery: ParallelAggregating, max_threads: 10, final: true - Expression x 10: - HashJoinProbe: - Expression: - Expression: - MockTableScan + HashJoinProbe x 10: + Expression: + Expression: + MockTableScan @ ~test_suite_name: JoinThenAgg ~result_index: 2 @@ -661,11 +660,10 @@ CreatingSets Expression: SharedQuery: ParallelAggregating, max_threads: 20, final: true - Expression x 20: - HashJoinProbe: - Expression: - Expression: - MockExchangeReceiver + HashJoinProbe x 20: + Expression: + Expression: + MockExchangeReceiver @ ~test_suite_name: ListBase ~result_index: 0 diff --git a/dbms/src/Flash/tests/gtest_spill_join.cpp b/dbms/src/Flash/tests/gtest_spill_join.cpp index c1b7c608ce9..75097500d91 100644 --- a/dbms/src/Flash/tests/gtest_spill_join.cpp +++ b/dbms/src/Flash/tests/gtest_spill_join.cpp @@ -162,6 +162,10 @@ try auto request = context.scan("simple_test", l) .join(context.scan("simple_test", r), join_type, {col(k)}) .build(context); + auto request_column_prune = context.scan("simple_test", l) + .join(context.scan("simple_test", r), join_type, {col(k)}) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); { context.context->setSetting("max_bytes_before_external_join", Field(static_cast(10000))); @@ -174,6 +178,9 @@ try << "join_type = " << magic_enum::enum_name(join_type) << ", simple_test_index = " << j << ", concurrency = " << concurrency; } + ASSERT_COLUMNS_EQ_UR( + genScalarCountResults(expected_cols[i * simple_test_num + j]), + executeStreams(request_column_prune, 2)); } } } @@ -197,6 +204,10 @@ try auto request = context.scan("split_test", "t1") .join(context.scan("split_test", "t2"), tipb::JoinType::TypeInnerJoin, {col("a")}) .build(context); + auto request_column_prune = context.scan("split_test", "t1") + .join(context.scan("split_test", "t2"), tipb::JoinType::TypeInnerJoin, {col("a")}) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); auto join_restore_concurrences = {-1, 0, 1, 5}; auto concurrences = {2, 5, 10}; @@ -218,6 +229,7 @@ try { ASSERT_COLUMNS_EQ_UR(expect, executeStreams(request, concurrency)); } + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(expect), executeStreams(request_column_prune, 2)); } WRAP_FOR_SPILL_TEST_END } @@ -272,6 +284,7 @@ try context.scan("outer_join_test", right_table_name), tipb::JoinType::TypeRightOuterJoin, {col("a")}) + .project({fmt::format("{}.a", left_table_name), fmt::format("{}.b", right_table_name)}) .build(context); if (right_table_name == "right_table_1_concurrency") { @@ -280,7 +293,10 @@ try } else { - ASSERT_COLUMNS_EQ_UR(ref_columns, executeStreams(request, original_max_streams)) + ColumnsWithTypeAndName ref; + ref.push_back(ref_columns[0]); + ref.push_back(ref_columns[3]); + ASSERT_COLUMNS_EQ_UR(ref, executeStreams(request, original_max_streams)) << "left_table_name = " << left_table_name << ", right_table_name = " << right_table_name; } } @@ -290,11 +306,10 @@ try { for (size_t exchange_concurrency : right_exchange_receiver_concurrency) { + auto right_name = fmt::format("right_exchange_receiver_{}_concurrency", exchange_concurrency); request = context.scan("outer_join_test", left_table_name) .join( - context.receive( - fmt::format("right_exchange_receiver_{}_concurrency", exchange_concurrency), - exchange_concurrency), + context.receive(right_name, exchange_concurrency), tipb::JoinType::TypeRightOuterJoin, {col("a")}, {}, @@ -302,6 +317,7 @@ try {}, {}, exchange_concurrency) + .project({fmt::format("{}.b", left_table_name), fmt::format("{}.a", right_name)}) .build(context); if (exchange_concurrency == 1) { @@ -309,9 +325,12 @@ try } else { - ASSERT_COLUMNS_EQ_UR(ref_columns, executeStreams(request, original_max_streams)); + ColumnsWithTypeAndName ref; + ref.push_back(ref_columns[1]); + ref.push_back(ref_columns[2]); + ASSERT_COLUMNS_EQ_UR(ref, executeStreams(request, original_max_streams)); if (original_max_streams_small < exchange_concurrency) - ASSERT_COLUMNS_EQ_UR(ref_columns, executeStreams(request, original_max_streams_small)); + ASSERT_COLUMNS_EQ_UR(ref, executeStreams(request, original_max_streams_small)); } } } @@ -355,6 +374,7 @@ try {}, {}, 0) + .project({fmt::format("{}.a", left_table_name), fmt::format("{}.b", right_table_name)}) .build(context); if (right_table_name == "right_table_1_concurrency") { @@ -363,7 +383,10 @@ try } else { - ASSERT_COLUMNS_EQ_UR(ref_columns, executeStreams(request, original_max_streams)) + ColumnsWithTypeAndName ref; + ref.push_back(ref_columns[0]); + ref.push_back(ref_columns[3]); + ASSERT_COLUMNS_EQ_UR(ref, executeStreams(request, original_max_streams)) << "left_table_name = " << left_table_name << ", right_table_name = " << right_table_name; } } @@ -376,9 +399,7 @@ try String exchange_name = fmt::format("right_exchange_receiver_{}_concurrency", exchange_concurrency); request = context.scan("outer_join_test", left_table_name) .join( - context.receive( - fmt::format("right_exchange_receiver_{}_concurrency", exchange_concurrency), - exchange_concurrency), + context.receive(exchange_name, exchange_concurrency), tipb::JoinType::TypeRightOuterJoin, {col("a")}, {}, @@ -386,6 +407,7 @@ try {}, {}, exchange_concurrency) + .project({fmt::format("{}.b", left_table_name), fmt::format("{}.a", exchange_name)}) .build(context); if (exchange_concurrency == 1) { @@ -393,9 +415,12 @@ try } else { - ASSERT_COLUMNS_EQ_UR(ref_columns, executeStreams(request, original_max_streams)); + ColumnsWithTypeAndName ref; + ref.push_back(ref_columns[1]); + ref.push_back(ref_columns[2]); + ASSERT_COLUMNS_EQ_UR(ref, executeStreams(request, original_max_streams)); if (original_max_streams_small < exchange_concurrency) - ASSERT_COLUMNS_EQ_UR(ref_columns, executeStreams(request, original_max_streams_small)); + ASSERT_COLUMNS_EQ_UR(ref, executeStreams(request, original_max_streams_small)); } } } diff --git a/dbms/src/Interpreters/CrossJoinProbeHelper.cpp b/dbms/src/Interpreters/CrossJoinProbeHelper.cpp index c5166ee0b9b..73f5881a659 100644 --- a/dbms/src/Interpreters/CrossJoinProbeHelper.cpp +++ b/dbms/src/Interpreters/CrossJoinProbeHelper.cpp @@ -31,6 +31,7 @@ struct CrossJoinAdder size_t num_existing_columns, ColumnRawPtrs & src_left_columns, size_t num_columns_to_add, + const std::vector & right_column_index_in_right_block, size_t i, const Blocks & blocks, IColumn::Filter *, @@ -57,7 +58,8 @@ struct CrossJoinAdder for (size_t col_num = 0; col_num < num_columns_to_add; ++col_num) { - const IColumn * column_right = block_right.getByPosition(col_num).column.get(); + const IColumn * column_right + = block_right.getByPosition(right_column_index_in_right_block[col_num]).column.get(); dst_columns[num_existing_columns + col_num]->insertRangeFrom(*column_right, 0, rows_right); } expanded_row_size += rows_right; @@ -97,6 +99,7 @@ struct CrossJoinAdder size_t num_existing_columns, ColumnRawPtrs & src_left_columns, size_t num_columns_to_add, + const std::vector & right_column_index_in_right_block, size_t i, const Blocks & blocks, IColumn::Filter * is_row_matched, @@ -110,6 +113,7 @@ struct CrossJoinAdder num_existing_columns, src_left_columns, num_columns_to_add, + right_column_index_in_right_block, i, blocks, is_row_matched, @@ -151,6 +155,7 @@ struct CrossJoinAdder & /* right_column_index_in_right_block */, size_t /* i */, const Blocks & /* blocks */, IColumn::Filter * /* is_row_matched */, @@ -192,6 +197,7 @@ struct CrossJoinAdder & right_column_index_in_right_block, size_t i, const Blocks & blocks, IColumn::Filter * is_row_matched, @@ -205,6 +211,7 @@ struct CrossJoinAdder & right_column_index_in_right_block, size_t i, const Blocks & blocks, IColumn::Filter * is_row_matched, @@ -263,6 +271,7 @@ struct CrossJoinAdder & right_column_index_in_right_block, size_t i, const Blocks & blocks, IColumn::Filter * is_row_matched, @@ -307,6 +317,7 @@ struct CrossJoinAdder size_t num_existing_columns, ColumnRawPtrs & src_left_columns, size_t num_columns_to_add, + const std::vector & right_column_index_in_right_block, size_t i, const Blocks & blocks, IColumn::Filter * is_row_matched, @@ -359,6 +371,7 @@ struct CrossJoinAdder num_existing_columns, src_left_columns, num_columns_to_add - 1, + right_column_index_in_right_block, i, blocks, is_row_matched, @@ -400,28 +413,26 @@ struct CrossJoinAdder template Block crossProbeBlockDeepCopyRightBlockImpl(ProbeProcessInfo & probe_process_info, const Blocks & right_blocks) { - size_t num_existing_columns = probe_process_info.block.columns(); - size_t num_columns_to_add = probe_process_info.result_block_schema.columns() - num_existing_columns; + size_t num_existing_columns = probe_process_info.cross_join_data->left_column_index_in_left_block.size(); + size_t num_columns_to_add = probe_process_info.cross_join_data->right_column_index_in_right_block.size(); ColumnRawPtrs src_left_columns(num_existing_columns); for (size_t i = 0; i < num_existing_columns; ++i) { - src_left_columns[i] = probe_process_info.block.getByPosition(i).column.get(); + src_left_columns[i] = probe_process_info.block + .getByPosition(probe_process_info.cross_join_data->left_column_index_in_left_block[i]) + .column.get(); } - std::vector right_column_index; - for (size_t i = 0; i < num_columns_to_add; ++i) - right_column_index.push_back(num_existing_columns + i); - size_t current_row = probe_process_info.start_row; size_t block_rows = probe_process_info.block.rows(); - MutableColumns dst_columns(probe_process_info.result_block_schema.columns()); + MutableColumns dst_columns(probe_process_info.cross_join_data->result_block_schema.columns()); size_t reserved_rows = std::min( - (block_rows - current_row) * probe_process_info.right_rows_to_be_added_when_matched, + (block_rows - current_row) * probe_process_info.cross_join_data->right_rows_to_be_added_when_matched, probe_process_info.max_block_size); - for (size_t i = 0; i < probe_process_info.result_block_schema.columns(); ++i) + for (size_t i = 0; i < probe_process_info.cross_join_data->result_block_schema.columns(); ++i) { - dst_columns[i] = probe_process_info.result_block_schema.getByPosition(i).column->cloneEmpty(); + dst_columns[i] = probe_process_info.cross_join_data->result_block_schema.getByPosition(i).column->cloneEmpty(); if likely (reserved_rows > 0) dst_columns[i]->reserve(reserved_rows); } @@ -453,19 +464,20 @@ Block crossProbeBlockDeepCopyRightBlockImpl(ProbeProcessInfo & probe_process_inf continue; } } - if (probe_process_info.right_rows_to_be_added_when_matched > 0) + if (probe_process_info.cross_join_data->right_rows_to_be_added_when_matched > 0) { block_full = CrossJoinAdder::addFound( dst_columns, num_existing_columns, src_left_columns, num_columns_to_add, + probe_process_info.cross_join_data->right_column_index_in_right_block, current_row, right_blocks, filter_ptr, current_offset, offset_ptr, - probe_process_info.right_rows_to_be_added_when_matched, + probe_process_info.cross_join_data->right_rows_to_be_added_when_matched, probe_process_info.max_block_size); } else @@ -486,27 +498,29 @@ Block crossProbeBlockDeepCopyRightBlockImpl(ProbeProcessInfo & probe_process_inf break; } probe_process_info.updateEndRow(current_row); - return probe_process_info.result_block_schema.cloneWithColumns(std::move(dst_columns)); + return probe_process_info.cross_join_data->result_block_schema.cloneWithColumns(std::move(dst_columns)); } template std::pair crossProbeBlockShallowCopyRightBlockAddNotMatchedRows(ProbeProcessInfo & probe_process_info) { - size_t num_existing_columns = probe_process_info.block.columns(); - MutableColumns dst_columns = probe_process_info.result_block_schema.cloneEmptyColumns(); - if (probe_process_info.row_num_filtered_by_left_condition > 0) + size_t num_existing_columns = probe_process_info.cross_join_data->left_column_index_in_left_block.size(); + MutableColumns dst_columns = probe_process_info.cross_join_data->result_block_schema.cloneEmptyColumns(); + if (probe_process_info.cross_join_data->row_num_filtered_by_left_condition > 0) { for (auto & dst_column : dst_columns) - dst_column->reserve(probe_process_info.row_num_filtered_by_left_condition); + dst_column->reserve(probe_process_info.cross_join_data->row_num_filtered_by_left_condition); } auto * filter_ptr = probe_process_info.filter.get(); auto * offset_ptr = probe_process_info.offsets_to_replicate.get(); IColumn::Offset current_offset = 0; - size_t num_columns_to_add = probe_process_info.result_block_schema.columns() - probe_process_info.block.columns(); + size_t num_columns_to_add = probe_process_info.cross_join_data->right_column_index_in_right_block.size(); ColumnRawPtrs src_left_columns(num_existing_columns); for (size_t i = 0; i < num_existing_columns; ++i) { - src_left_columns[i] = probe_process_info.block.getByPosition(i).column.get(); + src_left_columns[i] = probe_process_info.block + .getByPosition(probe_process_info.cross_join_data->left_column_index_in_left_block[i]) + .column.get(); } IColumn::Filter::value_type filter_column_value{}; if constexpr (has_null_map) @@ -534,21 +548,21 @@ std::pair crossProbeBlockShallowCopyRightBlockAddNotMatchedRows(Pro /// construct fill filter and offset column if (!dst_columns[0]->empty()) { - assert(dst_columns[0]->size() == probe_process_info.row_num_filtered_by_left_condition); + assert(dst_columns[0]->size() == probe_process_info.cross_join_data->row_num_filtered_by_left_condition); if (filter_ptr != nullptr) { - for (size_t i = 0; i < probe_process_info.row_num_filtered_by_left_condition; ++i) + for (size_t i = 0; i < probe_process_info.cross_join_data->row_num_filtered_by_left_condition; ++i) { (*filter_ptr)[i] = filter_column_value; } } - for (size_t i = 0; i < probe_process_info.row_num_filtered_by_left_condition; ++i) + for (size_t i = 0; i < probe_process_info.cross_join_data->row_num_filtered_by_left_condition; ++i) { (*offset_ptr)[i] = i + 1; } } probe_process_info.all_rows_joined_finish = true; - return {probe_process_info.result_block_schema.cloneWithColumns(std::move(dst_columns)), false}; + return {probe_process_info.cross_join_data->result_block_schema.cloneWithColumns(std::move(dst_columns)), false}; } template @@ -559,7 +573,7 @@ std::pair crossProbeBlockShallowCopyRightBlockImpl( static_assert(KIND != ASTTableJoin::Kind::Cross_LeftOuterAnti); assert(probe_process_info.offsets_to_replicate != nullptr); - size_t num_existing_columns = probe_process_info.block.columns(); + size_t num_existing_columns = probe_process_info.cross_join_data->left_column_index_in_left_block.size(); if constexpr (has_null_map) { /// skip filtered rows, the filtered rows will be handled at the end of this block @@ -576,25 +590,33 @@ std::pair crossProbeBlockShallowCopyRightBlockImpl( return crossProbeBlockShallowCopyRightBlockAddNotMatchedRows( probe_process_info); } - assert(probe_process_info.next_right_block_index < right_blocks.size()); + assert(probe_process_info.cross_join_data->next_right_block_index < right_blocks.size()); - Block right_block = right_blocks[probe_process_info.next_right_block_index]; + Block right_block = right_blocks[probe_process_info.cross_join_data->next_right_block_index]; size_t right_row = right_block.rows(); assert(right_row > 0); - Block block = probe_process_info.result_block_schema.cloneEmpty(); + Block block = probe_process_info.cross_join_data->result_block_schema.cloneEmpty(); for (size_t i = 0; i < num_existing_columns; ++i) { /// left columns - assert(block.getByPosition(i).column != nullptr); + auto left_column_index = probe_process_info.cross_join_data->left_column_index_in_left_block[i]; + assert(probe_process_info.block.getByPosition(left_column_index).column != nullptr); Field value; - probe_process_info.block.getByPosition(i).column->get(probe_process_info.start_row, value); + probe_process_info.block.getByPosition(left_column_index).column->get(probe_process_info.start_row, value); block.getByPosition(i).column = block.getByPosition(i).type->createColumnConst(right_row, value); } - for (size_t i = 0; i < right_block.columns(); i++) + auto right_column_num = probe_process_info.cross_join_data->right_column_index_in_right_block.size(); + if constexpr (KIND == ASTTableJoin::Kind::Cross_LeftOuterSemi) + { + --right_column_num; + } + for (size_t i = 0; i < right_column_num; ++i) { /// right columns - block.getByPosition(i + num_existing_columns).column = right_block.getByPosition(i).column; + block.getByPosition(i + num_existing_columns).column + = right_block.getByPosition(probe_process_info.cross_join_data->right_column_index_in_right_block[i]) + .column; } if constexpr (KIND == ASTTableJoin::Kind::Cross_LeftOuterSemi) { @@ -609,7 +631,7 @@ std::pair crossProbeBlockShallowCopyRightBlockImpl( (*probe_process_info.filter)[0] = 1; } (*probe_process_info.offsets_to_replicate)[0] = right_row; - probe_process_info.next_right_block_index++; + probe_process_info.cross_join_data->next_right_block_index++; probe_process_info.updateEndRow(probe_process_info.start_row + 1); return {block, true}; } diff --git a/dbms/src/Interpreters/ExpressionActions.cpp b/dbms/src/Interpreters/ExpressionActions.cpp index 840900afe11..5831f99646c 100644 --- a/dbms/src/Interpreters/ExpressionActions.cpp +++ b/dbms/src/Interpreters/ExpressionActions.cpp @@ -539,7 +539,8 @@ void ExpressionActions::execute(Block & block) const action.execute(block); } -std::string ExpressionActions::getSmallestColumn(const NamesAndTypesList & columns) +template +std::string ExpressionActions::getSmallestColumn(const NameAndTypeContainer & columns) { std::optional min_size; String res; @@ -562,7 +563,7 @@ std::string ExpressionActions::getSmallestColumn(const NamesAndTypesList & colum return res; } -void ExpressionActions::finalize(const Names & output_columns) +void ExpressionActions::finalize(const Names & output_columns, bool keep_used_input_columns) { NameSet final_columns; for (const auto & name : output_columns) @@ -674,6 +675,15 @@ void ExpressionActions::finalize(const Names & output_columns) /// If the column after performing the function `refcount = 0`, it can be deleted. std::map columns_refcount; + NameSet columns_should_not_be_removed; + if (keep_used_input_columns) + { + /// if keep_used_input_columns is true, then don't remove the input_columns + /// this is used in nullaware/semi join which intends to reuse the input column + for (const auto & column : input_columns) + columns_should_not_be_removed.insert(column.name); + } + for (const auto & name : final_columns) ++columns_refcount[name]; @@ -696,21 +706,22 @@ void ExpressionActions::finalize(const Names & output_columns) { new_actions.push_back(action); - auto process = [&](const String & name) { + auto process = [&](const String & name, const ExpressionAction::Type & type) { auto refcount = --columns_refcount[name]; - if (refcount <= 0) + if (refcount <= 0 && columns_should_not_be_removed.count(name) == 0) { - new_actions.push_back(ExpressionAction::removeColumn(name)); + if (type != ExpressionAction::REMOVE_COLUMN) + new_actions.push_back(ExpressionAction::removeColumn(name)); if (sample_block.has(name)) sample_block.erase(name); } }; if (!action.source_name.empty()) - process(action.source_name); + process(action.source_name, action.type); for (const auto & name : action.argument_names) - process(name); + process(name, action.type); /// For `projection`, there is no reduction in `refcount`, because the `project` action replaces the names of the columns, in effect, already deleting them under the old names. } @@ -791,4 +802,7 @@ std::string ExpressionActionsChain::dumpChain() return ss.str(); } +template std::string ExpressionActions::getSmallestColumn(const NamesAndTypesList & columns); +template std::string ExpressionActions::getSmallestColumn(const NamesAndTypes & columns); + } // namespace DB diff --git a/dbms/src/Interpreters/ExpressionActions.h b/dbms/src/Interpreters/ExpressionActions.h index 585c239bdaa..552aec13041 100644 --- a/dbms/src/Interpreters/ExpressionActions.h +++ b/dbms/src/Interpreters/ExpressionActions.h @@ -185,7 +185,7 @@ class ExpressionActions /// - Does not reorder the columns. /// - Does not remove "unexpected" columns (for example, added by functions). /// - If output_columns is empty, leaves one arbitrary column (so that the number of rows in the block is not lost). - void finalize(const Names & output_columns); + void finalize(const Names & output_columns, bool keep_used_input_columns = false); const Actions & getActions() const { return actions; } @@ -208,7 +208,8 @@ class ExpressionActions std::string dumpActions() const; - static std::string getSmallestColumn(const NamesAndTypesList & columns); + template + static std::string getSmallestColumn(const NameAndTypeContainer & columns); private: NamesAndTypesList input_columns; diff --git a/dbms/src/Interpreters/Join.cpp b/dbms/src/Interpreters/Join.cpp index a97c68c9f57..d4b84908f8e 100644 --- a/dbms/src/Interpreters/Join.cpp +++ b/dbms/src/Interpreters/Join.cpp @@ -117,7 +117,7 @@ Join::Join( const SpillConfig & build_spill_config, const SpillConfig & probe_spill_config, const RestoreConfig & restore_config_, - const Names & tidb_output_column_names_, + const NamesAndTypes & output_columns_, const RegisterOperatorSpillContext & register_operator_spill_context_, AutoSpillTrigger * auto_spill_trigger_, const TiDB::TiDBCollators & collators_, @@ -151,7 +151,7 @@ Join::Join( shallow_copy_cross_probe_threshold_ > 0 ? shallow_copy_cross_probe_threshold_ : std::max(1, max_block_size / 10)) , probe_cache_column_threshold(probe_cache_column_threshold_) - , tidb_output_column_names(tidb_output_column_names_) + , output_columns(output_columns_) , is_test(is_test_) , log(Logger::get( restore_config.restore_round == 0 ? join_req_id @@ -195,6 +195,7 @@ Join::Join( LOG_WARNING(log, fmt::format("restore round reach to {}, spilling will be disabled.", max_restore_round)); hash_join_spill_context->disableSpill(); } + output_block = Block(output_columns); LOG_DEBUG( log, @@ -318,27 +319,27 @@ void Join::setBuildConcurrencyAndInitJoinPartition(size_t build_concurrency_) void Join::setSampleBlock(const Block & block) { - sample_block_with_columns_to_add = materializeBlock(block); + sample_block_without_keys = materializeBlock(block); - /// Move from `sample_block_with_columns_to_add` key columns to `sample_block_with_keys`, keeping the order. + /// Move from `sample_block_without_keys` key columns to `sample_block_only_keys`, keeping the order. size_t pos = 0; - while (pos < sample_block_with_columns_to_add.columns()) + while (pos < sample_block_without_keys.columns()) { - const auto & name = sample_block_with_columns_to_add.getByPosition(pos).name; + const auto & name = sample_block_without_keys.getByPosition(pos).name; if (key_names_right.end() != std::find(key_names_right.begin(), key_names_right.end(), name)) { - sample_block_with_keys.insert(sample_block_with_columns_to_add.getByPosition(pos)); - sample_block_with_columns_to_add.erase(pos); + sample_block_only_keys.insert(sample_block_without_keys.getByPosition(pos)); + sample_block_without_keys.erase(pos); } else ++pos; } - size_t num_columns_to_add = sample_block_with_columns_to_add.columns(); + size_t num_columns_to_add = sample_block_without_keys.columns(); for (size_t i = 0; i < num_columns_to_add; ++i) { - auto & column = sample_block_with_columns_to_add.getByPosition(i); + auto & column = sample_block_without_keys.getByPosition(i); if (!column.column) column.column = column.type->createColumn(); } @@ -346,15 +347,15 @@ void Join::setSampleBlock(const Block & block) /// In case of LEFT and FULL joins, if use_nulls, convert joined columns to Nullable. if (isLeftOuterJoin(kind) || kind == ASTTableJoin::Kind::Full) for (size_t i = 0; i < num_columns_to_add; ++i) - convertColumnToNullable(sample_block_with_columns_to_add.getByPosition(i)); + convertColumnToNullable(sample_block_without_keys.getByPosition(i)); if (isLeftOuterSemiFamily(kind)) - sample_block_with_columns_to_add.insert(ColumnWithTypeAndName(Join::match_helper_type, match_helper_name)); + sample_block_without_keys.insert(ColumnWithTypeAndName(Join::match_helper_type, match_helper_name)); } std::shared_ptr Join::createRestoreJoin(size_t max_bytes_before_external_join_, size_t restore_partition_id) { - return std::make_shared( + auto ret = std::make_shared( key_names_left, key_names_right, kind, @@ -367,7 +368,7 @@ std::shared_ptr Join::createRestoreJoin(size_t max_bytes_before_external_j hash_join_spill_context->createProbeSpillConfig( fmt::format("{}_{}_probe", join_req_id, restore_config.restore_round + 1)), RestoreConfig{restore_config.join_restore_concurrency, restore_config.restore_round + 1, restore_partition_id}, - tidb_output_column_names, + output_columns, register_operator_spill_context, auto_spill_trigger, collators, @@ -378,6 +379,15 @@ std::shared_ptr Join::createRestoreJoin(size_t max_bytes_before_external_j flag_mapped_entry_helper_name, probe_cache_column_threshold, is_test); + /// init output names after finalize, the restored join don't need to finalize + ret->output_columns_after_finalize = output_columns_after_finalize; + ret->output_column_names_set_after_finalize = output_column_names_set_after_finalize; + ret->output_columns_names_set_for_other_condition_after_finalize + = output_columns_names_set_for_other_condition_after_finalize; + ret->required_columns = required_columns; + ret->output_block_after_finalize = output_block_after_finalize; + ret->finalized = true; + return ret; } void Join::initBuild(const Block & sample_block, size_t build_concurrency_) @@ -620,6 +630,13 @@ bool Join::isRestoreJoin() const void Join::insertFromBlockInternal(Block * stored_block, size_t stream_index) { size_t keys_size = key_names_right.size(); + Block key_block; + if (!runtime_filter_list.empty()) + { + /// save the key column for runtime filter + for (const auto & name : key_names_right) + key_block.insert(stored_block->getByName(name)); + } const Block & block = *stored_block; @@ -732,7 +749,8 @@ void Join::insertFromBlockInternal(Block * stored_block, size_t stream_index) } // generator in runtime filter - generateRuntimeFilterValues(block); + if (!runtime_filter_list.empty()) + generateRuntimeFilterValues(key_block); } void Join::generateRuntimeFilterValues(const Block & block) @@ -835,22 +853,39 @@ void Join::handleOtherConditions( IColumn::Offsets * offsets_to_replicate, const std::vector & right_table_columns) const { + /// save block_rows because block.rows() returns the first column's size, after other_cond_expr->execute(block), + /// some column maybe removed, and the first column maybe the match_helper_column which does not have the same size + /// as the other columns + auto block_rows = block.rows(); non_equal_conditions.other_cond_expr->execute(block); auto filter_column = ColumnUInt8::create(); auto & filter = filter_column->getData(); mergeNullAndFilterResult(block, filter, non_equal_conditions.other_cond_name, false); - ColumnUInt8::Container row_filter(block.rows(), 0); + ColumnUInt8::Container row_filter(block_rows, 0); + + auto erase_useless_column = [&](Block & input_block) { + for (size_t i = 0; i < input_block.columns();) + { + auto & col_name = input_block.getByPosition(i).name; + if ((!flag_mapped_entry_helper_name.empty() && col_name == flag_mapped_entry_helper_name) + || output_column_names_set_after_finalize.find(col_name) + != output_column_names_set_after_finalize.end()) + ++i; + else + input_block.erase(i); + } + }; if (isLeftOuterSemiFamily(kind)) { - if (filter.size() != block.rows()) + if (filter.size() != block_rows) { assert(filter.empty()); - filter.assign(block.rows(), static_cast(1)); + filter.assign(block_rows, static_cast(1)); } - const auto helper_pos = block.getPositionByName(match_helper_name); + auto helper_pos = block.getPositionByName(match_helper_name); const auto * old_match_nullable = checkAndGetColumn(block.safeGetByPosition(helper_pos).column.get()); @@ -923,6 +958,8 @@ void Join::handleOtherConditions( match_nullmap_vec[i] = 1; } + erase_useless_column(block); + helper_pos = block.getPositionByName(match_helper_name); for (size_t i = 0; i < block.columns(); ++i) if (i != helper_pos) block.getByPosition(i).column = block.getByPosition(i).column->filter(row_filter, -1); @@ -936,10 +973,11 @@ void Join::handleOtherConditions( /// otherwise, it will check other_eq_filter_from_in_column, if other_eq_filter_from_in_column return false, this row should /// be returned, if other_eq_filter_from_in_column return true or null this row should not be returned. mergeNullAndFilterResult(block, filter, non_equal_conditions.other_eq_cond_from_in_name, isAntiJoin(kind)); - assert(block.rows() == filter.size()); + assert(block_rows == filter.size()); if (isInnerJoin(kind) || isNecessaryKindToUseRowFlaggedHashMap(kind)) { + erase_useless_column(block); /// inner | rightSemi | rightAnti | rightOuter join, just use other_filter_column to filter result for (size_t i = 0; i < block.columns(); ++i) block.safeGetByPosition(i).column = block.safeGetByPosition(i).column->filter(filter, -1); @@ -1002,12 +1040,14 @@ void Join::handleOtherConditions( static_cast(*result_column).applyNegatedNullMap(*filter_column); column.column = std::move(result_column); } + erase_useless_column(block); for (size_t i = 0; i < block.columns(); ++i) block.getByPosition(i).column = block.getByPosition(i).column->filter(row_filter, -1); return; } if (is_semi_family) { + erase_useless_column(block); /// for semi/anti join, filter out not matched rows for (size_t i = 0; i < block.columns(); ++i) block.getByPosition(i).column = block.getByPosition(i).column->filter(row_filter, -1); @@ -1020,13 +1060,30 @@ void Join::handleOtherConditions( void Join::handleOtherConditionsForOneProbeRow(Block & block, ProbeProcessInfo & probe_process_info) const { assert(kind != ASTTableJoin::Kind::Cross_RightOuter); + /// save block_rows because block.rows() returns the first column's size, after other_cond_expr->execute(block), + /// some column maybe removed, and the first column maybe the match_helper_column which does not have the same size + /// as the other columns + auto block_rows = block.rows(); /// inside this function, we can ensure that /// 1. probe_process_info.offsets_to_replicate.size() == 1 - /// 2. probe_process_info.offsets_to_replicate[0] == block.rows() + /// 2. probe_process_info.offsets_to_replicate[0] == block_rows /// 3. for anti semi join: probe_process_info.filter[0] == 1 /// 4. for left outer semi join: match_helper_column[0] == 1 assert(probe_process_info.offsets_to_replicate->size() == 1); - assert((*probe_process_info.offsets_to_replicate)[0] == block.rows()); + assert((*probe_process_info.offsets_to_replicate)[0] == block_rows); + + auto erase_useless_column = [&](Block & input_block) { + for (size_t i = 0; i < input_block.columns();) + { + auto & col_name = input_block.getByPosition(i).name; + if ((!flag_mapped_entry_helper_name.empty() && col_name == flag_mapped_entry_helper_name) + || output_column_names_set_after_finalize.find(col_name) + != output_column_names_set_after_finalize.end()) + ++i; + else + input_block.erase(i); + } + }; non_equal_conditions.other_cond_expr->execute(block); auto filter_column = ColumnUInt8::create(); @@ -1035,23 +1092,23 @@ void Join::handleOtherConditionsForOneProbeRow(Block & block, ProbeProcessInfo & UInt64 matched_row_count_in_current_block = 0; if (isLeftOuterSemiFamily(kind) && !non_equal_conditions.other_eq_cond_from_in_name.empty()) { - if (filter.size() != block.rows()) + if (filter.size() != block_rows) { assert(filter.empty()); - filter.assign(block.rows(), static_cast(1)); + filter.assign(block_rows, static_cast(1)); } - assert(probe_process_info.has_row_matched == false); + assert(probe_process_info.cross_join_data->has_row_matched == false); ColumnPtr eq_in_column = block.getByName(non_equal_conditions.other_eq_cond_from_in_name).column; auto [eq_in_vec, eq_in_nullmap] = getDataAndNullMapVectorFromFilterColumn(eq_in_column); - for (size_t i = 0; i < block.rows(); ++i) + for (size_t i = 0; i < block_rows; ++i) { if (!filter[i]) continue; if (eq_in_nullmap && (*eq_in_nullmap)[i]) - probe_process_info.has_row_null = true; + probe_process_info.cross_join_data->has_row_null = true; else if ((*eq_in_vec)[i]) { - probe_process_info.has_row_matched = true; + probe_process_info.cross_join_data->has_row_matched = true; break; } } @@ -1059,13 +1116,14 @@ void Join::handleOtherConditionsForOneProbeRow(Block & block, ProbeProcessInfo & else { mergeNullAndFilterResult(block, filter, non_equal_conditions.other_eq_cond_from_in_name, isAntiJoin(kind)); - assert(filter.size() == block.rows()); + assert(filter.size() == block_rows); matched_row_count_in_current_block = countBytesInFilter(filter); - probe_process_info.has_row_matched |= matched_row_count_in_current_block != 0; + probe_process_info.cross_join_data->has_row_matched |= matched_row_count_in_current_block != 0; } /// case 1, inner join if (kind == ASTTableJoin::Kind::Cross) { + erase_useless_column(block); if (matched_row_count_in_current_block > 0) { for (size_t i = 0; i < block.columns(); ++i) @@ -1083,17 +1141,18 @@ void Join::handleOtherConditionsForOneProbeRow(Block & block, ProbeProcessInfo & { if (matched_row_count_in_current_block > 0) { + erase_useless_column(block); for (size_t i = 0; i < block.columns(); ++i) block.safeGetByPosition(i).column = block.safeGetByPosition(i).column->filter(filter, matched_row_count_in_current_block); } - else if (probe_process_info.isCurrentProbeRowFinished() && !probe_process_info.has_row_matched) + else if (probe_process_info.isCurrentProbeRowFinished() && !probe_process_info.cross_join_data->has_row_matched) { /// no matched rows for current row, return the un-matched result for (size_t i = 0; i < block.columns(); ++i) block.getByPosition(i).column = block.getByPosition(i).column->cut(0, 1); filter.resize(1); - for (size_t right_table_column : probe_process_info.right_column_index) + for (size_t right_table_column : probe_process_info.cross_join_data->right_column_index_in_result_block) { auto & column = block.getByPosition(right_table_column); auto full_column @@ -1107,15 +1166,20 @@ void Join::handleOtherConditionsForOneProbeRow(Block & block, ProbeProcessInfo & static_cast(*result_column).applyNegatedNullMap(*filter_column); column.column = std::move(result_column); } + erase_useless_column(block); } else + { + erase_useless_column(block); block = block.cloneEmpty(); + } return; } /// case 3, semi join if (kind == ASTTableJoin::Kind::Cross_Semi) { - if (probe_process_info.has_row_matched) + erase_useless_column(block); + if (probe_process_info.cross_join_data->has_row_matched) { /// has matched rows, return the first row, and set the current row probe done for (size_t i = 0; i < block.columns(); ++i) @@ -1132,7 +1196,8 @@ void Join::handleOtherConditionsForOneProbeRow(Block & block, ProbeProcessInfo & /// case 4, anti join if (kind == ASTTableJoin::Kind::Cross_Anti) { - if (probe_process_info.has_row_matched) + erase_useless_column(block); + if (probe_process_info.cross_join_data->has_row_matched) { block = block.cloneEmpty(); probe_process_info.finishCurrentProbeRow(); @@ -1151,7 +1216,8 @@ void Join::handleOtherConditionsForOneProbeRow(Block & block, ProbeProcessInfo & /// case 5, left outer semi join if (isLeftOuterSemiFamily(kind)) { - if (probe_process_info.has_row_matched || probe_process_info.isCurrentProbeRowFinished()) + erase_useless_column(block); + if (probe_process_info.cross_join_data->has_row_matched || probe_process_info.isCurrentProbeRowFinished()) { for (size_t i = 0; i < block.columns(); ++i) block.getByPosition(i).column = block.getByPosition(i).column->cut(0, 1); @@ -1159,9 +1225,9 @@ void Join::handleOtherConditionsForOneProbeRow(Block & block, ProbeProcessInfo & auto & match_vec = match_col->getData(); auto match_nullmap = ColumnUInt8::create(1, 0); auto & match_nullmap_vec = match_nullmap->getData(); - if (probe_process_info.has_row_matched) + if (probe_process_info.cross_join_data->has_row_matched) match_vec[0] = 1; - else if (probe_process_info.has_row_null) + else if (probe_process_info.cross_join_data->has_row_null) match_nullmap_vec[0] = 1; block.getByName(match_helper_name).column = ColumnNullable::create(std::move(match_col), std::move(match_nullmap)); @@ -1182,6 +1248,17 @@ Block Join::doJoinBlockHash(ProbeProcessInfo & probe_process_info, const JoinBui probe_process_info.updateStartRow(); /// this makes a copy of `probe_process_info.block` Block block = probe_process_info.block; + const NameSet & probe_output_name_set = has_other_condition + ? output_columns_names_set_for_other_condition_after_finalize + : output_column_names_set_after_finalize; + for (size_t pos = 0; pos < block.columns();) + { + if (probe_output_name_set.find(block.getByPosition(pos).name) == probe_output_name_set.end()) + block.erase(pos); + else + ++pos; + } + size_t keys_size = key_names_left.size(); size_t existing_columns = block.columns(); @@ -1194,29 +1271,34 @@ Block Join::doJoinBlockHash(ProbeProcessInfo & probe_process_info, const JoinBui num_columns_to_skip = keys_size; /// Add new columns to the block. - size_t num_columns_to_add = sample_block_with_columns_to_add.columns(); + std::vector num_columns_to_add; + for (size_t i = 0; i < sample_block_without_keys.columns(); ++i) + { + if (probe_output_name_set.find(sample_block_without_keys.getByPosition(i).name) != probe_output_name_set.end()) + num_columns_to_add.push_back(i); + } std::vector right_table_column_indexes; - right_table_column_indexes.reserve(num_columns_to_add); + right_table_column_indexes.reserve(num_columns_to_add.size()); - for (size_t i = 0; i < num_columns_to_add; ++i) + for (size_t i = 0; i < num_columns_to_add.size(); ++i) { right_table_column_indexes.push_back(i + existing_columns); } MutableColumns added_columns; - added_columns.reserve(num_columns_to_add); + added_columns.reserve(num_columns_to_add.size()); std::vector right_indexes; - right_indexes.reserve(num_columns_to_add); + right_indexes.reserve(num_columns_to_add.size()); - size_t rows = block.rows(); - for (size_t i = 0; i < num_columns_to_add; ++i) + size_t rows = probe_process_info.block.rows(); + for (const auto & index : num_columns_to_add) { - const ColumnWithTypeAndName & src_column = sample_block_with_columns_to_add.getByPosition(i); + const ColumnWithTypeAndName & src_column = sample_block_without_keys.getByPosition(index); RUNTIME_CHECK_MSG( !block.has(src_column.name), - "block from probe side has a column with the same name: {} as a column in sample_block_with_columns_to_add", + "block from probe side has a column with the same name: {} as a column in sample_block_without_keys", src_column.name); added_columns.push_back(src_column.column->cloneEmpty()); @@ -1225,7 +1307,7 @@ Block Join::doJoinBlockHash(ProbeProcessInfo & probe_process_info, const JoinBui // todo figure out more accurate `rows` added_columns.back()->reserve(rows); } - right_indexes.push_back(num_columns_to_skip + i); + right_indexes.push_back(num_columns_to_skip + index); } bool use_row_flagged_hash_map = useRowFlaggedHashMap(kind, has_other_condition); @@ -1244,7 +1326,7 @@ Block Join::doJoinBlockHash(ProbeProcessInfo & probe_process_info, const JoinBui JoinPartition::probeBlock( partitions, rows, - probe_process_info.key_columns, + probe_process_info.hash_join_data->key_columns, key_sizes, added_columns, probe_process_info.null_map, @@ -1258,17 +1340,17 @@ Block Join::doJoinBlockHash(ProbeProcessInfo & probe_process_info, const JoinBui /// For RIGHT_SEMI/RIGHT_ANTI join without other conditions, hash table has been marked already, just return empty build table header if (isRightSemiFamily(kind) && !use_row_flagged_hash_map) { - return sample_block_with_columns_to_add; + return sample_block_without_keys; } - for (size_t i = 0; i < num_columns_to_add; ++i) + for (size_t index = 0; index < num_columns_to_add.size(); ++index) { - const ColumnWithTypeAndName & sample_col = sample_block_with_columns_to_add.getByPosition(i); - block.insert(ColumnWithTypeAndName(std::move(added_columns[i]), sample_col.type, sample_col.name)); + const ColumnWithTypeAndName & sample_col = sample_block_without_keys.getByPosition(num_columns_to_add[index]); + block.insert(ColumnWithTypeAndName(std::move(added_columns[index]), sample_col.type, sample_col.name)); } if (use_row_flagged_hash_map) block.insert(ColumnWithTypeAndName( - std::move(added_columns[num_columns_to_add]), + std::move(added_columns[num_columns_to_add.size()]), flag_mapped_entry_helper_type, flag_mapped_entry_helper_name)); @@ -1319,19 +1401,11 @@ Block Join::doJoinBlockHash(ProbeProcessInfo & probe_process_info, const JoinBui if (isRightSemiFamily(kind)) { // Return build table header for right semi/anti join - block = sample_block_with_columns_to_add; + block = sample_block_without_keys; } else if (kind == ASTTableJoin::Kind::RightOuter) { block.erase(flag_mapped_entry_helper_name); - if (!non_equal_conditions.other_cond_name.empty()) - { - block.erase(non_equal_conditions.other_cond_name); - } - if (!non_equal_conditions.other_eq_cond_from_in_name.empty()) - { - block.erase(non_equal_conditions.other_eq_cond_from_in_name); - } } } } @@ -1342,9 +1416,9 @@ Block Join::doJoinBlockHash(ProbeProcessInfo & probe_process_info, const JoinBui Block Join::removeUselessColumn(Block & block) const { Block projected_block; - for (const auto & name : tidb_output_column_names) + for (const auto & name_and_type : output_columns_after_finalize) { - auto & column = block.getByName(name); + auto & column = block.getByName(name_and_type.name); projected_block.insert(std::move(column)); } return projected_block; @@ -1405,7 +1479,7 @@ Block Join::doJoinBlockCross(ProbeProcessInfo & probe_process_info) const block, probe_process_info.filter.get(), probe_process_info.offsets_to_replicate.get(), - probe_process_info.right_column_index); + probe_process_info.cross_join_data->right_column_index_in_result_block); } return block; } @@ -1423,18 +1497,23 @@ Block Join::doJoinBlockCross(ProbeProcessInfo & probe_process_info) const /// state is saved in `probe_process_info` handleOtherConditionsForOneProbeRow(block, probe_process_info); } - for (size_t i = 0; i < probe_process_info.block.columns(); ++i) + for (size_t i = 0; i < probe_process_info.cross_join_data->left_column_index_in_left_block.size(); ++i) { - if (block.getByPosition(i).column->isColumnConst()) - block.getByPosition(i).column = block.getByPosition(i).column->convertToFullColumnIfConst(); + auto & name = probe_process_info.block + .getByPosition(probe_process_info.cross_join_data->left_column_index_in_left_block[i]) + .name; + if (block.has(name)) + { + auto & column_and_name = block.getByName(name); + if (column_and_name.column->isColumnConst()) + column_and_name.column = column_and_name.column->convertToFullColumnIfConst(); + } } if (isLeftOuterSemiFamily(kind)) { - auto helper_index - = probe_process_info.block.columns() + probe_process_info.right_column_index.size() - 1; - if (block.getByPosition(helper_index).column->isColumnConst()) - block.getByPosition(helper_index).column - = block.getByPosition(helper_index).column->convertToFullColumnIfConst(); + auto & help_column = block.getByName(match_helper_name); + if (help_column.column->isColumnConst()) + help_column.column = help_column.column->convertToFullColumnIfConst(); } } else if (non_equal_conditions.other_cond_expr != nullptr) @@ -1444,7 +1523,7 @@ Block Join::doJoinBlockCross(ProbeProcessInfo & probe_process_info) const block, probe_process_info.filter.get(), probe_process_info.offsets_to_replicate.get(), - probe_process_info.right_column_index); + probe_process_info.cross_join_data->right_column_index_in_result_block); } return block; } @@ -1460,7 +1539,9 @@ Block Join::joinBlockCross(ProbeProcessInfo & probe_process_info) const non_equal_conditions.left_filter_column, kind, strictness, - sample_block_with_columns_to_add, + sample_block_without_keys, + has_other_condition ? output_columns_names_set_for_other_condition_after_finalize + : output_column_names_set_after_finalize, right_rows_to_be_added_when_matched_for_cross_join, cross_probe_mode, blocks.size()); @@ -1486,34 +1567,16 @@ Block Join::joinBlockCross(ProbeProcessInfo & probe_process_info) const void Join::checkTypes(const Block & block) const { - checkTypesOfKeys(block, sample_block_with_keys); + checkTypesOfKeys(block, sample_block_only_keys); } Block Join::joinBlockNullAwareSemi(ProbeProcessInfo & probe_process_info) const { - Block block = probe_process_info.block; - - /// Rare case, when keys are constant. To avoid code bloat, simply materialize them. - /// Note: this variable can't be removed because it will take smart pointers' lifecycle to the end of this function. - Columns materialized_columns; - ColumnRawPtrs key_columns = extractAndMaterializeKeyColumns(block, materialized_columns, key_names_left); - - /// Note that `extractAllKeyNullMap` must be done before `extractNestedColumnsAndNullMap` - /// because `extractNestedColumnsAndNullMap` will change the nullable column to its nested column. - ColumnPtr all_key_null_map_holder; - ConstNullMapPtr all_key_null_map{}; - extractAllKeyNullMap(key_columns, all_key_null_map_holder, all_key_null_map); + probe_process_info.prepareForNullAware(key_names_left, non_equal_conditions.left_filter_column); - ColumnPtr null_map_holder; - ConstNullMapPtr null_map{}; - extractNestedColumnsAndNullMap(key_columns, null_map_holder, null_map); - - ColumnPtr filter_map_holder; - ConstNullMapPtr filter_map{}; - recordFilteredRows(block, non_equal_conditions.left_filter_column, filter_map_holder, filter_map); + Block block{}; +#define CALL(KIND, STRICTNESS, MAP) block = joinBlockNullAwareSemiImpl(probe_process_info); -#define CALL(KIND, STRICTNESS, MAP) \ - joinBlockNullAwareSemiImpl(block, key_columns, null_map, filter_map, all_key_null_map); using enum ASTTableJoin::Strictness; using enum ASTTableJoin::Kind; @@ -1542,19 +1605,17 @@ Block Join::joinBlockNullAwareSemi(ProbeProcessInfo & probe_process_info) const } template -void Join::joinBlockNullAwareSemiImpl( - Block & block, - const ColumnRawPtrs & key_columns, - const ConstNullMapPtr & null_map, - const ConstNullMapPtr & filter_map, - const ConstNullMapPtr & all_key_null_map) const +Block Join::joinBlockNullAwareSemiImpl(const ProbeProcessInfo & probe_process_info) const { - size_t rows = block.rows(); + size_t rows = probe_process_info.block.rows(); std::vector null_rows(partitions.size(), nullptr); for (size_t i = 0; i < partitions.size(); ++i) null_rows[i] = partitions[i]->getRowsNotInsertedToMap(); - NALeftSideInfo left_side_info(null_map, filter_map, all_key_null_map); + NALeftSideInfo left_side_info( + probe_process_info.null_map, + probe_process_info.null_aware_join_data->filter_map, + probe_process_info.null_aware_join_data->all_key_null_map); NARightSideInfo right_side_info( right_has_all_key_null_row.load(std::memory_order_relaxed), right_table_is_empty.load(std::memory_order_relaxed), @@ -1563,7 +1624,7 @@ void Join::joinBlockNullAwareSemiImpl( auto [res, res_list] = JoinPartition::probeBlockNullAwareSemi( partitions, rows, - key_columns, + probe_process_info.null_aware_join_data->key_columns, key_sizes, collators, left_side_info, @@ -1571,23 +1632,43 @@ void Join::joinBlockNullAwareSemiImpl( RUNTIME_ASSERT(res.size() == rows, "SemiJoinResult size {} must be equal to block size {}", res.size(), rows); + Block block{}; + for (size_t i = 0; i < probe_process_info.block.columns(); ++i) + { + const auto & column = probe_process_info.block.getByPosition(i); + if (output_columns_names_set_for_other_condition_after_finalize.contains(column.name)) + block.insert(column); + } + size_t left_columns = block.columns(); - size_t right_columns = sample_block_with_columns_to_add.columns(); /// Add new columns to the block. - for (const auto & src_column : sample_block_with_columns_to_add.getColumnsWithTypeAndName()) + std::vector right_column_indices_to_add; + + for (size_t i = 0; i < sample_block_without_keys.columns(); ++i) { - RUNTIME_CHECK_MSG( - !block.has(src_column.name), - "block from probe side has a column with the same name: {} as a column in sample_block_with_columns_to_add", - src_column.name); - block.insert(src_column); + const auto & column = sample_block_without_keys.getByPosition(i); + if (output_columns_names_set_for_other_condition_after_finalize.contains(column.name)) + { + RUNTIME_CHECK_MSG( + !block.has(column.name), + "block from probe side has a column with the same name: {} as a column in sample_block_without_keys", + column.name); + block.insert(column); + right_column_indices_to_add.push_back(i); + } } if (!res_list.empty()) { - NASemiJoinHelper - helper(block, left_columns, right_columns, blocks, null_rows, max_block_size, non_equal_conditions); + NASemiJoinHelper helper( + block, + left_columns, + right_column_indices_to_add, + blocks, + null_rows, + max_block_size, + non_equal_conditions); helper.joinResult(res_list); @@ -1600,17 +1681,15 @@ void Join::joinBlockNullAwareSemiImpl( if constexpr (KIND == ASTTableJoin::Kind::NullAware_Anti) filter = std::make_unique(rows); - MutableColumns added_columns(right_columns); - for (size_t i = 0; i < right_columns; ++i) - added_columns[i] = block.getByPosition(i + left_columns).column->cloneEmpty(); - + MutableColumnPtr left_semi_column_ptr = nullptr; ColumnInt8::Container * left_semi_column_data = nullptr; ColumnUInt8::Container * left_semi_null_map = nullptr; if constexpr ( KIND == ASTTableJoin::Kind::NullAware_LeftOuterSemi || KIND == ASTTableJoin::Kind::NullAware_LeftOuterAnti) { - auto * left_semi_column = typeid_cast(added_columns[right_columns - 1].get()); + left_semi_column_ptr = block.getByPosition(block.columns() - 1).column->cloneEmpty(); + auto * left_semi_column = typeid_cast(left_semi_column_ptr.get()); left_semi_column_data = &typeid_cast &>(left_semi_column->getNestedColumn()).getData(); left_semi_null_map = &left_semi_column->getNullMapColumn().getData(); left_semi_column_data->reserve(rows); @@ -1658,20 +1737,23 @@ void Join::joinBlockNullAwareSemiImpl( if constexpr ( KIND == ASTTableJoin::Kind::NullAware_LeftOuterSemi || KIND == ASTTableJoin::Kind::NullAware_LeftOuterAnti) { - block.getByPosition(block.columns() - 1).column = std::move(added_columns[right_columns - 1]); + block.getByPosition(block.columns() - 1).column = std::move(left_semi_column_ptr); } if constexpr (KIND == ASTTableJoin::Kind::NullAware_Anti) { for (size_t i = 0; i < left_columns; ++i) - block.getByPosition(i).column = block.getByPosition(i).column->filter(*filter, rows_for_anti); + { + auto & column = block.getByPosition(i); + if (output_column_names_set_after_finalize.contains(column.name)) + column.column = column.column->filter(*filter, rows_for_anti); + } } + return block; } Block Join::joinBlockSemi(ProbeProcessInfo & probe_process_info) const { - Block block = probe_process_info.block; - JoinBuildInfo join_build_info{ enable_fine_grained_shuffle, fine_grained_shuffle_count, @@ -1689,8 +1771,9 @@ Block Join::joinBlockSemi(ProbeProcessInfo & probe_process_info) const collators, restore_config.restore_round); + Block block{}; #define CALL(KIND, STRICTNESS, MAP) \ - joinBlockSemiImpl(block, join_build_info, probe_process_info); + block = joinBlockSemiImpl(join_build_info, probe_process_info); using enum ASTTableJoin::Strictness; using enum ASTTableJoin::Kind; @@ -1723,12 +1806,9 @@ Block Join::joinBlockSemi(ProbeProcessInfo & probe_process_info) const } template -void Join::joinBlockSemiImpl( - Block & block, - const JoinBuildInfo & join_build_info, - const ProbeProcessInfo & probe_process_info) const +Block Join::joinBlockSemiImpl(const JoinBuildInfo & join_build_info, const ProbeProcessInfo & probe_process_info) const { - size_t rows = block.rows(); + size_t rows = probe_process_info.block.rows(); auto [res, res_list] = JoinPartition::probeBlockSemi( partitions, @@ -1740,17 +1820,33 @@ void Join::joinBlockSemiImpl( RUNTIME_ASSERT(res.size() == rows, "SemiJoinResult size {} must be equal to block size {}", res.size(), rows); - size_t left_columns = block.columns(); - size_t right_columns = sample_block_with_columns_to_add.columns(); + const NameSet & probe_output_name_set = has_other_condition + ? output_columns_names_set_for_other_condition_after_finalize + : output_column_names_set_after_finalize; + Block block{}; + for (size_t i = 0; i < probe_process_info.block.columns(); ++i) + { + const auto & column = probe_process_info.block.getByPosition(i); + if (probe_output_name_set.contains(column.name)) + block.insert(column); + } + size_t left_columns = block.columns(); /// Add new columns to the block. - for (const auto & src_column : sample_block_with_columns_to_add.getColumnsWithTypeAndName()) + std::vector right_column_indices_to_add; + + for (size_t i = 0; i < sample_block_without_keys.columns(); ++i) { - RUNTIME_CHECK_MSG( - !block.has(src_column.name), - "block from probe side has a column with the same name: {} as a column in sample_block_with_columns_to_add", - src_column.name); - block.insert(src_column); + const auto & column = sample_block_without_keys.getByPosition(i); + if (probe_output_name_set.contains(column.name)) + { + RUNTIME_CHECK_MSG( + !block.has(column.name), + "block from probe side has a column with the same name: {} as a column in sample_block_without_keys", + column.name); + block.insert(column); + right_column_indices_to_add.push_back(i); + } } if constexpr (STRICTNESS == ASTTableJoin::Strictness::All) @@ -1758,7 +1854,7 @@ void Join::joinBlockSemiImpl( if (!res_list.empty()) { SemiJoinHelper - helper(block, left_columns, right_columns, max_block_size, non_equal_conditions); + helper(block, left_columns, right_column_indices_to_add, max_block_size, non_equal_conditions); helper.joinResult(res_list); @@ -1772,16 +1868,14 @@ void Join::joinBlockSemiImpl( if constexpr (KIND == ASTTableJoin::Kind::Semi || KIND == ASTTableJoin::Kind::Anti) filter = std::make_unique(rows); - MutableColumns added_columns(right_columns); - for (size_t i = 0; i < right_columns; ++i) - added_columns[i] = block.getByPosition(i + left_columns).column->cloneEmpty(); - + MutableColumnPtr left_semi_column_ptr = nullptr; ColumnInt8::Container * left_semi_column_data = nullptr; ColumnUInt8::Container * left_semi_null_map = nullptr; if constexpr (KIND == ASTTableJoin::Kind::LeftOuterSemi || KIND == ASTTableJoin::Kind::LeftOuterAnti) { - auto * left_semi_column = typeid_cast(added_columns[right_columns - 1].get()); + left_semi_column_ptr = block.getByPosition(block.columns() - 1).column->cloneEmpty(); + auto * left_semi_column = typeid_cast(left_semi_column_ptr.get()); left_semi_column_data = &typeid_cast &>(left_semi_column->getNestedColumn()).getData(); left_semi_column_data->reserve(rows); left_semi_null_map = &left_semi_column->getNullMapColumn().getData(); @@ -1842,14 +1936,19 @@ void Join::joinBlockSemiImpl( if constexpr (KIND == ASTTableJoin::Kind::LeftOuterSemi || KIND == ASTTableJoin::Kind::LeftOuterAnti) { - block.getByPosition(block.columns() - 1).column = std::move(added_columns[right_columns - 1]); + block.getByPosition(block.columns() - 1).column = std::move(left_semi_column_ptr); } if constexpr (KIND == ASTTableJoin::Kind::Semi || KIND == ASTTableJoin::Kind::Anti) { for (size_t i = 0; i < left_columns; ++i) - block.getByPosition(i).column = block.getByPosition(i).column->filter(*filter, rows_for_semi_anti); + { + auto & column = block.getByPosition(i); + if (output_column_names_set_after_finalize.contains(column.name)) + column.column = column.column->filter(*filter, rows_for_semi_anti); + } } + return block; } void Join::checkTypesOfKeys(const Block & block_left, const Block & block_right) const @@ -2121,6 +2220,7 @@ void Join::finishOneNonJoin(size_t partition_index) Block Join::joinBlock(ProbeProcessInfo & probe_process_info, bool dry_run) const { assert(!probe_process_info.all_rows_joined_finish); + assert(finalized); if unlikely (dry_run) { assert(probe_process_info.block.rows() == 0); @@ -2175,6 +2275,7 @@ BlockInputStreamPtr Join::createScanHashMapAfterProbeStream( size_t step, size_t max_block_size_) const { + RUNTIME_CHECK_MSG(finalized, "Create ScanHashMapAfterProbeStream before join be finalized"); return std::make_shared( *this, left_sample_block, @@ -2497,4 +2598,88 @@ void Join::wakeUpAllWaitingThreads() cancelRuntimeFilter("Join has been cancelled."); } +void Join::finalize(const Names & parent_require) +{ + if unlikely (finalized) + return; + /// finalize will do 3 things + /// 1. update expected_output_schema + /// 2. set expected_output_schema_for_other_condition + /// 3. generated needed input columns + NameSet required_names_set; + for (const auto & name : parent_require) + required_names_set.insert(name); + if unlikely (!match_helper_name.empty() && !required_names_set.contains(match_helper_name)) + { + /// should only happens in some tests + required_names_set.insert(match_helper_name); + } + for (const auto & name_and_type : output_columns) + { + if (required_names_set.find(name_and_type.name) != required_names_set.end()) + { + output_columns_after_finalize.push_back(name_and_type); + output_column_names_set_after_finalize.insert(name_and_type.name); + } + } + output_block_after_finalize = Block(output_columns_after_finalize); + Names updated_require; + if (match_helper_name.empty()) + updated_require = parent_require; + else + { + for (const auto & name : required_names_set) + if (name != match_helper_name) + updated_require.push_back(name); + required_names_set.erase(match_helper_name); + } + if (!non_equal_conditions.null_aware_eq_cond_name.empty()) + { + updated_require.push_back(non_equal_conditions.null_aware_eq_cond_name); + } + if (!non_equal_conditions.other_eq_cond_from_in_name.empty()) + updated_require.push_back(non_equal_conditions.other_eq_cond_from_in_name); + if (!non_equal_conditions.other_cond_name.empty()) + updated_require.push_back(non_equal_conditions.other_cond_name); + auto keep_used_input_columns + = !isCrossJoin(kind) && (isNullAwareSemiFamily(kind) || isSemiFamily(kind) || isLeftOuterSemiFamily(kind)); + /// nullaware/semi join will reuse the input columns so need to let finalize keep the input columns + if (non_equal_conditions.null_aware_eq_cond_expr != nullptr) + { + non_equal_conditions.null_aware_eq_cond_expr->finalize(updated_require, keep_used_input_columns); + updated_require = non_equal_conditions.null_aware_eq_cond_expr->getRequiredColumns(); + } + if (non_equal_conditions.other_cond_expr != nullptr) + { + non_equal_conditions.other_cond_expr->finalize(updated_require, keep_used_input_columns); + updated_require = non_equal_conditions.other_cond_expr->getRequiredColumns(); + } + /// remove duplicated column + required_names_set.clear(); + for (const auto & name : updated_require) + required_names_set.insert(name); + /// add some internal used columns + if (!non_equal_conditions.left_filter_column.empty()) + required_names_set.insert(non_equal_conditions.left_filter_column); + if (!non_equal_conditions.right_filter_column.empty()) + required_names_set.insert(non_equal_conditions.right_filter_column); + /// add join key to required_columns + for (const auto & name : key_names_right) + required_names_set.insert(name); + for (const auto & name : key_names_left) + required_names_set.insert(name); + + + if (non_equal_conditions.other_cond_expr != nullptr || non_equal_conditions.null_aware_eq_cond_expr != nullptr) + { + for (const auto & name : required_names_set) + output_columns_names_set_for_other_condition_after_finalize.insert(name); + if (!match_helper_name.empty()) + output_columns_names_set_for_other_condition_after_finalize.insert(match_helper_name); + } + for (const auto & name : required_names_set) + required_columns.push_back(name); + finalized = true; +} + } // namespace DB diff --git a/dbms/src/Interpreters/Join.h b/dbms/src/Interpreters/Join.h index 72f9b425fc3..8716533cf32 100644 --- a/dbms/src/Interpreters/Join.h +++ b/dbms/src/Interpreters/Join.h @@ -171,7 +171,7 @@ class Join const SpillConfig & build_spill_config_, const SpillConfig & probe_spill_config_, const RestoreConfig & restore_config_, - const Names & tidb_output_column_names_, + const NamesAndTypes & output_columns_, const RegisterOperatorSpillContext & register_operator_spill_context_, AutoSpillTrigger * auto_spill_trigger_, const TiDB::TiDBCollators & collators_, @@ -324,6 +324,9 @@ class Join const JoinProfileInfoPtr profile_info = std::make_shared(); HashJoinSpillContextPtr hash_join_spill_context; + const Block & getOutputBlock() const { return finalized ? output_block_after_finalize : output_block; } + const Names & getRequiredColumns() const { return required_columns; } + void finalize(const Names & parent_require); private: friend class ScanHashMapAfterProbeBlockInputStream; @@ -407,11 +410,18 @@ class Join Sizes key_sizes; /// Block with columns from the right-side table except key columns. - Block sample_block_with_columns_to_add; + Block sample_block_without_keys; /// Block with key columns in the same order they appear in the right-side table. - Block sample_block_with_keys; + Block sample_block_only_keys; - Names tidb_output_column_names; + NamesAndTypes output_columns; + Block output_block; + NamesAndTypes output_columns_after_finalize; + Block output_block_after_finalize; + NameSet output_column_names_set_after_finalize; + NameSet output_columns_names_set_for_other_condition_after_finalize; + Names required_columns; + bool finalized = false; bool is_test; @@ -482,18 +492,10 @@ class Join Block doJoinBlockCross(ProbeProcessInfo & probe_process_info) const; template - void joinBlockNullAwareSemiImpl( - Block & block, - const ColumnRawPtrs & key_columns, - const ConstNullMapPtr & null_map, - const ConstNullMapPtr & filter_map, - const ConstNullMapPtr & all_key_null_map) const; + Block joinBlockNullAwareSemiImpl(const ProbeProcessInfo & probe_process_info) const; template - void joinBlockSemiImpl( - Block & block, - const JoinBuildInfo & join_build_info, - const ProbeProcessInfo & probe_process_info) const; + Block joinBlockSemiImpl(const JoinBuildInfo & join_build_info, const ProbeProcessInfo & probe_process_info) const; IColumn::Selector hashToSelector(const WeakHash32 & hash) const; IColumn::Selector selectDispatchBlock(const Strings & key_columns_names, const Block & from_block); diff --git a/dbms/src/Interpreters/JoinPartition.cpp b/dbms/src/Interpreters/JoinPartition.cpp index 88174ae81b6..a060878c4f7 100644 --- a/dbms/src/Interpreters/JoinPartition.cpp +++ b/dbms/src/Interpreters/JoinPartition.cpp @@ -1502,10 +1502,10 @@ void NO_INLINE probeBlockImplTypeCase( if (need_virtual_dispatch_for_probe_block) { RUNTIME_ASSERT(!(join_build_info.restore_round > 0 && join_build_info.enable_fine_grained_shuffle)); - RUNTIME_ASSERT(probe_process_info.hash_data->getData().size() == rows); + RUNTIME_ASSERT(probe_process_info.hash_join_data->hash_data->getData().size() == rows); } - const auto & build_hash_data = probe_process_info.hash_data->getData(); + const auto & build_hash_data = probe_process_info.hash_join_data->hash_data->getData(); size_t i; bool block_full = false; for (i = probe_process_info.start_row; i < rows; ++i) @@ -1746,6 +1746,7 @@ probeBlockNullAwareSemiInternal( { if ((*left_side_info.null_map)[i]) { + /// some key is null if constexpr (STRICTNESS == ASTTableJoin::Strictness::Any) { if (key_columns.size() == 1 || right_side_info.has_all_key_null_row @@ -1922,22 +1923,22 @@ probeBlockSemiInternal( } } - KeyGetter key_getter(probe_process_info.key_columns, key_sizes, collators); + KeyGetter key_getter(probe_process_info.hash_join_data->key_columns, key_sizes, collators); std::vector sort_key_containers; - sort_key_containers.resize(probe_process_info.key_columns.size()); + sort_key_containers.resize(probe_process_info.hash_join_data->key_columns.size()); Arena pool; bool need_virtual_dispatch_for_probe_block = join_build_info.needVirtualDispatchForProbeBlock(); if (need_virtual_dispatch_for_probe_block) { RUNTIME_ASSERT(!(join_build_info.restore_round > 0 && join_build_info.enable_fine_grained_shuffle)); - RUNTIME_ASSERT(probe_process_info.hash_data->getData().size() == rows); + RUNTIME_ASSERT(probe_process_info.hash_join_data->hash_data->getData().size() == rows); } PaddedPODArray> res; res.reserve(rows); std::list *> res_list; - const auto & build_hash_data = probe_process_info.hash_data->getData(); + const auto & build_hash_data = probe_process_info.hash_join_data->hash_data->getData(); for (size_t i = 0; i < rows; ++i) { if constexpr (has_null_map) diff --git a/dbms/src/Interpreters/NullAwareSemiJoinHelper.cpp b/dbms/src/Interpreters/NullAwareSemiJoinHelper.cpp index 859fc4e3b23..80e66b205b2 100644 --- a/dbms/src/Interpreters/NullAwareSemiJoinHelper.cpp +++ b/dbms/src/Interpreters/NullAwareSemiJoinHelper.cpp @@ -43,6 +43,7 @@ void NASemiJoinResult::fillRightColumns( MutableColumns & added_columns, size_t left_columns, size_t right_columns, + const std::vector & right_column_indices_to_add, const std::vector & null_rows, size_t & current_offset, size_t max_pace) @@ -77,7 +78,9 @@ void NASemiJoinResult::fillRightColumns( for (size_t i = 0; i < current_pace && iter != nullptr; ++i) { for (size_t j = 0; j < right_columns; ++j) - added_columns[j + left_columns]->insertFrom(*iter->block->getByPosition(j).column.get(), iter->row_num); + added_columns[j + left_columns]->insertFrom( + *iter->block->getByPosition(right_column_indices_to_add[j]).column.get(), + iter->row_num); ++current_offset; iter = iter->next; } @@ -96,12 +99,16 @@ void NASemiJoinResult::fillRightColumns( while (pos_in_columns_vector < rows.materialized_columns_vec.size() && count > 0) { + /// todo only materialize used columns const auto & columns = rows.materialized_columns_vec[pos_in_columns_vector]; const size_t columns_size = columns[0]->size(); size_t insert_cnt = std::min(count, columns_size - pos_in_columns); for (size_t j = 0; j < right_columns; ++j) - added_columns[j + left_columns]->insertRangeFrom(*columns[j].get(), pos_in_columns, insert_cnt); + added_columns[j + left_columns]->insertRangeFrom( + *columns[right_column_indices_to_add[j]].get(), + pos_in_columns, + insert_cnt); pos_in_columns += insert_cnt; count -= insert_cnt; @@ -233,14 +240,14 @@ template ::NASemiJoinHelper( Block & block_, size_t left_columns_, - size_t right_columns_, + const std::vector & right_column_indices_to_add_, const BlocksList & right_blocks_, const std::vector & null_rows_, size_t max_block_size_, const JoinNonEqualConditions & non_equal_conditions_) : block(block_) , left_columns(left_columns_) - , right_columns(right_columns_) + , right_column_indices_to_add(right_column_indices_to_add_) , right_blocks(right_blocks_) , null_rows(null_rows_) , max_block_size(max_block_size_) @@ -249,6 +256,7 @@ NASemiJoinHelper::NASemiJoinHelper( static_assert(KIND == NullAware_Anti || KIND == NullAware_LeftOuterAnti || KIND == NullAware_LeftOuterSemi); static_assert(STRICTNESS == Any || STRICTNESS == All); + right_columns = right_column_indices_to_add.size(); RUNTIME_CHECK(block.columns() == left_columns + right_columns); if constexpr (KIND == NullAware_LeftOuterAnti || KIND == NullAware_LeftOuterSemi) @@ -337,6 +345,7 @@ void NASemiJoinHelper::runStep( columns, left_columns, right_columns, + right_column_indices_to_add, null_rows, current_offset, max_block_size - current_offset); @@ -392,7 +401,8 @@ void NASemiJoinHelper::runStepAllBlocks(std::list( exec_block, diff --git a/dbms/src/Interpreters/NullAwareSemiJoinHelper.h b/dbms/src/Interpreters/NullAwareSemiJoinHelper.h index fad4af4e9b9..83ab56d332b 100644 --- a/dbms/src/Interpreters/NullAwareSemiJoinHelper.h +++ b/dbms/src/Interpreters/NullAwareSemiJoinHelper.h @@ -116,6 +116,7 @@ class NASemiJoinResult MutableColumns & added_columns, size_t left_columns, size_t right_columns, + const std::vector & right_column_indices_to_add, const std::vector & null_rows, size_t & current_offset, size_t max_pace); @@ -162,7 +163,7 @@ class NASemiJoinHelper NASemiJoinHelper( Block & block, size_t left_columns, - size_t right_columns, + const std::vector & right_column_indices_to_add, const BlocksList & right_blocks, const std::vector & null_rows, size_t max_block_size, @@ -187,6 +188,7 @@ class NASemiJoinHelper Block & block; size_t left_columns; size_t right_columns; + const std::vector & right_column_indices_to_add; const BlocksList & right_blocks; const std::vector & null_rows; size_t max_block_size; diff --git a/dbms/src/Interpreters/ProbeProcessInfo.cpp b/dbms/src/Interpreters/ProbeProcessInfo.cpp index 9ed792e77d4..e0e4416d443 100644 --- a/dbms/src/Interpreters/ProbeProcessInfo.cpp +++ b/dbms/src/Interpreters/ProbeProcessInfo.cpp @@ -36,18 +36,12 @@ void ProbeProcessInfo::resetBlock(Block && block_, size_t partition_index_) null_map_holder = nullptr; filter.reset(); offsets_to_replicate.reset(); - key_columns.clear(); - materialized_columns.clear(); - hash_data = nullptr; - result_block_schema.clear(); - right_column_index.clear(); - right_rows_to_be_added_when_matched = 0; - cross_probe_mode = CrossProbeMode::DEEP_COPY_RIGHT_BLOCK; - right_block_size = 0; - next_right_block_index = 0; - row_num_filtered_by_left_condition = 0; - has_row_matched = false; - has_row_null = false; + if (hash_join_data) + hash_join_data->reset(); + if (cross_join_data) + cross_join_data->reset(); + if (null_aware_join_data) + null_aware_join_data->reset(); } void ProbeProcessInfo::prepareForHashProbe( @@ -61,11 +55,14 @@ void ProbeProcessInfo::prepareForHashProbe( { if (prepare_for_probe_done) return; + if (unlikely(hash_join_data == nullptr)) + hash_join_data = std::make_unique(); /// Rare case, when keys are constant. To avoid code bloat, simply materialize them. /// Note: this variable can't be removed because it will take smart pointers' lifecycle to the end of this function. - key_columns = extractAndMaterializeKeyColumns(block, materialized_columns, key_names); + hash_join_data->key_columns + = extractAndMaterializeKeyColumns(block, hash_join_data->materialized_columns, key_names); /// Keys with NULL value in any column won't join to anything. - extractNestedColumnsAndNullMap(key_columns, null_map_holder, null_map); + extractNestedColumnsAndNullMap(hash_join_data->key_columns, null_map_holder, null_map); /// reuse null_map to record the filtered rows, the rows contains NULL or does not /// match the join filter won't join to anything recordFilteredRows(block, filter_column, null_map_holder, null_map); @@ -93,12 +90,18 @@ void ProbeProcessInfo::prepareForHashProbe( if (!isSemiFamily(kind) && !isLeftOuterSemiFamily(kind) && strictness == ASTTableJoin::Strictness::All) offsets_to_replicate = std::make_unique(block.rows()); - hash_data = std::make_unique(0); + hash_join_data->hash_data = std::make_unique(0); if (need_compute_hash) { std::vector sort_key_containers; - sort_key_containers.resize(key_columns.size()); - computeDispatchHash(block.rows(), key_columns, collators, sort_key_containers, restore_round, *hash_data); + sort_key_containers.resize(hash_join_data->key_columns.size()); + computeDispatchHash( + block.rows(), + hash_join_data->key_columns, + collators, + sort_key_containers, + restore_round, + *hash_join_data->hash_data); } prepare_for_probe_done = true; } @@ -107,17 +110,20 @@ void ProbeProcessInfo::prepareForCrossProbe( const String & filter_column, ASTTableJoin::Kind kind, ASTTableJoin::Strictness strictness, - const Block & sample_block_with_columns_to_add, + const Block & sample_block_without_keys, + const NameSet & output_column_names_set, size_t right_rows_to_be_added_when_matched_, CrossProbeMode cross_probe_mode_, size_t right_block_size_) { if (prepare_for_probe_done) return; + if (unlikely(cross_join_data == nullptr)) + cross_join_data = std::make_unique(); - right_rows_to_be_added_when_matched = right_rows_to_be_added_when_matched_; - cross_probe_mode = cross_probe_mode_; - right_block_size = right_block_size_; + cross_join_data->right_rows_to_be_added_when_matched = right_rows_to_be_added_when_matched_; + cross_join_data->cross_probe_mode = cross_probe_mode_; + cross_join_data->right_block_size = right_block_size_; recordFilteredRows(block, filter_column, null_map_holder, null_map); if (kind == ASTTableJoin::Kind::Cross_Anti && strictness == ASTTableJoin::Strictness::All) @@ -128,23 +134,61 @@ void ProbeProcessInfo::prepareForCrossProbe( /// Should convert all the columns in block to nullable if it is cross right join, here we don't need /// to do so because cross_right join is converted to cross left join during compile - result_block_schema = block.cloneEmpty(); - for (size_t i = 0; i < sample_block_with_columns_to_add.columns(); ++i) + if unlikely (cross_join_data->result_block_schema.columns() == 0) { - const ColumnWithTypeAndName & src_column = sample_block_with_columns_to_add.getByPosition(i); - RUNTIME_CHECK_MSG( - !result_block_schema.has(src_column.name), - "block from probe side has a column with the same name: {} as a column in sample_block_with_columns_to_add", - src_column.name); - result_block_schema.insert(src_column); + /// these information only need to be init once + for (size_t i = 0; i < block.columns(); ++i) + { + auto & column = block.getByPosition(i); + if (output_column_names_set.contains(column.name)) + { + cross_join_data->result_block_schema.insert(column.cloneEmpty()); + cross_join_data->left_column_index_in_left_block.push_back(i); + } + } + for (size_t i = 0; i < sample_block_without_keys.columns(); ++i) + { + const ColumnWithTypeAndName & src_column = sample_block_without_keys.getByPosition(i); + if (output_column_names_set.contains(src_column.name)) + { + RUNTIME_CHECK_MSG( + !cross_join_data->result_block_schema.has(src_column.name), + "block from probe side has a column with the same name: {} as a column in " + "sample_block_without_keys", + src_column.name); + cross_join_data->result_block_schema.insert(src_column); + cross_join_data->right_column_index_in_right_block.push_back(i); + } + } + auto offset = cross_join_data->left_column_index_in_left_block.size(); + for (size_t i = 0; i < cross_join_data->right_column_index_in_right_block.size(); ++i) + cross_join_data->right_column_index_in_result_block.push_back(offset + i); } - size_t num_existing_columns = block.columns(); - size_t num_columns_to_add = sample_block_with_columns_to_add.columns(); - for (size_t i = 0; i < num_columns_to_add; ++i) - right_column_index.push_back(num_existing_columns + i); + if (cross_join_data->cross_probe_mode == CrossProbeMode::SHALLOW_COPY_RIGHT_BLOCK && null_map != nullptr) + cross_join_data->row_num_filtered_by_left_condition = countBytesInFilter(*null_map); + prepare_for_probe_done = true; +} + +void ProbeProcessInfo::prepareForNullAware(const Names & key_names, const String & filter_column) +{ + assert(prepare_for_probe_done == false); + if unlikely (null_aware_join_data == nullptr) + null_aware_join_data = std::make_unique(); + /// Rare case, when keys are constant. To avoid code bloat, simply materialize them. + /// Note: this variable can't be removed because it will take smart pointers' lifecycle to the end of this function. + null_aware_join_data->key_columns + = extractAndMaterializeKeyColumns(block, null_aware_join_data->materialized_columns, key_names); + + /// Note that `extractAllKeyNullMap` must be done before `extractNestedColumnsAndNullMap` + /// because `extractNestedColumnsAndNullMap` will change the nullable column to its nested column. + extractAllKeyNullMap( + null_aware_join_data->key_columns, + null_aware_join_data->all_key_null_map_holder, + null_aware_join_data->all_key_null_map); + + extractNestedColumnsAndNullMap(null_aware_join_data->key_columns, null_map_holder, null_map); - if (cross_probe_mode == CrossProbeMode::SHALLOW_COPY_RIGHT_BLOCK && null_map != nullptr) - row_num_filtered_by_left_condition = countBytesInFilter(*null_map); + recordFilteredRows(block, filter_column, null_aware_join_data->filter_map_holder, null_aware_join_data->filter_map); prepare_for_probe_done = true; } @@ -159,13 +203,13 @@ void ProbeProcessInfo::cutFilterAndOffsetVector(size_t start, size_t end) const bool ProbeProcessInfo::isCurrentProbeRowFinished() const { /// only used in cross join of shallow copy cross probe mode - return next_right_block_index == right_block_size; + return cross_join_data->next_right_block_index == cross_join_data->right_block_size; } -void ProbeProcessInfo::finishCurrentProbeRow() +void ProbeProcessInfo::finishCurrentProbeRow() const { /// only used in cross join of shallow copy cross probe mode - next_right_block_index = right_block_size; + cross_join_data->next_right_block_index = cross_join_data->right_block_size; } } // namespace DB diff --git a/dbms/src/Interpreters/ProbeProcessInfo.h b/dbms/src/Interpreters/ProbeProcessInfo.h index 9f145fb577e..b3e1582b7b7 100644 --- a/dbms/src/Interpreters/ProbeProcessInfo.h +++ b/dbms/src/Interpreters/ProbeProcessInfo.h @@ -29,6 +29,69 @@ enum class CrossProbeMode SHALLOW_COPY_RIGHT_BLOCK, }; +struct HashJoinProbeProcessData +{ + Columns materialized_columns; + ColumnRawPtrs key_columns; + /// TODO: consider adding a virtual column in Sender side to avoid computing cost and potential inconsistency by heterogeneous envs(AMD64, ARM64) + /// Note: 1. Not sure, if inconsistency will do happen in heterogeneous envs + /// 2. Virtual column would take up a little more network bandwidth, might lead to poor performance if network was bottleneck + /// Currently, the computation cost is tolerable, since it's a very simple crc32 hash algorithm, and heterogeneous envs support is not considered + std::unique_ptr hash_data; /// to reproduce hash values according to build stage + void reset() + { + key_columns.clear(); + materialized_columns.clear(); + hash_data = nullptr; + } +}; +struct CrossJoinProbeProcessData +{ + Block result_block_schema; + std::vector right_column_index_in_result_block; + std::vector right_column_index_in_right_block; + std::vector left_column_index_in_left_block; + size_t right_rows_to_be_added_when_matched = 0; + CrossProbeMode cross_probe_mode = CrossProbeMode::DEEP_COPY_RIGHT_BLOCK; + /// the following fields are used for NO_COPY_RIGHT_BLOCK probe + size_t right_block_size = 0; + /// the rows that is filtered by left condition + size_t row_num_filtered_by_left_condition = 0; + size_t next_right_block_index = 0; + /// used for outer/semi/anti/left outer semi/left outer anti join + bool has_row_matched = false; + /// used for left outer semi/left outer anti join + bool has_row_null = false; + void reset() + { + right_rows_to_be_added_when_matched = 0; + cross_probe_mode = CrossProbeMode::DEEP_COPY_RIGHT_BLOCK; + right_block_size = 0; + next_right_block_index = 0; + row_num_filtered_by_left_condition = 0; + has_row_matched = false; + has_row_null = false; + } +}; +struct NullAwareJoinProbeProcessData +{ + Columns materialized_columns; + ColumnRawPtrs key_columns; + ColumnPtr filter_map_holder = nullptr; + ConstNullMapPtr filter_map = nullptr; + ColumnPtr all_key_null_map_holder = nullptr; + ConstNullMapPtr all_key_null_map = nullptr; + void reset() + { + key_columns.clear(); + materialized_columns.clear(); + filter_map_holder = nullptr; + filter_map = nullptr; + all_key_null_map_holder = nullptr; + all_key_null_map = nullptr; + } +}; + struct ProbeProcessInfo { Block block; @@ -50,28 +113,13 @@ struct ProbeProcessInfo std::unique_ptr offsets_to_replicate = nullptr; /// for hash probe - Columns materialized_columns; - ColumnRawPtrs key_columns; - /// TODO: consider adding a virtual column in Sender side to avoid computing cost and potential inconsistency by heterogeneous envs(AMD64, ARM64) - /// Note: 1. Not sure, if inconsistency will do happen in heterogeneous envs - /// 2. Virtual column would take up a little more network bandwidth, might lead to poor performance if network was bottleneck - /// Currently, the computation cost is tolerable, since it's a very simple crc32 hash algorithm, and heterogeneous envs support is not considered - std::unique_ptr hash_data; /// to reproduce hash values according to build stage + std::unique_ptr hash_join_data; /// for cross probe - Block result_block_schema; - std::vector right_column_index; - size_t right_rows_to_be_added_when_matched = 0; - CrossProbeMode cross_probe_mode = CrossProbeMode::DEEP_COPY_RIGHT_BLOCK; - /// the following fields are used for NO_COPY_RIGHT_BLOCK probe - size_t right_block_size = 0; - /// the rows that is filtered by left condition - size_t row_num_filtered_by_left_condition = 0; - size_t next_right_block_index = 0; - /// used for outer/semi/anti/left outer semi/left outer anti join - bool has_row_matched = false; - /// used for left outer semi/left outer anti join - bool has_row_null = false; + std::unique_ptr cross_join_data; + + /// for null-aware join + std::unique_ptr null_aware_join_data; ProbeProcessInfo(UInt64 max_block_size_, UInt64 cache_columns_threshold_) : partition_index(0) @@ -88,11 +136,11 @@ struct ProbeProcessInfo { if constexpr (is_shallow_cross_probe_mode) { - if (next_right_block_index < right_block_size) + if (cross_join_data->next_right_block_index < cross_join_data->right_block_size) return; - next_right_block_index = 0; - has_row_matched = false; - has_row_null = false; + cross_join_data->next_right_block_index = 0; + cross_join_data->has_row_matched = false; + cross_join_data->has_row_null = false; } assert(start_row <= end_row); start_row = end_row; @@ -108,12 +156,13 @@ struct ProbeProcessInfo end_row = next_row_to_probe; if constexpr (is_shallow_cross_probe_mode) { - if (next_right_block_index < right_block_size) + if (cross_join_data->next_right_block_index < cross_join_data->right_block_size) { /// current probe is not finished, just return return; } - all_rows_joined_finish = row_num_filtered_by_left_condition == 0 && end_row == block.rows(); + all_rows_joined_finish + = cross_join_data->row_num_filtered_by_left_condition == 0 && end_row == block.rows(); } else { @@ -133,13 +182,15 @@ struct ProbeProcessInfo const String & filter_column, ASTTableJoin::Kind kind, ASTTableJoin::Strictness strictness, - const Block & sample_block_with_columns_to_add, + const Block & sample_block_without_keys, + const NameSet & output_column_names_set, size_t right_rows_to_be_added_when_matched, CrossProbeMode cross_probe_mode, size_t right_block_size); + void prepareForNullAware(const Names & key_names, const String & filter_column); void cutFilterAndOffsetVector(size_t start, size_t end) const; bool isCurrentProbeRowFinished() const; - void finishCurrentProbeRow(); + void finishCurrentProbeRow() const; }; } // namespace DB diff --git a/dbms/src/Interpreters/SemiJoinHelper.cpp b/dbms/src/Interpreters/SemiJoinHelper.cpp index 9606eea78c9..38e60e60170 100644 --- a/dbms/src/Interpreters/SemiJoinHelper.cpp +++ b/dbms/src/Interpreters/SemiJoinHelper.cpp @@ -40,6 +40,7 @@ void SemiJoinResult::fillRightColumns( MutableColumns & added_columns, size_t left_columns, size_t right_columns, + const std::vector & right_column_indices_to_add, size_t & current_offset, size_t max_pace) { @@ -60,7 +61,9 @@ void SemiJoinResult::fillRightColumns( for (size_t i = 0; i < current_pace && iter != nullptr; ++i) { for (size_t j = 0; j < right_columns; ++j) - added_columns[j + left_columns]->insertFrom(*iter->block->getByPosition(j).column.get(), iter->row_num); + added_columns[j + left_columns]->insertFrom( + *iter->block->getByPosition(right_column_indices_to_add[j]).column.get(), + iter->row_num); ++current_offset; iter = iter->next; } @@ -142,17 +145,18 @@ template SemiJoinHelper::SemiJoinHelper( Block & block_, size_t left_columns_, - size_t right_columns_, + const std::vector & right_column_indices_to_added_, size_t max_block_size_, const JoinNonEqualConditions & non_equal_conditions_) : block(block_) , left_columns(left_columns_) - , right_columns(right_columns_) + , right_column_indices_to_add(right_column_indices_to_added_) , max_block_size(max_block_size_) , non_equal_conditions(non_equal_conditions_) { static_assert(KIND == Semi || KIND == Anti || KIND == LeftOuterAnti || KIND == LeftOuterSemi); + right_columns = right_column_indices_to_add.size(); RUNTIME_CHECK(block.columns() == left_columns + right_columns); if constexpr (KIND == LeftOuterAnti || KIND == LeftOuterSemi) @@ -195,6 +199,7 @@ void SemiJoinHelper::joinResult(std::list & res_list) columns, left_columns, right_columns, + right_column_indices_to_add, current_offset, max_block_size - current_offset); diff --git a/dbms/src/Interpreters/SemiJoinHelper.h b/dbms/src/Interpreters/SemiJoinHelper.h index 00816146b1c..3759da9c391 100644 --- a/dbms/src/Interpreters/SemiJoinHelper.h +++ b/dbms/src/Interpreters/SemiJoinHelper.h @@ -102,6 +102,7 @@ class SemiJoinResult MutableColumns & added_columns, size_t left_columns, size_t right_columns, + const std::vector & right_column_indices_to_add, size_t & current_offset, size_t max_pace); @@ -134,7 +135,7 @@ class SemiJoinHelper SemiJoinHelper( Block & block, size_t left_columns, - size_t right_columns, + const std::vector & right_column_indices_to_add, size_t max_block_size, const JoinNonEqualConditions & non_equal_conditions); @@ -153,6 +154,7 @@ class SemiJoinHelper Block & block; size_t left_columns; size_t right_columns; + std::vector right_column_indices_to_add; size_t max_block_size; const JoinNonEqualConditions & non_equal_conditions;