Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

merge output blocks if need in hash join #6529

Merged
merged 17 commits into from
Dec 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions dbms/src/Core/Block.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,45 @@ static ReturnType checkBlockStructure(const Block & lhs, const Block & rhs, cons
return ReturnType(true);
}

Block mergeBlocks(Blocks && blocks)
{
if (blocks.empty())
{
return {};
}

if (blocks.size() == 1)
{
return std::move(blocks[0]);
}

auto & first_block = blocks[0];
size_t result_rows = 0;
for (const auto & block : blocks)
{
result_rows += block.rows();
}

MutableColumns dst_columns(first_block.columns());

for (size_t i = 0; i < first_block.columns(); ++i)
{
dst_columns[i] = (*std::move(first_block.getByPosition(i).column)).mutate();
dst_columns[i]->reserve(result_rows);
}

for (size_t i = 1; i < blocks.size(); ++i)
{
if (likely(blocks[i].rows()) > 0)
{
for (size_t column = 0; column < blocks[i].columns(); ++column)
{
dst_columns[column]->insertRangeFrom(*blocks[i].getByPosition(column).column, 0, blocks[i].rows());
}
}
}
return first_block.cloneWithColumns(std::move(dst_columns));
}

bool blocksHaveEqualStructure(const Block & lhs, const Block & rhs)
{
Expand Down
2 changes: 2 additions & 0 deletions dbms/src/Core/Block.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ class Block
*/
void updateHash(SipHash & hash) const;


private:
void eraseImpl(size_t position);
void initializeIndexByName();
Expand All @@ -157,6 +158,7 @@ class Block
using Blocks = std::vector<Block>;
using BlocksList = std::list<Block>;

Block mergeBlocks(Blocks && blocks);

/// Compare number of columns, data types, column types, column names, and values of constant columns.
bool blocksHaveEqualStructure(const Block & lhs, const Block & rhs);
Expand Down
21 changes: 19 additions & 2 deletions dbms/src/DataStreams/HashJoinProbeBlockInputStream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
// limitations under the License.

#include <DataStreams/HashJoinProbeBlockInputStream.h>
#include <Interpreters/ExpressionActions.h>

namespace DB
{
Expand All @@ -25,6 +24,7 @@ HashJoinProbeBlockInputStream::HashJoinProbeBlockInputStream(
: log(Logger::get(req_id))
, join(join_)
, probe_process_info(max_block_size)
, squashing_transform(max_block_size)
{
children.push_back(input);

Expand Down Expand Up @@ -66,18 +66,35 @@ Block HashJoinProbeBlockInputStream::getHeader() const
}

Block HashJoinProbeBlockInputStream::readImpl()
{
// if join finished, return {} directly.
if (squashing_transform.isJoinFinished())
SeaRise marked this conversation as resolved.
Show resolved Hide resolved
{
return Block{};
}

while (squashing_transform.needAppendBlock())
{
Block result_block = getOutputBlock();
squashing_transform.appendBlock(result_block);
}
return squashing_transform.getFinalOutputBlock();
}

Block HashJoinProbeBlockInputStream::getOutputBlock()
{
if (probe_process_info.all_rows_joined_finish)
{
Block block = children.back()->read();
if (!block)
{
return block;
}
join->checkTypes(block);
probe_process_info.resetBlock(std::move(block));
}

return join->joinBlock(probe_process_info);
}


} // namespace DB
4 changes: 3 additions & 1 deletion dbms/src/DataStreams/HashJoinProbeBlockInputStream.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@
#pragma once

#include <DataStreams/IProfilingBlockInputStream.h>
#include <DataStreams/SquashingHashJoinBlockTransform.h>
#include <Interpreters/Join.h>

namespace DB
{


/** Executes a certain expression over the block.
* Basically the same as ExpressionBlockInputStream,
* but requires that there must be a join probe action in the Expression.
Expand All @@ -47,11 +47,13 @@ class HashJoinProbeBlockInputStream : public IProfilingBlockInputStream

protected:
Block readImpl() override;
Block getOutputBlock();

private:
const LoggerPtr log;
JoinPtr join;
ProbeProcessInfo probe_process_info;
SquashingHashJoinBlockTransform squashing_transform;
};

} // namespace DB
85 changes: 85 additions & 0 deletions dbms/src/DataStreams/SquashingHashJoinBlockTransform.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// Copyright 2022 PingCAP, Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <DataStreams/SquashingHashJoinBlockTransform.h>

namespace DB
{

SquashingHashJoinBlockTransform::SquashingHashJoinBlockTransform(UInt64 max_block_size_)
: output_rows(0)
, max_block_size(max_block_size_)
, join_finished(false)
{}

void SquashingHashJoinBlockTransform::handleOverLimitBlock()
{
// if over_limit_block is not null, we need to push it into blocks.
if (over_limit_block)
{
assert(!(output_rows && blocks.empty()));
output_rows += over_limit_block->rows();
SeaRise marked this conversation as resolved.
Show resolved Hide resolved
blocks.push_back(std::move(over_limit_block.value()));
over_limit_block.reset();
}
}

void SquashingHashJoinBlockTransform::appendBlock(Block & block)
{
if (!block)
{
// if append block is {}, mark join finished.
join_finished = true;
return;
}
size_t current_rows = block.rows();

if (!output_rows || output_rows + current_rows <= max_block_size)
{
blocks.push_back(std::move(block));
output_rows += current_rows;
}
else
{
// if output_rows + current_rows > max block size, put the current result block into over_limit_block and handle it in next read.
SeaRise marked this conversation as resolved.
Show resolved Hide resolved
assert(!over_limit_block);
over_limit_block.emplace(std::move(block));
}
}

Block SquashingHashJoinBlockTransform::getFinalOutputBlock()
{
Block final_block = mergeBlocks(std::move(blocks));
reset();
handleOverLimitBlock();
return final_block;
}

void SquashingHashJoinBlockTransform::reset()
{
blocks.clear();
output_rows = 0;
}

bool SquashingHashJoinBlockTransform::isJoinFinished() const
{
return join_finished;
}

bool SquashingHashJoinBlockTransform::needAppendBlock() const
{
return !over_limit_block && !join_finished;
}

} // namespace DB
44 changes: 44 additions & 0 deletions dbms/src/DataStreams/SquashingHashJoinBlockTransform.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Copyright 2022 PingCAP, Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <Core/Block.h>

namespace DB
{

class SquashingHashJoinBlockTransform
{
public:
SquashingHashJoinBlockTransform(UInt64 max_block_size_);

void appendBlock(Block & block);
Block getFinalOutputBlock();
bool isJoinFinished() const;
bool needAppendBlock() const;


private:
void handleOverLimitBlock();
void reset();

Blocks blocks;
std::optional<Block> over_limit_block;
size_t output_rows;
UInt64 max_block_size;
bool join_finished;
};

} // namespace DB
21 changes: 20 additions & 1 deletion dbms/src/Flash/tests/gtest_join_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -700,7 +700,7 @@ CATCH
TEST_F(JoinExecutorTestRunner, SplitJoinResult)
try
{
context.addMockTable("split_test", "t1", {{"a", TiDB::TP::TypeLong}}, {toVec<Int32>("a", {1, 1, 1, 1, 1, 1, 1, 1, 1, 1})});
context.addMockTable("split_test", "t1", {{"a", TiDB::TP::TypeLong}, {"b", TiDB::TP::TypeLong}}, {toVec<Int32>("a", {1, 1, 1, 1, 1, 1, 1, 1, 1, 1}), toVec<Int32>("b", {1, 1, 3, 3, 1, 1, 3, 3, 1, 3})});
context.addMockTable("split_test", "t2", {{"a", TiDB::TP::TypeLong}}, {toVec<Int32>("a", {1, 1, 1, 1, 1})});

auto request = context
Expand All @@ -720,6 +720,25 @@ try
ASSERT_EQ(expect[i][j], blocks[j].rows());
}
}

// with other condition
const auto cond = gt(col("b"), lit(Field(static_cast<Int64>(2))));
request = context
.scan("split_test", "t1")
.join(context.scan("split_test", "t2"), tipb::JoinType::TypeInnerJoin, {col("a")}, {}, {}, {cond}, {})

.build(context);
expect = {{5, 5, 5, 5, 5}, {5, 5, 5, 5, 5}, {5, 5, 5, 5, 5}, {25}, {25}, {25}, {25}, {25}};
for (size_t i = 0; i < block_sizes.size(); ++i)
{
context.context.setSetting("max_block_size", Field(static_cast<UInt64>(block_sizes[i])));
auto blocks = getExecuteStreamsReturnBlocks(request);
ASSERT_EQ(expect[i].size(), blocks.size());
for (size_t j = 0; j < blocks.size(); ++j)
{
ASSERT_EQ(expect[i][j], blocks[j].rows());
}
}
}
CATCH

Expand Down
Loading