From bbce32a7fb316c241971b50eade9f3286e42c0db Mon Sep 17 00:00:00 2001 From: Fu Zhe Date: Mon, 2 Aug 2021 15:24:59 +0800 Subject: [PATCH] Use ThreadFactory to auto set thread attributes (#2496) --- dbms/src/Common/ThreadFactory.h | 71 +++++++++++++++++++ .../AsynchronousBlockInputStream.h | 9 ++- .../CreatingSetsBlockInputStream.cpp | 7 +- .../CreatingSetsBlockInputStream.h | 2 +- .../DedupSortedBlockInputStream.cpp | 5 +- .../DataStreams/DedupSortedBlockInputStream.h | 2 +- ...regatedMemoryEfficientBlockInputStream.cpp | 30 +++----- ...ggregatedMemoryEfficientBlockInputStream.h | 4 +- .../src/DataStreams/ParallelInputsProcessor.h | 6 +- .../DataStreams/SharedQueryBlockInputStream.h | 3 +- dbms/src/Interpreters/Aggregator.cpp | 19 +++-- .../DistributedBlockOutputStream.cpp | 12 +--- 12 files changed, 111 insertions(+), 59 deletions(-) create mode 100644 dbms/src/Common/ThreadFactory.h diff --git a/dbms/src/Common/ThreadFactory.h b/dbms/src/Common/ThreadFactory.h new file mode 100644 index 00000000000..59fc159ddcd --- /dev/null +++ b/dbms/src/Common/ThreadFactory.h @@ -0,0 +1,71 @@ +#pragma once + +#include +#include +#include +#include + +namespace DB +{ + +/// ThreadFactory helps to set attributes on new threads or threadpool's jobs. +/// Current supported attributes: +/// 1. MemoryTracker +/// 2. ThreadName +/// +/// ThreadFactory should only be constructed on stack. +class ThreadFactory +{ +public: + /// force_overwrite_thread_attribute is only used for ThreadPool's jobs. + /// For new threads it is treated as always true. + explicit ThreadFactory(bool force_overwrite_thread_attribute = false, std::string thread_name_ = "") + : force_overwrite(force_overwrite_thread_attribute), thread_name(thread_name_) {} + + ThreadFactory(const ThreadFactory &) = delete; + ThreadFactory & operator=(const ThreadFactory &) = delete; + + ThreadFactory(ThreadFactory &&) = default; + ThreadFactory & operator=(ThreadFactory &&) = default; + + template + std::thread newThread(F && f, Args &&... args) + { + auto memory_tracker = current_memory_tracker; + auto wrapped_func = [memory_tracker, thread_name = thread_name, f = std::move(f)](Args &&... args) + { + setAttributes(memory_tracker, thread_name, true); + return std::invoke(f, std::forward(args)...); + }; + return std::thread(wrapped_func, std::forward(args)...); + } + + template + ThreadPool::Job newJob(F && f, Args &&... args) + { + auto memory_tracker = current_memory_tracker; + /// Use std::tuple to workaround the limit on the lambda's init-capture of C++17. + /// See https://stackoverflow.com/questions/47496358/c-lambdas-how-to-capture-variadic-parameter-pack-from-the-upper-scope + return [force_overwrite = force_overwrite, memory_tracker, thread_name = thread_name, f = std::move(f), args = std::make_tuple(std::move(args)...)] + { + setAttributes(memory_tracker, thread_name, force_overwrite); + return std::apply(f, std::move(args)); + }; + } +private: + static void setAttributes(MemoryTracker * memory_tracker, const std::string & thread_name, bool force_overwrite) + { + if (force_overwrite || !current_memory_tracker) + { + current_memory_tracker = memory_tracker; + if (!thread_name.empty()) + setThreadName(thread_name.c_str()); + } + } + + bool force_overwrite = false; + std::string thread_name; +}; + +} // namespace DB + diff --git a/dbms/src/DataStreams/AsynchronousBlockInputStream.h b/dbms/src/DataStreams/AsynchronousBlockInputStream.h index 0a80628cf2a..935d330f999 100644 --- a/dbms/src/DataStreams/AsynchronousBlockInputStream.h +++ b/dbms/src/DataStreams/AsynchronousBlockInputStream.h @@ -5,6 +5,7 @@ #include #include #include +#include #include #include @@ -97,7 +98,7 @@ class AsynchronousBlockInputStream : public IProfilingBlockInputStream /// If there were no calculations yet, calculate the first block synchronously if (!started) { - calculate(current_memory_tracker); + calculate(); started = true; } else /// If the calculations are already in progress - wait for the result @@ -121,12 +122,12 @@ class AsynchronousBlockInputStream : public IProfilingBlockInputStream void next() { ready.reset(); - pool.schedule(std::bind(&AsynchronousBlockInputStream::calculate, this, current_memory_tracker)); + pool.schedule(ThreadFactory(false, "AsyncBlockInput").newJob([this] { calculate(); })); } /// Calculations that can be performed in a separate thread - void calculate(MemoryTracker * memory_tracker) + void calculate() { CurrentMetrics::Increment metric_increment{CurrentMetrics::QueryThread}; @@ -135,8 +136,6 @@ class AsynchronousBlockInputStream : public IProfilingBlockInputStream if (first) { first = false; - setThreadName("AsyncBlockInput"); - current_memory_tracker = memory_tracker; children.back()->readPrefix(); } diff --git a/dbms/src/DataStreams/CreatingSetsBlockInputStream.cpp b/dbms/src/DataStreams/CreatingSetsBlockInputStream.cpp index 3fcef57544d..68f7405ecb2 100644 --- a/dbms/src/DataStreams/CreatingSetsBlockInputStream.cpp +++ b/dbms/src/DataStreams/CreatingSetsBlockInputStream.cpp @@ -1,4 +1,5 @@ #include +#include #include #include #include @@ -104,7 +105,7 @@ void CreatingSetsBlockInputStream::createAll() { if (isCancelledOrThrowIfKilled()) return; - workers.push_back(std::thread(&CreatingSetsBlockInputStream::createOne, this, std::ref(elem.second), current_memory_tracker)); + workers.emplace_back(ThreadFactory().newThread([this, &subquery = elem.second]{ createOne(subquery); })); FAIL_POINT_TRIGGER_EXCEPTION(FailPoints::exception_in_creating_set_input_stream); } } @@ -121,12 +122,10 @@ void CreatingSetsBlockInputStream::createAll() } } -void CreatingSetsBlockInputStream::createOne(SubqueryForSet & subquery, MemoryTracker * memory_tracker) +void CreatingSetsBlockInputStream::createOne(SubqueryForSet & subquery) { try { - - current_memory_tracker = memory_tracker; LOG_DEBUG(log, (subquery.set ? "Creating set. " : "") << (subquery.join ? "Creating join. " : "") << (subquery.table ? "Filling temporary table. " : "") << " for task " diff --git a/dbms/src/DataStreams/CreatingSetsBlockInputStream.h b/dbms/src/DataStreams/CreatingSetsBlockInputStream.h index 4aa20af188a..e83658abfb4 100644 --- a/dbms/src/DataStreams/CreatingSetsBlockInputStream.h +++ b/dbms/src/DataStreams/CreatingSetsBlockInputStream.h @@ -67,7 +67,7 @@ class CreatingSetsBlockInputStream : public IProfilingBlockInputStream Logger * log = &Logger::get("CreatingSetsBlockInputStream"); void createAll(); - void createOne(SubqueryForSet & subquery, MemoryTracker * memory_tracker); + void createOne(SubqueryForSet & subquery); }; } // namespace DB diff --git a/dbms/src/DataStreams/DedupSortedBlockInputStream.cpp b/dbms/src/DataStreams/DedupSortedBlockInputStream.cpp index 7af6652ae58..36c8990b262 100644 --- a/dbms/src/DataStreams/DedupSortedBlockInputStream.cpp +++ b/dbms/src/DataStreams/DedupSortedBlockInputStream.cpp @@ -2,6 +2,7 @@ #include #include +#include // #define DEDUP_TRACER #ifndef DEDUP_TRACER @@ -36,7 +37,7 @@ DedupSortedBlockInputStream::DedupSortedBlockInputStream(BlockInputStreams & inp readers.schedule(std::bind(&DedupSortedBlockInputStream::asynFetch, this, i)); LOG_DEBUG(log, "Start deduping in single thread, using priority-queue"); - dedup_thread = std::make_unique([this] { asynDedupByQueue(); }); + dedup_thread = std::make_unique(ThreadFactory().newThread([this] { asyncDedupByQueue(); })); } @@ -105,7 +106,7 @@ void DedupSortedBlockInputStream::readFromSource(DedupCursors & output, BoundQue } -void DedupSortedBlockInputStream::asynDedupByQueue() +void DedupSortedBlockInputStream::asyncDedupByQueue() { BoundQueue bounds; DedupCursors cursors(source_blocks.size()); diff --git a/dbms/src/DataStreams/DedupSortedBlockInputStream.h b/dbms/src/DataStreams/DedupSortedBlockInputStream.h index 34caac503a7..56d9c70336f 100644 --- a/dbms/src/DataStreams/DedupSortedBlockInputStream.h +++ b/dbms/src/DataStreams/DedupSortedBlockInputStream.h @@ -49,7 +49,7 @@ class DedupSortedBlockInputStream : public IProfilingBlockInputStream } private: - void asynDedupByQueue(); + void asyncDedupByQueue(); void asynFetch(size_t pisition); void fetchBlock(size_t pisition); diff --git a/dbms/src/DataStreams/MergingAggregatedMemoryEfficientBlockInputStream.cpp b/dbms/src/DataStreams/MergingAggregatedMemoryEfficientBlockInputStream.cpp index bf436bb8547..9bdcd0898ab 100644 --- a/dbms/src/DataStreams/MergingAggregatedMemoryEfficientBlockInputStream.cpp +++ b/dbms/src/DataStreams/MergingAggregatedMemoryEfficientBlockInputStream.cpp @@ -1,7 +1,7 @@ #include #include #include -#include +#include #include @@ -175,14 +175,12 @@ void MergingAggregatedMemoryEfficientBlockInputStream::start() { auto & child = children[i]; - auto memory_tracker = current_memory_tracker; - reading_pool->schedule([&child, memory_tracker] - { - current_memory_tracker = memory_tracker; - setThreadName("MergeAggReadThr"); - CurrentMetrics::Increment metric_increment{CurrentMetrics::QueryThread}; - child->readPrefix(); - }); + reading_pool->schedule( + ThreadFactory(true, "MergeAggReadThr").newJob([&child] + { + CurrentMetrics::Increment metric_increment{CurrentMetrics::QueryThread}; + child->readPrefix(); + })); } reading_pool->wait(); @@ -196,8 +194,7 @@ void MergingAggregatedMemoryEfficientBlockInputStream::start() */ for (size_t i = 0; i < merging_threads; ++i) - pool.schedule(std::bind(&MergingAggregatedMemoryEfficientBlockInputStream::mergeThread, - this, current_memory_tracker)); + pool.schedule(ThreadFactory(true, "MergeAggMergThr").newJob([this] { mergeThread(); })); } } @@ -293,10 +290,8 @@ void MergingAggregatedMemoryEfficientBlockInputStream::finalize() } -void MergingAggregatedMemoryEfficientBlockInputStream::mergeThread(MemoryTracker * memory_tracker) +void MergingAggregatedMemoryEfficientBlockInputStream::mergeThread() { - setThreadName("MergeAggMergThr"); - current_memory_tracker = memory_tracker; CurrentMetrics::Increment metric_increment{CurrentMetrics::QueryThread}; try @@ -480,14 +475,11 @@ MergingAggregatedMemoryEfficientBlockInputStream::BlocksToMerge MergingAggregate { if (need_that_input(input)) { - auto memory_tracker = current_memory_tracker; - reading_pool->schedule([&input, &read_from_input, memory_tracker] + reading_pool->schedule(ThreadFactory(true, "MergeAggReadThr").newJob([&input, &read_from_input] { - current_memory_tracker = memory_tracker; - setThreadName("MergeAggReadThr"); CurrentMetrics::Increment metric_increment{CurrentMetrics::QueryThread}; read_from_input(input); - }); + })); } } diff --git a/dbms/src/DataStreams/MergingAggregatedMemoryEfficientBlockInputStream.h b/dbms/src/DataStreams/MergingAggregatedMemoryEfficientBlockInputStream.h index 837c10869cf..cd689408f42 100644 --- a/dbms/src/DataStreams/MergingAggregatedMemoryEfficientBlockInputStream.h +++ b/dbms/src/DataStreams/MergingAggregatedMemoryEfficientBlockInputStream.h @@ -7,8 +7,6 @@ #include -class MemoryTracker; - namespace DB { @@ -151,7 +149,7 @@ class MergingAggregatedMemoryEfficientBlockInputStream final : public IProfiling std::unique_ptr parallel_merge_data; - void mergeThread(MemoryTracker * memory_tracker); + void mergeThread(); void finalize(); }; diff --git a/dbms/src/DataStreams/ParallelInputsProcessor.h b/dbms/src/DataStreams/ParallelInputsProcessor.h index 115dc6e1a3a..1c372891804 100644 --- a/dbms/src/DataStreams/ParallelInputsProcessor.h +++ b/dbms/src/DataStreams/ParallelInputsProcessor.h @@ -12,6 +12,7 @@ #include #include #include +#include /** Allows to process multiple block input streams (sources) in parallel, using specified number of threads. @@ -106,7 +107,7 @@ class ParallelInputsProcessor active_threads = max_threads; threads.reserve(max_threads); for (size_t i = 0; i < max_threads; ++i) - threads.emplace_back(std::bind(&ParallelInputsProcessor::thread, this, current_memory_tracker, i)); + threads.emplace_back(ThreadFactory(true, "ParalInputsProc").newThread([this, i]{ thread(i); })); } /// Ask all sources to stop earlier than they run out. @@ -174,9 +175,8 @@ class ParallelInputsProcessor } } - void thread(MemoryTracker * memory_tracker, size_t thread_num) + void thread(size_t thread_num) { - current_memory_tracker = memory_tracker; std::exception_ptr exception; setThreadName("ParalInputsProc"); diff --git a/dbms/src/DataStreams/SharedQueryBlockInputStream.h b/dbms/src/DataStreams/SharedQueryBlockInputStream.h index edbe0e773d1..4307e41cc73 100644 --- a/dbms/src/DataStreams/SharedQueryBlockInputStream.h +++ b/dbms/src/DataStreams/SharedQueryBlockInputStream.h @@ -3,6 +3,7 @@ #include #include +#include #include #include @@ -56,7 +57,7 @@ class SharedQueryBlockInputStream : public IProfilingBlockInputStream read_prefixed = true; /// Start reading thread. - thread = std::thread(&SharedQueryBlockInputStream::fetchBlocks, this); + thread = ThreadFactory().newThread([this] { fetchBlocks(); }); } void readSuffix() override diff --git a/dbms/src/Interpreters/Aggregator.cpp b/dbms/src/Interpreters/Aggregator.cpp index c60853b5e62..90b1388b52c 100644 --- a/dbms/src/Interpreters/Aggregator.cpp +++ b/dbms/src/Interpreters/Aggregator.cpp @@ -23,6 +23,7 @@ #include #include #include +#include #include #include @@ -1175,7 +1176,7 @@ BlocksList Aggregator::prepareBlocksAndFillTwoLevelImpl( [thread_id, &converter] { return converter(thread_id); }); if (thread_pool) - thread_pool->schedule([thread_id, &tasks] { tasks[thread_id](); }); + thread_pool->schedule(ThreadFactory().newJob([thread_id, &tasks] { tasks[thread_id](); })); else tasks[thread_id](); } @@ -1594,14 +1595,12 @@ class MergingAndConvertingBlockInputStream : public IProfilingBlockInputStream if (max_scheduled_bucket_num >= NUM_BUCKETS) return; - parallel_merge_data->pool.schedule(std::bind(&MergingAndConvertingBlockInputStream::thread, this, - max_scheduled_bucket_num, current_memory_tracker)); + parallel_merge_data->pool.schedule( + ThreadFactory(true, "MergingAggregtd").newJob([this]{ thread(max_scheduled_bucket_num); })); } - void thread(Int32 bucket_num, MemoryTracker * memory_tracker) + void thread(Int32 bucket_num) { - current_memory_tracker = memory_tracker; - setThreadName("MergingAggregtd"); CurrentMetrics::Increment metric_increment{CurrentMetrics::QueryThread}; try @@ -1964,10 +1963,8 @@ void Aggregator::mergeStream(const BlockInputStreamPtr & stream, AggregatedDataV LOG_TRACE(log, "Merging partially aggregated two-level data."); - auto merge_bucket = [&bucket_to_blocks, &result, this](Int32 bucket, Arena * aggregates_pool, MemoryTracker * memory_tracker) + auto merge_bucket = [&bucket_to_blocks, &result, this](Int32 bucket, Arena * aggregates_pool) { - current_memory_tracker = memory_tracker; - for (Block & block : bucket_to_blocks[bucket]) { if (isCancelled()) @@ -2000,10 +1997,10 @@ void Aggregator::mergeStream(const BlockInputStreamPtr & stream, AggregatedDataV result.aggregates_pools.push_back(std::make_shared()); Arena * aggregates_pool = result.aggregates_pools.back().get(); - auto task = std::bind(merge_bucket, bucket, aggregates_pool, current_memory_tracker); + auto task = std::bind(merge_bucket, bucket, aggregates_pool); if (thread_pool) - thread_pool->schedule(task); + thread_pool->schedule(ThreadFactory().newJob(task)); else task(); } diff --git a/dbms/src/Storages/Distributed/DistributedBlockOutputStream.cpp b/dbms/src/Storages/Distributed/DistributedBlockOutputStream.cpp index d739badba88..da98128a5df 100644 --- a/dbms/src/Storages/Distributed/DistributedBlockOutputStream.cpp +++ b/dbms/src/Storages/Distributed/DistributedBlockOutputStream.cpp @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -190,8 +191,7 @@ void DistributedBlockOutputStream::waitForJobs() ThreadPool::Job DistributedBlockOutputStream::runWritingJob(DistributedBlockOutputStream::JobReplica & job, const Block & current_block) { - auto memory_tracker = current_memory_tracker; - return [this, memory_tracker, &job, ¤t_block]() + return ThreadFactory(false, "DistrOutStrProc").newJob([this, &job, ¤t_block] { ++job.blocks_started; @@ -203,12 +203,6 @@ ThreadPool::Job DistributedBlockOutputStream::runWritingJob(DistributedBlockOutp job.max_elapsed_time_for_block_ms = std::max(job.max_elapsed_time_for_block_ms, elapsed_time_for_block_ms); }); - if (!current_memory_tracker) - { - current_memory_tracker = memory_tracker; - setThreadName("DistrOutStrProc"); - } - const auto & shard_info = cluster->getShardsInfo()[job.shard_index]; size_t num_shards = cluster->getShardsInfo().size(); auto & shard_job = per_shard_jobs[job.shard_index]; @@ -295,7 +289,7 @@ ThreadPool::Job DistributedBlockOutputStream::runWritingJob(DistributedBlockOutp job.blocks_written += 1; job.rows_written += shard_block.rows(); - }; + }); }