diff --git a/dbms/src/Common/PtrHolder.h b/dbms/src/Common/PtrHolder.h index 9be26f86b69..bb96dc7fd21 100644 --- a/dbms/src/Common/PtrHolder.h +++ b/dbms/src/Common/PtrHolder.h @@ -15,6 +15,7 @@ #pragma once #include +#include namespace DB { @@ -26,7 +27,6 @@ class PtrHolder { assert(obj_); std::lock_guard lock(mu); - assert(!obj); obj = std::move(obj_); } @@ -39,13 +39,20 @@ class PtrHolder return res; } - auto operator->() + auto * operator->() { std::lock_guard lock(mu); assert(obj != nullptr); return obj.get(); } + auto & operator*() + { + std::lock_guard lock(mu); + assert(obj != nullptr); + return *obj.get(); + } + private: std::mutex mu; Ptr obj; diff --git a/dbms/src/DataStreams/HashJoinProbeBlockInputStream.cpp b/dbms/src/DataStreams/HashJoinProbeBlockInputStream.cpp index 0c82b6dd671..df59cefe618 100644 --- a/dbms/src/DataStreams/HashJoinProbeBlockInputStream.cpp +++ b/dbms/src/DataStreams/HashJoinProbeBlockInputStream.cpp @@ -16,6 +16,8 @@ #include #include +#include + namespace DB { HashJoinProbeBlockInputStream::HashJoinProbeBlockInputStream( @@ -26,26 +28,25 @@ HashJoinProbeBlockInputStream::HashJoinProbeBlockInputStream( UInt64 max_block_size_) : log(Logger::get(req_id)) , original_join(join_) - , join(original_join) - , need_output_non_joined_data(join->needReturnNonJoinedData()) - , current_non_joined_stream_index(non_joined_stream_index) - , max_block_size(max_block_size_) - , probe_process_info(max_block_size_) { children.push_back(input); - current_probe_stream = children.back(); - RUNTIME_CHECK_MSG(join != nullptr, "join ptr should not be null."); - RUNTIME_CHECK_MSG(join->getProbeConcurrency() > 0, "Join probe concurrency must be greater than 0"); - auto input_header = input->getHeader(); - assert(input_header.rows() == 0); - if (need_output_non_joined_data) - non_joined_stream = join->createStreamWithNonJoinedRows(input_header, current_non_joined_stream_index, join->getProbeConcurrency(), max_block_size); + RUNTIME_CHECK_MSG(original_join != 
nullptr, "join ptr should not be null."); + RUNTIME_CHECK_MSG(original_join->getProbeConcurrency() > 0, "Join probe concurrency must be greater than 0"); + + probe_exec.set(HashJoinProbeExec::build(original_join, input, non_joined_stream_index, max_block_size_)); + probe_exec->setCancellationHook([&]() { return isCancelledOrThrowIfKilled(); }); + ProbeProcessInfo header_probe_process_info(0); - header_probe_process_info.resetBlock(std::move(input_header)); + header_probe_process_info.resetBlock(input->getHeader()); header = original_join->joinBlock(header_probe_process_info, true); } +void HashJoinProbeBlockInputStream::readSuffixImpl() +{ + LOG_DEBUG(log, "Finish join probe, total output rows {}, joined rows {}, non joined rows {}", joined_rows + non_joined_rows, joined_rows, non_joined_rows); +} + Block HashJoinProbeBlockInputStream::getHeader() const { return header; @@ -54,177 +55,55 @@ Block HashJoinProbeBlockInputStream::getHeader() const void HashJoinProbeBlockInputStream::cancel(bool kill) { IProfilingBlockInputStream::cancel(kill); - JoinPtr current_join; - RestoreInfo restore_info; - { - std::lock_guard lock(mutex); - current_join = join; - restore_info.non_joined_stream = non_joined_stream; - restore_info.build_stream = restore_build_stream; - restore_info.probe_stream = restore_probe_stream; - } - /// Join::wakeUpAllWaitingThreads wakes up all the threads waiting in Join::waitUntilAllBuildFinished/waitUntilAllProbeFinished, - /// and once this function is called, all the subsequent call of Join::waitUntilAllBuildFinished/waitUntilAllProbeFinished will - /// skip waiting directly. - /// HashJoinProbeBlockInputStream::cancel will be called in two cases: - /// 1. the query is cancelled by the caller or meet error: in this case, wake up all waiting threads is safe, because no data - /// will be used data anymore - /// 2. 
the query is executed normally, and one of the data stream has read an empty block, the the data stream and all its - /// children will call `cancel(false)`, in this case, there is two sub-cases - /// a. the data stream read an empty block because of EOF, then it means there must be no threads waiting in Join, so wake - /// up all waiting threads is safe because actually there is no threads to be waken up - /// b. the data stream read an empty block because of early exit of some executor(like limit), in this case, waking up the - /// waiting threads is not 100% safe because if the probe thread is waken up when build is not finished yet, it may get - /// wrong result. Currently, the execution framework ensures that when any of the data stream read empty block because - /// of early exit, no further data will be used, and in order to make sure no wrong result is generated - /// - for threads reading joined data: will return empty block if build is not finished yet - /// - for threads reading non joined data: will return empty block if build or probe is not finished yet - current_join->wakeUpAllWaitingThreads(); - if (restore_info.non_joined_stream != nullptr) - { - auto * p_stream = dynamic_cast(restore_info.non_joined_stream.get()); - if (p_stream != nullptr) - p_stream->cancel(kill); - } - if (restore_info.probe_stream != nullptr) - { - auto * p_stream = dynamic_cast(restore_info.probe_stream.get()); - if (p_stream != nullptr) - p_stream->cancel(kill); - } - if (restore_info.build_stream != nullptr) - { - auto * p_stream = dynamic_cast(restore_info.build_stream.get()); - if (p_stream != nullptr) - p_stream->cancel(kill); - } + + probe_exec->cancel(); } Block HashJoinProbeBlockInputStream::readImpl() { - try - { - Block ret = getOutputBlock(); - return ret; - } - catch (...) 
- { - auto error_message = getCurrentExceptionMessage(false, true); - join->meetError(error_message); - throw Exception(error_message); - } + return getOutputBlock(); } -void HashJoinProbeBlockInputStream::readSuffixImpl() +void HashJoinProbeBlockInputStream::switchStatus(ProbeStatus to) { - LOG_DEBUG(log, "Finish join probe, total output rows {}, joined rows {}, non joined rows {}", joined_rows + non_joined_rows, joined_rows, non_joined_rows); + LOG_TRACE(log, fmt::format("{} -> {}", magic_enum::enum_name(status), magic_enum::enum_name(to))); + status = to; } void HashJoinProbeBlockInputStream::onCurrentProbeDone() { - if (join->isRestoreJoin()) - current_probe_stream->readSuffix(); - join->finishOneProbe(); - if (need_output_non_joined_data || join->isEnableSpill()) - { - status = ProbeStatus::WAIT_PROBE_FINISH; - } - else - { - status = ProbeStatus::FINISHED; - } + switchStatus(probe_exec->onProbeFinish() ? ProbeStatus::FINISHED : ProbeStatus::WAIT_PROBE_FINISH); } void HashJoinProbeBlockInputStream::onCurrentReadNonJoinedDataDone() { - non_joined_stream->readSuffix(); - if (!join->isEnableSpill()) - { - status = ProbeStatus::FINISHED; - } - else - { - join->finishOneNonJoin(current_non_joined_stream_index); - status = ProbeStatus::GET_RESTORE_JOIN; - } + switchStatus(probe_exec->onNonJoinedFinish() ? 
ProbeStatus::FINISHED : ProbeStatus::GET_RESTORE_JOIN); } void HashJoinProbeBlockInputStream::tryGetRestoreJoin() { - /// find restore join in DFS way - while (true) + if (auto restore_probe_exec = probe_exec->tryGetRestoreExec(); restore_probe_exec && !unlikely(isCancelledOrThrowIfKilled())) { - assert(join->isEnableSpill()); - /// first check if current join has a partition to restore - if (join->hasPartitionSpilledWithLock()) - { - auto restore_info = join->getOneRestoreStream(max_block_size); - /// get a restore join - if (restore_info.join) - { - /// restored join should always enable spill - assert(restore_info.join->isEnableSpill()); - parents.push_back(join); - { - std::lock_guard lock(mutex); - if (isCancelledOrThrowIfKilled()) - { - status = ProbeStatus::FINISHED; - return; - } - join = restore_info.join; - restore_build_stream = restore_info.build_stream; - restore_probe_stream = restore_info.probe_stream; - non_joined_stream = restore_info.non_joined_stream; - current_probe_stream = restore_probe_stream; - if (non_joined_stream != nullptr) - current_non_joined_stream_index = dynamic_cast(non_joined_stream.get())->getNonJoinedIndex(); - } - status = ProbeStatus::RESTORE_BUILD; - return; - } - assert(join->hasPartitionSpilledWithLock() == false); - } - /// current join has no more partition to restore, so check if previous join still has partition to restore - if (!parents.empty()) - { - /// replace current join with previous join - std::lock_guard lock(mutex); - if (isCancelledOrThrowIfKilled()) - { - status = ProbeStatus::FINISHED; - return; - } - else - { - join = parents.back(); - parents.pop_back(); - restore_probe_stream = nullptr; - restore_build_stream = nullptr; - current_probe_stream = nullptr; - non_joined_stream = nullptr; - } - } - else - { - /// no previous join, set status to FINISHED - status = ProbeStatus::FINISHED; - return; - } + probe_exec.set(std::move(restore_probe_exec)); + switchStatus(ProbeStatus::RESTORE_BUILD); + } + else + { + 
switchStatus(ProbeStatus::FINISHED); } } void HashJoinProbeBlockInputStream::onAllProbeDone() { - if (need_output_non_joined_data) + auto & cur_probe_exec = *probe_exec; + if (cur_probe_exec.needOutputNonJoinedData()) { - assert(non_joined_stream != nullptr); - status = ProbeStatus::READ_NON_JOINED_DATA; - non_joined_stream->readPrefix(); + cur_probe_exec.onNonJoinedStart(); + switchStatus(ProbeStatus::READ_NON_JOINED_DATA); } else { - status = ProbeStatus::GET_RESTORE_JOIN; + switchStatus(ProbeStatus::GET_RESTORE_JOIN); } } @@ -234,43 +113,43 @@ Block HashJoinProbeBlockInputStream::getOutputBlock() { while (true) { + if unlikely (isCancelledOrThrowIfKilled()) + return {}; + switch (status) { case ProbeStatus::WAIT_BUILD_FINISH: - join->waitUntilAllBuildFinished(); + { + auto & cur_probe_exec = *probe_exec; + cur_probe_exec.waitUntilAllBuildFinished(); /// after Build finish, always go to Probe stage - if (join->isRestoreJoin()) - current_probe_stream->readSuffix(); - status = ProbeStatus::PROBE; + cur_probe_exec.onProbeStart(); + switchStatus(ProbeStatus::PROBE); break; + } case ProbeStatus::PROBE: { - assert(current_probe_stream != nullptr); - if (probe_process_info.all_rows_joined_finish) + auto ret = probe_exec->probe(); + if (!ret) + { + onCurrentProbeDone(); + break; + } + else { - auto [partition_index, block] = getOneProbeBlock(); - if (!block) - { - onCurrentProbeDone(); - break; - } - else - { - join->checkTypes(block); - probe_process_info.resetBlock(std::move(block), partition_index); - } + joined_rows += ret.rows(); + return ret; } - auto ret = join->joinBlock(probe_process_info); - joined_rows += ret.rows(); - return ret; } case ProbeStatus::WAIT_PROBE_FINISH: - join->waitUntilAllProbeFinished(); + { + probe_exec->waitUntilAllProbeFinished(); onAllProbeDone(); break; + } case ProbeStatus::READ_NON_JOINED_DATA: { - auto block = non_joined_stream->read(); + auto block = probe_exec->fetchNonJoined(); non_joined_rows += block.rows(); if (!block) { @@ 
-286,11 +165,8 @@ Block HashJoinProbeBlockInputStream::getOutputBlock() } case ProbeStatus::RESTORE_BUILD: { - probe_process_info.all_rows_joined_finish = true; - restore_build_stream->readPrefix(); - while (restore_build_stream->read()) {}; - restore_build_stream->readSuffix(); - status = ProbeStatus::WAIT_BUILD_FINISH; + probe_exec->restoreBuild(); + switchStatus(ProbeStatus::WAIT_BUILD_FINISH); break; } case ProbeStatus::FINISHED: @@ -300,46 +176,11 @@ Block HashJoinProbeBlockInputStream::getOutputBlock() } catch (...) { - /// set status to finish if any exception happens - status = ProbeStatus::FINISHED; - throw; - } -} - -std::tuple HashJoinProbeBlockInputStream::getOneProbeBlock() -{ - size_t partition_index = 0; - Block block; - - /// Even if spill is enabled, if spill is not triggered during build, - /// there is no need to dispatch probe block - if (!join->isSpilled()) - { - block = current_probe_stream->read(); - } - else - { - while (true) - { - if (!probe_partition_blocks.empty()) - { - auto partition_block = probe_partition_blocks.front(); - probe_partition_blocks.pop_front(); - partition_index = std::get<0>(partition_block); - block = std::get<1>(partition_block); - break; - } - else - { - auto new_block = current_probe_stream->read(); - if (new_block) - join->dispatchProbeBlock(new_block, probe_partition_blocks); - else - break; - } - } + auto error_message = getCurrentExceptionMessage(true, true); + probe_exec->meetError(error_message); + switchStatus(ProbeStatus::FINISHED); + throw Exception(error_message); } - return {partition_index, block}; } } // namespace DB diff --git a/dbms/src/DataStreams/HashJoinProbeBlockInputStream.h b/dbms/src/DataStreams/HashJoinProbeBlockInputStream.h index d0d21858879..344c8d05e7e 100644 --- a/dbms/src/DataStreams/HashJoinProbeBlockInputStream.h +++ b/dbms/src/DataStreams/HashJoinProbeBlockInputStream.h @@ -14,8 +14,8 @@ #pragma once +#include #include -#include #include namespace DB @@ -48,6 +48,8 @@ class 
HashJoinProbeBlockInputStream : public IProfilingBlockInputStream protected: Block readImpl() override; + void readSuffixImpl() override; + private: /* * spill not enabled: @@ -112,33 +114,25 @@ class HashJoinProbeBlockInputStream : public IProfilingBlockInputStream FINISHED, /// the final state }; + void switchStatus(ProbeStatus to); Block getOutputBlock(); - std::tuple getOneProbeBlock(); void onCurrentProbeDone(); void onAllProbeDone(); void onCurrentReadNonJoinedDataDone(); void tryGetRestoreJoin(); - void readSuffixImpl() override; + +private: const LoggerPtr log; - /// join/non_joined_stream/restore_build_stream/restore_probe_stream can be modified during the runtime - /// although read/write to those are almost only in 1 thread, but an exception is cancel thread will - /// read them, so need to protect the multi-threads access - std::mutex mutex; JoinPtr original_join; - JoinPtr join; - const bool need_output_non_joined_data; - size_t current_non_joined_stream_index; - BlockInputStreamPtr current_probe_stream; - UInt64 max_block_size; - ProbeProcessInfo probe_process_info; - BlockInputStreamPtr non_joined_stream; - BlockInputStreamPtr restore_build_stream; - BlockInputStreamPtr restore_probe_stream; + /// probe_exec may be replaced at runtime (see tryGetRestoreJoin). + /// It is read and written almost exclusively by one thread, + /// but the cancel thread also dereferences it in cancel(), + /// so it is kept in a HashJoinProbeExecHolder to guard the concurrent access. + HashJoinProbeExecHolder probe_exec; ProbeStatus status{ProbeStatus::WAIT_BUILD_FINISH}; size_t joined_rows = 0; size_t non_joined_rows = 0; - std::list parents; - std::list> probe_partition_blocks; + Block header; }; diff --git a/dbms/src/DataStreams/HashJoinProbeExec.cpp b/dbms/src/DataStreams/HashJoinProbeExec.cpp new file mode 100644 index 00000000000..6954524ede9 --- /dev/null +++ b/dbms/src/DataStreams/HashJoinProbeExec.cpp @@ -0,0 +1,261 @@ +// Copyright 2023 PingCAP, Ltd. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +namespace DB +{ +HashJoinProbeExecPtr HashJoinProbeExec::build( + const JoinPtr & join, + const BlockInputStreamPtr & probe_stream, + size_t non_joined_stream_index, + size_t max_block_size) +{ + bool need_output_non_joined_data = join->needReturnNonJoinedData(); + BlockInputStreamPtr non_joined_stream = nullptr; + if (need_output_non_joined_data) + non_joined_stream = join->createStreamWithNonJoinedRows(probe_stream->getHeader(), non_joined_stream_index, join->getProbeConcurrency(), max_block_size); + + return std::make_shared( + join, + nullptr, + probe_stream, + need_output_non_joined_data, + non_joined_stream_index, + non_joined_stream, + max_block_size); +} + +HashJoinProbeExec::HashJoinProbeExec( + const JoinPtr & join_, + const BlockInputStreamPtr & restore_build_stream_, + const BlockInputStreamPtr & probe_stream_, + bool need_output_non_joined_data_, + size_t non_joined_stream_index_, + const BlockInputStreamPtr & non_joined_stream_, + size_t max_block_size_) + : join(join_) + , restore_build_stream(restore_build_stream_) + , probe_stream(probe_stream_) + , need_output_non_joined_data(need_output_non_joined_data_) + , non_joined_stream_index(non_joined_stream_index_) + , non_joined_stream(non_joined_stream_) + , max_block_size(max_block_size_) + , probe_process_info(max_block_size_) +{} + +void HashJoinProbeExec::waitUntilAllBuildFinished() +{ + 
join->waitUntilAllBuildFinished(); +} + +void HashJoinProbeExec::waitUntilAllProbeFinished() +{ + join->waitUntilAllProbeFinished(); +} + +void HashJoinProbeExec::restoreBuild() +{ + restore_build_stream->readPrefix(); + if unlikely (is_cancelled()) + return; + while (restore_build_stream->read()) + { + if unlikely (is_cancelled()) + return; + } + restore_build_stream->readSuffix(); +} + +PartitionBlock HashJoinProbeExec::getProbeBlock() +{ + /// Even if spill is enabled, if spill is not triggered during build, + /// there is no need to dispatch probe block + if (!join->isSpilled()) + { + return PartitionBlock{probe_stream->read()}; + } + else + { + while (true) + { + if unlikely (is_cancelled()) + return {}; + + if (!probe_partition_blocks.empty()) + { + auto partition_block = std::move(probe_partition_blocks.front()); + probe_partition_blocks.pop_front(); + return partition_block; + } + else + { + auto new_block = probe_stream->read(); + if (new_block) + join->dispatchProbeBlock(new_block, probe_partition_blocks); + else + return {}; + } + } + } +} + +Block HashJoinProbeExec::probe() +{ + if (probe_process_info.all_rows_joined_finish) + { + auto partition_block = getProbeBlock(); + if (partition_block) + { + join->checkTypes(partition_block.block); + probe_process_info.resetBlock(std::move(partition_block.block), partition_block.partition_index); + } + else + { + return {}; + } + } + return join->joinBlock(probe_process_info); +} + +HashJoinProbeExecPtr HashJoinProbeExec::tryGetRestoreExec() +{ + if unlikely (is_cancelled()) + return {}; + + /// find restore exec in DFS way + if (auto ret = doTryGetRestoreExec(); ret) + return ret; + + /// current join has no more partition to restore, so check if previous join still has partition to restore + return parent ? 
parent->tryGetRestoreExec() : HashJoinProbeExecPtr{}; +} + +HashJoinProbeExecPtr HashJoinProbeExec::doTryGetRestoreExec() +{ + assert(join->isEnableSpill()); + /// first check if current join has a partition to restore + if (join->hasPartitionSpilledWithLock()) + { + /// get a restore join + if (auto restore_info = join->getOneRestoreStream(max_block_size); restore_info) + { + /// restored join should always enable spill + assert(restore_info->join && restore_info->join->isEnableSpill()); + size_t non_joined_stream_index = 0; + if (need_output_non_joined_data) + { + assert(restore_info->non_joined_stream); + non_joined_stream_index = dynamic_cast(restore_info->non_joined_stream.get())->getNonJoinedIndex(); + } + auto restore_probe_exec = std::make_shared( + restore_info->join, + restore_info->build_stream, + restore_info->probe_stream, + need_output_non_joined_data, + non_joined_stream_index, + restore_info->non_joined_stream, + max_block_size); + restore_probe_exec->parent = shared_from_this(); + restore_probe_exec->setCancellationHook(is_cancelled); + return restore_probe_exec; + } + assert(join->hasPartitionSpilledWithLock() == false); + } + return {}; +} + +void HashJoinProbeExec::cancel() +{ + /// Join::wakeUpAllWaitingThreads wakes up all the threads waiting in Join::waitUntilAllBuildFinished/waitUntilAllProbeFinished, + /// and once this function is called, all subsequent calls of Join::waitUntilAllBuildFinished/waitUntilAllProbeFinished will + /// skip waiting directly. + /// HashJoinProbeBlockInputStream::cancel will be called in two cases: + /// 1. the query is cancelled by the caller or meets an error: in this case, waking up all waiting threads is safe, because no data + /// will be used anymore + /// 2. the query is executed normally, and one of the data streams has read an empty block, then the data stream and all its + /// children will call `cancel(false)`, in this case, there are two sub-cases + /// a. 
the data stream read an empty block because of EOF, then it means there must be no threads waiting in Join, so wake + /// up all waiting threads is safe because there are actually no threads to be woken up + /// b. the data stream read an empty block because of early exit of some executor (like limit), in this case, waking up the + /// waiting threads is not 100% safe because if the probe thread is woken up when build is not finished yet, it may get + /// a wrong result. Currently, the execution framework ensures that when any of the data streams reads an empty block because + /// of early exit, no further data will be used, and in order to make sure no wrong result is generated + /// - for threads reading joined data: will return empty block if build is not finished yet + /// - for threads reading non joined data: will return empty block if build or probe is not finished yet + join->wakeUpAllWaitingThreads(); + if (non_joined_stream != nullptr) + { + if (auto * p_stream = dynamic_cast(non_joined_stream.get()); p_stream != nullptr) + p_stream->cancel(false); + } + if (probe_stream != nullptr) + { + if (auto * p_stream = dynamic_cast(probe_stream.get()); p_stream != nullptr) + p_stream->cancel(false); + } + if (restore_build_stream != nullptr) + { + if (auto * p_stream = dynamic_cast(restore_build_stream.get()); p_stream != nullptr) + p_stream->cancel(false); + } +} + +void HashJoinProbeExec::meetError(const String & error_message) +{ + join->meetError(error_message); +} + +void HashJoinProbeExec::onProbeStart() +{ + if (join->isRestoreJoin()) + probe_stream->readPrefix(); +} + +bool HashJoinProbeExec::onProbeFinish() +{ + if (join->isRestoreJoin()) + probe_stream->readSuffix(); + join->finishOneProbe(); + return !need_output_non_joined_data && !join->isEnableSpill(); +} + +void HashJoinProbeExec::onNonJoinedStart() +{ + assert(non_joined_stream != nullptr); + non_joined_stream->readPrefix(); +} + +Block HashJoinProbeExec::fetchNonJoined() +{ + assert(non_joined_stream != 
nullptr); + return non_joined_stream->read(); +} + +bool HashJoinProbeExec::onNonJoinedFinish() +{ + non_joined_stream->readSuffix(); + if (!join->isEnableSpill()) + { + return true; + } + else + { + join->finishOneNonJoin(non_joined_stream_index); + return false; + } +} +} // namespace DB diff --git a/dbms/src/DataStreams/HashJoinProbeExec.h b/dbms/src/DataStreams/HashJoinProbeExec.h new file mode 100644 index 00000000000..2fa1f215382 --- /dev/null +++ b/dbms/src/DataStreams/HashJoinProbeExec.h @@ -0,0 +1,109 @@ +// Copyright 2023 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include +#include + +#include + +namespace DB +{ +class HashJoinProbeExec; +using HashJoinProbeExecPtr = std::shared_ptr; + +class HashJoinProbeExec : public std::enable_shared_from_this +{ +public: + static HashJoinProbeExecPtr build( + const JoinPtr & join, + const BlockInputStreamPtr & probe_stream, + size_t non_joined_stream_index, + size_t max_block_size); + + using CancellationHook = std::function; + + HashJoinProbeExec( + const JoinPtr & join_, + const BlockInputStreamPtr & restore_build_stream_, + const BlockInputStreamPtr & probe_stream_, + bool need_output_non_joined_data_, + size_t non_joined_stream_index_, + const BlockInputStreamPtr & non_joined_stream_, + size_t max_block_size_); + + void waitUntilAllBuildFinished(); + + void waitUntilAllProbeFinished(); + + HashJoinProbeExecPtr tryGetRestoreExec(); + + void cancel(); + + void meetError(const String & error_message); + + void restoreBuild(); + + void onProbeStart(); + // Returns an empty block when no probe block is left to join. + Block probe(); + // Returns true if the probe_exec ends. + // Returns false if the probe_exec continues to execute. + bool onProbeFinish(); + + bool needOutputNonJoinedData() const { return need_output_non_joined_data; } + void onNonJoinedStart(); + Block fetchNonJoined(); + // Returns true if the probe_exec ends. + // Returns false if the probe_exec continues to execute. 
+ bool onNonJoinedFinish(); + + void setCancellationHook(CancellationHook cancellation_hook) + { + is_cancelled = std::move(cancellation_hook); + } + +private: + PartitionBlock getProbeBlock(); + + HashJoinProbeExecPtr doTryGetRestoreExec(); + +private: + const JoinPtr join; + + const BlockInputStreamPtr restore_build_stream; + + const BlockInputStreamPtr probe_stream; + + const bool need_output_non_joined_data; + const size_t non_joined_stream_index; + const BlockInputStreamPtr non_joined_stream; + + const size_t max_block_size; + + CancellationHook is_cancelled{[]() { + return false; + }}; + + ProbeProcessInfo probe_process_info; + PartitionBlocks probe_partition_blocks; + + HashJoinProbeExecPtr parent; +}; + +using HashJoinProbeExecHolder = PtrHolder; +} // namespace DB diff --git a/dbms/src/Interpreters/Join.cpp b/dbms/src/Interpreters/Join.cpp index e1f08ca161a..fd0c5908943 100644 --- a/dbms/src/Interpreters/Join.cpp +++ b/dbms/src/Interpreters/Join.cpp @@ -1437,19 +1437,20 @@ void Join::joinBlockNullAwareImpl( void Join::checkTypesOfKeys(const Block & block_left, const Block & block_right) const { size_t keys_size = key_names_left.size(); - for (size_t i = 0; i < keys_size; ++i) { /// Compare up to Nullability. 
- DataTypePtr left_type = removeNullable(block_left.getByName(key_names_left[i]).type); DataTypePtr right_type = removeNullable(block_right.getByName(key_names_right[i]).type); - - if (!left_type->equals(*right_type)) - throw Exception("Type mismatch of columns to JOIN by: " - + key_names_left[i] + " " + left_type->getName() + " at left, " - + key_names_right[i] + " " + right_type->getName() + " at right", - ErrorCodes::TYPE_MISMATCH); + if unlikely (!left_type->equals(*right_type)) + throw Exception( + fmt::format( + "Type mismatch of columns to JOIN by: {} {} at left, {} {} at right", + key_names_left[i], + left_type->getName(), + key_names_right[i], + right_type->getName()), + ErrorCodes::TYPE_MISMATCH); } } @@ -1795,7 +1796,7 @@ bool Join::hasPartitionSpilled() return !spilled_partition_indexes.empty(); } -RestoreInfo Join::getOneRestoreStream(size_t max_block_size_) +std::optional Join::getOneRestoreStream(size_t max_block_size_) { std::unique_lock lock(build_probe_mutex); if (meet_error) @@ -1818,7 +1819,7 @@ RestoreInfo Join::getOneRestoreStream(size_t max_block_size_) { spilled_partition_indexes.pop_front(); } - return {restore_join, non_joined_data_stream, build_stream, probe_stream}; + return RestoreInfo{restore_join, std::move(non_joined_data_stream), std::move(build_stream), std::move(probe_stream)}; } if (spilled_partition_indexes.empty()) { @@ -1856,7 +1857,7 @@ RestoreInfo Join::getOneRestoreStream(size_t max_block_size_) restore_non_joined_data_streams[i] = restore_join->createStreamWithNonJoinedRows(probe_stream->getHeader(), i, restore_join_build_concurrency, max_block_size_); } auto non_joined_data_stream = get_back_stream(restore_non_joined_data_streams); - return {restore_join, non_joined_data_stream, build_stream, probe_stream}; + return RestoreInfo{restore_join, std::move(non_joined_data_stream), std::move(build_stream), std::move(probe_stream)}; } catch (...) 
{ @@ -1869,7 +1870,7 @@ RestoreInfo Join::getOneRestoreStream(size_t max_block_size_) } } -void Join::dispatchProbeBlock(Block & block, std::list> & partition_blocks_list) +void Join::dispatchProbeBlock(Block & block, PartitionBlocks & partition_blocks_list) { Blocks partition_blocks = dispatchBlock(key_names_left, block); for (size_t i = 0; i < partition_blocks.size(); ++i) @@ -1892,7 +1893,9 @@ void Join::dispatchProbeBlock(Block & block, std::list probe_spiller->spillBlocks(std::move(blocks_to_spill), i); } else - partition_blocks_list.push_back({i, partition_blocks[i]}); + { + partition_blocks_list.emplace_back(i, std::move(partition_blocks[i])); + } } } diff --git a/dbms/src/Interpreters/Join.h b/dbms/src/Interpreters/Join.h index e23ec428508..e4b167e4168 100644 --- a/dbms/src/Interpreters/Join.h +++ b/dbms/src/Interpreters/Join.h @@ -31,8 +31,49 @@ namespace DB { -struct ProbeProcessInfo; -struct RestoreInfo; +class Join; +using JoinPtr = std::shared_ptr; +using Joins = std::vector; + +struct RestoreInfo +{ + JoinPtr join; + BlockInputStreamPtr non_joined_stream; + BlockInputStreamPtr build_stream; + BlockInputStreamPtr probe_stream; + + RestoreInfo(JoinPtr & join_, BlockInputStreamPtr && non_joined_data_stream_, BlockInputStreamPtr && build_stream_, BlockInputStreamPtr && probe_stream_) + : join(join_) + , non_joined_stream(std::move(non_joined_data_stream_)) + , build_stream(std::move(build_stream_)) + , probe_stream(std::move(probe_stream_)) + {} +}; + +struct PartitionBlock +{ + size_t partition_index; + Block block; + + PartitionBlock() + : partition_index(0) + , block({}) + {} + + explicit PartitionBlock(Block && block_) + : partition_index(0) + , block(std::move(block_)) + {} + + PartitionBlock(size_t partition_index_, Block && block_) + : partition_index(partition_index_) + , block(std::move(block_)) + {} + + explicit operator bool() const { return static_cast(block); } + bool operator!() const { return !block; } +}; +using PartitionBlocks = 
std::list; /** Data structure for implementation of JOIN. * It is just a hash table: keys -> rows of joined ("right") table. @@ -87,8 +128,6 @@ struct RestoreInfo; * Always generate Nullable column and substitute NULLs for non-joined rows, * as in standard SQL. */ -using JoinPtr = std::shared_ptr; -using Joins = std::vector; class Join { @@ -149,9 +188,9 @@ class Join bool isSpilled() const { return is_spilled; } - RestoreInfo getOneRestoreStream(size_t max_block_size); + std::optional getOneRestoreStream(size_t max_block_size); - void dispatchProbeBlock(Block & block, std::list> & partition_blocks_list); + void dispatchProbeBlock(Block & block, PartitionBlocks & partition_blocks_list); Blocks dispatchBlock(const Strings & key_columns_names, const Block & from_block); @@ -378,19 +417,4 @@ class Join void workAfterProbeFinish(); }; -struct RestoreInfo -{ - JoinPtr join; - BlockInputStreamPtr non_joined_stream; - BlockInputStreamPtr build_stream; - BlockInputStreamPtr probe_stream; - - RestoreInfo() = default; - RestoreInfo(JoinPtr & join_, BlockInputStreamPtr non_joined_data_stream_, BlockInputStreamPtr build_stream_, BlockInputStreamPtr probe_stream_) - : join(join_) - , non_joined_stream(non_joined_data_stream_) - , build_stream(build_stream_) - , probe_stream(probe_stream_){}; -}; - } // namespace DB