move MergingAndConvertingBlockInputStream out of Aggregator.cpp (#6531)
ref #5900
SeaRise authored Dec 23, 2022
1 parent 87d7c9b commit 1e7edb8
Showing 2 changed files with 259 additions and 218 deletions.
225 changes: 7 additions & 218 deletions dbms/src/Interpreters/Aggregator.cpp
@@ -32,6 +32,7 @@
#include <Encryption/WriteBufferFromFileProvider.h>
#include <IO/CompressedWriteBuffer.h>
#include <Interpreters/Aggregator.h>
#include <Interpreters/MergingAndConvertingBlockInputStream.h>
#include <Storages/Transaction/CollatorUtils.h>
#include <common/demangle.h>

@@ -56,7 +57,6 @@ extern const int LOGICAL_ERROR;
namespace FailPoints
{
extern const char random_aggregate_create_state_failpoint[];
extern const char random_aggregate_merge_failpoint[];
} // namespace FailPoints

#define AggregationMethodName(NAME) AggregatedDataVariants::AggregationMethod_##NAME
@@ -1773,223 +1773,6 @@ void NO_INLINE Aggregator::mergeBucketImpl(
}


/** Combines aggregation states together, turns them into blocks, and outputs streams.
* If the aggregation states are two-level, then it produces blocks strictly in order of 'bucket_num'.
* (This is important for distributed processing.)
* In doing so, it can handle different buckets in parallel, using up to `threads` threads.
*/
class MergingAndConvertingBlockInputStream : public IProfilingBlockInputStream
{
public:
/** The input is a set of non-empty sets of partially aggregated data,
* which are all either single-level, or are two-level.
*/
MergingAndConvertingBlockInputStream(const Aggregator & aggregator_, ManyAggregatedDataVariants & data_, bool final_, size_t threads_)
: log(Logger::get(aggregator_.log ? aggregator_.log->identifier() : ""))
, aggregator(aggregator_)
, data(data_)
, final(final_)
, threads(threads_)
{
/// At least we need one arena in first data item per thread
if (!data.empty() && threads > data[0]->aggregates_pools.size())
{
Arenas & first_pool = data[0]->aggregates_pools;
for (size_t j = first_pool.size(); j < threads; ++j)
first_pool.emplace_back(std::make_shared<Arena>());
}
}

String getName() const override { return "MergingAndConverting"; }

Block getHeader() const override { return aggregator.getHeader(final); }

~MergingAndConvertingBlockInputStream() override
{
LOG_TRACE(&Poco::Logger::get(__PRETTY_FUNCTION__), "Waiting for threads to finish");

/// We need to wait for threads to finish before destructor of 'parallel_merge_data',
/// because the threads access 'parallel_merge_data'.
if (parallel_merge_data && parallel_merge_data->thread_pool)
parallel_merge_data->thread_pool->wait();
}

protected:
Block readImpl() override
{
if (data.empty())
return {};

if (current_bucket_num >= NUM_BUCKETS)
return {};

FAIL_POINT_TRIGGER_EXCEPTION(FailPoints::random_aggregate_merge_failpoint);

AggregatedDataVariantsPtr & first = data[0];

if (current_bucket_num == -1)
{
++current_bucket_num;

if (first->type == AggregatedDataVariants::Type::without_key || aggregator.params.overflow_row)
{
aggregator.mergeWithoutKeyDataImpl(data);
return aggregator.prepareBlockAndFillWithoutKey(
*first,
final,
first->type != AggregatedDataVariants::Type::without_key);
}
}

if (!first->isTwoLevel())
{
if (current_bucket_num > 0)
return {};

if (first->type == AggregatedDataVariants::Type::without_key)
return {};

++current_bucket_num;

#define M(NAME) \
case AggregationMethodType(NAME): \
{ \
aggregator.mergeSingleLevelDataImpl<AggregationMethodName(NAME)>(data); \
break; \
}
switch (first->type)
{
APPLY_FOR_VARIANTS_SINGLE_LEVEL(M)
default:
throw Exception("Unknown aggregated data variant.", ErrorCodes::UNKNOWN_AGGREGATED_DATA_VARIANT);
}
#undef M
return aggregator.prepareBlockAndFillSingleLevel(*first, final);
}
else
{
if (!parallel_merge_data)
{
parallel_merge_data = std::make_unique<ParallelMergeData>(threads);
for (size_t i = 0; i < threads; ++i)
scheduleThreadForNextBucket();
}

Block res;

while (true)
{
std::unique_lock lock(parallel_merge_data->mutex);

if (parallel_merge_data->exception)
std::rethrow_exception(parallel_merge_data->exception);

auto it = parallel_merge_data->ready_blocks.find(current_bucket_num);
if (it != parallel_merge_data->ready_blocks.end())
{
++current_bucket_num;
scheduleThreadForNextBucket();

if (it->second)
{
res.swap(it->second);
break;
}
else if (current_bucket_num >= NUM_BUCKETS)
break;
}

parallel_merge_data->condvar.wait(lock);
}

return res;
}
}

private:
const LoggerPtr log;
const Aggregator & aggregator;
ManyAggregatedDataVariants data;
bool final;
size_t threads;

std::atomic<Int32> current_bucket_num = -1;
std::atomic<Int32> max_scheduled_bucket_num = -1;
static constexpr Int32 NUM_BUCKETS = 256;

struct ParallelMergeData
{
std::map<Int32, Block> ready_blocks;
std::exception_ptr exception;
std::mutex mutex;
std::condition_variable condvar;
std::shared_ptr<ThreadPoolManager> thread_pool;

explicit ParallelMergeData(size_t threads)
: thread_pool(newThreadPoolManager(threads))
{}
};

std::unique_ptr<ParallelMergeData> parallel_merge_data;

void scheduleThreadForNextBucket()
{
int num = max_scheduled_bucket_num.fetch_add(1) + 1;
if (num >= NUM_BUCKETS)
return;

parallel_merge_data->thread_pool->schedule(true, [this, num] { thread(num); });
}

void thread(Int32 bucket_num)
{
try
{
/// TODO: add no_more_keys support maybe

auto & merged_data = *data[0];
auto method = merged_data.type;
Block block;

/// Select Arena to avoid race conditions
size_t thread_number = static_cast<size_t>(bucket_num) % threads;
Arena * arena = merged_data.aggregates_pools.at(thread_number).get();

#define M(NAME) \
case AggregationMethodType(NAME): \
{ \
aggregator.mergeBucketImpl<AggregationMethodName(NAME)>(data, bucket_num, arena); \
block = aggregator.convertOneBucketToBlock( \
merged_data, \
*ToAggregationMethodPtr(NAME, merged_data.aggregation_method_impl), \
arena, \
final, \
bucket_num); \
break; \
}
switch (method)
{
APPLY_FOR_VARIANTS_TWO_LEVEL(M)
default:
break;
}
#undef M

std::lock_guard lock(parallel_merge_data->mutex);
parallel_merge_data->ready_blocks[bucket_num] = std::move(block);
}
catch (...)
{
std::lock_guard lock(parallel_merge_data->mutex);
if (!parallel_merge_data->exception)
parallel_merge_data->exception = std::current_exception();
}

parallel_merge_data->condvar.notify_all();
}
};


std::unique_ptr<IBlockInputStream> Aggregator::mergeAndConvertToBlocks(
ManyAggregatedDataVariants & data_variants,
bool final,
@@ -2751,5 +2534,11 @@ void Aggregator::setCancellationHook(CancellationHook cancellation_hook)
is_cancelled = cancellation_hook;
}

#undef AggregationMethodName
#undef AggregationMethodNameTwoLevel
#undef AggregationMethodType
#undef AggregationMethodTypeTwoLevel
#undef ToAggregationMethodPtr
#undef ToAggregationMethodPtrTwoLevel

} // namespace DB
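
The second changed file is not rendered in this capture; judging by the commit title and the newly added #include, it is presumably dbms/src/Interpreters/MergingAndConvertingBlockInputStream.h, which now holds the class body verbatim. For readers skimming the removed code, here is a minimal, self-contained sketch (standard library only, not TiFlash code) of the hand-off pattern that readImpl() and thread() implement above: worker threads merge buckets in parallel and publish results into a map keyed by bucket number under a mutex, while a single consumer waits on a condition variable and emits blocks strictly in increasing bucket order. The bucket count, thread count, and Block stand-in below are illustrative placeholders.

// Minimal sketch of the ordered parallel-merge hand-off (NOT TiFlash code).
#include <condition_variable>
#include <iostream>
#include <map>
#include <mutex>
#include <string>
#include <thread>
#include <vector>

struct Block
{
    std::string payload; // stand-in for a real DB::Block
};

int main()
{
    constexpr int num_buckets = 8;
    constexpr int num_threads = 3;

    std::map<int, Block> ready_blocks; // bucket_num -> merged block
    std::mutex mutex;
    std::condition_variable condvar;

    // Workers: each thread handles a strided subset of buckets, playing the
    // role of the scheduleThreadForNextBucket()/thread() pair above, but
    // without a thread pool.
    std::vector<std::thread> workers;
    for (int t = 0; t < num_threads; ++t)
    {
        workers.emplace_back([&, t] {
            for (int bucket = t; bucket < num_buckets; bucket += num_threads)
            {
                Block block{"merged bucket " + std::to_string(bucket)};
                {
                    std::lock_guard lock(mutex);
                    ready_blocks[bucket] = std::move(block);
                }
                condvar.notify_all();
            }
        });
    }

    // Consumer: emit blocks strictly in bucket order, waiting for the next
    // expected bucket even if later buckets are already finished.
    for (int next = 0; next < num_buckets; ++next)
    {
        std::unique_lock lock(mutex);
        condvar.wait(lock, [&] { return ready_blocks.count(next) > 0; });
        std::cout << ready_blocks[next].payload << '\n';
        ready_blocks.erase(next);
    }

    for (auto & w : workers)
        w.join();
    return 0;
}

As the class comment above notes, the strict ordering matters in the two-level case because distributed processing downstream relies on receiving blocks in 'bucket_num' order, which is why the consumer blocks on the next expected bucket rather than emitting whichever bucket finishes first.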
