Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move MergingAndConvertingBlockInputStream out of Aggregator.cpp #6531

Merged
merged 5 commits into from
Dec 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
225 changes: 7 additions & 218 deletions dbms/src/Interpreters/Aggregator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <Encryption/WriteBufferFromFileProvider.h>
#include <IO/CompressedWriteBuffer.h>
#include <Interpreters/Aggregator.h>
#include <Interpreters/MergingAndConvertingBlockInputStream.h>
#include <Storages/Transaction/CollatorUtils.h>
#include <common/demangle.h>

Expand All @@ -56,7 +57,6 @@ extern const int LOGICAL_ERROR;
/// Fail point names used by this file. They are only declared here (extern);
/// the definitions presumably live in the fail point registry — TODO confirm.
namespace FailPoints
{
extern const char random_aggregate_create_state_failpoint[];
/// Passed to FAIL_POINT_TRIGGER_EXCEPTION in MergingAndConvertingBlockInputStream::readImpl.
extern const char random_aggregate_merge_failpoint[];
} // namespace FailPoints

/// Expands to the qualified name AggregatedDataVariants::AggregationMethod_<NAME> (token pasting).
/// Used below as the template argument of the per-variant merge functions.
#define AggregationMethodName(NAME) AggregatedDataVariants::AggregationMethod_##NAME
Expand Down Expand Up @@ -1773,223 +1773,6 @@ void NO_INLINE Aggregator::mergeBucketImpl(
}


/** Combines aggregation states together, turns them into blocks, and outputs streams.
* If the aggregation states are two-level, then it produces blocks strictly in order of 'bucket_num'.
* (This is important for distributed processing.)
* In doing so, it can handle different buckets in parallel, using up to `threads` threads.
*/
class MergingAndConvertingBlockInputStream : public IProfilingBlockInputStream
{
public:
/** The input is a set of non-empty sets of partially aggregated data,
* which are all either single-level, or are two-level.
*/
MergingAndConvertingBlockInputStream(const Aggregator & aggregator_, ManyAggregatedDataVariants & data_, bool final_, size_t threads_)
: log(Logger::get(aggregator_.log ? aggregator_.log->identifier() : ""))
, aggregator(aggregator_)
, data(data_)
, final(final_)
, threads(threads_)
{
/// At least we need one arena in first data item per thread
if (!data.empty() && threads > data[0]->aggregates_pools.size())
{
Arenas & first_pool = data[0]->aggregates_pools;
for (size_t j = first_pool.size(); j < threads; ++j)
first_pool.emplace_back(std::make_shared<Arena>());
}
}

String getName() const override { return "MergingAndConverting"; }

Block getHeader() const override { return aggregator.getHeader(final); }

~MergingAndConvertingBlockInputStream() override
{
LOG_TRACE(&Poco::Logger::get(__PRETTY_FUNCTION__), "Waiting for threads to finish");

/// We need to wait for threads to finish before destructor of 'parallel_merge_data',
/// because the threads access 'parallel_merge_data'.
if (parallel_merge_data && parallel_merge_data->thread_pool)
parallel_merge_data->thread_pool->wait();
}

protected:
Block readImpl() override
{
if (data.empty())
return {};

if (current_bucket_num >= NUM_BUCKETS)
return {};

FAIL_POINT_TRIGGER_EXCEPTION(FailPoints::random_aggregate_merge_failpoint);

AggregatedDataVariantsPtr & first = data[0];

if (current_bucket_num == -1)
{
++current_bucket_num;

if (first->type == AggregatedDataVariants::Type::without_key || aggregator.params.overflow_row)
{
aggregator.mergeWithoutKeyDataImpl(data);
return aggregator.prepareBlockAndFillWithoutKey(
*first,
final,
first->type != AggregatedDataVariants::Type::without_key);
}
}

if (!first->isTwoLevel())
{
if (current_bucket_num > 0)
return {};

if (first->type == AggregatedDataVariants::Type::without_key)
return {};

++current_bucket_num;

#define M(NAME) \
case AggregationMethodType(NAME): \
{ \
aggregator.mergeSingleLevelDataImpl<AggregationMethodName(NAME)>(data); \
break; \
}
switch (first->type)
{
APPLY_FOR_VARIANTS_SINGLE_LEVEL(M)
default:
throw Exception("Unknown aggregated data variant.", ErrorCodes::UNKNOWN_AGGREGATED_DATA_VARIANT);
}
#undef M
return aggregator.prepareBlockAndFillSingleLevel(*first, final);
}
else
{
if (!parallel_merge_data)
{
parallel_merge_data = std::make_unique<ParallelMergeData>(threads);
for (size_t i = 0; i < threads; ++i)
scheduleThreadForNextBucket();
}

Block res;

while (true)
{
std::unique_lock lock(parallel_merge_data->mutex);

if (parallel_merge_data->exception)
std::rethrow_exception(parallel_merge_data->exception);

auto it = parallel_merge_data->ready_blocks.find(current_bucket_num);
if (it != parallel_merge_data->ready_blocks.end())
{
++current_bucket_num;
scheduleThreadForNextBucket();

if (it->second)
{
res.swap(it->second);
break;
}
else if (current_bucket_num >= NUM_BUCKETS)
break;
}

parallel_merge_data->condvar.wait(lock);
}

return res;
}
}

private:
const LoggerPtr log;
const Aggregator & aggregator;
ManyAggregatedDataVariants data;
bool final;
size_t threads;

std::atomic<Int32> current_bucket_num = -1;
std::atomic<Int32> max_scheduled_bucket_num = -1;
static constexpr Int32 NUM_BUCKETS = 256;

struct ParallelMergeData
{
std::map<Int32, Block> ready_blocks;
std::exception_ptr exception;
std::mutex mutex;
std::condition_variable condvar;
std::shared_ptr<ThreadPoolManager> thread_pool;

explicit ParallelMergeData(size_t threads)
: thread_pool(newThreadPoolManager(threads))
{}
};

std::unique_ptr<ParallelMergeData> parallel_merge_data;

void scheduleThreadForNextBucket()
{
int num = max_scheduled_bucket_num.fetch_add(1) + 1;
if (num >= NUM_BUCKETS)
return;

parallel_merge_data->thread_pool->schedule(true, [this, num] { thread(num); });
}

void thread(Int32 bucket_num)
{
try
{
/// TODO: add no_more_keys support maybe

auto & merged_data = *data[0];
auto method = merged_data.type;
Block block;

/// Select Arena to avoid race conditions
size_t thread_number = static_cast<size_t>(bucket_num) % threads;
Arena * arena = merged_data.aggregates_pools.at(thread_number).get();

#define M(NAME) \
case AggregationMethodType(NAME): \
{ \
aggregator.mergeBucketImpl<AggregationMethodName(NAME)>(data, bucket_num, arena); \
block = aggregator.convertOneBucketToBlock( \
merged_data, \
*ToAggregationMethodPtr(NAME, merged_data.aggregation_method_impl), \
arena, \
final, \
bucket_num); \
break; \
}
switch (method)
{
APPLY_FOR_VARIANTS_TWO_LEVEL(M)
default:
break;
}
#undef M

std::lock_guard lock(parallel_merge_data->mutex);
parallel_merge_data->ready_blocks[bucket_num] = std::move(block);
}
catch (...)
{
std::lock_guard lock(parallel_merge_data->mutex);
if (!parallel_merge_data->exception)
parallel_merge_data->exception = std::current_exception();
}

parallel_merge_data->condvar.notify_all();
}
};


std::unique_ptr<IBlockInputStream> Aggregator::mergeAndConvertToBlocks(
ManyAggregatedDataVariants & data_variants,
bool final,
Expand Down Expand Up @@ -2751,5 +2534,11 @@ void Aggregator::setCancellationHook(CancellationHook cancellation_hook)
is_cancelled = cancellation_hook;
}

/// Remove the aggregation-method helper macros before the namespace is closed,
/// presumably so that they do not leak into whatever is compiled after this file's
/// contents — TODO confirm. Only AggregationMethodName is defined in the visible part of this file.
#undef AggregationMethodName
#undef AggregationMethodNameTwoLevel
#undef AggregationMethodType
#undef AggregationMethodTypeTwoLevel
#undef ToAggregationMethodPtr
#undef ToAggregationMethodPtrTwoLevel

} // namespace DB
Loading