From e1ea438b8779836070962b4361364273dd449b47 Mon Sep 17 00:00:00 2001 From: xufei Date: Mon, 6 Mar 2023 11:58:09 +0800 Subject: [PATCH] simplify code (#5) Signed-off-by: xufei --- .../HashJoinProbeBlockInputStream.cpp | 39 +++++-------------- .../HashJoinProbeBlockInputStream.h | 2 + 2 files changed, 11 insertions(+), 30 deletions(-) diff --git a/dbms/src/DataStreams/HashJoinProbeBlockInputStream.cpp b/dbms/src/DataStreams/HashJoinProbeBlockInputStream.cpp index 1f3aded89c2..e83fe38cbd5 100644 --- a/dbms/src/DataStreams/HashJoinProbeBlockInputStream.cpp +++ b/dbms/src/DataStreams/HashJoinProbeBlockInputStream.cpp @@ -27,44 +27,20 @@ HashJoinProbeBlockInputStream::HashJoinProbeBlockInputStream( : log(Logger::get(req_id)) , original_join(join_) , join(original_join) + , need_output_non_joined_data(join->needReturnNonJoinedData()) , current_non_joined_stream_index(non_joined_stream_index) , max_block_size(max_block_size_) , probe_process_info(max_block_size_) { children.push_back(input); + current_probe_stream = children.back(); RUNTIME_CHECK_MSG(join != nullptr, "join ptr should not be null."); RUNTIME_CHECK_MSG(join->getProbeConcurrency() > 0, "Join probe concurrency must be greater than 0"); - if (join->needReturnNonJoinedData()) + if (need_output_non_joined_data) non_joined_stream = join->createStreamWithNonJoinedRows(input->getHeader(), current_non_joined_stream_index, join->getProbeConcurrency(), max_block_size); } -Block HashJoinProbeBlockInputStream::getTotals() -{ - /// getTotals will be deleted soon, so don't care the implementation - if (auto * child = !original_join->isRestoreJoin() ? dynamic_cast(&*children.back()) : dynamic_cast(&*restore_probe_stream)) - { - totals = child->getTotals(); - if (!totals) - { - if (original_join->hasTotals()) - { - for (const auto & name_and_type : child->getHeader().getColumnsWithTypeAndName()) - { - auto column = name_and_type.type->createColumn(); - column->insertDefault(); - totals.insert(ColumnWithTypeAndName(std::move(column), name_and_type.type, name_and_type.name)); - } - } - else - return totals; /// There's nothing to JOIN. - } - original_join->joinTotals(totals); - } - - return totals; -} - Block HashJoinProbeBlockInputStream::getHeader() const { Block res = children.back()->getHeader(); @@ -153,6 +129,7 @@ Block HashJoinProbeBlockInputStream::getOutputBlock() case ProbeStatus::PROBE: { join->waitUntilAllBuildFinished(); + assert(current_probe_stream != nullptr); if (probe_process_info.all_rows_joined_finish) { size_t partition_index = 0; @@ -160,7 +137,7 @@ Block HashJoinProbeBlockInputStream::getOutputBlock() if (!join->isEnableSpill()) { - block = children.back()->read(); + block = current_probe_stream->read(); } else { @@ -173,7 +150,7 @@ Block HashJoinProbeBlockInputStream::getOutputBlock() { if (join->isEnableSpill()) { - block = !join->isRestoreJoin() ? children.back()->read() : restore_probe_stream->read(); + block = current_probe_stream->read(); if (block) { join->dispatchProbeBlock(block, probe_partition_blocks); @@ -182,7 +159,7 @@ Block HashJoinProbeBlockInputStream::getOutputBlock() } join->finishOneProbe(); - if (join->needReturnNonJoinedData()) + if (need_output_non_joined_data) { status = ProbeStatus::WAIT_FOR_READ_NON_JOINED_DATA; } @@ -245,6 +222,7 @@ Block HashJoinProbeBlockInputStream::getOutputBlock() restore_build_stream = build_stream; restore_probe_stream = probe_stream; non_joined_stream = restore_non_joined_stream; + current_probe_stream = restore_probe_stream; if (non_joined_stream != nullptr) current_non_joined_stream_index = dynamic_cast(non_joined_stream.get())->getNonJoinedIndex(); } @@ -277,6 +255,7 @@ Block HashJoinProbeBlockInputStream::getOutputBlock() parents.pop_back(); restore_probe_stream = nullptr; restore_build_stream = nullptr; + current_probe_stream = nullptr; non_joined_stream = nullptr; } else diff --git a/dbms/src/DataStreams/HashJoinProbeBlockInputStream.h b/dbms/src/DataStreams/HashJoinProbeBlockInputStream.h index 977430ac9aa..78e7041d16e 100644 --- a/dbms/src/DataStreams/HashJoinProbeBlockInputStream.h +++ b/dbms/src/DataStreams/HashJoinProbeBlockInputStream.h @@ -70,7 +70,9 @@ class HashJoinProbeBlockInputStream : public IProfilingBlockInputStream std::mutex mutex; JoinPtr original_join; JoinPtr join; + const bool need_output_non_joined_data; size_t current_non_joined_stream_index; + BlockInputStreamPtr current_probe_stream; UInt64 max_block_size; ProbeProcessInfo probe_process_info; BlockInputStreamPtr non_joined_stream;