Skip to content

Commit

Permalink
refine handleOtherConditions in join (#8642)
Browse files Browse the repository at this point in the history
ref #8633
  • Loading branch information
windtalker authored Jan 10, 2024
1 parent e317f2a commit bd99043
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 81 deletions.
109 changes: 37 additions & 72 deletions dbms/src/Interpreters/Join.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -780,6 +780,8 @@ void Join::cancelRuntimeFilter(const String & reason)
}
}

namespace
{
void mergeNullAndFilterResult(
Block & block,
ColumnVector<UInt8>::Container & filter_column,
Expand Down Expand Up @@ -833,6 +835,24 @@ void mergeNullAndFilterResult(
}
}
}
/// For a left join, convert the right-table columns of `block` to null for rows that were not joined.
///
/// Every column of `block` whose name is also present in `right_columns` is rewritten by
/// ColumnNullable::applyNegatedNullMap(filter_column); all other columns are left untouched.
/// NOTE(review): presumably rows whose filter value is 0 become NULL - confirm against
/// ColumnNullable::applyNegatedNullMap.
///
/// @param block          the block whose right-table columns are rewritten in place
/// @param right_columns  only its column names are consulted, to decide which columns to rewrite
/// @param filter_column  passed unchanged to applyNegatedNullMap for each rewritten column
///
/// Fails RUNTIME_CHECK_MSG if a selected column is not nullable.
void applyNullToNotMatchedRows(Block & block, const Block & right_columns, const ColumnUInt8 & filter_column)
{
    const size_t column_count = block.columns();
    for (size_t pos = 0; pos < column_count; ++pos)
    {
        auto & entry = block.getByPosition(pos);
        /// Only columns that also appear in right_columns are rewritten.
        if (!right_columns.has(entry.name))
            continue;

        /// A const column is expanded to a full column before the null map is applied.
        auto materialized
            = entry.column->isColumnConst() ? entry.column->convertToFullColumnIfConst() : entry.column;
        RUNTIME_CHECK_MSG(materialized->isColumnNullable(), "the right table column for left join must be nullable");

        auto to_mutate = materialized;
        auto nulled_column = (*std::move(to_mutate)).mutate();
        static_cast<ColumnNullable &>(*nulled_column).applyNegatedNullMap(filter_column);
        entry.column = std::move(nulled_column);
    }
}
} // namespace

/**
* handle other join conditions
Expand All @@ -847,11 +867,8 @@ void mergeNullAndFilterResult(
* @param left_table_columns
* @param right_table_columns
*/
void Join::handleOtherConditions(
Block & block,
IColumn::Filter * anti_filter,
IColumn::Offsets * offsets_to_replicate,
const std::vector<size_t> & right_table_columns) const
void Join::handleOtherConditions(Block & block, IColumn::Filter * anti_filter, IColumn::Offsets * offsets_to_replicate)
const
{
/// save block_rows because block.rows() returns the first column's size; after other_cond_expr->execute(block),
/// some columns may be removed, and the first column may be the match_helper_column which does not have the same size
Expand All @@ -870,8 +887,7 @@ void Join::handleOtherConditions(
{
auto & col_name = input_block.getByPosition(i).name;
if ((!flag_mapped_entry_helper_name.empty() && col_name == flag_mapped_entry_helper_name)
|| output_column_names_set_after_finalize.find(col_name)
!= output_column_names_set_after_finalize.end())
|| output_column_names_set_after_finalize.contains(col_name))
++i;
else
input_block.erase(i);
Expand All @@ -885,10 +901,8 @@ void Join::handleOtherConditions(
assert(filter.empty());
filter.assign(block_rows, static_cast<UInt8>(1));
}
auto helper_pos = block.getPositionByName(match_helper_name);

const auto * old_match_nullable
= checkAndGetColumn<ColumnNullable>(block.safeGetByPosition(helper_pos).column.get());
= checkAndGetColumn<ColumnNullable>(block.getByName(match_helper_name).column.get());
const auto & old_match_vec
= static_cast<const ColumnVector<Int8> *>(old_match_nullable->getNestedColumnPtr().get())->getData();

Expand Down Expand Up @@ -959,7 +973,7 @@ void Join::handleOtherConditions(
}

erase_useless_column(block);
helper_pos = block.getPositionByName(match_helper_name);
auto helper_pos = block.getPositionByName(match_helper_name);
for (size_t i = 0; i < block.columns(); ++i)
if (i != helper_pos)
block.getByPosition(i).column = block.getByPosition(i).column->filter(row_filter, -1);
Expand Down Expand Up @@ -1023,34 +1037,17 @@ void Join::handleOtherConditions(
}
prev_offset = current_offset;
}
erase_useless_column(block);
if (isLeftOuterJoin(kind))
{
/// for left join, convert right column to null if not joined
for (size_t right_table_column : right_table_columns)
{
auto & column = block.getByPosition(right_table_column);
if (output_column_names_set_after_finalize.contains(column.name))
{
auto full_column
= column.column->isColumnConst() ? column.column->convertToFullColumnIfConst() : column.column;
if (!full_column->isColumnNullable())
{
throw Exception("Should not reach here, the right table column for left join must be nullable");
}
auto current_column = full_column;
auto result_column = (*std::move(current_column)).mutate();
static_cast<ColumnNullable &>(*result_column).applyNegatedNullMap(*filter_column);
column.column = std::move(result_column);
}
}
erase_useless_column(block);
applyNullToNotMatchedRows(block, sample_block_without_keys, *filter_column);
for (size_t i = 0; i < block.columns(); ++i)
block.getByPosition(i).column = block.getByPosition(i).column->filter(row_filter, -1);
return;
}
if (is_semi_family)
{
erase_useless_column(block);
/// for semi/anti join, filter out not matched rows
for (size_t i = 0; i < block.columns(); ++i)
block.getByPosition(i).column = block.getByPosition(i).column->filter(row_filter, -1);
Expand Down Expand Up @@ -1080,8 +1077,7 @@ void Join::handleOtherConditionsForOneProbeRow(Block & block, ProbeProcessInfo &
{
auto & col_name = input_block.getByPosition(i).name;
if ((!flag_mapped_entry_helper_name.empty() && col_name == flag_mapped_entry_helper_name)
|| output_column_names_set_after_finalize.find(col_name)
!= output_column_names_set_after_finalize.end())
|| output_column_names_set_after_finalize.contains(col_name))
++i;
else
input_block.erase(i);
Expand Down Expand Up @@ -1123,10 +1119,10 @@ void Join::handleOtherConditionsForOneProbeRow(Block & block, ProbeProcessInfo &
matched_row_count_in_current_block = countBytesInFilter(filter);
probe_process_info.cross_join_data->has_row_matched |= matched_row_count_in_current_block != 0;
}
erase_useless_column(block);
/// case 1, inner join
if (kind == ASTTableJoin::Kind::Cross)
{
erase_useless_column(block);
if (matched_row_count_in_current_block > 0)
{
for (size_t i = 0; i < block.columns(); ++i)
Expand All @@ -1144,7 +1140,6 @@ void Join::handleOtherConditionsForOneProbeRow(Block & block, ProbeProcessInfo &
{
if (matched_row_count_in_current_block > 0)
{
erase_useless_column(block);
for (size_t i = 0; i < block.columns(); ++i)
block.safeGetByPosition(i).column
= block.safeGetByPosition(i).column->filter(filter, matched_row_count_in_current_block);
Expand All @@ -1155,36 +1150,17 @@ void Join::handleOtherConditionsForOneProbeRow(Block & block, ProbeProcessInfo &
for (size_t i = 0; i < block.columns(); ++i)
block.getByPosition(i).column = block.getByPosition(i).column->cut(0, 1);
filter.resize(1);
for (size_t right_table_column : probe_process_info.cross_join_data->right_column_index_in_result_block)
{
auto & column = block.getByPosition(right_table_column);
if (output_column_names_set_after_finalize.contains(column.name))
{
auto full_column
= column.column->isColumnConst() ? column.column->convertToFullColumnIfConst() : column.column;
if (!full_column->isColumnNullable())
{
throw Exception("Should not reach here, the right table column for left join must be nullable");
}
auto current_column = full_column;
auto result_column = (*std::move(current_column)).mutate();
static_cast<ColumnNullable &>(*result_column).applyNegatedNullMap(*filter_column);
column.column = std::move(result_column);
}
}
erase_useless_column(block);
applyNullToNotMatchedRows(block, sample_block_without_keys, *filter_column);
}
else
{
erase_useless_column(block);
block = block.cloneEmpty();
}
return;
}
/// case 3, semi join
if (kind == ASTTableJoin::Kind::Cross_Semi)
{
erase_useless_column(block);
if (probe_process_info.cross_join_data->has_row_matched)
{
/// has matched rows, return the first row, and set the current row probe done
Expand All @@ -1202,7 +1178,6 @@ void Join::handleOtherConditionsForOneProbeRow(Block & block, ProbeProcessInfo &
/// case 4, anti join
if (kind == ASTTableJoin::Kind::Cross_Anti)
{
erase_useless_column(block);
if (probe_process_info.cross_join_data->has_row_matched)
{
block = block.cloneEmpty();
Expand All @@ -1222,7 +1197,6 @@ void Join::handleOtherConditionsForOneProbeRow(Block & block, ProbeProcessInfo &
/// case 5, left outer semi join
if (isLeftOuterSemiFamily(kind))
{
erase_useless_column(block);
if (probe_process_info.cross_join_data->has_row_matched || probe_process_info.isCurrentProbeRowFinished())
{
for (size_t i = 0; i < block.columns(); ++i)
Expand Down Expand Up @@ -1284,14 +1258,6 @@ Block Join::doJoinBlockHash(ProbeProcessInfo & probe_process_info, const JoinBui
num_columns_to_add.push_back(i);
}

std::vector<size_t> right_table_column_indexes;
right_table_column_indexes.reserve(num_columns_to_add.size());

for (size_t i = 0; i < num_columns_to_add.size(); ++i)
{
right_table_column_indexes.push_back(i + existing_columns);
}

MutableColumns added_columns;
added_columns.reserve(num_columns_to_add.size());

Expand Down Expand Up @@ -1389,7 +1355,7 @@ Block Join::doJoinBlockHash(ProbeProcessInfo & probe_process_info, const JoinBui
if (has_other_condition)
{
assert(offsets_to_replicate != nullptr);
handleOtherConditions(block, nullptr, offsets_to_replicate.get(), right_table_column_indexes);
handleOtherConditions(block, nullptr, offsets_to_replicate.get());

if (useRowFlaggedHashMap(kind, has_other_condition))
{
Expand Down Expand Up @@ -1484,8 +1450,7 @@ Block Join::doJoinBlockCross(ProbeProcessInfo & probe_process_info) const
handleOtherConditions(
block,
probe_process_info.filter.get(),
probe_process_info.offsets_to_replicate.get(),
probe_process_info.cross_join_data->right_column_index_in_result_block);
probe_process_info.offsets_to_replicate.get());
}
return block;
}
Expand Down Expand Up @@ -1528,8 +1493,7 @@ Block Join::doJoinBlockCross(ProbeProcessInfo & probe_process_info) const
handleOtherConditions(
block,
probe_process_info.filter.get(),
probe_process_info.offsets_to_replicate.get(),
probe_process_info.cross_join_data->right_column_index_in_result_block);
probe_process_info.offsets_to_replicate.get());
}
return block;
}
Expand Down Expand Up @@ -2647,16 +2611,17 @@ void Join::finalize(const Names & parent_require)
updated_require.push_back(non_equal_conditions.other_eq_cond_from_in_name);
if (!non_equal_conditions.other_cond_name.empty())
updated_require.push_back(non_equal_conditions.other_cond_name);
auto keep_used_input_columns
= !isCrossJoin(kind) && (isNullAwareSemiFamily(kind) || isSemiFamily(kind) || isLeftOuterSemiFamily(kind));
/// nullaware/semi join will reuse the input columns so need to let finalize keep the input columns
if (non_equal_conditions.null_aware_eq_cond_expr != nullptr)
{
non_equal_conditions.null_aware_eq_cond_expr->finalize(updated_require, true);
non_equal_conditions.null_aware_eq_cond_expr->finalize(updated_require, keep_used_input_columns);
updated_require = non_equal_conditions.null_aware_eq_cond_expr->getRequiredColumns();
}
if (non_equal_conditions.other_cond_expr != nullptr)
{
/// todo don't keep input columns for non-semi/non-nullaware join
non_equal_conditions.other_cond_expr->finalize(updated_require, true);
non_equal_conditions.other_cond_expr->finalize(updated_require, keep_used_input_columns);
updated_require = non_equal_conditions.other_cond_expr->getRequiredColumns();
}
/// remove duplicated column
Expand Down
6 changes: 1 addition & 5 deletions dbms/src/Interpreters/Join.h
Original file line number Diff line number Diff line change
Expand Up @@ -481,11 +481,7 @@ class Join
*
* @param block
*/
void handleOtherConditions(
Block & block,
IColumn::Filter * filter,
IColumn::Offsets * offsets_to_replicate,
const std::vector<size_t> & right_table_column) const;
void handleOtherConditions(Block & block, IColumn::Filter * filter, IColumn::Offsets * offsets_to_replicate) const;

void handleOtherConditionsForOneProbeRow(Block & block, ProbeProcessInfo & probe_process_info) const;

Expand Down
3 changes: 0 additions & 3 deletions dbms/src/Interpreters/ProbeProcessInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,6 @@ void ProbeProcessInfo::prepareForCrossProbe(
cross_join_data->right_column_index_in_right_block.push_back(i);
}
}
auto offset = cross_join_data->left_column_index_in_left_block.size();
for (size_t i = 0; i < cross_join_data->right_column_index_in_right_block.size(); ++i)
cross_join_data->right_column_index_in_result_block.push_back(offset + i);
}
if (cross_join_data->cross_probe_mode == CrossProbeMode::SHALLOW_COPY_RIGHT_BLOCK && null_map != nullptr)
cross_join_data->row_num_filtered_by_left_condition = countBytesInFilter(*null_map);
Expand Down
1 change: 0 additions & 1 deletion dbms/src/Interpreters/ProbeProcessInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ struct HashJoinProbeProcessData
struct CrossJoinProbeProcessData
{
Block result_block_schema;
std::vector<size_t> right_column_index_in_result_block;
std::vector<size_t> right_column_index_in_right_block;
std::vector<size_t> left_column_index_in_left_block;
size_t right_rows_to_be_added_when_matched = 0;
Expand Down

0 comments on commit bd99043

Please sign in to comment.