From 816b8d5d5dff61b7fa452311c0b666947c34deb8 Mon Sep 17 00:00:00 2001 From: Meng Xin Date: Tue, 27 Dec 2022 22:46:16 +0800 Subject: [PATCH] merge output blocks if need in hash join (#6529) close pingcap/tiflash#6533 --- dbms/src/Core/Block.cpp | 39 ++++++++ dbms/src/Core/Block.h | 2 + .../HashJoinProbeBlockInputStream.cpp | 21 ++++- .../HashJoinProbeBlockInputStream.h | 4 +- .../SquashingHashJoinBlockTransform.cpp | 85 ++++++++++++++++++ .../SquashingHashJoinBlockTransform.h | 44 +++++++++ dbms/src/Flash/tests/gtest_join_executor.cpp | 21 ++++- .../gtest_squashing_hash_join_transform.cpp | 89 +++++++++++++++++++ .../tests/gtest_segment_test_basic.cpp | 4 +- dbms/src/TestUtils/ExecutorTestUtils.cpp | 6 +- dbms/src/TestUtils/ExecutorTestUtils.h | 2 - dbms/src/TestUtils/MPPTaskTestUtils.cpp | 2 +- 12 files changed, 307 insertions(+), 12 deletions(-) create mode 100644 dbms/src/DataStreams/SquashingHashJoinBlockTransform.cpp create mode 100644 dbms/src/DataStreams/SquashingHashJoinBlockTransform.h create mode 100644 dbms/src/Flash/tests/gtest_squashing_hash_join_transform.cpp diff --git a/dbms/src/Core/Block.cpp b/dbms/src/Core/Block.cpp index 69fc45ec3c1..b8adade5a84 100644 --- a/dbms/src/Core/Block.cpp +++ b/dbms/src/Core/Block.cpp @@ -514,6 +514,45 @@ static ReturnType checkBlockStructure(const Block & lhs, const Block & rhs, cons return ReturnType(true); } +Block mergeBlocks(Blocks && blocks) +{ + if (blocks.empty()) + { + return {}; + } + + if (blocks.size() == 1) + { + return std::move(blocks[0]); + } + + auto & first_block = blocks[0]; + size_t result_rows = 0; + for (const auto & block : blocks) + { + result_rows += block.rows(); + } + + MutableColumns dst_columns(first_block.columns()); + + for (size_t i = 0; i < first_block.columns(); ++i) + { + dst_columns[i] = (*std::move(first_block.getByPosition(i).column)).mutate(); + dst_columns[i]->reserve(result_rows); + } + + for (size_t i = 1; i < blocks.size(); ++i) + { + if (likely(blocks[i].rows()) > 0) + { + for (size_t column = 0; column < blocks[i].columns(); ++column) + { + dst_columns[column]->insertRangeFrom(*blocks[i].getByPosition(column).column, 0, blocks[i].rows()); + } + } + } + return first_block.cloneWithColumns(std::move(dst_columns)); +} bool blocksHaveEqualStructure(const Block & lhs, const Block & rhs) { diff --git a/dbms/src/Core/Block.h b/dbms/src/Core/Block.h index 206d6d959cc..5f2cabe7859 100644 --- a/dbms/src/Core/Block.h +++ b/dbms/src/Core/Block.h @@ -149,6 +149,7 @@ class Block */ void updateHash(SipHash & hash) const; + private: void eraseImpl(size_t position); void initializeIndexByName(); @@ -157,6 +158,7 @@ class Block using Blocks = std::vector; using BlocksList = std::list; +Block mergeBlocks(Blocks && blocks); /// Compare number of columns, data types, column types, column names, and values of constant columns. bool blocksHaveEqualStructure(const Block & lhs, const Block & rhs); diff --git a/dbms/src/DataStreams/HashJoinProbeBlockInputStream.cpp b/dbms/src/DataStreams/HashJoinProbeBlockInputStream.cpp index b7ae64cfafc..2fd304b162f 100644 --- a/dbms/src/DataStreams/HashJoinProbeBlockInputStream.cpp +++ b/dbms/src/DataStreams/HashJoinProbeBlockInputStream.cpp @@ -13,7 +13,6 @@ // limitations under the License. #include -#include namespace DB { @@ -25,6 +24,7 @@ HashJoinProbeBlockInputStream::HashJoinProbeBlockInputStream( : log(Logger::get(req_id)) , join(join_) , probe_process_info(max_block_size) + , squashing_transform(max_block_size) { children.push_back(input); @@ -66,12 +66,30 @@ Block HashJoinProbeBlockInputStream::getHeader() const } Block HashJoinProbeBlockInputStream::readImpl() +{ + // if join finished, return {} directly. + if (squashing_transform.isJoinFinished()) + { + return Block{}; + } + + while (squashing_transform.needAppendBlock()) + { + Block result_block = getOutputBlock(); + squashing_transform.appendBlock(result_block); + } + return squashing_transform.getFinalOutputBlock(); +} + +Block HashJoinProbeBlockInputStream::getOutputBlock() { if (probe_process_info.all_rows_joined_finish) { Block block = children.back()->read(); if (!block) + { return block; + } join->checkTypes(block); probe_process_info.resetBlock(std::move(block)); } @@ -79,5 +97,4 @@ Block HashJoinProbeBlockInputStream::readImpl() return join->joinBlock(probe_process_info); } - } // namespace DB diff --git a/dbms/src/DataStreams/HashJoinProbeBlockInputStream.h b/dbms/src/DataStreams/HashJoinProbeBlockInputStream.h index 3cc6fc4af6b..cf6e557d32c 100644 --- a/dbms/src/DataStreams/HashJoinProbeBlockInputStream.h +++ b/dbms/src/DataStreams/HashJoinProbeBlockInputStream.h @@ -15,12 +15,12 @@ #pragma once #include +#include #include namespace DB { - /** Executes a certain expression over the block. * Basically the same as ExpressionBlockInputStream, * but requires that there must be a join probe action in the Expression. @@ -47,11 +47,13 @@ class HashJoinProbeBlockInputStream : public IProfilingBlockInputStream protected: Block readImpl() override; + Block getOutputBlock(); private: const LoggerPtr log; JoinPtr join; ProbeProcessInfo probe_process_info; + SquashingHashJoinBlockTransform squashing_transform; }; } // namespace DB diff --git a/dbms/src/DataStreams/SquashingHashJoinBlockTransform.cpp b/dbms/src/DataStreams/SquashingHashJoinBlockTransform.cpp new file mode 100644 index 00000000000..9c876d7883d --- /dev/null +++ b/dbms/src/DataStreams/SquashingHashJoinBlockTransform.cpp @@ -0,0 +1,85 @@ +// Copyright 2022 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +namespace DB +{ + +SquashingHashJoinBlockTransform::SquashingHashJoinBlockTransform(UInt64 max_block_size_) + : output_rows(0) + , max_block_size(max_block_size_) + , join_finished(false) +{} + +void SquashingHashJoinBlockTransform::handleOverLimitBlock() +{ + // if over_limit_block is not null, we need to push it into blocks. + if (over_limit_block) + { + assert(!(output_rows && blocks.empty())); + output_rows += over_limit_block->rows(); + blocks.push_back(std::move(over_limit_block.value())); + over_limit_block.reset(); + } +} + +void SquashingHashJoinBlockTransform::appendBlock(Block & block) +{ + if (!block) + { + // if append block is {}, mark join finished. + join_finished = true; + return; + } + size_t current_rows = block.rows(); + + if (!output_rows || output_rows + current_rows <= max_block_size) + { + blocks.push_back(std::move(block)); + output_rows += current_rows; + } + else + { + // if output_rows + current_rows > max block size, put the current result block into over_limit_block and handle it in next read. + assert(!over_limit_block); + over_limit_block.emplace(std::move(block)); + } +} + +Block SquashingHashJoinBlockTransform::getFinalOutputBlock() +{ + Block final_block = mergeBlocks(std::move(blocks)); + reset(); + handleOverLimitBlock(); + return final_block; +} + +void SquashingHashJoinBlockTransform::reset() +{ + blocks.clear(); + output_rows = 0; +} + +bool SquashingHashJoinBlockTransform::isJoinFinished() const +{ + return join_finished; +} + +bool SquashingHashJoinBlockTransform::needAppendBlock() const +{ + return !over_limit_block && !join_finished; +} + +} // namespace DB diff --git a/dbms/src/DataStreams/SquashingHashJoinBlockTransform.h b/dbms/src/DataStreams/SquashingHashJoinBlockTransform.h new file mode 100644 index 00000000000..956dac0903f --- /dev/null +++ b/dbms/src/DataStreams/SquashingHashJoinBlockTransform.h @@ -0,0 +1,44 @@ +// Copyright 2022 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +namespace DB +{ + +class SquashingHashJoinBlockTransform +{ +public: + SquashingHashJoinBlockTransform(UInt64 max_block_size_); + + void appendBlock(Block & block); + Block getFinalOutputBlock(); + bool isJoinFinished() const; + bool needAppendBlock() const; + + +private: + void handleOverLimitBlock(); + void reset(); + + Blocks blocks; + std::optional over_limit_block; + size_t output_rows; + UInt64 max_block_size; + bool join_finished; +}; + +} // namespace DB \ No newline at end of file diff --git a/dbms/src/Flash/tests/gtest_join_executor.cpp b/dbms/src/Flash/tests/gtest_join_executor.cpp index b7e3ff58683..e20a19b3174 100644 --- a/dbms/src/Flash/tests/gtest_join_executor.cpp +++ b/dbms/src/Flash/tests/gtest_join_executor.cpp @@ -700,7 +700,7 @@ CATCH TEST_F(JoinExecutorTestRunner, SplitJoinResult) try { - context.addMockTable("split_test", "t1", {{"a", TiDB::TP::TypeLong}}, {toVec("a", {1, 1, 1, 1, 1, 1, 1, 1, 1, 1})}); + context.addMockTable("split_test", "t1", {{"a", TiDB::TP::TypeLong}, {"b", TiDB::TP::TypeLong}}, {toVec("a", {1, 1, 1, 1, 1, 1, 1, 1, 1, 1}), toVec("b", {1, 1, 3, 3, 1, 1, 3, 3, 1, 3})}); context.addMockTable("split_test", "t2", {{"a", TiDB::TP::TypeLong}}, {toVec("a", {1, 1, 1, 1, 1})}); auto request = context @@ -720,6 +720,25 @@ try ASSERT_EQ(expect[i][j], blocks[j].rows()); } } + + // with other condition + const auto cond = gt(col("b"), lit(Field(static_cast(2)))); + request = context + .scan("split_test", "t1") + .join(context.scan("split_test", "t2"), tipb::JoinType::TypeInnerJoin, {col("a")}, {}, {}, {cond}, {}) + + .build(context); + expect = {{5, 5, 5, 5, 5}, {5, 5, 5, 5, 5}, {5, 5, 5, 5, 5}, {25}, {25}, {25}, {25}, {25}}; + for (size_t i = 0; i < block_sizes.size(); ++i) + { + context.context.setSetting("max_block_size", Field(static_cast(block_sizes[i]))); + auto blocks = getExecuteStreamsReturnBlocks(request); + ASSERT_EQ(expect[i].size(), blocks.size()); + for (size_t j = 0; j < blocks.size(); ++j) + { + ASSERT_EQ(expect[i][j], blocks[j].rows()); + } + } } CATCH diff --git a/dbms/src/Flash/tests/gtest_squashing_hash_join_transform.cpp b/dbms/src/Flash/tests/gtest_squashing_hash_join_transform.cpp new file mode 100644 index 00000000000..1f61878da48 --- /dev/null +++ b/dbms/src/Flash/tests/gtest_squashing_hash_join_transform.cpp @@ -0,0 +1,89 @@ +// Copyright 2022 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + + +namespace DB +{ +namespace tests +{ +class SquashingHashJoinBlockTransformTest : public ::testing::Test +{ +public: + void SetUp() override {} + static ColumnWithTypeAndName toVec(const std::vector & v) + { + return createColumn(v); + } + + static void check(Blocks blocks, UInt64 max_block_size) + { + for (size_t i = 0; i < blocks.size(); ++i) + { + ASSERT(blocks[i].rows() <= max_block_size); + } + } +}; + +TEST_F(SquashingHashJoinBlockTransformTest, testALL) +try +{ + std::vector block_size{1, 5, 10, 99, 999, 9999, 39999, DEFAULT_BLOCK_SIZE}; + size_t merge_block_count = 10000; + + for (auto size : block_size) + { + Int64 expect_rows = 0; + Blocks test_blocks; + + for (size_t i = 0; i < merge_block_count; ++i) + { + size_t rand_block_size = std::rand() % size + 1; + expect_rows += rand_block_size; + std::vector values; + for (size_t j = 0; j < rand_block_size; ++j) + { + values.push_back(1); + } + Block block{toVec(values)}; + test_blocks.push_back(block); + } + test_blocks.push_back(Block{}); + + Blocks final_blocks; + size_t index = 0; + Int64 actual_rows = 0; + SquashingHashJoinBlockTransform squashing_transform(size); + while (!squashing_transform.isJoinFinished()) + { + while (squashing_transform.needAppendBlock()) + { + Block result_block = test_blocks[index++]; + squashing_transform.appendBlock(result_block); + } + final_blocks.push_back(squashing_transform.getFinalOutputBlock()); + actual_rows += final_blocks.back().rows(); + } + check(final_blocks, std::min(size, expect_rows)); + ASSERT(actual_rows == expect_rows); + } +} +CATCH + +} // namespace tests +} // namespace DB \ No newline at end of file diff --git a/dbms/src/Storages/DeltaMerge/tests/gtest_segment_test_basic.cpp b/dbms/src/Storages/DeltaMerge/tests/gtest_segment_test_basic.cpp index 617a1e7f13c..a890c868d0b 100644 --- a/dbms/src/Storages/DeltaMerge/tests/gtest_segment_test_basic.cpp +++ b/dbms/src/Storages/DeltaMerge/tests/gtest_segment_test_basic.cpp @@ -301,7 +301,7 @@ Block SegmentTestBasic::prepareWriteBlock(Int64 start_key, Int64 end_key, bool i is_deleted); } -Block mergeBlocks(std::vector && blocks) +Block sortMergeBlocks(std::vector && blocks) { auto accumulated_block = std::move(blocks[0]); @@ -391,7 +391,7 @@ Block SegmentTestBasic::prepareWriteBlockInSegmentRange(PageId segment_id, UInt6 remaining_rows); } - return mergeBlocks(std::move(blocks)); + return sortMergeBlocks(std::move(blocks)); } void SegmentTestBasic::writeSegment(PageId segment_id, UInt64 write_rows, std::optional start_at) diff --git a/dbms/src/TestUtils/ExecutorTestUtils.cpp b/dbms/src/TestUtils/ExecutorTestUtils.cpp index dc279f3ea7f..7719d021b37 100644 --- a/dbms/src/TestUtils/ExecutorTestUtils.cpp +++ b/dbms/src/TestUtils/ExecutorTestUtils.cpp @@ -166,7 +166,7 @@ void ExecutorTest::executeAndAssertRowsEqual(const std::shared_ptr streams) Blocks actual_blocks; for (const auto & stream : streams) readStream(actual_blocks, stream); - return mergeBlocks(actual_blocks).getColumnsWithTypeAndName(); + return mergeBlocksForTest(std::move(actual_blocks)).getColumnsWithTypeAndName(); } void ExecutorTest::enablePlanner(bool is_enable) @@ -238,7 +238,7 @@ ColumnsWithTypeAndName ExecutorTest::executeStreams(DAGContext * dag_context) // Currently, don't care about regions information in tests. Blocks blocks; queryExecute(context.context, /*internal=*/true)->execute([&blocks](const Block & block) { blocks.push_back(block); }).verify(); - return mergeBlocks(blocks).getColumnsWithTypeAndName(); + return mergeBlocksForTest(std::move(blocks)).getColumnsWithTypeAndName(); } Blocks ExecutorTest::getExecuteStreamsReturnBlocks(const std::shared_ptr & request, size_t concurrency) diff --git a/dbms/src/TestUtils/ExecutorTestUtils.h b/dbms/src/TestUtils/ExecutorTestUtils.h index ee014e4b069..79c279f2822 100644 --- a/dbms/src/TestUtils/ExecutorTestUtils.h +++ b/dbms/src/TestUtils/ExecutorTestUtils.h @@ -30,8 +30,6 @@ TiDB::TP dataTypeToTP(const DataTypePtr & type); ColumnsWithTypeAndName readBlock(BlockInputStreamPtr stream); ColumnsWithTypeAndName readBlocks(std::vector streams); -Block mergeBlocks(Blocks blocks); - #define WRAP_FOR_DIS_ENABLE_PLANNER_BEGIN \ std::vector bools{false, true}; \ diff --git a/dbms/src/TestUtils/MPPTaskTestUtils.cpp b/dbms/src/TestUtils/MPPTaskTestUtils.cpp index d33ae8e5910..b187f3e6f5a 100644 --- a/dbms/src/TestUtils/MPPTaskTestUtils.cpp +++ b/dbms/src/TestUtils/MPPTaskTestUtils.cpp @@ -98,7 +98,7 @@ ColumnsWithTypeAndName extractColumns(Context & context, const std::shared_ptrchunks()) blocks.emplace_back(codec->decode(chunk.rows_data(), schema)); - return mergeBlocks(blocks).getColumnsWithTypeAndName(); + return mergeBlocks(std::move(blocks)).getColumnsWithTypeAndName(); } ColumnsWithTypeAndName MPPTaskTestUtils::executeCoprocessorTask(std::shared_ptr & dag_request)