diff --git a/dbms/src/DataStreams/SquashingHashJoinBlockTransform.cpp b/dbms/src/DataStreams/SquashingHashJoinBlockTransform.cpp deleted file mode 100644 index 7a43984088c..00000000000 --- a/dbms/src/DataStreams/SquashingHashJoinBlockTransform.cpp +++ /dev/null @@ -1,85 +0,0 @@ -// Copyright 2022 PingCAP, Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include - -namespace DB -{ - -SquashingHashJoinBlockTransform::SquashingHashJoinBlockTransform(UInt64 max_block_size_) - : output_rows(0) - , max_block_size(max_block_size_) - , join_finished(false) -{} - -void SquashingHashJoinBlockTransform::handleOverLimitBlock() -{ - // if over_limit_block is not null, we need to push it into blocks. - if (over_limit_block) - { - assert(!(output_rows && blocks.empty())); - output_rows += over_limit_block->rows(); - blocks.push_back(std::move(over_limit_block.value())); - over_limit_block.reset(); - } -} - -void SquashingHashJoinBlockTransform::appendBlock(Block & block) -{ - if (!block) - { - // if append block is {}, mark join finished. - join_finished = true; - return; - } - size_t current_rows = block.rows(); - - if (!output_rows || output_rows + current_rows <= max_block_size) - { - blocks.push_back(std::move(block)); - output_rows += current_rows; - } - else - { - // if output_rows + current_rows > max block size, put the current result block into over_limit_block and handle it in next read. 
- assert(!over_limit_block); - over_limit_block.emplace(std::move(block)); - } -} - -Block SquashingHashJoinBlockTransform::getFinalOutputBlock() -{ - Block final_block = vstackBlocks(std::move(blocks)); - reset(); - handleOverLimitBlock(); - return final_block; -} - -void SquashingHashJoinBlockTransform::reset() -{ - blocks.clear(); - output_rows = 0; -} - -bool SquashingHashJoinBlockTransform::isJoinFinished() const -{ - return join_finished; -} - -bool SquashingHashJoinBlockTransform::needAppendBlock() const -{ - return !over_limit_block && !join_finished; -} - -} // namespace DB diff --git a/dbms/src/DataStreams/SquashingHashJoinBlockTransform.h b/dbms/src/DataStreams/SquashingHashJoinBlockTransform.h deleted file mode 100644 index 956dac0903f..00000000000 --- a/dbms/src/DataStreams/SquashingHashJoinBlockTransform.h +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright 2022 PingCAP, Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#pragma once - -#include - -namespace DB -{ - -class SquashingHashJoinBlockTransform -{ -public: - SquashingHashJoinBlockTransform(UInt64 max_block_size_); - - void appendBlock(Block & block); - Block getFinalOutputBlock(); - bool isJoinFinished() const; - bool needAppendBlock() const; - - -private: - void handleOverLimitBlock(); - void reset(); - - Blocks blocks; - std::optional over_limit_block; - size_t output_rows; - UInt64 max_block_size; - bool join_finished; -}; - -} // namespace DB \ No newline at end of file diff --git a/dbms/src/DataStreams/tests/gtest_squashing_hash_join_transform.cpp b/dbms/src/DataStreams/tests/gtest_squashing_hash_join_transform.cpp deleted file mode 100644 index 7e38db3f60e..00000000000 --- a/dbms/src/DataStreams/tests/gtest_squashing_hash_join_transform.cpp +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright 2023 PingCAP, Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include -#include -#include -#include - - -namespace DB -{ -namespace tests -{ -class SquashingHashJoinBlockTransformTest : public ::testing::Test -{ -public: - void SetUp() override {} - static ColumnWithTypeAndName toVec(const std::vector & v) - { - return createColumn(v); - } - - static void check(Blocks blocks, UInt64 max_block_size) - { - for (auto & block : blocks) - { - ASSERT(block.rows() <= max_block_size); - } - } -}; - -TEST_F(SquashingHashJoinBlockTransformTest, testALL) -try -{ - std::vector block_size{1, 5, 10, 99, 999, 9999, 39999, DEFAULT_BLOCK_SIZE}; - size_t merge_block_count = 10000; - - for (auto size : block_size) - { - Int64 expect_rows = 0; - Blocks test_blocks; - - for (size_t i = 0; i < merge_block_count; ++i) - { - size_t rand_block_size = std::rand() % size + 1; - expect_rows += rand_block_size; - std::vector values; - for (size_t j = 0; j < rand_block_size; ++j) - { - values.push_back(1); - } - Block block{toVec(values)}; - test_blocks.push_back(block); - } - test_blocks.push_back(Block{}); - - Blocks final_blocks; - size_t index = 0; - Int64 actual_rows = 0; - SquashingHashJoinBlockTransform squashing_transform(size); - while (!squashing_transform.isJoinFinished()) - { - while (squashing_transform.needAppendBlock()) - { - Block result_block = test_blocks[index++]; - squashing_transform.appendBlock(result_block); - } - final_blocks.push_back(squashing_transform.getFinalOutputBlock()); - actual_rows += final_blocks.back().rows(); - } - check(final_blocks, std::min(size, expect_rows)); - ASSERT(actual_rows == expect_rows); - } -} -CATCH - -} // namespace tests -} // namespace DB \ No newline at end of file diff --git a/dbms/src/Flash/tests/gtest_join_executor.cpp b/dbms/src/Flash/tests/gtest_join_executor.cpp index d332011c51d..3471741f5bb 100644 --- a/dbms/src/Flash/tests/gtest_join_executor.cpp +++ b/dbms/src/Flash/tests/gtest_join_executor.cpp @@ -800,6 +800,65 @@ try } CATCH +TEST_F(JoinExecutorTestRunner, MergeAfterSplit) +try +{ 
+ context.addMockTable("split_test", "t1", {{"a", TiDB::TP::TypeLong}, {"b", TiDB::TP::TypeLong}}, {toVec("a", {1, 1, 1, 1, 1, 1, 1, 1, 1, 1}), toVec("b", {2, 2, 2, 2, 2, 2, 2, 2, 2, 2})}); + context.addMockTable("split_test", "t2", {{"a", TiDB::TP::TypeLong}, {"c", TiDB::TP::TypeLong}}, {toVec("a", {1, 1, 1, 1, 1}), toVec("c", {1, 2, 3, 4, 5})}); + + std::vector block_sizes{ + 1, + 2, + 7, + 25, + 49, + 50, + 51, + DEFAULT_BLOCK_SIZE}; + auto join_types = {tipb::JoinType::TypeInnerJoin, tipb::JoinType::TypeSemiJoin}; + std::vector>> expects{ + { + {1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + {1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + {4, 3, 2, 1}, + {5, 5}, + {9, 1}, + {10}, + {10}, + {10}, + }, + { + {1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + {2, 2, 2, 2, 2}, + {7, 3}, + {10}, + {10}, + {10}, + {10}, + {10}, + }, + }; + for (size_t index = 0; index < join_types.size(); index++) + { + auto request = context + .scan("split_test", "t1") + .join(context.scan("split_test", "t2"), *(join_types.begin() + index), {col("a")}, {}, {}, {gt(col("b"), col("c"))}, {}) + .build(context); + auto & expect = expects[index]; + + for (size_t i = 0; i < block_sizes.size(); ++i) + { + context.context->setSetting("max_block_size", Field(static_cast(block_sizes[i]))); + auto blocks = getExecuteStreamsReturnBlocks(request); + ASSERT_EQ(expect[i].size(), blocks.size()); + for (size_t j = 0; j < blocks.size(); ++j) + { + ASSERT_EQ(expect[i][j], blocks[j].rows()); + } + } + } +} +CATCH TEST_F(JoinExecutorTestRunner, SpillToDisk) try diff --git a/dbms/src/Interpreters/Join.cpp b/dbms/src/Interpreters/Join.cpp index 4d093cef282..09c8c1ca5a3 100644 --- a/dbms/src/Interpreters/Join.cpp +++ b/dbms/src/Interpreters/Join.cpp @@ -130,6 +130,8 @@ Join::Join( , match_helper_name(match_helper_name) , kind(kind_) , strictness(strictness_) + , original_strictness(strictness) + , may_probe_side_expanded_after_join(mayProbeSideExpandedAfterJoin(kind, strictness)) , key_names_left(key_names_left_) , key_names_right(key_names_right_) , 
build_concurrency(0) @@ -138,7 +140,6 @@ Join::Join( , active_probe_threads(0) , collators(collators_) , non_equal_conditions(non_equal_conditions_) - , original_strictness(strictness) , max_block_size(max_block_size_) , max_bytes_before_external_join(max_bytes_before_external_join_) , build_spill_config(build_spill_config_) @@ -821,8 +822,10 @@ void Join::handleOtherConditions(Block & block, std::unique_ptr throw Exception("Logical error: unknown combination of JOIN", ErrorCodes::LOGICAL_ERROR); } -Block Join::joinBlockHash(ProbeProcessInfo & probe_process_info) const +Block Join::doJoinBlockHash(ProbeProcessInfo & probe_process_info) const { + probe_process_info.updateStartRow(); + /// this makes a copy of `probe_process_info.block` Block block = probe_process_info.block; size_t keys_size = key_names_left.size(); @@ -966,6 +969,26 @@ Block Join::joinBlockHash(ProbeProcessInfo & probe_process_info) const return block; } +Block Join::joinBlockHash(ProbeProcessInfo & probe_process_info) const +{ + std::vector result_blocks; + size_t result_rows = 0; + while (true) + { + auto block = doJoinBlockHash(probe_process_info); + assert(block); + result_rows += block.rows(); + result_blocks.push_back(std::move(block)); + /// exit the while loop if + /// 1. probe_process_info.all_rows_joined_finish is true, which means all the rows in current block is processed + /// 2. 
the block may be expanded after join and result_rows reaches the min_result_block_size + if (probe_process_info.all_rows_joined_finish || (may_probe_side_expanded_after_join && result_rows >= probe_process_info.min_result_block_size)) + break; + } + assert(!result_blocks.empty()); + return vstackBlocks(std::move(result_blocks)); +} + namespace { template @@ -1245,6 +1268,8 @@ Block Join::joinBlockCross(ProbeProcessInfo & probe_process_info) const DISPATCH(false) } #undef DISPATCH + /// TODO: control the returned block size for cross join + probe_process_info.all_rows_joined_finish = true; return block; } @@ -1307,6 +1332,9 @@ Block Join::joinBlockNullAware(ProbeProcessInfo & probe_process_info) const FAIL_POINT_TRIGGER_EXCEPTION(FailPoints::random_join_prob_failpoint); + /// Null-aware join never expands the left block, so handling the whole block at one time is enough + probe_process_info.all_rows_joined_finish = true; + return block; } @@ -1585,6 +1613,7 @@ void Join::finishOneNonJoin(size_t partition_index) Block Join::joinBlock(ProbeProcessInfo & probe_process_info, bool dry_run) const { + assert(!probe_process_info.all_rows_joined_finish); if unlikely (dry_run) { assert(probe_process_info.block.rows() == 0); @@ -1600,8 +1629,6 @@ Block Join::joinBlock(ProbeProcessInfo & probe_process_info, bool dry_run) const } std::shared_lock lock(rwlock); - probe_process_info.updateStartRow(); - Block block{}; using enum ASTTableJoin::Strictness; @@ -1629,11 +1656,6 @@ Block Join::joinBlock(ProbeProcessInfo & probe_process_info, bool dry_run) const block.getByName(match_helper_name).column = ColumnNullable::create(std::move(col_non_matched), std::move(nullable_column->getNullMapColumnPtr())); } - if (isCrossJoin(kind) || isNullAwareSemiFamily(kind)) - { - probe_process_info.all_rows_joined_finish = true; - } - return block; } diff --git a/dbms/src/Interpreters/Join.h b/dbms/src/Interpreters/Join.h index e4b167e4168..960a4fc4589 100644 --- a/dbms/src/Interpreters/Join.h +++ 
b/dbms/src/Interpreters/Join.h @@ -265,6 +265,8 @@ class Join ASTTableJoin::Kind kind; ASTTableJoin::Strictness strictness; + ASTTableJoin::Strictness original_strictness; + const bool may_probe_side_expanded_after_join; /// Names of key columns (columns for equi-JOIN) in "left" table (in the order they appear in USING clause). const Names key_names_left; @@ -290,7 +292,6 @@ class Join const JoinNonEqualConditions non_equal_conditions; - ASTTableJoin::Strictness original_strictness; size_t max_block_size; /** Blocks of "right" table. */ @@ -378,6 +379,7 @@ class Join void insertFromBlockInternal(Block * stored_block, size_t stream_index); Block joinBlockHash(ProbeProcessInfo & probe_process_info) const; + Block doJoinBlockHash(ProbeProcessInfo & probe_process_info) const; Block joinBlockNullAware(ProbeProcessInfo & probe_process_info) const; diff --git a/dbms/src/Interpreters/JoinUtils.cpp b/dbms/src/Interpreters/JoinUtils.cpp index a0b2b73883e..d20c1d0bf65 100644 --- a/dbms/src/Interpreters/JoinUtils.cpp +++ b/dbms/src/Interpreters/JoinUtils.cpp @@ -26,6 +26,8 @@ void ProbeProcessInfo::resetBlock(Block && block_, size_t partition_index_) all_rows_joined_finish = false; // If the probe block size is greater than max_block_size, we will set max_block_size to the probe block size to avoid some unnecessary split. 
max_block_size = std::max(max_block_size, block.rows()); + // min_result_block_size is used to avoid generating too many small blocks; it defaults to 50% of the block size + min_result_block_size = std::max(1, (std::min(block.rows(), max_block_size) + 1) / 2); } void ProbeProcessInfo::updateStartRow() @@ -64,4 +66,20 @@ void computeDispatchHash(size_t rows, data[i] = updateHashValue(join_restore_round, data[i]); } } + +bool mayProbeSideExpandedAfterJoin(ASTTableJoin::Kind kind, ASTTableJoin::Strictness strictness) +{ + /// null aware semi/left semi/anti joins never expand the probe side + if (isNullAwareSemiFamily(kind)) + return false; + if (isLeftSemiFamily(kind)) + return false; + if (isAntiJoin(kind)) + return false; + /// strictness == Any means semi join, which never expands the probe side + if (strictness == ASTTableJoin::Strictness::Any) + return false; + /// for all the other cases, return true by default + return true; +} } // namespace DB diff --git a/dbms/src/Interpreters/JoinUtils.h b/dbms/src/Interpreters/JoinUtils.h index 6d3f0c154fd..8f60f1e4760 100644 --- a/dbms/src/Interpreters/JoinUtils.h +++ b/dbms/src/Interpreters/JoinUtils.h @@ -58,17 +58,22 @@ inline bool isNullAwareSemiFamily(ASTTableJoin::Kind kind) return kind == ASTTableJoin::Kind::NullAware_Anti || kind == ASTTableJoin::Kind::NullAware_LeftAnti || kind == ASTTableJoin::Kind::NullAware_LeftSemi; } + +bool mayProbeSideExpandedAfterJoin(ASTTableJoin::Kind kind, ASTTableJoin::Strictness strictness); + struct ProbeProcessInfo { Block block; size_t partition_index; UInt64 max_block_size; + UInt64 min_result_block_size; size_t start_row; size_t end_row; bool all_rows_joined_finish; explicit ProbeProcessInfo(UInt64 max_block_size_) : max_block_size(max_block_size_) + , min_result_block_size((max_block_size + 1) / 2) , all_rows_joined_finish(true){}; void resetBlock(Block && block_, size_t partition_index_ = 0);