Skip to content

Commit

Permalink
merge output blocks if need in hash join (pingcap#6529)
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxin9014 authored Dec 27, 2022
1 parent 957d2d4 commit 816b8d5
Show file tree
Hide file tree
Showing 12 changed files with 307 additions and 12 deletions.
39 changes: 39 additions & 0 deletions dbms/src/Core/Block.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,45 @@ static ReturnType checkBlockStructure(const Block & lhs, const Block & rhs, cons
return ReturnType(true);
}

Block mergeBlocks(Blocks && blocks)
{
if (blocks.empty())
{
return {};
}

if (blocks.size() == 1)
{
return std::move(blocks[0]);
}

auto & first_block = blocks[0];
size_t result_rows = 0;
for (const auto & block : blocks)
{
result_rows += block.rows();
}

MutableColumns dst_columns(first_block.columns());

for (size_t i = 0; i < first_block.columns(); ++i)
{
dst_columns[i] = (*std::move(first_block.getByPosition(i).column)).mutate();
dst_columns[i]->reserve(result_rows);
}

for (size_t i = 1; i < blocks.size(); ++i)
{
if (likely(blocks[i].rows()) > 0)
{
for (size_t column = 0; column < blocks[i].columns(); ++column)
{
dst_columns[column]->insertRangeFrom(*blocks[i].getByPosition(column).column, 0, blocks[i].rows());
}
}
}
return first_block.cloneWithColumns(std::move(dst_columns));
}

bool blocksHaveEqualStructure(const Block & lhs, const Block & rhs)
{
Expand Down
2 changes: 2 additions & 0 deletions dbms/src/Core/Block.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ class Block
*/
void updateHash(SipHash & hash) const;


private:
void eraseImpl(size_t position);
void initializeIndexByName();
Expand All @@ -157,6 +158,7 @@ class Block
using Blocks = std::vector<Block>;
using BlocksList = std::list<Block>;

Block mergeBlocks(Blocks && blocks);

/// Compare number of columns, data types, column types, column names, and values of constant columns.
bool blocksHaveEqualStructure(const Block & lhs, const Block & rhs);
Expand Down
21 changes: 19 additions & 2 deletions dbms/src/DataStreams/HashJoinProbeBlockInputStream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
// limitations under the License.

#include <DataStreams/HashJoinProbeBlockInputStream.h>
#include <Interpreters/ExpressionActions.h>

namespace DB
{
Expand All @@ -25,6 +24,7 @@ HashJoinProbeBlockInputStream::HashJoinProbeBlockInputStream(
: log(Logger::get(req_id))
, join(join_)
, probe_process_info(max_block_size)
, squashing_transform(max_block_size)
{
children.push_back(input);

Expand Down Expand Up @@ -66,18 +66,35 @@ Block HashJoinProbeBlockInputStream::getHeader() const
}

Block HashJoinProbeBlockInputStream::readImpl()
{
// if join finished, return {} directly.
if (squashing_transform.isJoinFinished())
{
return Block{};
}

while (squashing_transform.needAppendBlock())
{
Block result_block = getOutputBlock();
squashing_transform.appendBlock(result_block);
}
return squashing_transform.getFinalOutputBlock();
}

Block HashJoinProbeBlockInputStream::getOutputBlock()
{
if (probe_process_info.all_rows_joined_finish)
{
Block block = children.back()->read();
if (!block)
{
return block;
}
join->checkTypes(block);
probe_process_info.resetBlock(std::move(block));
}

return join->joinBlock(probe_process_info);
}


} // namespace DB
4 changes: 3 additions & 1 deletion dbms/src/DataStreams/HashJoinProbeBlockInputStream.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@
#pragma once

#include <DataStreams/IProfilingBlockInputStream.h>
#include <DataStreams/SquashingHashJoinBlockTransform.h>
#include <Interpreters/Join.h>

namespace DB
{


/** Executes a certain expression over the block.
* Basically the same as ExpressionBlockInputStream,
* but requires that there must be a join probe action in the Expression.
Expand All @@ -47,11 +47,13 @@ class HashJoinProbeBlockInputStream : public IProfilingBlockInputStream

protected:
Block readImpl() override;
Block getOutputBlock();

private:
const LoggerPtr log;
JoinPtr join;
ProbeProcessInfo probe_process_info;
SquashingHashJoinBlockTransform squashing_transform;
};

} // namespace DB
85 changes: 85 additions & 0 deletions dbms/src/DataStreams/SquashingHashJoinBlockTransform.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// Copyright 2022 PingCAP, Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <DataStreams/SquashingHashJoinBlockTransform.h>

namespace DB
{

SquashingHashJoinBlockTransform::SquashingHashJoinBlockTransform(UInt64 max_block_size_)
: output_rows(0)
, max_block_size(max_block_size_)
, join_finished(false)
{}

void SquashingHashJoinBlockTransform::handleOverLimitBlock()
{
// if over_limit_block is not null, we need to push it into blocks.
if (over_limit_block)
{
assert(!(output_rows && blocks.empty()));
output_rows += over_limit_block->rows();
blocks.push_back(std::move(over_limit_block.value()));
over_limit_block.reset();
}
}

void SquashingHashJoinBlockTransform::appendBlock(Block & block)
{
if (!block)
{
// if append block is {}, mark join finished.
join_finished = true;
return;
}
size_t current_rows = block.rows();

if (!output_rows || output_rows + current_rows <= max_block_size)
{
blocks.push_back(std::move(block));
output_rows += current_rows;
}
else
{
// if output_rows + current_rows > max block size, put the current result block into over_limit_block and handle it in next read.
assert(!over_limit_block);
over_limit_block.emplace(std::move(block));
}
}

Block SquashingHashJoinBlockTransform::getFinalOutputBlock()
{
Block final_block = mergeBlocks(std::move(blocks));
reset();
handleOverLimitBlock();
return final_block;
}

void SquashingHashJoinBlockTransform::reset()
{
blocks.clear();
output_rows = 0;
}

bool SquashingHashJoinBlockTransform::isJoinFinished() const
{
return join_finished;
}

bool SquashingHashJoinBlockTransform::needAppendBlock() const
{
return !over_limit_block && !join_finished;
}

} // namespace DB
44 changes: 44 additions & 0 deletions dbms/src/DataStreams/SquashingHashJoinBlockTransform.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Copyright 2022 PingCAP, Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <Core/Block.h>

namespace DB
{

class SquashingHashJoinBlockTransform
{
public:
SquashingHashJoinBlockTransform(UInt64 max_block_size_);

void appendBlock(Block & block);
Block getFinalOutputBlock();
bool isJoinFinished() const;
bool needAppendBlock() const;


private:
void handleOverLimitBlock();
void reset();

Blocks blocks;
std::optional<Block> over_limit_block;
size_t output_rows;
UInt64 max_block_size;
bool join_finished;
};

} // namespace DB
21 changes: 20 additions & 1 deletion dbms/src/Flash/tests/gtest_join_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -700,7 +700,7 @@ CATCH
TEST_F(JoinExecutorTestRunner, SplitJoinResult)
try
{
context.addMockTable("split_test", "t1", {{"a", TiDB::TP::TypeLong}}, {toVec<Int32>("a", {1, 1, 1, 1, 1, 1, 1, 1, 1, 1})});
context.addMockTable("split_test", "t1", {{"a", TiDB::TP::TypeLong}, {"b", TiDB::TP::TypeLong}}, {toVec<Int32>("a", {1, 1, 1, 1, 1, 1, 1, 1, 1, 1}), toVec<Int32>("b", {1, 1, 3, 3, 1, 1, 3, 3, 1, 3})});
context.addMockTable("split_test", "t2", {{"a", TiDB::TP::TypeLong}}, {toVec<Int32>("a", {1, 1, 1, 1, 1})});

auto request = context
Expand All @@ -720,6 +720,25 @@ try
ASSERT_EQ(expect[i][j], blocks[j].rows());
}
}

// with other condition
const auto cond = gt(col("b"), lit(Field(static_cast<Int64>(2))));
request = context
.scan("split_test", "t1")
.join(context.scan("split_test", "t2"), tipb::JoinType::TypeInnerJoin, {col("a")}, {}, {}, {cond}, {})

.build(context);
expect = {{5, 5, 5, 5, 5}, {5, 5, 5, 5, 5}, {5, 5, 5, 5, 5}, {25}, {25}, {25}, {25}, {25}};
for (size_t i = 0; i < block_sizes.size(); ++i)
{
context.context.setSetting("max_block_size", Field(static_cast<UInt64>(block_sizes[i])));
auto blocks = getExecuteStreamsReturnBlocks(request);
ASSERT_EQ(expect[i].size(), blocks.size());
for (size_t j = 0; j < blocks.size(); ++j)
{
ASSERT_EQ(expect[i][j], blocks[j].rows());
}
}
}
CATCH

Expand Down
Loading

0 comments on commit 816b8d5

Please sign in to comment.