Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use ThreadFactory to auto set thread attributes #2496

Merged
merged 8 commits into from
Aug 2, 2021
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions dbms/src/Common/ThreadFactory.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#pragma once

#include <Common/MemoryTracker.h>
#include <Common/setThreadName.h>
#include <common/ThreadPool.h>
#include <thread>

namespace DB
{

/// ThreadFactory helps to set attributes on new threads or threadpool's jobs.
/// Current supported attributes:
/// 1. MemoryTracker
/// 2. ThreadName
///
/// ThreadFactory should only be constructed on stack.
class ThreadFactory
{
public:
/// force_overwrite_thread_attribute is only used for ThreadPool's jobs.
/// For new threads it is treated as always true.
explicit ThreadFactory(bool force_overwrite_thread_attribute = false, std::string thread_name_ = "")
: force_overwrite(force_overwrite_thread_attribute), thread_name(thread_name_) {}

ThreadFactory(const ThreadFactory &) = delete;
ThreadFactory & operator=(const ThreadFactory &) = delete;

ThreadFactory(ThreadFactory &&) = default;
ThreadFactory & operator=(ThreadFactory &&) = default;

template <typename F, typename ... Args>
std::thread newThread(F && f, Args &&... args)
{
auto memory_tracker = current_memory_tracker;
auto wrapped_func = [memory_tracker, thread_name = thread_name, f = std::move(f)](Args &&... args)
{
setAttributes(memory_tracker, thread_name, true);
return std::invoke(f, std::forward<Args>(args)...);
};
return std::thread(wrapped_func, std::forward<Args>(args)...);
}

template <typename F, typename ... Args>
ThreadPool::Job newJob(F && f, Args &&... args)
{
auto memory_tracker = current_memory_tracker;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Who will provide the current_memory_tracker in current thread?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently only ProcessListElement's ctor.

/// Use std::tuple to workaround the limit on the lambda's init-capture of C++17.
/// See https://stackoverflow.com/questions/47496358/c-lambdas-how-to-capture-variadic-parameter-pack-from-the-upper-scope
return [force_overwrite = force_overwrite, memory_tracker, thread_name = thread_name, f = std::move(f), args = std::make_tuple(std::move(args)...)]
{
setAttributes(memory_tracker, thread_name, force_overwrite);
return std::apply(f, std::move(args));
};
}
private:
static void setAttributes(MemoryTracker * memory_tracker, const std::string & thread_name, bool force_overwrite)
{
if (force_overwrite || !current_memory_tracker)
{
current_memory_tracker = memory_tracker;
if (!thread_name.empty())
setThreadName(thread_name.c_str());
}
}

bool force_overwrite = false;
std::string thread_name;
};

} // namespace DB

9 changes: 4 additions & 5 deletions dbms/src/DataStreams/AsynchronousBlockInputStream.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <DataStreams/IProfilingBlockInputStream.h>
#include <Common/setThreadName.h>
#include <Common/CurrentMetrics.h>
#include <Common/ThreadFactory.h>
#include <common/ThreadPool.h>
#include <Common/MemoryTracker.h>

Expand Down Expand Up @@ -97,7 +98,7 @@ class AsynchronousBlockInputStream : public IProfilingBlockInputStream
/// If there were no calculations yet, calculate the first block synchronously
if (!started)
{
calculate(current_memory_tracker);
calculate();
started = true;
}
else /// If the calculations are already in progress - wait for the result
Expand All @@ -121,12 +122,12 @@ class AsynchronousBlockInputStream : public IProfilingBlockInputStream
void next()
{
ready.reset();
pool.schedule(std::bind(&AsynchronousBlockInputStream::calculate, this, current_memory_tracker));
pool.schedule(ThreadFactory(false, "AsyncBlockInput").newJob([this] { calculate(); }));
}


/// Calculations that can be performed in a separate thread
void calculate(MemoryTracker * memory_tracker)
void calculate()
{
CurrentMetrics::Increment metric_increment{CurrentMetrics::QueryThread};

Expand All @@ -135,8 +136,6 @@ class AsynchronousBlockInputStream : public IProfilingBlockInputStream
if (first)
{
first = false;
setThreadName("AsyncBlockInput");
current_memory_tracker = memory_tracker;
children.back()->readPrefix();
}

Expand Down
7 changes: 3 additions & 4 deletions dbms/src/DataStreams/CreatingSetsBlockInputStream.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <Common/FailPoint.h>
#include <Common/ThreadFactory.h>
#include <DataStreams/CreatingSetsBlockInputStream.h>
#include <DataStreams/IBlockOutputStream.h>
#include <DataStreams/materializeBlock.h>
Expand Down Expand Up @@ -104,7 +105,7 @@ void CreatingSetsBlockInputStream::createAll()
{
if (isCancelledOrThrowIfKilled())
return;
workers.push_back(std::thread(&CreatingSetsBlockInputStream::createOne, this, std::ref(elem.second), current_memory_tracker));
workers.emplace_back(ThreadFactory().newThread([this, &subquery = elem.second]{ createOne(subquery); }));
FAIL_POINT_TRIGGER_EXCEPTION(FailPoints::exception_in_creating_set_input_stream);
}
}
Expand All @@ -121,12 +122,10 @@ void CreatingSetsBlockInputStream::createAll()
}
}

void CreatingSetsBlockInputStream::createOne(SubqueryForSet & subquery, MemoryTracker * memory_tracker)
void CreatingSetsBlockInputStream::createOne(SubqueryForSet & subquery)
{
try
{

current_memory_tracker = memory_tracker;
LOG_DEBUG(log,
(subquery.set ? "Creating set. " : "")
<< (subquery.join ? "Creating join. " : "") << (subquery.table ? "Filling temporary table. " : "") << " for task "
Expand Down
2 changes: 1 addition & 1 deletion dbms/src/DataStreams/CreatingSetsBlockInputStream.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class CreatingSetsBlockInputStream : public IProfilingBlockInputStream
Logger * log = &Logger::get("CreatingSetsBlockInputStream");

void createAll();
void createOne(SubqueryForSet & subquery, MemoryTracker * memory_tracker);
void createOne(SubqueryForSet & subquery);
};

} // namespace DB
5 changes: 3 additions & 2 deletions dbms/src/DataStreams/DedupSortedBlockInputStream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <Common/setThreadName.h>
#include <Common/CurrentMetrics.h>
#include <Common/ThreadFactory.h>

// #define DEDUP_TRACER
#ifndef DEDUP_TRACER
Expand Down Expand Up @@ -36,7 +37,7 @@ DedupSortedBlockInputStream::DedupSortedBlockInputStream(BlockInputStreams & inp
readers.schedule(std::bind(&DedupSortedBlockInputStream::asynFetch, this, i));

LOG_DEBUG(log, "Start deduping in single thread, using priority-queue");
dedup_thread = std::make_unique<std::thread>([this] { asynDedupByQueue(); });
dedup_thread = std::make_unique<std::thread>(ThreadFactory().newThread([this] { asyncDedupByQueue(); }));
}


Expand Down Expand Up @@ -105,7 +106,7 @@ void DedupSortedBlockInputStream::readFromSource(DedupCursors & output, BoundQue
}


void DedupSortedBlockInputStream::asynDedupByQueue()
void DedupSortedBlockInputStream::asyncDedupByQueue()
{
BoundQueue bounds;
DedupCursors cursors(source_blocks.size());
Expand Down
2 changes: 1 addition & 1 deletion dbms/src/DataStreams/DedupSortedBlockInputStream.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class DedupSortedBlockInputStream : public IProfilingBlockInputStream
}

private:
void asynDedupByQueue();
void asyncDedupByQueue();
void asynFetch(size_t pisition);

void fetchBlock(size_t pisition);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#include <future>
#include <Common/setThreadName.h>
#include <Common/CurrentMetrics.h>
#include <Common/MemoryTracker.h>
#include <Common/ThreadFactory.h>
#include <DataStreams/MergingAggregatedMemoryEfficientBlockInputStream.h>


Expand Down Expand Up @@ -175,14 +175,12 @@ void MergingAggregatedMemoryEfficientBlockInputStream::start()
{
auto & child = children[i];

auto memory_tracker = current_memory_tracker;
reading_pool->schedule([&child, memory_tracker]
{
current_memory_tracker = memory_tracker;
setThreadName("MergeAggReadThr");
CurrentMetrics::Increment metric_increment{CurrentMetrics::QueryThread};
child->readPrefix();
});
reading_pool->schedule(
ThreadFactory(true, "MergeAggReadThr").newJob([&child]
{
CurrentMetrics::Increment metric_increment{CurrentMetrics::QueryThread};
child->readPrefix();
}));
}

reading_pool->wait();
Expand All @@ -196,8 +194,7 @@ void MergingAggregatedMemoryEfficientBlockInputStream::start()
*/

for (size_t i = 0; i < merging_threads; ++i)
pool.schedule(std::bind(&MergingAggregatedMemoryEfficientBlockInputStream::mergeThread,
this, current_memory_tracker));
pool.schedule(ThreadFactory(true, "MergeAggMergThr").newJob([this] { mergeThread(); }));
}
}

Expand Down Expand Up @@ -293,10 +290,8 @@ void MergingAggregatedMemoryEfficientBlockInputStream::finalize()
}


void MergingAggregatedMemoryEfficientBlockInputStream::mergeThread(MemoryTracker * memory_tracker)
void MergingAggregatedMemoryEfficientBlockInputStream::mergeThread()
{
setThreadName("MergeAggMergThr");
current_memory_tracker = memory_tracker;
CurrentMetrics::Increment metric_increment{CurrentMetrics::QueryThread};

try
Expand Down Expand Up @@ -480,14 +475,11 @@ MergingAggregatedMemoryEfficientBlockInputStream::BlocksToMerge MergingAggregate
{
if (need_that_input(input))
{
auto memory_tracker = current_memory_tracker;
reading_pool->schedule([&input, &read_from_input, memory_tracker]
reading_pool->schedule(ThreadFactory(true, "MergeAggReadThr").newJob([&input, &read_from_input]
{
current_memory_tracker = memory_tracker;
setThreadName("MergeAggReadThr");
CurrentMetrics::Increment metric_increment{CurrentMetrics::QueryThread};
read_from_input(input);
});
}));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
#include <condition_variable>


class MemoryTracker;

namespace DB
{

Expand Down Expand Up @@ -151,7 +149,7 @@ class MergingAggregatedMemoryEfficientBlockInputStream final : public IProfiling

std::unique_ptr<ParallelMergeData> parallel_merge_data;

void mergeThread(MemoryTracker * memory_tracker);
void mergeThread();

void finalize();
};
Expand Down
6 changes: 3 additions & 3 deletions dbms/src/DataStreams/ParallelInputsProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <Common/setThreadName.h>
#include <Common/CurrentMetrics.h>
#include <Common/MemoryTracker.h>
#include <Common/ThreadFactory.h>


/** Allows to process multiple block input streams (sources) in parallel, using specified number of threads.
Expand Down Expand Up @@ -106,7 +107,7 @@ class ParallelInputsProcessor
active_threads = max_threads;
threads.reserve(max_threads);
for (size_t i = 0; i < max_threads; ++i)
threads.emplace_back(std::bind(&ParallelInputsProcessor::thread, this, current_memory_tracker, i));
threads.emplace_back(ThreadFactory(true, "ParalInputsProc").newThread([this, i]{ thread(i); }));
}

/// Ask all sources to stop earlier than they run out.
Expand Down Expand Up @@ -174,9 +175,8 @@ class ParallelInputsProcessor
}
}

void thread(MemoryTracker * memory_tracker, size_t thread_num)
void thread(size_t thread_num)
{
current_memory_tracker = memory_tracker;
std::exception_ptr exception;

setThreadName("ParalInputsProc");
Expand Down
3 changes: 2 additions & 1 deletion dbms/src/DataStreams/SharedQueryBlockInputStream.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <thread>

#include <Common/ConcurrentBoundedQueue.h>
#include <Common/ThreadFactory.h>
#include <common/logger_useful.h>
#include <Common/typeid_cast.h>

Expand Down Expand Up @@ -56,7 +57,7 @@ class SharedQueryBlockInputStream : public IProfilingBlockInputStream
read_prefixed = true;

/// Start reading thread.
thread = std::thread(&SharedQueryBlockInputStream::fetchBlocks, this);
thread = ThreadFactory().newThread([this] { fetchBlocks(); });
}

void readSuffix() override
Expand Down
24 changes: 10 additions & 14 deletions dbms/src/Interpreters/Aggregator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <Interpreters/Aggregator.h>
#include <Common/ClickHouseRevision.h>
#include <Common/MemoryTracker.h>
#include <Common/ThreadFactory.h>
#include <Common/typeid_cast.h>
#include <common/demangle.h>

Expand Down Expand Up @@ -1009,9 +1010,8 @@ BlocksList Aggregator::prepareBlocksAndFillTwoLevelImpl(
bool final,
ThreadPool * thread_pool) const
{
auto converter = [&](size_t bucket, MemoryTracker * memory_tracker)
auto converter = [&](size_t bucket)
{
current_memory_tracker = memory_tracker;
return convertOneBucketToBlock(data_variants, method, final, bucket);
};

Expand All @@ -1026,10 +1026,10 @@ BlocksList Aggregator::prepareBlocksAndFillTwoLevelImpl(
if (method.data.impls[bucket].empty())
continue;

tasks[bucket] = std::packaged_task<Block()>(std::bind(converter, bucket, current_memory_tracker));
tasks[bucket] = std::packaged_task<Block()>(std::bind(converter, bucket));

if (thread_pool)
thread_pool->schedule([bucket, &tasks] { tasks[bucket](); });
thread_pool->schedule(ThreadFactory().newJob([bucket, &tasks] { tasks[bucket](); }));
else
tasks[bucket]();
}
Expand Down Expand Up @@ -1439,14 +1439,12 @@ class MergingAndConvertingBlockInputStream : public IProfilingBlockInputStream
if (max_scheduled_bucket_num >= NUM_BUCKETS)
return;

parallel_merge_data->pool.schedule(std::bind(&MergingAndConvertingBlockInputStream::thread, this,
max_scheduled_bucket_num, current_memory_tracker));
parallel_merge_data->pool.schedule(
ThreadFactory(true, "MergingAggregtd").newJob([this]{ thread(max_scheduled_bucket_num); }));
}

void thread(Int32 bucket_num, MemoryTracker * memory_tracker)
void thread(Int32 bucket_num)
{
current_memory_tracker = memory_tracker;
setThreadName("MergingAggregtd");
CurrentMetrics::Increment metric_increment{CurrentMetrics::QueryThread};

try
Expand Down Expand Up @@ -1735,10 +1733,8 @@ void Aggregator::mergeStream(const BlockInputStreamPtr & stream, AggregatedDataV

LOG_TRACE(log, "Merging partially aggregated two-level data.");

auto merge_bucket = [&bucket_to_blocks, &result, this](Int32 bucket, Arena * aggregates_pool, MemoryTracker * memory_tracker)
auto merge_bucket = [&bucket_to_blocks, &result, this](Int32 bucket, Arena * aggregates_pool)
{
current_memory_tracker = memory_tracker;

for (Block & block : bucket_to_blocks[bucket])
{
if (isCancelled())
Expand Down Expand Up @@ -1771,10 +1767,10 @@ void Aggregator::mergeStream(const BlockInputStreamPtr & stream, AggregatedDataV
result.aggregates_pools.push_back(std::make_shared<Arena>());
Arena * aggregates_pool = result.aggregates_pools.back().get();

auto task = std::bind(merge_bucket, bucket, aggregates_pool, current_memory_tracker);
auto task = std::bind(merge_bucket, bucket, aggregates_pool);

if (thread_pool)
thread_pool->schedule(task);
thread_pool->schedule(ThreadFactory().newJob(task));
else
task();
}
Expand Down
Loading