Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
SeaRise committed Mar 27, 2023
1 parent 1e4f678 commit 4619f6f
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 38 deletions.
10 changes: 7 additions & 3 deletions dbms/src/DataStreams/HashJoinProbeBlockInputStream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ HashJoinProbeBlockInputStream::HashJoinProbeBlockInputStream(
RUNTIME_CHECK_MSG(original_join->getProbeConcurrency() > 0, "Join probe concurrency must be greater than 0");

probe_exec.set(HashJoinProbeExec::build(original_join, input, non_joined_stream_index, max_block_size_));
probe_exec->setCancellationHook([&]() { return isCancelledOrThrowIfKilled(); });
}

void HashJoinProbeBlockInputStream::readSuffixImpl()
Expand Down Expand Up @@ -85,7 +86,7 @@ void HashJoinProbeBlockInputStream::onCurrentReadNonJoinedDataDone()
void HashJoinProbeBlockInputStream::tryGetRestoreJoin()
{
auto cur_probe_exec = *probe_exec;
auto restore_probe_exec = cur_probe_exec->tryGetRestoreExec([&]() { return isCancelledOrThrowIfKilled(); });
auto restore_probe_exec = cur_probe_exec->tryGetRestoreExec();
if (restore_probe_exec.has_value() && !isCancelledOrThrowIfKilled())
{
probe_exec.set(std::move(*restore_probe_exec));
Expand All @@ -99,7 +100,7 @@ void HashJoinProbeBlockInputStream::tryGetRestoreJoin()

void HashJoinProbeBlockInputStream::onAllProbeDone()
{
const auto & cur_probe_exec = *probe_exec;
auto cur_probe_exec = *probe_exec;
if (cur_probe_exec->needOutputNonJoinedData())
{
cur_probe_exec->onNonJoinedStart();
Expand All @@ -117,11 +118,14 @@ Block HashJoinProbeBlockInputStream::getOutputBlock()
{
while (true)
{
if unlikely (isCancelledOrThrowIfKilled())
return {};

switch (status)
{
case ProbeStatus::WAIT_BUILD_FINISH:
{
const auto & cur_probe_exec = *probe_exec;
auto cur_probe_exec = *probe_exec;
cur_probe_exec->waitUntilAllBuildFinished();
/// after Build finish, always go to Probe stage
cur_probe_exec->onProbeStart();
Expand Down
2 changes: 0 additions & 2 deletions dbms/src/DataStreams/HashJoinProbeBlockInputStream.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

#include <DataStreams/HashJoinProbeExec.h>
#include <DataStreams/IProfilingBlockInputStream.h>
#include <DataStreams/SquashingHashJoinBlockTransform.h>
#include <Interpreters/Join.h>

namespace DB
Expand Down Expand Up @@ -134,7 +133,6 @@ class HashJoinProbeBlockInputStream : public IProfilingBlockInputStream
ProbeStatus status{ProbeStatus::WAIT_BUILD_FINISH};
size_t joined_rows = 0;
size_t non_joined_rows = 0;
std::list<HashJoinProbeExecPtr> parents;
};

} // namespace DB
47 changes: 28 additions & 19 deletions dbms/src/DataStreams/HashJoinProbeExec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,13 @@ void HashJoinProbeExec::waitUntilAllProbeFinished()
void HashJoinProbeExec::restoreBuild()
{
restore_build_stream->readPrefix();
while (restore_build_stream->read()) {};
if unlikely (is_cancelled())
return;
while (restore_build_stream->read())
{
if unlikely (is_cancelled())
return;
}
restore_build_stream->readSuffix();
}

Expand All @@ -89,6 +95,9 @@ std::tuple<size_t, Block> HashJoinProbeExec::getProbeBlock()
{
while (true)
{
if unlikely (is_cancelled())
return {0, {}};

if (!probe_partition_blocks.empty())
{
auto partition_block = probe_partition_blocks.front();
Expand Down Expand Up @@ -128,20 +137,20 @@ Block HashJoinProbeExec::probe()
return join->joinBlock(probe_process_info);
}

std::optional<HashJoinProbeExecPtr> HashJoinProbeExec::tryGetRestoreExec(std::function<bool()> && is_cancelled)
std::optional<HashJoinProbeExecPtr> HashJoinProbeExec::tryGetRestoreExec()
{
/// find restore exec in DFS way
if (is_cancelled())
if unlikely (is_cancelled())
return {};

/// find restore exec in DFS way
auto ret = doTryGetRestoreExec();
if (ret.has_value())
return ret;

/// current join has no more partition to restore, so check if previous join still has partition to restore
if (parent.has_value())
{
return (*parent)->tryGetRestoreExec(std::move(is_cancelled));
return (*parent)->tryGetRestoreExec();
}
else
{
Expand All @@ -155,24 +164,27 @@ std::optional<HashJoinProbeExecPtr> HashJoinProbeExec::doTryGetRestoreExec()
/// first check if current join has a partition to restore
if (join->hasPartitionSpilledWithLock())
{
auto restore_info = join->getOneRestoreStream(max_block_size);
/// get a restore join
if (restore_info.join)
if (auto restore_info = join->getOneRestoreStream(max_block_size); restore_info)
{
/// restored join should always enable spill
assert(restore_info.join->isEnableSpill());
assert(restore_info->join && restore_info->join->isEnableSpill());
size_t non_joined_stream_index = 0;
if (need_output_non_joined_data)
non_joined_stream_index = dynamic_cast<NonJoinedBlockInputStream *>(restore_info.non_joined_stream.get())->getNonJoinedIndex();
{
assert(restore_info->non_joined_stream);
non_joined_stream_index = dynamic_cast<NonJoinedBlockInputStream *>(restore_info->non_joined_stream.get())->getNonJoinedIndex();
}
auto restore_probe_exec = std::make_shared<HashJoinProbeExec>(
restore_info.join,
restore_info.build_stream,
restore_info.probe_stream,
restore_info->join,
restore_info->build_stream,
restore_info->probe_stream,
need_output_non_joined_data,
non_joined_stream_index,
restore_info.non_joined_stream,
restore_info->non_joined_stream,
max_block_size);
restore_probe_exec->parent = shared_from_this();
restore_probe_exec->setCancellationHook(is_cancelled);
return {std::move(restore_probe_exec)};
}
assert(join->hasPartitionSpilledWithLock() == false);
Expand All @@ -197,20 +209,17 @@ void HashJoinProbeExec::cancel()
join->cancel();
if (non_joined_stream != nullptr)
{
auto * p_stream = dynamic_cast<IProfilingBlockInputStream *>(non_joined_stream.get());
if (p_stream != nullptr)
if (auto * p_stream = dynamic_cast<IProfilingBlockInputStream *>(non_joined_stream.get()); p_stream != nullptr)
p_stream->cancel(false);
}
if (probe_stream != nullptr)
{
auto * p_stream = dynamic_cast<IProfilingBlockInputStream *>(probe_stream.get());
if (p_stream != nullptr)
if (auto * p_stream = dynamic_cast<IProfilingBlockInputStream *>(probe_stream.get()); p_stream != nullptr)
p_stream->cancel(false);
}
if (restore_build_stream != nullptr)
{
auto * p_stream = dynamic_cast<IProfilingBlockInputStream *>(restore_build_stream.get());
if (p_stream != nullptr)
if (auto * p_stream = dynamic_cast<IProfilingBlockInputStream *>(restore_build_stream.get()); p_stream != nullptr)
p_stream->cancel(false);
}
}
Expand Down
31 changes: 21 additions & 10 deletions dbms/src/DataStreams/HashJoinProbeExec.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ class HashJoinProbeExec : public std::enable_shared_from_this<HashJoinProbeExec>
size_t non_joined_stream_index,
size_t max_block_size);

using CancellationHook = std::function<bool()>;

HashJoinProbeExec(
const JoinPtr & join_,
const BlockInputStreamPtr & restore_build_stream_,
Expand All @@ -47,7 +49,7 @@ class HashJoinProbeExec : public std::enable_shared_from_this<HashJoinProbeExec>

void waitUntilAllProbeFinished();

std::optional<HashJoinProbeExecPtr> tryGetRestoreExec(std::function<bool()> && is_cancelled);
std::optional<HashJoinProbeExecPtr> tryGetRestoreExec();

void cancel();

Expand All @@ -69,23 +71,32 @@ class HashJoinProbeExec : public std::enable_shared_from_this<HashJoinProbeExec>
// Returns false if the probe_exec continues to execute.
bool onNonJoinedFinish();

void setCancellationHook(CancellationHook cancellation_hook)
{
is_cancelled = std::move(cancellation_hook);
}

private:
std::tuple<size_t, Block> getProbeBlock();

std::optional<HashJoinProbeExecPtr> doTryGetRestoreExec();

private:
JoinPtr join;
const JoinPtr join;

const BlockInputStreamPtr restore_build_stream;

BlockInputStreamPtr restore_build_stream;
const BlockInputStreamPtr probe_stream;

BlockInputStreamPtr probe_stream;
const bool need_output_non_joined_data;
const size_t non_joined_stream_index;
const BlockInputStreamPtr non_joined_stream;

bool need_output_non_joined_data;
size_t non_joined_stream_index;
BlockInputStreamPtr non_joined_stream;
const size_t max_block_size;

size_t max_block_size;
CancellationHook is_cancelled{[]() {
return false;
}};

ProbeProcessInfo probe_process_info;
std::list<std::tuple<size_t, Block>> probe_partition_blocks;
Expand All @@ -96,14 +107,14 @@ class HashJoinProbeExec : public std::enable_shared_from_this<HashJoinProbeExec>
class HashJoinProbeExecHolder
{
public:
const HashJoinProbeExecPtr & operator->()
HashJoinProbeExecPtr operator->()
{
std::lock_guard lock(mu);
assert(exec);
return exec;
}

const HashJoinProbeExecPtr & operator*()
HashJoinProbeExecPtr operator*()
{
std::lock_guard lock(mu);
assert(exec);
Expand Down
6 changes: 3 additions & 3 deletions dbms/src/Interpreters/Join.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3118,7 +3118,7 @@ bool Join::hasPartitionSpilled()
return !spilled_partition_indexes.empty();
}

RestoreInfo Join::getOneRestoreStream(size_t max_block_size)
std::optional<RestoreInfo> Join::getOneRestoreStream(size_t max_block_size)
{
std::unique_lock lock(build_probe_mutex);
if (meet_error)
Expand All @@ -3141,7 +3141,7 @@ RestoreInfo Join::getOneRestoreStream(size_t max_block_size)
{
spilled_partition_indexes.pop_front();
}
return {restore_join, non_joined_data_stream, build_stream, probe_stream};
return RestoreInfo{restore_join, non_joined_data_stream, build_stream, probe_stream};
}
if (spilled_partition_indexes.empty())
{
Expand Down Expand Up @@ -3179,7 +3179,7 @@ RestoreInfo Join::getOneRestoreStream(size_t max_block_size)
restore_non_joined_data_streams[i] = restore_join->createStreamWithNonJoinedRows(probe_stream->getHeader(), i, restore_join_build_concurrency, max_block_size);
}
auto non_joined_data_stream = get_back_stream(restore_non_joined_data_streams);
return {restore_join, non_joined_data_stream, build_stream, probe_stream};
return RestoreInfo{restore_join, non_joined_data_stream, build_stream, probe_stream};
}
catch (...)
{
Expand Down
2 changes: 1 addition & 1 deletion dbms/src/Interpreters/Join.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ class Join

bool isSpilled() const { return is_spilled; }

RestoreInfo getOneRestoreStream(size_t max_block_size);
std::optional<RestoreInfo> getOneRestoreStream(size_t max_block_size);

void dispatchProbeBlock(Block & block, std::list<std::tuple<size_t, Block>> & partition_blocks_list);

Expand Down

0 comments on commit 4619f6f

Please sign in to comment.