Commit

Merge pull request #9578 from typhoonzero/threadpool_for_io
Multi stream thread pool
typhoonzero authored Apr 13, 2018
2 parents 2c552d4 + a08bf76 commit 1bdc726
Showing 5 changed files with 63 additions and 27 deletions.
19 changes: 19 additions & 0 deletions paddle/fluid/framework/threadpool.cc
@@ -14,8 +14,12 @@

#include "paddle/fluid/framework/threadpool.h"

#include "gflags/gflags.h"
#include "paddle/fluid/platform/enforce.h"

DEFINE_int32(io_threadpool_size, 100,
"number of threads used for doing IO, default 100");

namespace paddle {
namespace framework {

@@ -91,5 +95,20 @@ void ThreadPool::TaskLoop() {
}
}

std::unique_ptr<ThreadPool> ThreadPoolIO::io_threadpool_(nullptr);
std::once_flag ThreadPoolIO::io_init_flag_;

ThreadPool* ThreadPoolIO::GetInstanceIO() {
std::call_once(io_init_flag_, &ThreadPoolIO::InitIO);
return io_threadpool_.get();
}

void ThreadPoolIO::InitIO() {
if (io_threadpool_.get() == nullptr) {
// TODO(typhoonzero1986): make this configurable
io_threadpool_.reset(new ThreadPool(FLAGS_io_threadpool_size));
}
}

} // namespace framework
} // namespace paddle
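
The hunks above introduce a lazily constructed IO thread pool: ThreadPoolIO::GetInstanceIO() goes through std::call_once so InitIO() builds the pool exactly once, and the pool size comes from the new io_threadpool_size gflag (default 100). Below is a minimal usage sketch, not part of this commit; SketchBlockingSend is an illustrative name, and it assumes the binary parses gflags at startup so --io_threadpool_size can be overridden on the command line.

#include <future>
#include "paddle/fluid/framework/threadpool.h"

// Hypothetical caller; pool size could be overridden with e.g.
//   ./trainer --io_threadpool_size=32
void SketchBlockingSend() {
  // Dispatch a blocking (network or disk) task to the dedicated IO pool so it
  // does not occupy a worker thread of the default compute pool.
  std::future<void> done =
      paddle::framework::ThreadPoolIO::GetInstanceIO()->Run([] {
        // ... perform the blocking send here ...
      });
  done.wait();  // block until the IO task has finished
}
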
56 changes: 36 additions & 20 deletions paddle/fluid/framework/threadpool.h
@@ -14,12 +14,12 @@ limitations under the License. */

#pragma once

#include <condition_variable>
#include <condition_variable> // NOLINT
#include <functional>
#include <future>
#include <mutex>
#include <future> // NOLINT
#include <mutex> // NOLINT
#include <queue>
#include <thread>
#include <thread> // NOLINT
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/platform/enforce.h"
@@ -28,6 +28,22 @@ limitations under the License. */
namespace paddle {
namespace framework {

struct ExceptionHandler {
mutable std::future<std::unique_ptr<platform::EnforceNotMet>> future_;
explicit ExceptionHandler(
std::future<std::unique_ptr<platform::EnforceNotMet>>&& f)
: future_(std::move(f)) {}
void operator()() const {
auto ex = this->future_.get();
if (ex != nullptr) {
LOG(FATAL) << "The exception is thrown inside the thread pool. You "
"should use RunAndGetException to handle the exception.\n"
"The default exception handler is LOG(FATAL)."
<< ex->what();
}
}
};

// ThreadPool maintains a queue of tasks, and runs them using a fixed
// number of threads.
class ThreadPool {
@@ -87,22 +103,6 @@ class ThreadPool {
void Wait();

private:
struct ExceptionHandler {
mutable std::future<std::unique_ptr<platform::EnforceNotMet>> future_;
explicit ExceptionHandler(
std::future<std::unique_ptr<platform::EnforceNotMet>>&& f)
: future_(std::move(f)) {}
void operator()() const {
auto ex = this->future_.get();
if (ex != nullptr) {
LOG(FATAL) << "The exception is thrown inside the thread pool. You "
"should use RunAndGetException to handle the exception.\n"
"The default exception handler is LOG(FATAL)."
<< ex->what();
}
}
};

DISABLE_COPY_AND_ASSIGN(ThreadPool);

// If the task queue is empty and available is equal to the number of
@@ -135,6 +135,17 @@ class ThreadPool {
std::condition_variable completed_;
};

class ThreadPoolIO : ThreadPool {
public:
static ThreadPool* GetInstanceIO();
static void InitIO();

private:
// NOTE: the thread pool in the base class will be inherited here.
static std::unique_ptr<ThreadPool> io_threadpool_;
static std::once_flag io_init_flag_;
};

// Run a function asynchronously.
// NOTE: The function must return void. If the function need to return a value,
// you can use lambda to capture a value pointer.
@@ -143,5 +154,10 @@ std::future<void> Async(Callback callback) {
return ThreadPool::GetInstance()->Run(callback);
}

template <typename Callback>
std::future<void> AsyncIO(Callback callback) {
return ThreadPoolIO::GetInstanceIO()->Run(callback);
}

} // namespace framework
} // namespace paddle
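
The header changes hoist ExceptionHandler from ThreadPool's private section to namespace scope, mark the standard-library includes with // NOLINT for cpplint, declare the ThreadPoolIO singleton, and add an AsyncIO() helper that mirrors Async() but schedules onto the IO pool. As the existing comment notes, submitted callbacks must return void; a result is passed out through a captured pointer or reference. A small sketch of that pattern with the new helper follows (not part of the diff; ReadSizeOnIOPool and the literal 1024 are illustrative only).

#include <future>
#include "paddle/fluid/framework/threadpool.h"

int ReadSizeOnIOPool() {
  int bytes_read = 0;  // result slot; the submitted callback itself returns void
  std::future<void> fut = paddle::framework::AsyncIO([&bytes_read] {
    bytes_read = 1024;  // stand-in for a blocking read
  });
  fut.wait();  // once the future is ready, the captured result is safe to use
  return bytes_read;
}
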
12 changes: 7 additions & 5 deletions paddle/fluid/operators/detail/grpc_client.cc
@@ -35,7 +35,8 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val);

framework::Async([var_name_val, p_ctx, ep_val, p_scope, time_out, ch, this] {
framework::AsyncIO([var_name_val, p_ctx, ep_val, p_scope, time_out, ch,
this] {
auto* var = p_scope->FindVar(var_name_val);

::grpc::ByteBuffer req;
@@ -89,7 +90,8 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val);

framework::Async([var_name_val, ep_val, p_scope, p_ctx, time_out, ch, this] {
framework::AsyncIO([var_name_val, ep_val, p_scope, p_ctx, time_out, ch,
this] {
// prepare input
sendrecv::VariableMessage req;
req.set_varname(var_name_val);
@@ -132,8 +134,8 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val);

framework::Async([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx,
time_out, ch, this] {
framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx,
time_out, ch, this] {
auto* var = p_scope->FindVar(in_var_name_val);

::grpc::ByteBuffer req;
@@ -196,7 +198,7 @@ bool RPCClient::Wait() {
std::vector<std::future<void>> waits(req_count_);

for (int i = 0; i < req_count_; i++) {
waits[i] = framework::Async([i, &a, this] { a[i] = Proceed(); });
waits[i] = framework::AsyncIO([i, &a, this] { a[i] = Proceed(); });
}

for (int i = 0; i < req_count_; i++) {
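
In grpc_client.cc, the send, get, and prefetch callbacks, as well as the completion handlers in Wait(), now go through framework::AsyncIO() instead of framework::Async(), so blocking gRPC work no longer ties up the compute pool. The fan-out-then-wait structure of Wait() looks roughly like the sketch below; req_count_, Proceed(), and the member captures are the real code, while SketchWaitAll is a simplified, self-contained stand-in.

#include <future>
#include <vector>
#include "paddle/fluid/framework/threadpool.h"

bool SketchWaitAll(int req_count) {
  std::vector<int> ok(req_count, 0);  // one result slot per pending request
  std::vector<std::future<void>> waits(req_count);
  for (int i = 0; i < req_count; ++i) {
    // Each completion handler runs on the IO pool; its result lands in ok[i].
    waits[i] = paddle::framework::AsyncIO([i, &ok] { ok[i] = 1; });
  }
  for (auto& w : waits) {
    w.wait();  // join every outstanding request
  }
  for (int r : ok) {
    if (!r) return false;
  }
  return true;
}
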
2 changes: 1 addition & 1 deletion paddle/fluid/operators/detail/grpc_server.cc
@@ -217,10 +217,10 @@ void AsyncGRPCServer::RunSyncUpdate() {
std::function<void()> prefetch_register =
std::bind(&AsyncGRPCServer::TryToRegisterNewPrefetchOne, this);

// TODO(wuyi): Run these "HandleRequest" in thread pool
t_send_.reset(
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
cq_send_.get(), "cq_send", send_register)));

t_get_.reset(
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
cq_get_.get(), "cq_get", get_register)));
1 change: 0 additions & 1 deletion python/paddle/fluid/tests/book/test_recognize_digits.py
@@ -157,7 +157,6 @@ def train_loop(main_program):
for ip in pserver_ips.split(","):
eplist.append(':'.join([ip, port]))
pserver_endpoints = ",".join(eplist) # ip:port,ip:port...
pserver_endpoints = os.getenv("PSERVERS")
trainers = int(os.getenv("TRAINERS"))
current_endpoint = os.getenv("POD_IP") + ":" + port
trainer_id = int(os.getenv("PADDLE_INIT_TRAINER_ID"))
