diff --git a/dbms/src/Interpreters/Aggregator.cpp b/dbms/src/Interpreters/Aggregator.cpp
index b7193833031..4a9ffc1c993 100644
--- a/dbms/src/Interpreters/Aggregator.cpp
+++ b/dbms/src/Interpreters/Aggregator.cpp
@@ -32,6 +32,7 @@
 #include
 #include
 #include
+#include <Interpreters/MergingAndConvertingBlockInputStream.h>
 #include
 #include
 #include
@@ -56,7 +57,6 @@ extern const int LOGICAL_ERROR;
 namespace FailPoints
 {
 extern const char random_aggregate_create_state_failpoint[];
-extern const char random_aggregate_merge_failpoint[];
 } // namespace FailPoints

 #define AggregationMethodName(NAME) AggregatedDataVariants::AggregationMethod_##NAME
@@ -1773,223 +1773,6 @@ void NO_INLINE Aggregator::mergeBucketImpl(
 }


-/** Combines aggregation states together, turns them into blocks, and outputs streams.
-  * If the aggregation states are two-level, then it produces blocks strictly in order of 'bucket_num'.
-  * (This is important for distributed processing.)
-  * In doing so, it can handle different buckets in parallel, using up to `threads` threads.
-  */
-class MergingAndConvertingBlockInputStream : public IProfilingBlockInputStream
-{
-public:
-    /** The input is a set of non-empty sets of partially aggregated data,
-      * which are all either single-level, or are two-level.
-      */
-    MergingAndConvertingBlockInputStream(const Aggregator & aggregator_, ManyAggregatedDataVariants & data_, bool final_, size_t threads_)
-        : log(Logger::get(aggregator_.log ? aggregator_.log->identifier() : ""))
-        , aggregator(aggregator_)
-        , data(data_)
-        , final(final_)
-        , threads(threads_)
-    {
-        /// At least we need one arena in first data item per thread
-        if (!data.empty() && threads > data[0]->aggregates_pools.size())
-        {
-            Arenas & first_pool = data[0]->aggregates_pools;
-            for (size_t j = first_pool.size(); j < threads; ++j)
-                first_pool.emplace_back(std::make_shared<Arena>());
-        }
-    }
-
-    String getName() const override { return "MergingAndConverting"; }
-
-    Block getHeader() const override { return aggregator.getHeader(final); }
-
-    ~MergingAndConvertingBlockInputStream() override
-    {
-        LOG_TRACE(&Poco::Logger::get(__PRETTY_FUNCTION__), "Waiting for threads to finish");
-
-        /// We need to wait for threads to finish before destructor of 'parallel_merge_data',
-        /// because the threads access 'parallel_merge_data'.
-        if (parallel_merge_data && parallel_merge_data->thread_pool)
-            parallel_merge_data->thread_pool->wait();
-    }
-
-protected:
-    Block readImpl() override
-    {
-        if (data.empty())
-            return {};
-
-        if (current_bucket_num >= NUM_BUCKETS)
-            return {};
-
-        FAIL_POINT_TRIGGER_EXCEPTION(FailPoints::random_aggregate_merge_failpoint);
-
-        AggregatedDataVariantsPtr & first = data[0];
-
-        if (current_bucket_num == -1)
-        {
-            ++current_bucket_num;
-
-            if (first->type == AggregatedDataVariants::Type::without_key || aggregator.params.overflow_row)
-            {
-                aggregator.mergeWithoutKeyDataImpl(data);
-                return aggregator.prepareBlockAndFillWithoutKey(
-                    *first,
-                    final,
-                    first->type != AggregatedDataVariants::Type::without_key);
-            }
-        }
-
-        if (!first->isTwoLevel())
-        {
-            if (current_bucket_num > 0)
-                return {};
-
-            if (first->type == AggregatedDataVariants::Type::without_key)
-                return {};
-
-            ++current_bucket_num;
-
-#define M(NAME)                                                                 \
-    case AggregationMethodType(NAME):                                           \
-    {                                                                           \
-        aggregator.mergeSingleLevelDataImpl<AggregationMethodName(NAME)>(data); \
-        break;                                                                  \
-    }
-            switch (first->type)
-            {
-                APPLY_FOR_VARIANTS_SINGLE_LEVEL(M)
-            default:
-                throw Exception("Unknown aggregated data variant.", ErrorCodes::UNKNOWN_AGGREGATED_DATA_VARIANT);
-            }
-#undef M
-            return aggregator.prepareBlockAndFillSingleLevel(*first, final);
-        }
-        else
-        {
-            if (!parallel_merge_data)
-            {
-                parallel_merge_data = std::make_unique<ParallelMergeData>(threads);
-                for (size_t i = 0; i < threads; ++i)
-                    scheduleThreadForNextBucket();
-            }
-
-            Block res;
-
-            while (true)
-            {
-                std::unique_lock lock(parallel_merge_data->mutex);
-
-                if (parallel_merge_data->exception)
-                    std::rethrow_exception(parallel_merge_data->exception);
-
-                auto it = parallel_merge_data->ready_blocks.find(current_bucket_num);
-                if (it != parallel_merge_data->ready_blocks.end())
-                {
-                    ++current_bucket_num;
-                    scheduleThreadForNextBucket();
-
-                    if (it->second)
-                    {
-                        res.swap(it->second);
-                        break;
-                    }
-                    else if (current_bucket_num >= NUM_BUCKETS)
-                        break;
-                }
-
-                parallel_merge_data->condvar.wait(lock);
-            }
-
-            return res;
-        }
-    }
-
-private:
-    const LoggerPtr log;
-    const Aggregator & aggregator;
-    ManyAggregatedDataVariants data;
-    bool final;
-    size_t threads;
-
-    std::atomic<Int32> current_bucket_num = -1;
-    std::atomic<Int32> max_scheduled_bucket_num = -1;
-    static constexpr Int32 NUM_BUCKETS = 256;
-
-    struct ParallelMergeData
-    {
-        std::map<Int32, Block> ready_blocks;
-        std::exception_ptr exception;
-        std::mutex mutex;
-        std::condition_variable condvar;
-        std::shared_ptr<ThreadPoolManager> thread_pool;
-
-        explicit ParallelMergeData(size_t threads)
-            : thread_pool(newThreadPoolManager(threads))
-        {}
-    };
-
-    std::unique_ptr<ParallelMergeData> parallel_merge_data;
-
-    void scheduleThreadForNextBucket()
-    {
-        int num = max_scheduled_bucket_num.fetch_add(1) + 1;
-        if (num >= NUM_BUCKETS)
-            return;
-
-        parallel_merge_data->thread_pool->schedule(true, [this, num] { thread(num); });
-    }
-
-    void thread(Int32 bucket_num)
-    {
-        try
-        {
-            /// TODO: add no_more_keys support maybe
-
-            auto & merged_data = *data[0];
-            auto method = merged_data.type;
-            Block block;
-
-            /// Select Arena to avoid race conditions
-            size_t thread_number = static_cast<size_t>(bucket_num) % threads;
-            Arena * arena = merged_data.aggregates_pools.at(thread_number).get();
-
-#define M(NAME)                                                                           \
-    case AggregationMethodType(NAME):                                                     \
-    {                                                                                     \
-        aggregator.mergeBucketImpl<AggregationMethodName(NAME)>(data, bucket_num, arena); \
-        block = aggregator.convertOneBucketToBlock(                                       \
-            merged_data,                                                                  \
-            *ToAggregationMethodPtr(NAME, merged_data.aggregation_method_impl),           \
-            arena,                                                                        \
-            final,                                                                        \
-            bucket_num);                                                                  \
-        break;                                                                            \
-    }
-            switch (method)
-            {
-                APPLY_FOR_VARIANTS_TWO_LEVEL(M)
-            default:
-                break;
-            }
-#undef M
-
-            std::lock_guard lock(parallel_merge_data->mutex);
-            parallel_merge_data->ready_blocks[bucket_num] = std::move(block);
-        }
-        catch (...)
-        {
-            std::lock_guard lock(parallel_merge_data->mutex);
-            if (!parallel_merge_data->exception)
-                parallel_merge_data->exception = std::current_exception();
-        }
-
-        parallel_merge_data->condvar.notify_all();
-    }
-};
-
-
 std::unique_ptr<IBlockInputStream> Aggregator::mergeAndConvertToBlocks(
     ManyAggregatedDataVariants & data_variants,
     bool final,
@@ -2751,5 +2534,11 @@ void Aggregator::setCancellationHook(CancellationHook cancellation_hook)
 {
     is_cancelled = cancellation_hook;
 }

+#undef AggregationMethodName
+#undef AggregationMethodNameTwoLevel
+#undef AggregationMethodType
+#undef AggregationMethodTypeTwoLevel
+#undef ToAggregationMethodPtr
+#undef ToAggregationMethodPtrTwoLevel
 } // namespace DB
diff --git a/dbms/src/Interpreters/MergingAndConvertingBlockInputStream.h b/dbms/src/Interpreters/MergingAndConvertingBlockInputStream.h
new file mode 100644
index 00000000000..7e58eb8da81
--- /dev/null
+++ b/dbms/src/Interpreters/MergingAndConvertingBlockInputStream.h
@@ -0,0 +1,252 @@
+// Copyright 2022 PingCAP, Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <Common/FailPoint.h>
+#include <Common/ThreadManager.h>
+#include <DataStreams/IProfilingBlockInputStream.h>
+#include <Interpreters/Aggregator.h>
+
+namespace DB
+{
+namespace FailPoints
+{
+extern const char random_aggregate_merge_failpoint[];
+} // namespace FailPoints
+
+#define AggregationMethodName(NAME) AggregatedDataVariants::AggregationMethod_##NAME
+#define AggregationMethodType(NAME) AggregatedDataVariants::Type::NAME
+#define ToAggregationMethodPtr(NAME, ptr) (reinterpret_cast<AggregationMethodName(NAME) *>(ptr))
+
+/** Combines aggregation states together, turns them into blocks, and outputs streams.
+  * If the aggregation states are two-level, then it produces blocks strictly in order of 'bucket_num'.
+  * (This is important for distributed processing.)
+  * In doing so, it can handle different buckets in parallel, using up to `threads` threads.
+  */
+class MergingAndConvertingBlockInputStream : public IProfilingBlockInputStream
+{
+public:
+    /** The input is a set of non-empty sets of partially aggregated data,
+      * which are all either single-level, or are two-level.
+      */
+    MergingAndConvertingBlockInputStream(const Aggregator & aggregator_, ManyAggregatedDataVariants & data_, bool final_, size_t threads_)
+        : log(Logger::get(aggregator_.log ? aggregator_.log->identifier() : ""))
+        , aggregator(aggregator_)
+        , data(data_)
+        , final(final_)
+        , threads(threads_)
+    {
+        /// At least we need one arena in first data item per thread
+        if (!data.empty() && threads > data[0]->aggregates_pools.size())
+        {
+            Arenas & first_pool = data[0]->aggregates_pools;
+            for (size_t j = first_pool.size(); j < threads; ++j)
+                first_pool.emplace_back(std::make_shared<Arena>());
+        }
+    }
+
+    String getName() const override { return "MergingAndConverting"; }
+
+    Block getHeader() const override { return aggregator.getHeader(final); }
+
+    ~MergingAndConvertingBlockInputStream() override
+    {
+        LOG_TRACE(&Poco::Logger::get(__PRETTY_FUNCTION__), "Waiting for threads to finish");
+
+        /// We need to wait for threads to finish before destructor of 'parallel_merge_data',
+        /// because the threads access 'parallel_merge_data'.
+        if (parallel_merge_data && parallel_merge_data->thread_pool)
+            parallel_merge_data->thread_pool->wait();
+    }
+
+protected:
+    Block readImpl() override
+    {
+        if (data.empty())
+            return {};
+
+        if (current_bucket_num >= NUM_BUCKETS)
+            return {};
+
+        FAIL_POINT_TRIGGER_EXCEPTION(FailPoints::random_aggregate_merge_failpoint);
+
+        AggregatedDataVariantsPtr & first = data[0];
+
+        if (current_bucket_num == -1)
+        {
+            ++current_bucket_num;
+
+            if (first->type == AggregatedDataVariants::Type::without_key || aggregator.params.overflow_row)
+            {
+                aggregator.mergeWithoutKeyDataImpl(data);
+                return aggregator.prepareBlockAndFillWithoutKey(
+                    *first,
+                    final,
+                    first->type != AggregatedDataVariants::Type::without_key);
+            }
+        }
+
+        if (!first->isTwoLevel())
+        {
+            if (current_bucket_num > 0)
+                return {};
+
+            if (first->type == AggregatedDataVariants::Type::without_key)
+                return {};
+
+            ++current_bucket_num;
+
+#define M(NAME)                                                                 \
+    case AggregationMethodType(NAME):                                           \
+    {                                                                           \
+        aggregator.mergeSingleLevelDataImpl<AggregationMethodName(NAME)>(data); \
+        break;                                                                  \
+    }
+            switch (first->type)
+            {
+                APPLY_FOR_VARIANTS_SINGLE_LEVEL(M)
+            default:
+                throw Exception("Unknown aggregated data variant.", ErrorCodes::UNKNOWN_AGGREGATED_DATA_VARIANT);
+            }
+#undef M
+            return aggregator.prepareBlockAndFillSingleLevel(*first, final);
+        }
+        else
+        {
+            if (!parallel_merge_data)
+            {
+                parallel_merge_data = std::make_unique<ParallelMergeData>(threads);
+                for (size_t i = 0; i < threads; ++i)
+                    scheduleThreadForNextBucket();
+            }
+
+            Block res;
+
+            while (true)
+            {
+                std::unique_lock lock(parallel_merge_data->mutex);
+
+                if (parallel_merge_data->exception)
+                    std::rethrow_exception(parallel_merge_data->exception);
+
+                auto it = parallel_merge_data->ready_blocks.find(current_bucket_num);
+                if (it != parallel_merge_data->ready_blocks.end())
+                {
+                    ++current_bucket_num;
+                    scheduleThreadForNextBucket();
+
+                    if (it->second)
+                    {
+                        res.swap(it->second);
+                        break;
+                    }
+                    else if (current_bucket_num >= NUM_BUCKETS)
+                        break;
+                }
+
+                parallel_merge_data->condvar.wait(lock);
+            }
+
+            return res;
+        }
+    }
+
+private:
+    const LoggerPtr log;
+    const Aggregator & aggregator;
+    ManyAggregatedDataVariants data;
+    bool final;
+    size_t threads;
+
+    std::atomic<Int32> current_bucket_num = -1;
+    std::atomic<Int32> max_scheduled_bucket_num = -1;
+    static constexpr Int32 NUM_BUCKETS = 256;
+
+    struct ParallelMergeData
+    {
+        std::map<Int32, Block> ready_blocks;
+        std::exception_ptr exception;
+        std::mutex mutex;
+        std::condition_variable condvar;
+        std::shared_ptr<ThreadPoolManager> thread_pool;
+
+        explicit ParallelMergeData(size_t threads)
+            : thread_pool(newThreadPoolManager(threads))
+        {}
+    };
+
+    std::unique_ptr<ParallelMergeData> parallel_merge_data;
+
+    void scheduleThreadForNextBucket()
+    {
+        int num = max_scheduled_bucket_num.fetch_add(1) + 1;
+        if (num >= NUM_BUCKETS)
+            return;
+
+        parallel_merge_data->thread_pool->schedule(true, [this, num] { thread(num); });
+    }
+
+    void thread(Int32 bucket_num)
+    {
+        try
+        {
+            /// TODO: add no_more_keys support maybe
+
+            auto & merged_data = *data[0];
+            auto method = merged_data.type;
+            Block block;
+
+            /// Select Arena to avoid race conditions
+            size_t thread_number = static_cast<size_t>(bucket_num) % threads;
+            Arena * arena = merged_data.aggregates_pools.at(thread_number).get();
+
+#define M(NAME)                                                                           \
+    case AggregationMethodType(NAME):                                                     \
+    {                                                                                     \
+        aggregator.mergeBucketImpl<AggregationMethodName(NAME)>(data, bucket_num, arena); \
+        block = aggregator.convertOneBucketToBlock(                                       \
+            merged_data,                                                                  \
+            *ToAggregationMethodPtr(NAME, merged_data.aggregation_method_impl),           \
+            arena,                                                                        \
+            final,                                                                        \
+            bucket_num);                                                                  \
+        break;                                                                            \
+    }
+            switch (method)
+            {
+                APPLY_FOR_VARIANTS_TWO_LEVEL(M)
+            default:
+                break;
+            }
+#undef M
+
+            std::lock_guard lock(parallel_merge_data->mutex);
+            parallel_merge_data->ready_blocks[bucket_num] = std::move(block);
+        }
+        catch (...)
+        {
+            std::lock_guard lock(parallel_merge_data->mutex);
+            if (!parallel_merge_data->exception)
+                parallel_merge_data->exception = std::current_exception();
+        }
+
+        parallel_merge_data->condvar.notify_all();
+    }
+};
+
+#undef AggregationMethodName
+#undef AggregationMethodType
+#undef ToAggregationMethodPtr
+} // namespace DB
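Side note for reviewers tracing the call path: the moved stream is still constructed inside `Aggregator::mergeAndConvertToBlocks` (whose signature is visible as context above), and consumers drive it through the ordinary `IBlockInputStream` pull interface. Below is a minimal sketch of such a read loop; `drainMergedBlocks` and the `consume` callback are illustrative names, not part of this patch — only `readPrefix()`/`read()`/`readSuffix()` are the real stream API.

```cpp
#include <DataStreams/IBlockInputStream.h>

#include <functional>

namespace DB
{
/// Pull every merged block out of the stream returned by
/// Aggregator::mergeAndConvertToBlocks(). For two-level data the blocks
/// arrive strictly ordered by bucket_num, as the class comment promises.
void drainMergedBlocks(const BlockInputStreamPtr & stream,
                       const std::function<void(Block &)> & consume)
{
    stream->readPrefix();
    while (Block block = stream->read())
        consume(block); // e.g. forward to the next stage of the query pipeline
    stream->readSuffix();
}
} // namespace DB
```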
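The mechanism worth reviewing in `readImpl`/`thread` is the ordered hand-off: worker tasks publish finished buckets into `ready_blocks` under the mutex, while the reader sleeps on the condition variable until the bucket number it needs next appears, so blocks are emitted in `bucket_num` order even though buckets finish out of order. A self-contained distillation of that pattern follows, using plain `std::thread` and an `int` payload in place of the `ThreadPoolManager` and `Block`.

```cpp
#include <condition_variable>
#include <cstdio>
#include <map>
#include <mutex>
#include <thread>
#include <vector>

int main()
{
    constexpr int num_buckets = 8;
    std::map<int, int> ready; // bucket_num -> merged payload
    std::mutex mutex;
    std::condition_variable condvar;

    // Workers finish in arbitrary order and publish under the lock.
    std::vector<std::thread> workers;
    workers.reserve(num_buckets);
    for (int bucket = 0; bucket < num_buckets; ++bucket)
        workers.emplace_back([&, bucket] {
            int merged = bucket * 100; // stand-in for merging one bucket
            {
                std::lock_guard lock(mutex);
                ready.emplace(bucket, merged);
            }
            condvar.notify_all();
        });

    // The reader emits strictly in bucket order, waiting for each bucket.
    for (int next = 0; next < num_buckets; ++next)
    {
        std::unique_lock lock(mutex);
        condvar.wait(lock, [&] { return ready.count(next) > 0; });
        std::printf("bucket %d -> %d\n", next, ready[next]);
    }

    for (auto & worker : workers)
        worker.join();
    return 0;
}
```

The production class additionally bounds the work in flight through `max_scheduled_bucket_num`: only `threads` buckets are being merged at once, and handing a finished bucket to the reader schedules the next one.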