Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

merge output blocks if needed in hash join #6529

Merged
merged 17 commits into from
Dec 27, 2022
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions dbms/src/Core/Block.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,36 @@ static ReturnType checkBlockStructure(const Block & lhs, const Block & rhs, cons
return ReturnType(true);
}

/// Merge a list of blocks sharing the same structure into a single block.
/// The first block is used as the structural sample and its columns are
/// mutated in place, so `blocks` is consumed (taken by rvalue reference).
/// Precondition: `blocks` is non-empty.
/// NOTE(review): all blocks are assumed to have identical structure
/// (column count/order/types) — this is not checked here; confirm callers
/// guarantee it (checkBlockStructure exists above for that purpose).
Block mergeBlocks(Blocks && blocks)
{
    assert(!blocks.empty());
    auto & sample_block = blocks[0];

    // Pre-compute the total row count so every destination column can
    // reserve once instead of reallocating while appending.
    size_t result_rows = 0;
    for (const auto & block : blocks)
        result_rows += block.rows();

    // Take ownership of the sample block's columns (COW mutate) and make
    // room for all rows up front.
    MutableColumns dst_columns(sample_block.columns());
    for (size_t i = 0; i < sample_block.columns(); ++i)
    {
        dst_columns[i] = (*std::move(sample_block.getByPosition(i).column)).mutate();
        dst_columns[i]->reserve(result_rows);
    }

    // Append the rows of every remaining block column-by-column.
    for (size_t i = 1; i < blocks.size(); ++i)
    {
        if (blocks[i].rows() > 0)
        {
            for (size_t column = 0; column < blocks[i].columns(); ++column)
            {
                dst_columns[column]->insertRangeFrom(*blocks[i].getByPosition(column).column, 0, blocks[i].rows());
            }
        }
    }
    return sample_block.cloneWithColumns(std::move(dst_columns));
}

bool blocksHaveEqualStructure(const Block & lhs, const Block & rhs)
{
Expand Down
2 changes: 2 additions & 0 deletions dbms/src/Core/Block.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ class Block
*/
void updateHash(SipHash & hash) const;


private:
void eraseImpl(size_t position);
void initializeIndexByName();
Expand All @@ -157,6 +158,7 @@ class Block
using Blocks = std::vector<Block>;
using BlocksList = std::list<Block>;

Block mergeBlocks(Blocks && blocks);

/// Compare number of columns, data types, column types, column names, and values of constant columns.
bool blocksHaveEqualStructure(const Block & lhs, const Block & rhs);
Expand Down
71 changes: 68 additions & 3 deletions dbms/src/DataStreams/HashJoinProbeBlockInputStream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ HashJoinProbeBlockInputStream::HashJoinProbeBlockInputStream(
: log(Logger::get(req_id))
, join(join_)
, probe_process_info(max_block_size)
, join_finished(false)
{
children.push_back(input);

Expand Down Expand Up @@ -67,17 +68,81 @@ Block HashJoinProbeBlockInputStream::getHeader() const

/// Produce the next batch of joined rows. Small blocks coming out of the
/// join are accumulated and merged so that each returned block approaches
/// (but does not exceed by accumulation) max_block_size rows; an oversized
/// block is deferred to the next read via `over_limit_block`.
Block HashJoinProbeBlockInputStream::readImpl()
{
    // Once the join has emitted its last batch, keep signalling
    // end-of-stream with an empty block.
    if (join_finished)
    {
        return Block{};
    }

    result_blocks.clear();
    size_t output_rows = 0;

    // A block deferred from the previous read (it would have pushed that
    // batch past max_block_size) is accounted for first.
    if (over_limit_block)
    {
        output_rows += over_limit_block.rows();
        result_blocks.push_back(std::move(over_limit_block));
        over_limit_block = Block{};
    }

    // Accumulate joined blocks until the batch reaches max_block_size rows.
    while (output_rows <= probe_process_info.max_block_size)
    {
        Block result_block = getOutputBlock(probe_process_info);

        if (!result_block)
        {
            // Upstream exhausted: flush whatever was accumulated and mark
            // the join finished; with nothing accumulated, propagate the
            // empty block directly.
            if (!result_blocks.empty())
            {
                join_finished = true;
                return mergeResultBlocks(std::move(result_blocks));
            }
            return result_block;
        }
        size_t current_rows = result_block.rows();

        // Always take at least one block per batch; further blocks are
        // taken only while they keep the batch within max_block_size.
        if (!output_rows || output_rows + current_rows <= probe_process_info.max_block_size)
        {
            result_blocks.push_back(std::move(result_block));
        }
        else
        {
            // Too large for this batch: stash it for the next read.
            over_limit_block = std::move(result_block);
        }
        output_rows += current_rows;
    }

    return mergeResultBlocks(std::move(result_blocks));
}

/// Fetch the next joined block. When the current probe block has been fully
/// joined, pull a fresh block from the child stream, type-check it and reset
/// the probe state; an empty child block (end of input) is returned as-is so
/// the caller can finish the stream.
Block HashJoinProbeBlockInputStream::getOutputBlock(ProbeProcessInfo & probe_process_info_) const
{
    if (probe_process_info_.all_rows_joined_finish)
    {
        Block block = children.back()->read();
        if (!block)
        {
            return block;
        }
        join->checkTypes(block);
        probe_process_info_.resetBlock(std::move(block));
    }

    return join->joinBlock(probe_process_info_);
}

/// Combine the accumulated result blocks into one. A single block is moved
/// out directly, skipping the column-level merge (and its copy of the
/// block's column pointers) entirely.
Block HashJoinProbeBlockInputStream::mergeResultBlocks(Blocks && result_blocks)
{
    if (result_blocks.size() == 1)
    {
        // The parameter is an rvalue reference, so moving is safe and
        // avoids copying the Block's column pointers.
        return std::move(result_blocks[0]);
    }
    return mergeBlocks(std::move(result_blocks));
}

} // namespace DB
5 changes: 5 additions & 0 deletions dbms/src/DataStreams/HashJoinProbeBlockInputStream.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ class HashJoinProbeBlockInputStream : public IProfilingBlockInputStream
String getName() const override { return name; }
Block getTotals() override;
Block getHeader() const override;
Block getOutputBlock(ProbeProcessInfo & probe_process_info_) const;
static Block mergeResultBlocks(Blocks && result_blocks);

protected:
Block readImpl() override;
Expand All @@ -52,6 +54,9 @@ class HashJoinProbeBlockInputStream : public IProfilingBlockInputStream
const LoggerPtr log;
JoinPtr join;
ProbeProcessInfo probe_process_info;
Blocks result_blocks;
Block over_limit_block;
SeaRise marked this conversation as resolved.
Show resolved Hide resolved
bool join_finished;
};

} // namespace DB
21 changes: 20 additions & 1 deletion dbms/src/Flash/tests/gtest_join_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -700,7 +700,7 @@ CATCH
TEST_F(JoinExecutorTestRunner, SplitJoinResult)
try
{
context.addMockTable("split_test", "t1", {{"a", TiDB::TP::TypeLong}}, {toVec<Int32>("a", {1, 1, 1, 1, 1, 1, 1, 1, 1, 1})});
context.addMockTable("split_test", "t1", {{"a", TiDB::TP::TypeLong}, {"b", TiDB::TP::TypeLong}}, {toVec<Int32>("a", {1, 1, 1, 1, 1, 1, 1, 1, 1, 1}), toVec<Int32>("b", {1, 1, 3, 3, 1, 1, 3, 3, 1, 3})});
context.addMockTable("split_test", "t2", {{"a", TiDB::TP::TypeLong}}, {toVec<Int32>("a", {1, 1, 1, 1, 1})});

auto request = context
Expand All @@ -720,6 +720,25 @@ try
ASSERT_EQ(expect[i][j], blocks[j].rows());
}
}

// with other condition
const auto cond = gt(col("b"), lit(Field(static_cast<Int64>(2))));
request = context
.scan("split_test", "t1")
.join(context.scan("split_test", "t2"), tipb::JoinType::TypeInnerJoin, {col("a")}, {}, {}, {cond}, {})

.build(context);
expect = {{5, 5, 5, 5, 5}, {5, 5, 5, 5, 5}, {5, 5, 5, 5, 5}, {25}, {25}, {25}, {25}, {25}};
for (size_t i = 0; i < block_sizes.size(); ++i)
{
context.context.setSetting("max_block_size", Field(static_cast<UInt64>(block_sizes[i])));
auto blocks = getExecuteStreamsReturnBlocks(request);
ASSERT_EQ(expect[i].size(), blocks.size());
for (size_t j = 0; j < blocks.size(); ++j)
{
ASSERT_EQ(expect[i][j], blocks[j].rows());
}
}
}
CATCH

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ Block SegmentTestBasic::prepareWriteBlock(Int64 start_key, Int64 end_key, bool i
is_deleted);
}

Block mergeBlocks(std::vector<Block> && blocks)
Block sortMergeBlocks(std::vector<Block> && blocks)
{
auto accumulated_block = std::move(blocks[0]);
SeaRise marked this conversation as resolved.
Show resolved Hide resolved

Expand Down Expand Up @@ -391,7 +391,7 @@ Block SegmentTestBasic::prepareWriteBlockInSegmentRange(PageId segment_id, UInt6
remaining_rows);
}

return mergeBlocks(std::move(blocks));
return sortMergeBlocks(std::move(blocks));
}

void SegmentTestBasic::writeSegment(PageId segment_id, UInt64 write_rows, std::optional<Int64> start_at)
Expand Down
6 changes: 3 additions & 3 deletions dbms/src/TestUtils/ExecutorTestUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ void ExecutorTest::executeAndAssertRowsEqual(const std::shared_ptr<tipb::DAGRequ
});
}

Block mergeBlocks(Blocks blocks)
Block mergeBlocksForTest(Blocks blocks)
SeaRise marked this conversation as resolved.
Show resolved Hide resolved
{
if (blocks.empty())
return {};
Expand Down Expand Up @@ -214,7 +214,7 @@ DB::ColumnsWithTypeAndName readBlocks(std::vector<BlockInputStreamPtr> streams)
Blocks actual_blocks;
for (const auto & stream : streams)
readStream(actual_blocks, stream);
return mergeBlocks(actual_blocks).getColumnsWithTypeAndName();
return mergeBlocksForTest(actual_blocks).getColumnsWithTypeAndName();
}

void ExecutorTest::enablePlanner(bool is_enable)
Expand All @@ -231,7 +231,7 @@ DB::ColumnsWithTypeAndName ExecutorTest::executeStreams(const std::shared_ptr<ti
// Currently, don't care about regions information in tests.
Blocks blocks;
queryExecute(context.context, /*internal=*/true)->execute([&blocks](const Block & block) { blocks.push_back(block); }).verify();
return mergeBlocks(blocks).getColumnsWithTypeAndName();
return mergeBlocksForTest(blocks).getColumnsWithTypeAndName();
}

Blocks ExecutorTest::getExecuteStreamsReturnBlocks(const std::shared_ptr<tipb::DAGRequest> & request, size_t concurrency)
Expand Down
2 changes: 1 addition & 1 deletion dbms/src/TestUtils/ExecutorTestUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ TiDB::TP dataTypeToTP(const DataTypePtr & type);

ColumnsWithTypeAndName readBlock(BlockInputStreamPtr stream);
ColumnsWithTypeAndName readBlocks(std::vector<BlockInputStreamPtr> streams);
Block mergeBlocks(Blocks blocks);
Block mergeBlocksForTest(Blocks blocks);


#define WRAP_FOR_DIS_ENABLE_PLANNER_BEGIN \
Expand Down
2 changes: 1 addition & 1 deletion dbms/src/TestUtils/MPPTaskTestUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ ColumnsWithTypeAndName extractColumns(Context & context, const std::shared_ptr<t
auto schema = getSelectSchema(context);
for (const auto & chunk : dag_response->chunks())
blocks.emplace_back(codec->decode(chunk.rows_data(), schema));
return mergeBlocks(blocks).getColumnsWithTypeAndName();
return mergeBlocksForTest(blocks).getColumnsWithTypeAndName();
}

ColumnsWithTypeAndName MPPTaskTestUtils::executeCoprocessorTask(std::shared_ptr<tipb::DAGRequest> & dag_request)
Expand Down