diff --git a/paddle/fluid/framework/threadpool.cc b/paddle/fluid/framework/threadpool.cc index 9854d618d2b29e..f26f212d4d5793 100644 --- a/paddle/fluid/framework/threadpool.cc +++ b/paddle/fluid/framework/threadpool.cc @@ -14,8 +14,12 @@ #include "paddle/fluid/framework/threadpool.h" +#include "gflags/gflags.h" #include "paddle/fluid/platform/enforce.h" +DEFINE_int32(io_threadpool_size, 100, + "number of threads used for doing IO, default 100"); + namespace paddle { namespace framework { @@ -91,5 +95,20 @@ void ThreadPool::TaskLoop() { } } +std::unique_ptr ThreadPoolIO::io_threadpool_(nullptr); +std::once_flag ThreadPoolIO::io_init_flag_; + +ThreadPool* ThreadPoolIO::GetInstanceIO() { + std::call_once(io_init_flag_, &ThreadPoolIO::InitIO); + return io_threadpool_.get(); +} + +void ThreadPoolIO::InitIO() { + if (io_threadpool_.get() == nullptr) { + // TODO(typhoonzero1986): make this configurable + io_threadpool_.reset(new ThreadPool(FLAGS_io_threadpool_size)); + } +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/threadpool.h b/paddle/fluid/framework/threadpool.h index f9dce7105e32ff..94111ee335b1a5 100644 --- a/paddle/fluid/framework/threadpool.h +++ b/paddle/fluid/framework/threadpool.h @@ -14,12 +14,12 @@ limitations under the License. */ #pragma once -#include +#include // NOLINT #include -#include -#include +#include // NOLINT +#include // NOLINT #include -#include +#include // NOLINT #include #include "glog/logging.h" #include "paddle/fluid/platform/enforce.h" @@ -28,6 +28,22 @@ limitations under the License. */ namespace paddle { namespace framework { +struct ExceptionHandler { + mutable std::future> future_; + explicit ExceptionHandler( + std::future>&& f) + : future_(std::move(f)) {} + void operator()() const { + auto ex = this->future_.get(); + if (ex != nullptr) { + LOG(FATAL) << "The exception is thrown inside the thread pool. You " + "should use RunAndGetException to handle the exception.\n" + "The default exception handler is LOG(FATAL)." + << ex->what(); + } + } +}; + // ThreadPool maintains a queue of tasks, and runs them using a fixed // number of threads. class ThreadPool { @@ -87,22 +103,6 @@ class ThreadPool { void Wait(); private: - struct ExceptionHandler { - mutable std::future> future_; - explicit ExceptionHandler( - std::future>&& f) - : future_(std::move(f)) {} - void operator()() const { - auto ex = this->future_.get(); - if (ex != nullptr) { - LOG(FATAL) << "The exception is thrown inside the thread pool. You " - "should use RunAndGetException to handle the exception.\n" - "The default exception handler is LOG(FATAL)." - << ex->what(); - } - } - }; - DISABLE_COPY_AND_ASSIGN(ThreadPool); // If the task queue is empty and avaialbe is equal to the number of @@ -135,6 +135,17 @@ class ThreadPool { std::condition_variable completed_; }; +class ThreadPoolIO : ThreadPool { + public: + static ThreadPool* GetInstanceIO(); + static void InitIO(); + + private: + // NOTE: threadpool in base will be inhereted here. + static std::unique_ptr io_threadpool_; + static std::once_flag io_init_flag_; +}; + // Run a function asynchronously. // NOTE: The function must return void. If the function need to return a value, // you can use lambda to capture a value pointer. @@ -143,5 +154,10 @@ std::future Async(Callback callback) { return ThreadPool::GetInstance()->Run(callback); } +template +std::future AsyncIO(Callback callback) { + return ThreadPoolIO::GetInstanceIO()->Run(callback); +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc index 8bbfd1f1592599..b546aa1d2f8aff 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/detail/grpc_client.cc @@ -35,7 +35,8 @@ bool RPCClient::AsyncSendVariable(const std::string& ep, const framework::Scope* p_scope = &scope; const auto ch = GetChannel(ep_val); - framework::Async([var_name_val, p_ctx, ep_val, p_scope, time_out, ch, this] { + framework::AsyncIO([var_name_val, p_ctx, ep_val, p_scope, time_out, ch, + this] { auto* var = p_scope->FindVar(var_name_val); ::grpc::ByteBuffer req; @@ -90,7 +91,8 @@ bool RPCClient::AsyncGetVariable(const std::string& ep, const framework::Scope* p_scope = &scope; const auto ch = GetChannel(ep_val); - framework::Async([var_name_val, ep_val, p_scope, p_ctx, time_out, ch, this] { + framework::AsyncIO([var_name_val, ep_val, p_scope, p_ctx, time_out, ch, + this] { // prepare input sendrecv::VariableMessage req; req.set_varname(var_name_val); @@ -133,8 +135,8 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep, const framework::Scope* p_scope = &scope; const auto ch = GetChannel(ep_val); - framework::Async([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx, - time_out, ch, this] { + framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx, + time_out, ch, this] { auto* var = p_scope->FindVar(in_var_name_val); ::grpc::ByteBuffer req; @@ -197,7 +199,7 @@ bool RPCClient::Wait() { std::vector> waits(req_count_); for (int i = 0; i < req_count_; i++) { - waits[i] = framework::Async([i, &a, this] { a[i] = Proceed(); }); + waits[i] = framework::AsyncIO([i, &a, this] { a[i] = Proceed(); }); } for (int i = 0; i < req_count_; i++) { diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index d5fc163bc25409..36dad5dd43a6a0 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -216,10 +216,10 @@ void AsyncGRPCServer::RunSyncUpdate() { std::function prefetch_register = std::bind(&AsyncGRPCServer::TryToRegisterNewPrefetchOne, this); + // TODO(wuyi): Run these "HandleRequest" in thread pool t_send_.reset( new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, cq_send_.get(), "cq_send", send_register))); - t_get_.reset( new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, cq_get_.get(), "cq_get", get_register))); diff --git a/python/paddle/fluid/tests/book/test_recognize_digits.py b/python/paddle/fluid/tests/book/test_recognize_digits.py index e4997b4069f60f..5ec6890c1b0dab 100644 --- a/python/paddle/fluid/tests/book/test_recognize_digits.py +++ b/python/paddle/fluid/tests/book/test_recognize_digits.py @@ -157,7 +157,6 @@ def train_loop(main_program): for ip in pserver_ips.split(","): eplist.append(':'.join([ip, port])) pserver_endpoints = ",".join(eplist) # ip:port,ip:port... - pserver_endpoints = os.getenv("PSERVERS") trainers = int(os.getenv("TRAINERS")) current_endpoint = os.getenv("POD_IP") + ":" + port trainer_id = int(os.getenv("PADDLE_INIT_TRAINER_ID"))