diff --git a/paddle/fluid/framework/boxps_trainer.cc b/paddle/fluid/framework/boxps_trainer.cc
index bc9c138d02988..fb79516f23cbb 100644
--- a/paddle/fluid/framework/boxps_trainer.cc
+++ b/paddle/fluid/framework/boxps_trainer.cc
@@ -139,18 +139,20 @@ void BoxPSTrainer::InitTrainerEnv(const ProgramDesc& main_program,
     }
   }
+  std::set<std::string> async_param_name;
   if (async_mode_) {
-    dense_table_->Init(*root_scope_, *param_need_sync_.get(),
+    async_param_name = dense_table_->Init(*root_scope_, *param_need_sync_.get(),
                        persistable_vars_);
   }
   for (int i = 0; i < thread_num_; ++i) {
     auto this_worker = std::dynamic_pointer_cast<BoxPSWorker>(workers_[i]);
     this_worker->SetRootScope(root_scope_);
-    this_worker->CreateDeviceResource(main_program);
     if (async_mode_) {
-      this_worker->SetDenseTable(dense_table_.get());
+      this_worker->SetDenseTable(dense_table_.get());
+      this_worker->SetAsyncParamName(async_param_name);
     }
+    this_worker->CreateDeviceResource(main_program);
     //  CopyParameters(*root_scope_, i);
   }
 }
diff --git a/paddle/fluid/framework/boxps_worker.cc b/paddle/fluid/framework/boxps_worker.cc
index 20ff19379a480..3bdb430e7e6a6 100644
--- a/paddle/fluid/framework/boxps_worker.cc
+++ b/paddle/fluid/framework/boxps_worker.cc
@@ -36,13 +36,20 @@ namespace framework {
 BoxPSAsynDenseTable::BoxPSAsynDenseTable(const int device_num)
     : device_num_(device_num) {
-  device_grads_.resize(device_num);
+  int buffer_size = device_num * 4;  // magic number
+  device_grads_.resize(buffer_size);
+  buffer_poll_.reset(new PSBufferQueue(buffer_size));
+  for (int i = 0; i < buffer_size; i++) {
+    buffer_poll_->Send(&device_grads_[i]);
+  }
+  VLOG(0) << "BoxPSAsynDenseTable init finish";
 }
 BoxPSAsynDenseTable::~BoxPSAsynDenseTable() {}

-void BoxPSAsynDenseTable::Init(
+std::set<std::string> BoxPSAsynDenseTable::Init(
     const Scope& root_scope, const std::vector<std::string>& param_need_sync,
     const std::vector<std::string>& persistable_vars) {
+  std::set<std::string> async_param_name;
   root_scope_ = const_cast<Scope*>(&root_scope);
   VLOG(0) << "Begin Init For Async Optimize";
   for (const auto& e : param_need_sync) {
@@ -52,23 +59,49 @@ void BoxPSAsynDenseTable::Init(
       async_param_list_.push_back(e);
       async_param_list_.push_back(e + "_moment1_0");
       async_param_list_.push_back(e + "_moment2_0");
+      async_param_name.insert(e);
+      async_param_name.insert(e + "@GRAD");
     }
   }
-  ps_.resize(async_param_list_.size());
+  original_ps_.resize(async_param_list_.size());
   VLOG(0) << "async_param_list_.size(): " << async_param_list_.size();
   std::sort(async_param_list_.begin(),
             async_param_list_.end());  // xx_param.b_0, xx_param_moment1_0,
                                        // xx_param_moment2_0
-  for (size_t i = 0; i < async_param_list_.size(); ++i) {
+  for (size_t i = 0; i < async_param_list_.size(); i += 3) {
+    const LoDTensor& root_tensor =
+        root_scope.FindVar(async_param_list_[i])->Get<LoDTensor>();
+    total_param_len_ += root_tensor.numel();
+  }
+  VLOG(0) << "alloc param length dense table:" << total_param_len_;
+
+  ps_.mutable_data<float>({total_param_len_, 1}, platform::CPUPlace());
+  mom1_.mutable_data<float>({total_param_len_, 1}, platform::CPUPlace());
+  mom2_.mutable_data<float>({total_param_len_, 1}, platform::CPUPlace());
+  for (size_t i = 0; i < device_grads_.size(); ++i) {
+    device_grads_[i].mutable_data<float>(
+        {static_cast<int64_t>(total_param_len_), 1}, platform::CPUPlace());
+  }
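
Note on the layout change above: Init() now makes two passes over async_param_list_. The first pass (every third entry, i.e. the parameter itself) accumulates total_param_len_; then ps_, mom1_, and mom2_ are allocated once as flat CPU tensors of that length, and each of the device_num * 4 gradient buffers gets the same flat shape. The per-variable tensors in original_ps_ (filled in just below) are only views into these flat buffers. A minimal sketch of the same two-pass flatten-and-slice pattern in plain C++ (the FlatTable/ParamView types and variable names are illustrative, not Paddle API):

    #include <cstdint>
    #include <string>
    #include <utility>
    #include <vector>

    // One contiguous buffer per optimizer slot; every parameter is an
    // (offset, len) view into it, mirroring ps_/mom1_/mom2_ + original_ps_.
    struct ParamView {
      std::string name;
      int64_t offset;
      int64_t len;
    };

    struct FlatTable {
      std::vector<float> ps, mom1, mom2;  // flat storage, all the same length
      std::vector<ParamView> views;       // per-variable slices into the above

      void Build(const std::vector<std::pair<std::string, int64_t>>& params) {
        int64_t total = 0;
        for (const auto& p : params) total += p.second;  // pass 1: total length
        ps.resize(total); mom1.resize(total); mom2.resize(total);
        int64_t offset = 0;
        for (const auto& p : params) {                   // pass 2: carve views
          views.push_back({p.first, offset, p.second});
          offset += p.second;
        }
      }
      float* param(const ParamView& v) { return ps.data() + v.offset; }
    };

    int main() {
      FlatTable table;
      table.Build({{"fc_0.w_0", 1024}, {"fc_0.b_0", 128}});  // hypothetical vars
      return table.views[1].offset == 1024 ? 0 : 1;
    }
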
+
+  int64_t offset = 0;
+  VLOG(0) << " param size is " << async_param_list_.size();
+  for (size_t i = 0; i < async_param_list_.size(); i++) {
     VLOG(0) << "begin to copy " << async_param_list_[i];
     const LoDTensor& root_tensor =
         root_scope.FindVar(async_param_list_[i])->Get<LoDTensor>();
-    VLOG(0) << "its size is " << root_tensor.numel();
-    async_param_size_.push_back(root_tensor.numel());
-    ps_[i].mutable_data<float>({root_tensor.numel(), 1}, platform::CPUPlace());
+    auto dim = root_tensor.dims();
+    size_t len = root_tensor.numel();
+    if (i % 3 == 0) {
+      original_ps_[i].ShareDataWith(ps_.Slice(offset, offset + len))
+          .Resize(dim);
+    } else if (i % 3 == 1) {
+      original_ps_[i].ShareDataWith(mom1_.Slice(offset, offset + len))
+          .Resize(dim);
+    } else {
+      original_ps_[i].ShareDataWith(mom2_.Slice(offset, offset + len))
+          .Resize(dim);
+      offset += len;
+    }
     TensorCopy(*static_cast<const Tensor*>(&root_tensor), platform::CPUPlace(),
-               static_cast<Tensor*>(&(ps_[i])));
+               static_cast<Tensor*>(&(original_ps_[i])));
   }

   // Copy global lr for async mode
@@ -90,9 +123,23 @@ void BoxPSAsynDenseTable::Init(
     }
   }
   VLOG(0) << "base lr is " << base_lr_;
-  ps_buffer_.reset(new PSBufferQueue(8 * 3));  // magic number
-
+  ps_buffer_.reset(new PSBufferQueue(device_num_ * 3));  // magic number
+  all_lr_.resize(total_param_len_);
+  auto box_ptr = BoxWrapper::GetInstance();
+  std::map<std::string, float> lr_map = box_ptr->GetLRMap();
+  int lr_index = 0;
+  for (size_t i = 0; i < async_param_list_.size() / 3; ++i) {
+    float learning_rate = base_lr_;
+    if (lr_map.find(async_param_list_[i * 3]) != lr_map.end()) {
+      learning_rate = lr_map[async_param_list_[i * 3]];
+    }
+    for (int j = 0; j < original_ps_[i * 3].numel(); j++) {
+      all_lr_[lr_index++] = learning_rate;
+    }
+  }
+  InitThreadGroup();
   update_thread_ = new std::thread(&BoxPSAsynDenseTable::AsyncUpdate, this);
+  return async_param_name;
 }

 void BoxPSAsynDenseTable::Finalize(void) {
@@ -101,26 +148,81 @@ void BoxPSAsynDenseTable::Finalize(void) {
   }
   ps_buffer_->Close();
   update_thread_->join();
+  buffer_poll_->Close();
   for (size_t i = 0; i < async_param_list_.size(); ++i) {
     VLOG(0) << "begin to copy back" << async_param_list_[i];
     auto* root_tensor =
         root_scope_->Var(async_param_list_[i])->GetMutable<LoDTensor>();
-    TensorCopySync(*static_cast<const Tensor*>(&ps_[i]), platform::CPUPlace(),
-                   root_tensor);
+    TensorCopySync(*static_cast<const Tensor*>(&original_ps_[i]),
+                   platform::CPUPlace(), root_tensor);
   }
   ps_buffer_ = nullptr;
+  buffer_poll_ = nullptr;
   delete update_thread_;
   update_thread_ = nullptr;
 }

+void BoxPSAsynDenseTable::ThreadUpdate(int thread_id,
+                                       std::vector<LoDTensor*>& grad,
+                                       size_t merge_num) {
+  float* grad_data = grad[0]->mutable_data<float>(platform::CPUPlace());
+  float* param_data = ps_.mutable_data<float>(platform::CPUPlace());
+  float* mom1_data = mom1_.mutable_data<float>(platform::CPUPlace());
+  float* mom2_data = mom2_.mutable_data<float>(platform::CPUPlace());
+  // merge grad
+  const size_t start = thread_start_index_[thread_id];
+  const size_t end = thread_end_index_[thread_id];
+  if (merge_num == 2) {
+    LoDTensor* grad_tensor_1 = grad[1];
+    float* grad_tensor_1_data =
+        grad_tensor_1->mutable_data<float>(platform::CPUPlace());
+    for (size_t j = start; j < end; ++j) {
+      grad_data[j] = (grad_data[j] + grad_tensor_1_data[j]) / 2;
+    }
+  } else if (merge_num == 3) {
+    LoDTensor* grad_tensor_1 = grad[1];
+    LoDTensor* grad_tensor_2 = grad[2];
+    float* grad_tensor_1_data =
+        grad_tensor_1->mutable_data<float>(platform::CPUPlace());
+    float* grad_tensor_2_data =
+        grad_tensor_2->mutable_data<float>(platform::CPUPlace());
+    for (size_t j = start; j < end; ++j) {
+      grad_data[j] =
+          (grad_data[j] + grad_tensor_1_data[j] + grad_tensor_2_data[j]) / 3;
+    }
+  } else if (merge_num == 4) {
+    LoDTensor* grad_tensor_1 = grad[1];
+    LoDTensor* grad_tensor_2 = grad[2];
+    LoDTensor* grad_tensor_3 = grad[3];
+    float* grad_tensor_1_data =
+        grad_tensor_1->mutable_data<float>(platform::CPUPlace());
+    float* grad_tensor_2_data =
+        grad_tensor_2->mutable_data<float>(platform::CPUPlace());
+    float* grad_tensor_3_data =
+        grad_tensor_3->mutable_data<float>(platform::CPUPlace());
+    for (size_t j = start; j < end; ++j) {
+      grad_data[j] = (grad_data[j] + grad_tensor_1_data[j] +
+                      grad_tensor_2_data[j] + grad_tensor_3_data[j]) / 4;
+    }
+  }
+
+  for (size_t j = start; j < end; ++j) {
+    mom1_data[j] =
+        0.99 * mom1_data[j] + 0.01 * grad_data[j];  // magic beta and epsilon
+    mom2_data[j] =
+        0.9999 * mom2_data[j] + 0.0001 * grad_data[j] * grad_data[j];
+    param_data[j] -=
+        all_lr_[j] * (mom1_data[j] / (sqrt(mom2_data[j]) + 1e-8));
+  }
+  return;
+}
+
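
ThreadUpdate() above is the per-thread half of the new update path: each pool thread owns the index range [thread_start_index_[tid], thread_end_index_[tid]), averages up to four pending gradient buffers element-wise, and applies an Adam-style rule with hard-coded moment factors (beta1 = 0.99, beta2 = 0.9999, eps = 1e-8, no bias correction) and the precomputed per-element learning-rate table all_lr_. The unrolled merge_num == 2/3/4 branches are equivalent to a single loop over merge_num; a sketch of that equivalence (not the PR's code, which writes the average back into grad[0] rather than a local):

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Average merge_num gradient buffers and apply one Adam-style step on the
    // index range [start, end), matching the constants hard-coded above.
    void MergeAndAdamStep(const std::vector<float*>& grads, size_t merge_num,
                          float* param, float* mom1, float* mom2,
                          const float* lr, size_t start, size_t end) {
      for (size_t j = start; j < end; ++j) {
        float g = grads[0][j];
        for (size_t k = 1; k < merge_num; ++k) g += grads[k][j];
        g /= static_cast<float>(merge_num);
        mom1[j] = 0.99f * mom1[j] + 0.01f * g;          // beta1 = 0.99
        mom2[j] = 0.9999f * mom2[j] + 0.0001f * g * g;  // beta2 = 0.9999
        param[j] -= lr[j] * (mom1[j] / (std::sqrt(mom2[j]) + 1e-8f));
      }
    }

    int main() {
      std::vector<float> g0(4, 1.f), g1(4, 3.f), p(4, 0.f);
      std::vector<float> m1(4, 0.f), m2(4, 0.f), lr(4, 0.1f);
      std::vector<float*> grads = {g0.data(), g1.data()};
      MergeAndAdamStep(grads, 2, p.data(), m1.data(), m2.data(), lr.data(), 0, 4);
      return p[0] < 0.f ? 0 : 1;  // param moved against the positive gradient
    }
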
 void BoxPSAsynDenseTable::AsyncUpdate() {
   VLOG(0) << "Begin AsyncUpdate";
-  std::vector<std::vector<LoDTensor>*> grad(4, nullptr);  // max package
+  std::vector<LoDTensor*> grad(4, nullptr);  // max package
   auto box_ptr = BoxWrapper::GetInstance();
-  std::map<std::string, float> lr_map = box_ptr->GetLRMap();

   while (ps_buffer_->Receive(&grad[0])) {
     size_t merge_num = ps_buffer_->Size() + 1;
@@ -131,111 +233,61 @@ void BoxPSAsynDenseTable::AsyncUpdate() {
       ps_buffer_->Receive(&grad[i]);
     }
     AutoWRLock ps_lock(&ps_lock_);
-    // VLOG(0) << "AsyncUpdate recevie grads, and begin to update param, merge "
-    //         << merge_num;
-    for (size_t i = 0; i < async_param_list_.size() / 3; ++i) {
-      LoDTensor* param_tensor = &ps_[i * 3];
-      LoDTensor* mom1_tensor = &ps_[i * 3 + 1];
-      LoDTensor* mom2_tensor = &ps_[i * 3 + 2];
-      LoDTensor* grad_tensor = &(*grad[0])[i];
-      auto len = async_param_size_[i * 3];
-      float* grad_data = grad_tensor->mutable_data<float>(platform::CPUPlace());
-      float* param_data =
-          param_tensor->mutable_data<float>(platform::CPUPlace());
-      float* mom1_data = mom1_tensor->mutable_data<float>(platform::CPUPlace());
-      float* mom2_data = mom2_tensor->mutable_data<float>(platform::CPUPlace());
-
-      // merge grad
-      for (size_t k = 1; k < merge_num; ++k) {
-        LoDTensor* other_grad_tensor = &(*grad[k])[i];
-        float* other_grad_data =
-            other_grad_tensor->mutable_data<float>(platform::CPUPlace());
-        for (size_t j = 0; j < len; ++j) {
-          grad_data[j] += other_grad_data[j];
-        }
-      }
-      if (merge_num > 1) {
-        for (size_t j = 0; j < len; ++j) {
-          grad_data[j] /= merge_num;
-        }
-      }
-      // float tmp = param_data[0];
-      float learning_rate = base_lr_;
-      if (lr_map.find(async_param_list_[i * 3]) != lr_map.end()) {
-        learning_rate = lr_map[async_param_list_[i * 3]];
-      }
-      // VLOG(0) << "learning rate for " << async_param_list_[i * 3] << " is "
-      //         << learning_rate;
-      for (size_t j = 0; j < len; ++j) {
-        mom1_data[j] = 0.99 * mom1_data[j] +
-                       0.01 * grad_data[j];  // magic beta and episilon
-        mom2_data[j] =
-            0.9999 * mom2_data[j] + 0.0001 * grad_data[j] * grad_data[j];
-        param_data[j] -=
-            learning_rate * (mom1_data[j] / (sqrt(mom2_data[j]) + 1e-8));
-      }
-      // VLOG(0) << "update dense for " << async_param_list_[i*3] << ", param["
-      //         << tmp << "] - 0.000005 * [" << mom1_data[0] << "] / ["
-      //         << mom1_data[1] << "] = [" << param_data[0] << "]";
-    }
-  }
-  VLOG(0) << "Quit AsyncUpdate";
-}
-
-void BoxPSAsynDenseTable::ReShape(const platform::Place& place) {
-  int device_id = boost::get<platform::CUDAPlace>(place).GetDeviceId();
-  auto& grad = device_grads_[device_id];
-  grad.resize(async_param_size_.size() / 3);
-  for (size_t i = 0; i < async_param_size_.size(); ++i) {
-    if (i % 3 != 0) {
-      continue;
-    }
-    grad[i / 3].mutable_data<float>(
-        {static_cast<int64_t>(async_param_size_[i]), 1}, place);
-  }
-}
+    std::vector<std::future<void>> wait_futures;
+    for (int64_t i = 0; i < thread_num_; ++i) {
+      wait_futures.emplace_back(thread_pool->Run(
+          [this, i, &grad, merge_num]() { ThreadUpdate(i, grad, merge_num); }));
+    }
+    for (int64_t i = 0; i < thread_num_; ++i) {
+      wait_futures[i].get();
+    }
+    for (size_t i = 0; i < merge_num; ++i) {
+      buffer_poll_->Send(grad[i]);
+    }
+  }
+
+  VLOG(0) << "Quit AsyncUpdate";
+}
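
AsyncUpdate() now cooperates with PushDense() through two queues: buffer_poll_ acts as a free list of the device_num * 4 preallocated gradient buffers, and ps_buffer_ is the work queue. PushDense() takes a free buffer, fills it, and sends it to the updater; AsyncUpdate() drains up to four pending buffers, fans ThreadUpdate() out across the pool, and returns the buffers to the free list, so no allocation happens on the hot path. A self-contained sketch of this recycle pattern (SimpleQueue is a stand-in for operators::reader::BlockingQueue, not its real implementation):

    #include <condition_variable>
    #include <cstdio>
    #include <mutex>
    #include <queue>
    #include <thread>
    #include <vector>

    // Minimal blocking queue standing in for operators::reader::BlockingQueue.
    template <typename T>
    class SimpleQueue {
     public:
      void Send(T v) {
        { std::lock_guard<std::mutex> l(m_); q_.push(v); }
        cv_.notify_one();
      }
      bool Receive(T* v) {
        std::unique_lock<std::mutex> l(m_);
        cv_.wait(l, [this] { return !q_.empty() || closed_; });
        if (q_.empty()) return false;  // closed and drained
        *v = q_.front(); q_.pop();
        return true;
      }
      void Close() {
        { std::lock_guard<std::mutex> l(m_); closed_ = true; }
        cv_.notify_all();
      }
     private:
      std::queue<T> q_; std::mutex m_;
      std::condition_variable cv_; bool closed_ = false;
    };

    int main() {
      std::vector<std::vector<float>> buffers(8, std::vector<float>(16));
      SimpleQueue<std::vector<float>*> free_pool, work;  // buffer_poll_/ps_buffer_
      for (auto& b : buffers) free_pool.Send(&b);        // preallocated free list

      std::thread updater([&] {                          // AsyncUpdate role
        std::vector<float>* g = nullptr;
        while (work.Receive(&g)) { /* merge + update */ free_pool.Send(g); }
      });
      for (int step = 0; step < 32; ++step) {            // PushDense role
        std::vector<float>* g = nullptr;
        free_pool.Receive(&g);  // blocks only if the updater falls behind
        work.Send(g);
      }
      work.Close();
      updater.join();
      std::printf("done\n");
      return 0;
    }
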

 // async
 void BoxPSAsynDenseTable::PullDense(const platform::Place& place,
-                                    const Scope& scope) {
+                                    Tensor* tensor) {
   //  while (ps_buffer_->Size() != 0) {  // Size have lock, may have perf problem.
   //                                     // And will hang when the lock was removed
   //    ;
   //  }
   AutoRDLock ps_lock(&ps_lock_);
-  for (size_t i = 0; i < async_param_list_.size(); ++i) {
-    if (i % 3 != 0) {
-      continue;
-    }
-    const std::string& param_name = async_param_list_[i];
-    Variable* var = scope.FindVar(param_name);
-    LoDTensor* tensor = var->GetMutable<LoDTensor>();
-    TensorCopy(*static_cast<const Tensor*>(&ps_[i]), place,
+  TensorCopy(*static_cast<const Tensor*>(&ps_), place,
              static_cast<Tensor*>(tensor));
-
-    // float *p = (*ps_)[i].mutable_data<float>(platform::CPUPlace());
-    // VLOG(0) << "pull dense for " << (*async_param_list_)[i]
-    //         << ", and the first ele is " << p[0];
-  }
 }

 void BoxPSAsynDenseTable::PushDense(const platform::Place& place,
-                                    const Scope& scope) {
-  int device_id = boost::get<platform::CUDAPlace>(place).GetDeviceId();
-  auto& grad = device_grads_[device_id];
-  for (size_t i = 0; i < async_param_list_.size(); ++i) {
-    if (i % 3 != 0) {
-      continue;
-    }
-    // VLOG(0) << "push dense for " << (*async_param_list_)[i] << "@GRAD";
-    std::string grad_name = async_param_list_[i] + "@GRAD";
-    Variable* var = scope.FindVar(grad_name);
-    CHECK(var != nullptr) << "var[" << grad_name << "] not found";
-    LoDTensor* tensor = var->GetMutable<LoDTensor>();
-    // VLOG(0) << "the first element of grad_name is: " << tmp;
-    TensorCopy(*static_cast<const Tensor*>(tensor), platform::CPUPlace(),
-               static_cast<Tensor*>(&grad[i / 3]));
-  }
-  ps_buffer_->Send(&grad);
+                                    Tensor* tensor) {
+  LoDTensor* grad = nullptr;
+  buffer_poll_->Receive(&grad);
+  TensorCopy(*static_cast<const Tensor*>(tensor), platform::CPUPlace(),
+             static_cast<Tensor*>(grad));
+  ps_buffer_->Send(grad);
+}
+
+void BoxPSAsynDenseTable::InitThreadGroup() {
+  thread_num_ = 32;
+  thread_start_index_.resize(thread_num_, 0);
+  thread_end_index_.resize(thread_num_, 0);
+  size_t prefix_sum = 0;
+  size_t thread_update_avg_len = total_param_len_ / thread_num_;
+  int unalloc_len = total_param_len_ % thread_num_;
+  for (int i = 0; i < thread_num_; i++) {
+    thread_start_index_[i] = prefix_sum;
+    if (i < unalloc_len) {
+      prefix_sum += thread_update_avg_len + 1;
+    } else {
+      prefix_sum += thread_update_avg_len;
+    }
+    thread_end_index_[i] = prefix_sum;
+  }
+  thread_pool.reset(new paddle::framework::ThreadPool(thread_num_));
 }
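
InitThreadGroup() statically partitions the flat parameter vector across 32 updater threads; the first total_param_len_ % thread_num_ threads take one extra element, so the per-thread [start, end) ranges tile the buffer exactly. The same computation, extracted as a checkable helper (illustrative, not part of the PR):

    #include <cassert>
    #include <cstddef>
    #include <utility>
    #include <vector>

    // Split [0, len) into n contiguous ranges; the first len % n ranges get one
    // extra element, so the ranges cover the buffer without gaps or overlap.
    std::vector<std::pair<size_t, size_t>> Partition(size_t len, size_t n) {
      std::vector<std::pair<size_t, size_t>> ranges(n);
      size_t avg = len / n, rem = len % n, pos = 0;
      for (size_t i = 0; i < n; ++i) {
        size_t width = avg + (i < rem ? 1 : 0);
        ranges[i] = {pos, pos + width};
        pos += width;
      }
      assert(pos == len);  // every element covered exactly once
      return ranges;
    }

    int main() {
      auto r = Partition(103, 32);  // 103 = 32 * 3 + 7: first 7 ranges get 4
      return (r[0].second - r[0].first == 4 && r[31].second == 103) ? 0 : 1;
    }
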

 static const int DenseKStepNode = 1;
@@ -248,8 +300,8 @@ void BoxPSWorker::Initialize(const TrainerDesc& desc) {

 void BoxPSWorker::SetDenseTable(BoxPSAsynDenseTable* dense) {
   dense_table_ = dense;
-  dense_table_->ReShape(place_);
 }
+
 int BoxPSWorker::CheckNeedParam(VarDesc* var) {
   if (!var->Persistable()) {
     return 0;
@@ -277,6 +329,7 @@ int BoxPSWorker::CheckNeedParam(VarDesc* var) {
   }
   return 0;
 }
+
 int64_t BoxPSWorker::AllocParamTensor(int64_t* pad_len) {
   auto& block = program_->Block(0);
   // init var and copy persistable
@@ -317,6 +370,30 @@ int64_t BoxPSWorker::AllocParamTensor(int64_t* pad_len) {
   param_sync_.mutable_data<float>({all_sync_param_len, 1}, place_);
   return total_param_len;
 }
+
+int64_t BoxPSWorker::AllocParamTensorAsync() {
+  auto& block = program_->Block(0);
+  // init var and copy persistable
+  int64_t total_param_len = 0;
+  for (auto& var : block.AllVars()) {
+    std::string name = var->Name();
+    if (!var->Persistable() ||
+        async_param_name_.find(name) == async_param_name_.end()) {
+      continue;
+    }
+    const LoDTensor& root_tensor = root_scope_->FindVar(name)->Get<LoDTensor>();
+    int64_t numel = root_tensor.numel();
+    total_param_len += numel;
+  }
+
+  VLOG(2) << "param length:" << total_param_len
+          << ", param grad length:" << total_param_len
+          << ", device num:" << device_num_;
+
+  param_async_.mutable_data<float>({total_param_len, 1}, place_);
+  grad_async_.mutable_data<float>({total_param_len, 1}, place_);
+  return total_param_len;
+}
+
 void BoxPSWorker::CreateDeviceResource(const ProgramDesc& main_prog) {
   program_.reset(new ProgramDesc(main_prog));
   for (auto& op_desc : program_->Block(0).AllOps()) {
@@ -325,18 +402,39 @@ void BoxPSWorker::CreateDeviceResource(const ProgramDesc& main_prog) {
   int64_t pad_len = 0;
   if (sync_mode_ > 0) {
-    AllocParamTensor(&pad_len);
+    AllocParamTensor(&pad_len);
+  } else if (dense_table_) {
+    AllocParamTensorAsync();
   }
+
   auto& block = program_->Block(0);
   thread_scope_ = &(root_scope_->NewScope());

   int64_t offset = 0;
+  int64_t grad_offset = 0;
+  // make param and param@GRAD in same order
+  std::vector<VarDesc*> sorted_var = block.AllVars();
+  std::sort(sorted_var.begin(), sorted_var.end(),
+            [](const VarDesc* var1, const VarDesc* var2) {
+              return var1->Name() < var2->Name();
+            });
   // init var and copy persistable
-  for (auto& var : block.AllVars()) {
+  for (auto& var : sorted_var) {
     std::string name = var->Name();
     if (!var->Persistable()) {
-      auto* ptr = thread_scope_->Var(name);
-      InitializeVariable(ptr, var->GetType());
+      if (dense_table_ &&
+          async_param_name_.find(name) != async_param_name_.end()) {
+        // param@GRAD cannot be found in root_scope_; use the param's
+        // shape and length instead
+        const LoDTensor& root_tensor =
+            root_scope_->FindVar(name.substr(0, name.length() - 5))
+                ->Get<LoDTensor>();
+        LoDTensor* gpu_tensor =
+            thread_scope_->Var(name)->GetMutable<LoDTensor>();
+        auto dim = root_tensor.dims();
+        size_t len = root_tensor.numel();
+        gpu_tensor
+            ->ShareDataWith(grad_async_.Slice(grad_offset, grad_offset + len))
+            .Resize(dim);
+        grad_offset += len;
+      } else {
+        auto* ptr = thread_scope_->Var(name);
+        InitializeVariable(ptr, var->GetType());
+      }
     } else {
       const LoDTensor& root_tensor =
           root_scope_->FindVar(name)->Get<LoDTensor>();
@@ -349,13 +447,23 @@ void BoxPSWorker::CreateDeviceResource(const ProgramDesc& main_prog) {
             .Resize(dim);
         offset += len;
       }
+    } else if (dense_table_) {
+      if (async_param_name_.find(name) != async_param_name_.end()) {
+        auto dim = root_tensor.dims();
+        size_t len = root_tensor.numel();
+        gpu_tensor->ShareDataWith(param_async_.Slice(offset, offset + len))
+            .Resize(dim);
+        offset += len;
+      }
     }
     TensorCopy(*static_cast<const Tensor*>(&root_tensor), place_,
                static_cast<Tensor*>(gpu_tensor));
     }
   }
   if (sync_mode_ > 0) {
-    CHECK(offset <= (param_sync_.numel() - pad_len));
+    CHECK(offset <= (param_sync_.numel() - pad_len));
+  } else if (dense_table_) {
+    CHECK(offset <= param_async_.numel());
+    CHECK(grad_offset <= grad_async_.numel());
   }
 }
 void BoxPSWorker::SyncParam(void) {
@@ -435,14 +543,17 @@ void BoxPSWorker::TrainFiles() {
     VLOG(2) << "[" << device_id_
             << "]begin running ops, batch size:" << batch_size
            << ", batch id=" << step;
+
     if (dense_table_) {
-      dense_table_->PullDense(place_, *thread_scope_);
+      dense_table_->PullDense(place_, &param_async_);
     }
+
     for (auto& op : ops_) {
       op->Run(*thread_scope_, place_);
     }
+
    if (dense_table_) {
-      dense_table_->PushDense(place_, *thread_scope_);
+      dense_table_->PushDense(place_, &grad_async_);
     } else if (sync_mode_ == DenseKStepNode || sync_mode_ == DenseKStepALL) {
       if (step > param_sync_step_) {
         step = 0;
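
With the flat layout, the per-step dense traffic in TrainFiles() collapses to one copy in each direction: PullDense() refreshes param_async_ as a whole (every async parameter variable in the thread scope aliases a slice of it via ShareDataWith), and PushDense() ships grad_async_ as a whole. A toy demonstration of why one flat copy updates every aliased view (variable names are made up):

    #include <cstddef>
    #include <cstring>
    #include <vector>

    // Every async param var in the worker scope is a view into the flat device
    // buffer, so refreshing the flat buffer refreshes all of them at once.
    struct View { float* data; size_t len; };

    int main() {
      std::vector<float> flat(10, 0.f);  // stands in for param_async_
      View w1{flat.data() + 0, 4};       // "fc_0.w_0" slice (hypothetical)
      View w2{flat.data() + 4, 6};       // "fc_1.w_0" slice (hypothetical)
      std::vector<float> host(10, 1.f);  // stands in for the table's flat ps_
      std::memcpy(flat.data(), host.data(), 10 * sizeof(float));  // "PullDense"
      return (w1.data[0] == 1.f && w2.data[5] == 1.f) ? 0 : 1;
    }
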
diff --git a/paddle/fluid/framework/device_worker.h b/paddle/fluid/framework/device_worker.h
index 5cf3b7597060d..f6f00aeb599cc 100644
--- a/paddle/fluid/framework/device_worker.h
+++ b/paddle/fluid/framework/device_worker.h
@@ -35,6 +35,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/reader.h"
 #include "paddle/fluid/framework/trainer_desc.pb.h"
 #include "paddle/fluid/framework/variable_helper.h"
+#include "paddle/fluid/framework/threadpool.h"
 #include "paddle/fluid/operators/reader/blocking_queue.h"
 #include "paddle/fluid/platform/place.h"
 #include "paddle/fluid/platform/port.h"
@@ -583,33 +584,42 @@ class SectionWorker : public DeviceWorker {
 #ifdef PADDLE_WITH_BOX_PS
 class BoxPSAsynDenseTable {
-  typedef operators::reader::BlockingQueue<std::vector<LoDTensor>*>
-      PSBufferQueue;
+  typedef operators::reader::BlockingQueue<LoDTensor*> PSBufferQueue;

  public:
   explicit BoxPSAsynDenseTable(const int device_num);
   ~BoxPSAsynDenseTable();

-  void Init(const Scope& root_scope,
+  std::set<std::string> Init(const Scope& root_scope,
             const std::vector<std::string>& param_need_sync,
             const std::vector<std::string>& persistable_vars);
   void Finalize(void);
-
-  // async
-  void ReShape(const platform::Place& place);
-  void PullDense(const platform::Place& place, const Scope& scope);
-  void PushDense(const platform::Place& place, const Scope& scope);
+  void PullDense(const platform::Place& place, Tensor* tensor);
+  void PushDense(const platform::Place& place, Tensor* tensor);
+  void InitThreadGroup();
+  void ThreadUpdate(int thread_id, std::vector<LoDTensor*>& grad,
+                    size_t merge_num);
   void AsyncUpdate();

  private:
   int device_num_ = 0;
-  std::vector<std::vector<LoDTensor>> device_grads_;
+  std::vector<LoDTensor> device_grads_;
   std::vector<std::string> async_param_list_;
-  std::vector<LoDTensor> ps_;
-  std::vector<size_t> async_param_size_;
+  std::vector<LoDTensor> original_ps_;
+  LoDTensor ps_;
+  LoDTensor mom1_;
+  LoDTensor mom2_;
+
+  std::vector<float> all_lr_;
+  std::shared_ptr<PSBufferQueue> buffer_poll_ = nullptr;
   std::shared_ptr<PSBufferQueue> ps_buffer_ = nullptr;
   Scope* root_scope_ = nullptr;
-
+  int64_t total_param_len_ = 0;
+  std::vector<size_t> thread_start_index_;
+  std::vector<size_t> thread_end_index_;
+  std::shared_ptr<paddle::framework::ThreadPool> thread_pool = nullptr;
+  int thread_num_ = 0;
   RWLock ps_lock_;
   std::thread* update_thread_ = nullptr;
   float base_lr_ = -1;
@@ -639,11 +649,13 @@ class BoxPSWorker : public DeviceWorker {
   void SetParamSyncStep(int step) { param_sync_step_ = step; }
   void SetDenseSyncMode(int mode) { sync_mode_ = mode; }
   void SetOneRing(bool one_ring) { one_ring_ = one_ring; }
+  void SetAsyncParamName(const std::set<std::string>& async_param_name) {
+    async_param_name_ = async_param_name;
+  }

  protected:
   int PackBatchTask(void);
   int CheckNeedParam(VarDesc* var);
   int64_t AllocParamTensor(int64_t* pad_len);
+  int64_t AllocParamTensorAsync();
   void SyncParam(void);

  protected:
@@ -656,7 +668,10 @@ class BoxPSWorker : public DeviceWorker {

   // dense async table
   BoxPSAsynDenseTable* dense_table_ = nullptr;
+  Tensor param_async_;
+  Tensor grad_async_;
   Tensor param_sync_;
+  std::set<std::string> async_param_name_;
   int param_sync_step_ = 0;
   int sync_mode_ = 0;
   bool one_ring_ = false;
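
One ordering subtlety worth keeping in mind: the trainer change at the top of this diff moves CreateDeviceResource() after SetDenseTable()/SetAsyncParamName(), because CreateDeviceResource() consults dense_table_ and async_param_name_ to decide which variables should alias slices of param_async_/grad_async_. A toy model of that dependency (class shape and names are illustrative, not the PR's classes):

    #include <cassert>
    #include <set>
    #include <string>

    // CreateDeviceResource() decides per variable whether to alias the flat
    // async buffers, so the name set must be installed first.
    class Worker {
     public:
      void SetAsyncParamName(const std::set<std::string>& s) { names_ = s; }
      void CreateDeviceResource() {
        // The aliasing decision reads names_; an empty set would silently
        // fall back to ordinary per-variable allocation for every parameter.
        assert(!names_.empty());
      }

     private:
      std::set<std::string> names_;
    };

    int main() {
      Worker w;
      w.SetAsyncParamName({"fc_0.w_0", "fc_0.w_0@GRAD"});  // hypothetical names
      w.CreateDeviceResource();  // safe only after the setter has run
      return 0;
    }
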