diff --git a/CMakeLists.txt b/CMakeLists.txt
index e8e1d769131e7d..a66a057622203a 100755
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -51,6 +51,11 @@ message(STATUS "C compiler: ${CMAKE_C_COMPILER}, version: "
                 "${CMAKE_C_COMPILER_ID} ${CMAKE_C_COMPILER_VERSION}")
 message(STATUS "AR tools: ${CMAKE_AR}")
 
+# Turn off some warnings for the MUSL build
+if(WITH_MUSL)
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=deprecated-declarations -Wno-deprecated-declarations -Wno-error=pessimizing-move -Wno-error=deprecated-copy")
+endif()
+
 if(WIN32)
     option(MSVC_STATIC_CRT "use static C Runtime library by default" ON)
 
diff --git a/cmake/external/lite.cmake b/cmake/external/lite.cmake
index a39bb3b6995578..274511e3d39df8 100644
--- a/cmake/external/lite.cmake
+++ b/cmake/external/lite.cmake
@@ -132,7 +132,11 @@ if (NOT LITE_SOURCE_DIR OR NOT LITE_BINARY_DIR)
   endif()
 
   if (WITH_ARM)
-    set(LITE_OUTPUT_BIN_DIR inference_lite_lib.armlinux.armv8)
+    if(LITE_WITH_XPU)
+      set(LITE_OUTPUT_BIN_DIR inference_lite_lib.armlinux.armv8.xpu)
+    else()
+      set(LITE_OUTPUT_BIN_DIR inference_lite_lib.armlinux.armv8)
+    endif()
   else()
     set(LITE_OUTPUT_BIN_DIR inference_lite_lib)
   endif()
diff --git a/cmake/external/xpu.cmake b/cmake/external/xpu.cmake
index c9cf2572d1d5c4..75e0eb2e275c31 100644
--- a/cmake/external/xpu.cmake
+++ b/cmake/external/xpu.cmake
@@ -4,7 +4,7 @@ endif()
 
 INCLUDE(ExternalProject)
 SET(XPU_PROJECT "extern_xpu")
-SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/xpu_2020_12_07_cdfbf0c.tar.gz" CACHE STRING "" FORCE)
+SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/xpu_2020_12_11.tar.gz" CACHE STRING "" FORCE)
 SET(XPU_SOURCE_DIR "${THIRD_PARTY_PATH}/xpu")
 SET(XPU_DOWNLOAD_DIR "${XPU_SOURCE_DIR}/src/${XPU_PROJECT}")
 SET(XPU_INSTALL_DIR "${THIRD_PARTY_PATH}/install/xpu")
diff --git a/paddle/fluid/distributed/CMakeLists.txt b/paddle/fluid/distributed/CMakeLists.txt
index ee9037dec1a5d0..e99b8b76534369 100644
--- a/paddle/fluid/distributed/CMakeLists.txt
+++ b/paddle/fluid/distributed/CMakeLists.txt
@@ -14,3 +14,17 @@ endif()
 
 add_subdirectory(table)
 add_subdirectory(test)
+
+# Keep everything below disabled until CI supports brpc
+return()
+
+add_subdirectory(service)
+
+get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS)
+
+set_source_files_properties(fleet.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
+cc_library(fleet
+    SRCS fleet.cc
+    DEPS framework_proto ps_framework_proto ps_service variable_helper scope op_registry fs shell ${RPC_DEPS})
+
+target_link_libraries(fleet z)
diff --git a/paddle/fluid/distributed/fleet.cc b/paddle/fluid/distributed/fleet.cc
new file mode 100644
index 00000000000000..92211a72e748eb
--- /dev/null
+++ b/paddle/fluid/distributed/fleet.cc
@@ -0,0 +1,585 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
*/ + +#include "paddle/fluid/distributed/fleet.h" +#include +#include +#include "paddle/fluid/distributed/service/communicator.h" +#include "paddle/fluid/distributed/table/table.h" +#include "paddle/fluid/framework/channel.h" +#include "paddle/fluid/framework/data_feed.h" +#include "paddle/fluid/framework/io/fs.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/scope.h" + +namespace paddle { +namespace distributed { + +using framework::LoDTensor; +using framework::ProgramDesc; +using framework::VarDesc; +using framework::Variable; + +const uint32_t MAX_FEASIGN_NUM = 1024 * 100 * 100; +std::shared_ptr FleetWrapper::s_instance_ = NULL; +bool FleetWrapper::is_initialized_ = false; + +std::shared_ptr FleetWrapper::pserver_ptr_ = NULL; + +void FleetWrapper::SetClient2ClientConfig(int request_timeout_ms, + int connect_timeout_ms, + int max_retry) { + client2client_request_timeout_ms_ = request_timeout_ms; + client2client_connect_timeout_ms_ = connect_timeout_ms; + client2client_max_retry_ = max_retry; +} + +void FleetWrapper::LoadSparseOnServer(const std::string& path, + const std::string& meta, + uint32_t table_id) { + VLOG(3) << "load sparse table " << table_id << " with " << path << " meta " + << meta; + pserver_ptr_->_server_ptr->table(table_id)->load(path, meta); +} + +void FleetWrapper::InitServer(const std::string& dist_desc, + const std::vector& host_sign_list, + int index) { + if (!is_initialized_) { + VLOG(3) << "Going to init server"; + pserver_ptr_ = std::shared_ptr( + new paddle::distributed::PSCore()); + pserver_ptr_->init_server(dist_desc, &host_sign_list, host_sign_list.size(), + index); + is_initialized_ = true; + } else { + VLOG(3) << "Server can be initialized only once"; + } +} + +// void FleetWrapper::InitWorker( +// const std::string& dist_desc, const std::vector& +// host_sign_list, Scope* scope, const RpcCtxMap& send_ctx, const +// std::unordered_map>& +// dense_varnames, +// const std::map& envs, int node_num, int index) +// { +// if (!is_initialized_) { +// VLOG(3) << "Going to init worker"; + +// Communicator::InitInstance( +// send_ctx, dense_varnames, dist_desc, host_sign_list, scope, envs); + +// pserver_ptr_ = std::shared_ptr( +// new paddle::distributed::PSCore()); +// pserver_ptr_->init_worker(dist_desc, _regions, +// const_cast(host_sign_list.data()), +// node_num, index); +// is_initialized_ = true; +// } else { +// VLOG(3) << "Worker can be initialized only once"; +// } +// } + +void FleetWrapper::InitWorker( + const std::string& dist_desc, + const std::vector& host_sign_list, Scope* scope, + const RpcCtxMap& send_ctx, + const std::unordered_map>& + dense_varnames, + const std::map& envs, int node_num, int index) { + if (!is_initialized_) { + VLOG(3) << "Going to init worker"; + + Communicator::InitInstance( + send_ctx, dense_varnames, dist_desc, host_sign_list, scope, envs); + + pserver_ptr_ = std::shared_ptr( + new paddle::distributed::PSCore()); + pserver_ptr_->init_worker(dist_desc, _regions, &host_sign_list, node_num, + index); + is_initialized_ = true; + } else { + VLOG(3) << "Worker can be initialized only once"; + } +} + +void FleetWrapper::StopServer() { + VLOG(3) << "Going to stop server"; + auto* communicator = Communicator::GetInstance(); + auto status = communicator->_worker_ptr->stop_server(); + status.wait(); +} + +void FleetWrapper::FinalizeWorker() { + VLOG(3) << "Going to finalize worker"; + pserver_ptr_->finalize_worker(); +} + +void FleetWrapper::BarrierWithTable(uint32_t barrier_type) { + VLOG(3) << 
"Going to Barrier worker"; + auto* communicator = Communicator::GetInstance(); + communicator->BarrierWithTable(barrier_type); +} + +uint64_t FleetWrapper::RunServer(const std::string& ip, uint32_t port) { + VLOG(3) << "Going to run server with ip " << ip << " port " << port; + auto ret = pserver_ptr_->run_server(ip, port); + return ret; +} + +std::vector FleetWrapper::GetClientsInfo() { + VLOG(3) << "Going to get client info"; + return pserver_ptr_->get_client_info(); + return std::vector(); +} + +void FleetWrapper::CreateClient2ClientConnection() { + VLOG(3) << "Going to create client2client connection"; + pserver_ptr_->create_client2client_connection( + client2client_request_timeout_ms_, client2client_connect_timeout_ms_, + client2client_max_retry_); +} + +std::future FleetWrapper::PullSparseVarsAsync( + const Scope& scope, const uint64_t table_id, + const std::vector& var_names, std::vector* fea_keys, + std::vector>* fea_values, int fea_value_dim) { + fea_keys->clear(); + fea_keys->resize(0); + fea_keys->reserve(MAX_FEASIGN_NUM); + for (auto name : var_names) { + Variable* var = scope.FindVar(name); + if (var == nullptr) { + continue; + } + LoDTensor* tensor = var->GetMutable(); + CHECK(tensor != nullptr) << "tensor of var " << name << " is null"; + int64_t* ids = tensor->data(); + size_t len = tensor->numel(); + for (auto i = 0u; i < len; ++i) { + if (ids[i] == 0u) { + continue; + } + fea_keys->push_back(static_cast(ids[i])); + } + } + fea_values->resize(fea_keys->size() + 1); + for (auto& t : *fea_values) { + t.resize(fea_value_dim); + } + std::vector pull_result_ptr; + for (auto& t : *fea_values) { + pull_result_ptr.push_back(t.data()); + } + return pserver_ptr_->_worker_ptr->pull_sparse( + pull_result_ptr.data(), table_id, fea_keys->data(), fea_keys->size()); +} + +void FleetWrapper::PullSparseVarsSync( + const Scope& scope, const uint64_t table_id, + const std::vector& var_names, std::vector* fea_keys, + std::vector>* fea_values, int fea_value_dim, + const std::vector& var_emb_names) { + std::vector> pull_sparse_status; + pull_sparse_status.resize(0); + fea_keys->clear(); + fea_keys->resize(0); + fea_keys->reserve(MAX_FEASIGN_NUM); + for (size_t var_index = 0; var_index < var_names.size(); ++var_index) { + const std::string& name = var_names[var_index]; + Variable* var = scope.FindVar(name); + if (var == nullptr) { + continue; + } + LoDTensor* tensor = var->GetMutable(); + CHECK(tensor != nullptr) << "tensor of var " << name << " is null"; + int64_t* ids = tensor->data(); + size_t len = tensor->numel(); + + // skip slots which do not have embedding + const std::string& emb_name = var_emb_names[var_index]; + Variable* emb_var = scope.FindVar(emb_name); + if (emb_var == nullptr) { + continue; + } + + for (auto i = 0u; i < len; ++i) { + if (ids[i] == 0u) { + continue; + } + fea_keys->push_back(static_cast(ids[i])); + } + } + fea_values->resize(fea_keys->size() + 1); + for (auto& t : *fea_values) { + t.resize(fea_value_dim); + } + std::vector pull_result_ptr; + for (auto& t : *fea_values) { + pull_result_ptr.push_back(t.data()); + } + auto status = pserver_ptr_->_worker_ptr->pull_sparse( + pull_result_ptr.data(), table_id, fea_keys->data(), fea_keys->size()); + pull_sparse_status.push_back(std::move(status)); + for (auto& t : pull_sparse_status) { + t.wait(); + auto status = t.get(); + if (status != 0) { + LOG(ERROR) << "fleet pull sparse failed, status[" << status << "]"; + sleep(sleep_seconds_before_fail_exit_); + exit(-1); + } + } +} + +void 
FleetWrapper::PullSparseToTensorSync(const uint64_t table_id, int fea_dim, + uint64_t padding_id, + platform::Place place, + std::vector* inputs, + std::vector* outputs) { + std::vector fea_keys; + std::vector pull_result_ptr; + fea_keys.reserve(MAX_FEASIGN_NUM / 100); + pull_result_ptr.reserve(MAX_FEASIGN_NUM / 100); + std::vector init_value(fea_dim, 0); + framework::LoDTensor* output = nullptr; + float* output_data = nullptr; + size_t output_index = -1; + size_t output_len = 0; + for (size_t index = 0; index < inputs->size(); ++index) { + const framework::LoDTensor* tensor = inputs->at(index); + const int64_t* ids = tensor->data(); + size_t len = tensor->numel(); + for (size_t i = 0; i < len; ++i, output_len += fea_dim) { + if (!output || output_len == size_t(output->numel())) { + ++output_index; + CHECK(output_index < outputs->size()); // NOLINT + output = outputs->at(output_index); + output->set_lod(tensor->lod()); + output_data = output->mutable_data(place); + output_len = 0; + CHECK(output->numel() % fea_dim == 0); // NOLINT + CHECK(output_data != nullptr); // NOLINT + } + uint64_t real_id = static_cast(ids[i]); + if (real_id == padding_id) { + memcpy(output_data + output_len, init_value.data(), + sizeof(float) * fea_dim); + continue; + } + fea_keys.push_back(real_id); + pull_result_ptr.push_back(output_data + output_len); + } + } + auto* communicator = Communicator::GetInstance(); + auto status = communicator->_worker_ptr->pull_sparse( + pull_result_ptr.data(), table_id, fea_keys.data(), fea_keys.size()); + status.wait(); + auto ret = status.get(); + if (ret != 0) { + LOG(ERROR) << "fleet pull sparse failed, status[" << ret << "]"; + sleep(sleep_seconds_before_fail_exit_); + } +} + +void FleetWrapper::PullDenseVarsAsync( + const Scope& scope, const uint64_t tid, + const std::vector& var_names, + std::vector>* pull_dense_status, bool in_cpu) { + auto& regions = _regions[tid]; + regions.clear(); + regions.resize(var_names.size()); + for (auto i = 0u; i < var_names.size(); ++i) { + std::string varname = var_names[i]; + if (!in_cpu) { + varname = var_names[i] + "pin"; + } + Variable* var = scope.FindVar(varname); + LoDTensor* tensor = var->GetMutable(); + float* w = tensor->data(); + paddle::distributed::Region reg(w, tensor->numel()); + regions[i] = std::move(reg); + } + auto status = pserver_ptr_->_worker_ptr->pull_dense(regions.data(), + regions.size(), tid); + pull_dense_status->push_back(std::move(status)); +} + +void FleetWrapper::PullDenseVarsSync( + const Scope& scope, const uint64_t tid, + const std::vector& var_names) { + auto& regions = _regions[tid]; + regions.clear(); + regions.reserve(var_names.size()); + for (auto& t : var_names) { + Variable* var = scope.FindVar(t); + LoDTensor* tensor = var->GetMutable(); + float* w = tensor->data(); + paddle::distributed::Region reg(w, tensor->numel()); + regions.emplace_back(std::move(reg)); + } + auto* communicator = Communicator::GetInstance(); + auto status = communicator->_worker_ptr->pull_dense(regions.data(), + regions.size(), tid); + status.wait(); +} + +void FleetWrapper::PushDenseParamSync( + const Scope& scope, const uint64_t table_id, + const std::vector& var_names) { + auto place = platform::CPUPlace(); + std::vector regions; + for (auto& t : var_names) { + Variable* var = scope.FindVar(t); + CHECK(var != nullptr) << "var[" << t << "] not found"; + LoDTensor* tensor = var->GetMutable(); + float* g = tensor->mutable_data(place); + paddle::distributed::Region reg(g, tensor->numel()); + regions.emplace_back(std::move(reg)); + 
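    // Region is a thin (pointer, length) view over the tensor buffer; the
+    // bytes are only copied out when BrpcPsClient::push_dense_param builds
+    // the request attachment, so no extra staging copy happens here.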
} + auto* communicator = Communicator::GetInstance(); + auto push_status = communicator->_worker_ptr->push_dense_param( + regions.data(), regions.size(), table_id); + push_status.wait(); + auto status = push_status.get(); + CHECK(status == 0) << "push dense param failed, status[" << status << "]"; +} + +void FleetWrapper::PushDenseVarsSync( + Scope* scope, const uint64_t table_id, + const std::vector& var_names) {} + +void FleetWrapper::PushDenseVarsAsync( + const Scope& scope, const uint64_t table_id, + const std::vector& var_names, + std::vector>* push_sparse_status, float scale_datanorm, + int batch_size) { + auto* communicator = Communicator::GetInstance(); + PADDLE_ENFORCE_EQ( + communicator->Check(table_id), true, + platform::errors::InvalidArgument( + "can not find table: %s, please check your config", table_id)); + communicator->Send(var_names, scope); +} + +void FleetWrapper::PushSparseVarsAsync( + const Scope& scope, const uint64_t table_id, + const std::string& grad_varname, + std::vector>* push_sparse_status) { + std::vector varnames; + varnames.push_back(grad_varname); + + auto* communicator = Communicator::GetInstance(); + PADDLE_ENFORCE_EQ( + communicator->Check(table_id), true, + platform::errors::InvalidArgument( + "can not find table: %s, please check your config", table_id)); + communicator->Send(varnames, scope); +} + +void FleetWrapper::PushSparseVarsWithLabelAsync( + const Scope& scope, const uint64_t table_id, + const std::vector& fea_keys, const std::vector& fea_labels, + const std::vector& sparse_key_names, + const std::vector& sparse_grad_names, const int emb_dim, + std::vector>* push_values, + std::vector>* push_sparse_status, const int batch_size, + const bool use_cvm, const bool dump_slot, + std::vector* sparse_push_keys, const bool no_cvm) { + // not support + return; +} + +void FleetWrapper::PushSparseFromTensorWithLabelAsync( + const Scope& scope, const uint64_t table_id, int fea_dim, + uint64_t padding_id, bool scale_sparse, const std::string& accesor, + const std::string& click_name, platform::Place place, + const std::vector& input_names, + std::vector* inputs, + std::vector* outputs) { + // not support + return; +} + +void FleetWrapper::LoadModel(const std::string& path, const int mode) { + auto ret = pserver_ptr_->_worker_ptr->load(path, std::to_string(mode)); + ret.wait(); + if (ret.get() != 0) { + LOG(ERROR) << "load model from path:" << path << " failed"; + sleep(sleep_seconds_before_fail_exit_); + exit(-1); + } +} + +void FleetWrapper::LoadModelOneTable(const uint64_t table_id, + const std::string& path, const int mode) { + auto ret = + pserver_ptr_->_worker_ptr->load(table_id, path, std::to_string(mode)); + ret.wait(); + if (ret.get() != 0) { + LOG(ERROR) << "load model of table id: " << table_id + << ", from path: " << path << " failed"; + } +} + +void FleetWrapper::SaveModel(const std::string& path, const int mode) { + auto* communicator = Communicator::GetInstance(); + auto ret = communicator->_worker_ptr->save(path, std::to_string(mode)); + ret.wait(); + int32_t feasign_cnt = ret.get(); + if (feasign_cnt == -1) { + LOG(ERROR) << "save model failed"; + sleep(sleep_seconds_before_fail_exit_); + exit(-1); + } +} + +void FleetWrapper::SaveModelOneTable(const uint64_t table_id, + const std::string& path, const int mode) { + auto* communicator = Communicator::GetInstance(); + auto ret = + communicator->_worker_ptr->save(table_id, path, std::to_string(mode)); + ret.wait(); + if (ret.get() != 0) { + LOG(ERROR) << "save model of table id: " << table_id 
+ << ", to path: " << path << " failed"; + } +} + +void FleetWrapper::PrintTableStat(const uint64_t table_id) { + auto* communicator = Communicator::GetInstance(); + auto ret = communicator->_worker_ptr->print_table_stat(table_id); + ret.wait(); + int32_t err_code = ret.get(); + if (err_code == -1) { + LOG(ERROR) << "print table stat failed"; + } +} + +void FleetWrapper::ShrinkSparseTable(int table_id) { + auto ret = pserver_ptr_->_worker_ptr->shrink(table_id); + ret.wait(); +} + +void FleetWrapper::ClearModel() { + auto ret = pserver_ptr_->_worker_ptr->clear(); + ret.wait(); +} + +void FleetWrapper::ClearOneTable(const uint64_t table_id) { + auto ret = pserver_ptr_->_worker_ptr->clear(table_id); + ret.wait(); +} + +void FleetWrapper::ShrinkDenseTable(int table_id, Scope* scope, + std::vector var_list, + float decay, int emb_dim) { + std::vector regions; + for (std::string& name : var_list) { + if (name.find("batch_sum") != std::string::npos) { + Variable* var = scope->FindVar(name); + CHECK(var != nullptr) << "var[" << name << "] not found"; + VLOG(0) << "prepare shrink dense batch_sum"; + LoDTensor* tensor = var->GetMutable(); + float* g = tensor->data(); + + // show_batch_sum += N * log(decay) + std::string size_name = name; + size_name.replace(size_name.find("batch_sum"), size_name.length(), + "batch_size"); + Variable* var_size = scope->FindVar(size_name); + CHECK(var_size != nullptr) << "var[" << size_name << "] not found"; + VLOG(3) << "shrink dense batch_sum: " << name << ", " << size_name; + float* g_size = var_size->GetMutable()->data(); + + for (int k = 0; k < tensor->numel(); k += emb_dim) { + g[k] = g[k] + g_size[k] * log(decay); + } + paddle::distributed::Region reg(g, tensor->numel()); + regions.emplace_back(std::move(reg)); + } else { + Variable* var = scope->FindVar(name); + CHECK(var != nullptr) << "var[" << name << "] not found"; + LoDTensor* tensor = var->GetMutable(); + float* g = tensor->data(); + paddle::distributed::Region reg(g, tensor->numel()); + regions.emplace_back(std::move(reg)); + } + } + auto push_status = pserver_ptr_->_worker_ptr->push_dense_param( + regions.data(), regions.size(), table_id); + push_status.wait(); + auto status = push_status.get(); + if (status != 0) { + // PADDLE_THORW(platform::errors::Fatal( + // "push shrink dense param failed, status is [%d].", status)); + sleep(sleep_seconds_before_fail_exit_); + exit(-1); + } +} + +void FleetWrapper::ClientFlush() { + auto ret = pserver_ptr_->_worker_ptr->flush(); + ret.wait(); +} + +int FleetWrapper::RegisterClientToClientMsgHandler(int msg_type, + MsgHandlerFunc handler) { + VLOG(3) << "calling FleetWrapper::RegisterClientToClientMsgHandler"; + VLOG(3) << "pserver_ptr_=" << pserver_ptr_; + VLOG(3) << "_worker_ptr=" << pserver_ptr_->_worker_ptr; + return pserver_ptr_->_worker_ptr->registe_client2client_msg_handler(msg_type, + handler); +} + +std::future FleetWrapper::SendClientToClientMsg( + int msg_type, int to_client_id, const std::string& msg) { + return pserver_ptr_->_worker_ptr->send_client2client_msg(msg_type, + to_client_id, msg); +} + +std::default_random_engine& FleetWrapper::LocalRandomEngine() { + struct engine_wrapper_t { + std::default_random_engine engine; + + engine_wrapper_t() { + struct timespec tp; + clock_gettime(CLOCK_REALTIME, &tp); + double cur_time = tp.tv_sec + tp.tv_nsec * 1e-9; + static std::atomic x(0); + std::seed_seq sseq = {x++, x++, x++, (uint64_t)(cur_time * 1000)}; + engine.seed(sseq); + } + }; + thread_local engine_wrapper_t r; + return r.engine; +} + +size_t 
FleetWrapper::GetAbsoluteSum(size_t start, size_t end, size_t level, + const framework::LoD& lod) { + if (level >= lod.size() - 1) { + return end - start; + } + size_t ret = 0; + for (size_t i = start; i < end - 1; ++i) { + size_t pos1 = lod[level][i]; + size_t pos2 = lod[level][i + 1]; + ret += GetAbsoluteSum(pos1, pos2, level + 1, lod); + } + return ret; +} + +} // end namespace distributed +} // end namespace paddle diff --git a/paddle/fluid/distributed/fleet.h b/paddle/fluid/distributed/fleet.h new file mode 100644 index 00000000000000..7f106fafbf2e2e --- /dev/null +++ b/paddle/fluid/distributed/fleet.h @@ -0,0 +1,246 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include "paddle/fluid/distributed/communicator_common.h" +#include "paddle/fluid/distributed/service/service.h" +#include "paddle/fluid/framework/archive.h" +#include "paddle/fluid/framework/io/fs.h" +#include "paddle/fluid/framework/io/shell.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/framework/variable_helper.h" +#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN + +namespace paddle { +namespace distributed { + +using framework::LoDTensor; +using framework::Scope; +using framework::SelectedRows; +using framework::Variable; + +using RpcCtxMap = std::unordered_map; + +class FleetWrapper { + public: + virtual ~FleetWrapper() {} + FleetWrapper() { + scale_sparse_gradient_with_batch_size_ = true; + // trainer sleep some time for pserver core dump + sleep_seconds_before_fail_exit_ = 300; + // pserver request server timeout ms + client2client_request_timeout_ms_ = 500000; + // pserver connect server timeout_ms + client2client_connect_timeout_ms_ = 10000; + // pserver request max retry + client2client_max_retry_ = 3; + } + + // set client to client communication config + void SetClient2ClientConfig(int request_timeout_ms, int connect_timeout_ms, + int max_retry); + + // Pull sparse variables from server in sync mode + // Param: scope, table_id, var_names, fea_keys, fea_dim, var_emb_names + // Param: fea_values + void PullSparseVarsSync(const Scope& scope, const uint64_t table_id, + const std::vector& var_names, + std::vector* fea_keys, + std::vector>* fea_values, + int fea_dim, + const std::vector& var_emb_names); + + // Pull sparse variables from server in async mode + // Param: scope, table_id, var_names, fea_keys, fea_dim + // Param: fea_values std::future + std::future PullSparseVarsAsync( + const Scope& scope, const uint64_t table_id, + const std::vector& var_names, + std::vector* fea_keys, + std::vector>* fea_values, int fea_dim); + + // Pull sparse variables from server in sync mode + // pull immediately to tensors + void PullSparseToTensorSync(const uint64_t table_id, int fea_dim, + uint64_t padding_id, platform::Place place, 
+ std::vector* inputs, // NOLINT + std::vector* outputs); // NOLINT + + // pull dense variables from server in sync mod + // Param: scope, table_id, var_names + // Param: void + void PullDenseVarsSync(const Scope& scope, const uint64_t table_id, + const std::vector& var_names); + + // pull dense variables from server in async mod + // Param: scope, table_id, var_names + // Param: pull_dense_status + void PullDenseVarsAsync(const Scope& scope, const uint64_t table_id, + const std::vector& var_names, + std::vector>* pull_dense_status, + bool in_cpu); + + // push dense parameters(not gradients) to server in sync mode + void PushDenseParamSync(const Scope& scope, const uint64_t table_id, + const std::vector& var_names); + + void PushDenseVarsAsync(const Scope& scope, const uint64_t table_id, + const std::vector& var_names, + std::vector>* push_sparse_status, + float scale_datanorm, int batch_size); + + // push dense variables to server in sync mode + void PushDenseVarsSync(Scope* scope, const uint64_t table_id, + const std::vector& var_names); + + void PushSparseVarsAsync( + const Scope& scope, const uint64_t table_id, const std::string& grad, + std::vector>* push_sparse_status); + // This is specially designed for click/show stats in server + // Param: scope, table_id, fea_keys, fea_labels, sparse_key_names, + // sparse_grad_names, batch_size, use_cvm, dump_slot + // Param: push_values, push_sparse_status + void PushSparseVarsWithLabelAsync( + const Scope& scope, const uint64_t table_id, + const std::vector& fea_keys, + const std::vector& fea_labels, + const std::vector& sparse_key_names, + const std::vector& sparse_grad_names, const int emb_dim, + std::vector>* push_values, + std::vector>* push_sparse_status, + const int batch_size, const bool use_cvm, const bool dump_slot, + std::vector* sparse_push_keys, const bool no_cvm); + + // Push sparse variables to server in async mode + void PushSparseFromTensorWithLabelAsync( + const Scope& scope, const uint64_t table_id, int fea_dim, + uint64_t padding_id, bool scale_sparse, const std::string& accesor, + const std::string& click_name, platform::Place place, + const std::vector& input_names, + std::vector* inputs, // NOLINT + std::vector* outputs); // NOLINT + + // Push sparse variables to server in Async mode + // Param: scope, table_id, fea_keys, sparse_grad_names + // Param: push_values, push_sparse_status + + // init server + void LoadSparseOnServer(const std::string& path, const std::string& meta, + uint32_t table_id); + // init server + // void InitServer(const std::string& dist_desc, + // const std::vector& host_sign_list, int index); + void InitServer(const std::string& dist_desc, + const std::vector& host_sign_list, int index); + // init trainer + void InitWorker(const std::string& dist_desc, + const std::vector& host_sign_list, Scope* scope, + const RpcCtxMap& send_ctx, + const std::unordered_map>& + dense_varnames, + const std::map& envs, int node_num, + int index); + + // stop server + void StopServer(); + // finalize worker to make worker can be stop + void FinalizeWorker(); + // run server with ip port + uint64_t RunServer(const std::string& ip, uint32_t port); + // get client info + std::vector GetClientsInfo(); + // create client to client connection + void CreateClient2ClientConnection(); + // flush all push requests + void ClientFlush(); + + // barrier with barrier table + void BarrierWithTable(uint32_t barrier_type); + + void PrintTableStat(const uint64_t table_id); + // mode = 0, load all feature + // mode = 1, load delta 
feature, which means load diff + void LoadModel(const std::string& path, const int mode); + // mode = 0, load all feature + // mode = 1, load delta feature, which means load diff + void LoadModelOneTable(const uint64_t table_id, const std::string& path, + const int mode); + // mode = 0, save all feature + // mode = 1, save delta feature, which means save diff + void SaveModel(const std::string& path, const int mode); + // mode = 0, save all feature + // mode = 1, save delta feature, which means save diff + void SaveModelOneTable(const uint64_t table_id, const std::string& path, + const int mode); + // clear all models, release their memory + void ClearModel(); + // clear one table + void ClearOneTable(const uint64_t table_id); + // shrink sparse table + void ShrinkSparseTable(int table_id); + // shrink dense table + void ShrinkDenseTable(int table_id, Scope* scope, + std::vector var_list, float decay, + int emb_dim); + + typedef std::function MsgHandlerFunc; + // register client to client communication + int RegisterClientToClientMsgHandler(int msg_type, MsgHandlerFunc handler); + // send client to client message + std::future SendClientToClientMsg(int msg_type, int to_client_id, + const std::string& msg); + + // FleetWrapper singleton + static std::shared_ptr GetInstance() { + if (NULL == s_instance_) { + s_instance_.reset(new paddle::distributed::FleetWrapper()); + } + return s_instance_; + } + // this performs better than rand_r, especially large data + std::default_random_engine& LocalRandomEngine(); + + static std::shared_ptr pserver_ptr_; + + private: + static std::shared_ptr s_instance_; + size_t GetAbsoluteSum(size_t start, size_t end, size_t level, + const framework::LoD& lod); + + protected: + static bool is_initialized_; + std::map> _regions; + bool scale_sparse_gradient_with_batch_size_; + int32_t sleep_seconds_before_fail_exit_; + int client2client_request_timeout_ms_; + int client2client_connect_timeout_ms_; + int client2client_max_retry_; + DISABLE_COPY_AND_ASSIGN(FleetWrapper); +}; + +} // end namespace distributed +} // end namespace paddle diff --git a/paddle/fluid/distributed/service/CMakeLists.txt b/paddle/fluid/distributed/service/CMakeLists.txt new file mode 100644 index 00000000000000..0c767ad2b3fa6b --- /dev/null +++ b/paddle/fluid/distributed/service/CMakeLists.txt @@ -0,0 +1,40 @@ +set(BRPC_SRCS ps_client.cc server.cc) +set_source_files_properties(${BRPC_SRCS}) + +set(BRPC_DEPS brpc ssl crypto protobuf gflags glog zlib leveldb snappy gflags glog) + +brpc_library(sendrecv_rpc SRCS + ${BRPC_SRCS} + PROTO sendrecv.proto + DEPS ${BRPC_DEPS} ) + +set_property(GLOBAL PROPERTY RPC_DEPS sendrecv_rpc ${BRPC_DEPS} string_helper) + +get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS) + +set_source_files_properties(communicator.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +set_source_files_properties(service.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +set_source_files_properties(brpc_ps_server.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +set_source_files_properties(brpc_ps_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) + +set_source_files_properties(brpc_utils.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +set_source_files_properties(heter_server.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +set_source_files_properties(heter_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) + +set_source_files_properties(client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) 
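+# The RPC_DEPS global property set above is how downstream code picks up the
+# whole brpc stack; a typical consumer of the ps_service target defined below
+# would look like this (sketch; the target name my_ps_tool is illustrative):
+#   get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS)
+#   cc_library(my_ps_tool SRCS my_ps_tool.cc DEPS ps_service ${RPC_DEPS})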
+set_source_files_properties(ps_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
+set_source_files_properties(server.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
+
+
+cc_library(downpour_server SRCS brpc_ps_server.cc DEPS boost eigen3 table ${RPC_DEPS})
+cc_library(downpour_client SRCS brpc_ps_client.cc DEPS boost eigen3 table ${RPC_DEPS})
+
+cc_library(client SRCS ps_client.cc DEPS downpour_client boost ${RPC_DEPS})
+cc_library(server SRCS server.cc DEPS downpour_server boost ${RPC_DEPS})
+
+cc_library(communicator SRCS communicator.cc DEPS scope client boost table math_function selected_rows_functor ${RPC_DEPS})
+cc_library(ps_service SRCS service.cc DEPS communicator client server boost ${RPC_DEPS})
+
+cc_library(brpc_utils SRCS brpc_utils.cc DEPS ${COMMON_DEPS} ${RPC_DEPS})
+cc_library(heter_server SRCS heter_server.cc DEPS brpc_utils ${COMMON_DEPS} ${RPC_DEPS})
+cc_library(heter_client SRCS heter_client.cc DEPS brpc_utils ${COMMON_DEPS} ${RPC_DEPS})
diff --git a/paddle/fluid/distributed/service/brpc_ps_client.cc b/paddle/fluid/distributed/service/brpc_ps_client.cc
new file mode 100644
index 00000000000000..bc9d017532dff0
--- /dev/null
+++ b/paddle/fluid/distributed/service/brpc_ps_client.cc
@@ -0,0 +1,879 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <memory>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include "Eigen/Dense"
+#include "paddle/fluid/distributed/service/brpc_ps_client.h"
+#include "paddle/fluid/distributed/table/table.h"
+#include "paddle/fluid/framework/archive.h"
+
+static const int max_port = 65535;
+
+DEFINE_int32(pserver_push_dense_merge_limit, 12,
+             "limit max push_dense local merge requests");
+
+DEFINE_int32(pserver_push_sparse_merge_limit, 12,
+             "limit max push_sparse local merge requests");
+
+DEFINE_int32(pserver_pull_dense_limit, 12,
+             "limit max pull_dense local merge requests");
+
+DEFINE_int32(pserver_async_push_dense_interval_ms, 10,
+             "async push_dense to server interval");
+
+DEFINE_int32(pserver_async_push_sparse_interval_ms, 10,
+             "async push_sparse to server interval");
+
+DEFINE_bool(pserver_scale_gradient_by_merge, false,
+            "scale dense gradient when merged");
+
+DEFINE_int32(pserver_communicate_compress_type, 0,
+             "none:0 snappy:1 gzip:2 zlib:3 lz4:4");
+
+DEFINE_int32(pserver_max_async_call_num, 13,
+             "max task num in async_call_server");
+
+DEFINE_int32(pserver_timeout_ms, 500000, "pserver request server timeout_ms");
+
+DEFINE_int32(pserver_connect_timeout_ms, 10000,
+             "pserver connect server timeout_ms");
+
+DEFINE_int32(pserver_sparse_merge_thread, 1, "pserver sparse merge thread num");
+
+namespace paddle {
+namespace distributed {
+
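+// Worked example for get_sparse_shard() below: with shard_num = 10 and
+// server_num = 3, remind = 1, so local_shard_num = 4 and keys hashed to
+// shards 0-3, 4-7 and 8-9 are served by servers 0, 1 and 2 respectively.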
+inline size_t get_sparse_shard(uint32_t shard_num, uint32_t server_num,
+                               uint64_t key) {
+  size_t remind = shard_num % server_num;
+  size_t local_shard_num =
+      remind == 0 ? shard_num / server_num : shard_num / server_num + 1;
+  return (key % shard_num) / local_shard_num;
+}
+
+void DownpourPsClientService::service(
+    ::google::protobuf::RpcController *controller,
+    const ::paddle::PsRequestMessage *request,
+    ::paddle::PsResponseMessage *response, ::google::protobuf::Closure *done) {
+  brpc::ClosureGuard done_guard(done);
+  int ret = _client->handle_client2client_msg(
+      request->cmd_id(), request->client_id(), request->data());
+  response->set_err_code(0);
+  response->set_err_msg("");
+  if (ret != 0) {
+    response->set_err_code(-1);
+    response->set_err_msg("handle_client2client_msg failed");
+  }
+}
+
+// Start the client-side RPC service, used for client-to-client data exchange
+int32_t BrpcPsClient::start_client_service() {
+  if (_service.configure(this, _client_id) != 0) {
+    LOG(ERROR)
+        << "service initialize failed, service_name:DownpourPsClientService";
+    return -1;
+  }
+  _server.AddService(&_service, brpc::SERVER_DOESNT_OWN_SERVICE);
+  brpc::ServerOptions options;
+  int start_port = 8500;
+  options.num_threads = 24;
+
+  if (_server.Start(butil::my_ip_cstr(), brpc::PortRange(start_port, max_port),
+                    &options) != 0) {
+    LOG(ERROR) << "BrpcPsServer start failed";
+    return -1;
+  }
+  _env->registe_ps_client(butil::my_ip_cstr(), _server.listen_address().port,
+                          _client_id);
+  return 0;
+}
+
+int32_t BrpcPsClient::create_client2client_connection(
+    int pserver_timeout_ms, int pserver_connect_timeout_ms, int max_retry) {
+  brpc::ChannelOptions options;
+  options.protocol = "baidu_std";
+  options.timeout_ms = pserver_timeout_ms;
+  options.connection_type = "pooled";
+  options.connect_timeout_ms = pserver_connect_timeout_ms;
+  options.max_retry = max_retry;
+
+  auto client_list = _env->get_ps_clients();
+  _client_channels.resize(client_list.size());
+  std::ostringstream os;
+  std::string server_ip_port;
+  for (size_t i = 0; i < client_list.size(); ++i) {
+    server_ip_port.assign(client_list[i].ip.c_str());
+    server_ip_port.append(":");
+    server_ip_port.append(std::to_string(client_list[i].port));
+    _client_channels[i].reset(new brpc::Channel());
+    if (_client_channels[i]->Init(server_ip_port.c_str(), "", &options) != 0) {
+      LOG(ERROR) << "psclient connect to client:" << server_ip_port
+                 << " Failed!";
+    }
+    os << server_ip_port << ",";
+  }
+  LOG(INFO) << "Client connect success:" << os.str();
+  return 0;
+}
+
+int32_t BrpcPsClient::initialize() {
+  _async_call_num = 0;
+
+  brpc::ChannelOptions options;
+  options.protocol = "baidu_std";
+  options.timeout_ms = FLAGS_pserver_timeout_ms;
+  options.connection_type = "pooled";
+  options.connect_timeout_ms = FLAGS_pserver_connect_timeout_ms;
+  options.max_retry = 3;
+
+  std::ostringstream os;
+  std::string server_ip_port;
+  std::string client_ip(butil::my_ip_cstr());
+
+  // Fetch the server list and connect to every server
+  auto server_list = _env->get_ps_servers();
+  _server_channels.resize(server_list.size());
+  for (size_t i = 0; i < server_list.size(); ++i) {
+    server_ip_port.assign(server_list[i].ip.c_str());
+    server_ip_port.append(":");
+    server_ip_port.append(std::to_string(server_list[i].port));
+    for (size_t j = 0; j < _server_channels[i].size(); ++j) {
+      _server_channels[i][j].reset(new brpc::Channel());
+      if (_server_channels[i][j]->Init(server_ip_port.c_str(), "", &options) !=
+          0) {
+        LOG(ERROR) << "psclient connect to server:" << server_ip_port
+                   << " Failed!";
+        return -1;
+      }
+    }
+    os << server_ip_port << ",";
+  }
+  // Start the client's own listening service and build mutual
+  // client-to-client connections
+  start_client_service();
+
+  _running = true;
+  _flushing = false;
+  return 0;
+}
+
+int DownpourBrpcClosure::check_response(size_t request_idx, int cmd_id) {
+  if (_cntls[request_idx]->Failed()) {
+    LOG(ERROR) << "request cmd_id:" << cmd_id << " failed, "
+                  "err:"
+               << _cntls[request_idx]->ErrorText();
+    return -1;
+  }
+  if (_responses[request_idx].err_code() != 0) {
+    LOG(ERROR) << "response ret bad, server_idx:" << request_idx
+               << " cmd_id:" << cmd_id
+               << " err_code:" << _responses[request_idx].err_code()
+               << " err_msg:" << _responses[request_idx].err_msg();
+    return -1;
+  }
+  return 0;
+}
+
+int DownpourBrpcClosure::check_save_response(size_t request_idx, int cmd_id) {
+  // signed on purpose: err_code() doubles as the feasign count here and may
+  // carry a negative error code
+  int32_t feasign_size = 0;
+  if (_cntls[request_idx]->Failed()) {
+    LOG(ERROR) << "request cmd_id:" << cmd_id << " failed, "
+                  "err:"
+               << _cntls[request_idx]->ErrorText();
+    return -1;
+  }
+  feasign_size = _responses[request_idx].err_code();
+  if (feasign_size < 0) {
+    LOG(ERROR) << "response ret bad, server_idx:" << request_idx
+               << " cmd_id:" << cmd_id
+               << " err_code:" << _responses[request_idx].err_code()
+               << " err_msg:" << _responses[request_idx].err_msg();
+    return -1;
+  }
+  return feasign_size;
+}
+
+std::string DownpourBrpcClosure::get_response(size_t request_idx, int cmd_id) {
+  std::string data = _responses[request_idx].data();
+  return data;
+}
+
+std::future<int32_t> BrpcPsClient::print_table_stat(uint32_t table_id) {
+  size_t request_call_num = _server_channels.size();
+  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
+      request_call_num, [request_call_num, table_id](void *done) {
+        int ret = 0;
+        uint64_t feasign_size = 0;
+        uint64_t mf_size = 0;
+        paddle::framework::BinaryArchive ar;
+        auto *closure = (DownpourBrpcClosure *)done;
+        for (size_t i = 0; i < request_call_num; ++i) {
+          if (closure->check_response(i, PS_PRINT_TABLE_STAT) != 0) {
+            ret = -1;
+            break;
+          }
+          std::string resp = closure->get_response(i, PS_PRINT_TABLE_STAT);
+          ar.SetReadBuffer(const_cast<char *>(resp.c_str()), resp.length(),
+                           nullptr);
+
+          feasign_size += ar.Get<uint64_t>();
+          mf_size += ar.Get<uint64_t>();
+        }
+        closure->set_promise_value(ret);
+        std::cout << "table id: " << table_id
+                  << ", feasign size: " << feasign_size
+                  << ", mf size: " << mf_size << std::endl;
+      });
+  auto promise = std::make_shared<std::promise<int32_t>>();
+  closure->add_promise(promise);
+  std::future<int32_t> fut = promise->get_future();
+  for (size_t i = 0; i < request_call_num; ++i) {
+    closure->request(i)->set_cmd_id(PS_PRINT_TABLE_STAT);
+    closure->request(i)->set_table_id(table_id);
+    closure->request(i)->set_client_id(_client_id);
+    PsService_Stub rpc_stub(get_cmd_channel(i));
+    closure->cntl(i)->set_timeout_ms(
+        10800000);  // cmd msg don't limit timeout for save/load
+    rpc_stub.service(closure->cntl(i), closure->request(i),
+                     closure->response(i), closure);
+  }
+  return fut;
+}
+std::future<int32_t> BrpcPsClient::send_cmd(
+    uint32_t table_id, int cmd_id, const std::vector<std::string> &params) {
+  size_t request_call_num = _server_channels.size();
+  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
+      request_call_num, [request_call_num, cmd_id](void *done) {
+        int ret = 0;
+        auto *closure = (DownpourBrpcClosure *)done;
+        for (size_t i = 0; i < request_call_num; ++i) {
+          if (closure->check_response(i, cmd_id) != 0) {
+            ret = -1;
+            break;
+          }
+        }
+        closure->set_promise_value(ret);
+      });
+  auto promise = std::make_shared<std::promise<int32_t>>();
+  closure->add_promise(promise);
+  std::future<int32_t> fut = promise->get_future();
+  for (size_t i = 0; i < request_call_num; ++i) {
+    closure->request(i)->set_cmd_id(cmd_id);
+    closure->request(i)->set_table_id(table_id);
+    closure->request(i)->set_client_id(_client_id);
+    for (const auto &param :
params) { + closure->request(i)->add_params(param); + } + PsService_Stub rpc_stub(get_cmd_channel(i)); + closure->cntl(i)->set_timeout_ms( + 10800000); // cmd msg don't limit timeout for save/load + rpc_stub.service(closure->cntl(i), closure->request(i), + closure->response(i), closure); + } + return fut; +} + +std::future BrpcPsClient::send_save_cmd( + uint32_t table_id, int cmd_id, const std::vector ¶ms) { + size_t request_call_num = _server_channels.size(); + DownpourBrpcClosure *closure = new DownpourBrpcClosure( + request_call_num, [request_call_num, cmd_id](void *done) { + int ret = 0; + uint32_t feasign_size = 0; + auto *closure = (DownpourBrpcClosure *)done; + for (size_t i = 0; i < request_call_num; ++i) { + if (closure->check_save_response(i, cmd_id) < 0) { + ret = -1; + break; + } + feasign_size += closure->check_save_response(i, cmd_id); + } + if (ret == 0) { + closure->set_promise_value(feasign_size); + } else { + closure->set_promise_value(ret); + } + }); + auto promise = std::make_shared>(); + closure->add_promise(promise); + std::future fut = promise->get_future(); + for (size_t i = 0; i < request_call_num; ++i) { + closure->request(i)->set_cmd_id(cmd_id); + closure->request(i)->set_table_id(table_id); + closure->request(i)->set_client_id(_client_id); + for (const auto ¶m : params) { + closure->request(i)->add_params(param); + } + PsService_Stub rpc_stub(get_cmd_channel(i)); + closure->cntl(i)->set_timeout_ms( + 10800000); // cmd msg don't limit timeout for save/load + rpc_stub.service(closure->cntl(i), closure->request(i), + closure->response(i), closure); + } + return fut; +} + +std::future BrpcPsClient::shrink(uint32_t table_id) { + return send_cmd(table_id, PS_SHRINK_TABLE, {std::string("1")}); +} + +std::future BrpcPsClient::load(const std::string &epoch, + const std::string &mode) { + return send_cmd(-1, PS_LOAD_ALL_TABLE, {epoch, mode}); +} +std::future BrpcPsClient::load(uint32_t table_id, + const std::string &epoch, + const std::string &mode) { + return send_cmd(table_id, PS_LOAD_ONE_TABLE, {epoch, mode}); +} + +std::future BrpcPsClient::save(const std::string &epoch, + const std::string &mode) { + return send_save_cmd(-1, PS_SAVE_ALL_TABLE, {epoch, mode}); +} +std::future BrpcPsClient::save(uint32_t table_id, + const std::string &epoch, + const std::string &mode) { + return send_save_cmd(table_id, PS_SAVE_ONE_TABLE, {epoch, mode}); +} + +std::future BrpcPsClient::clear() { + return send_cmd(-1, PS_CLEAR_ALL_TABLE, {}); +} +std::future BrpcPsClient::clear(uint32_t table_id) { + return send_cmd(table_id, PS_CLEAR_ONE_TABLE, {}); +} + +std::future BrpcPsClient::flush() { + _flushing = true; + std::promise promise; + std::future fut = promise.get_future(); + do { + VLOG(3) << "wait _async_call_num:" << _async_call_num; + usleep(100000); // sleep 100ms wait async end + } while (_async_call_num > 0); + promise.set_value(0); + _flushing = false; + return fut; +} + +void BrpcPsClient::finalize_worker() { + flush(); + _running = false; + _server.Stop(1000); + _server.Join(); +} + +std::future BrpcPsClient::stop_server() { + return send_cmd(-1, PS_STOP_SERVER, {}); +} + +std::future BrpcPsClient::start_profiler() { + return send_cmd(-1, PS_START_PROFILER, {}); +} + +std::future BrpcPsClient::stop_profiler() { + return send_cmd(-1, PS_STOP_PROFILER, {}); +} + +std::future BrpcPsClient::barrier(size_t table_id, + uint32_t barrier_type) { + return send_cmd(table_id, PS_BARRIER, {std::to_string(barrier_type)}); +} + +std::future BrpcPsClient::pull_geo_param(size_t table_id, + 
std::vector *values, + std::vector *keys, + int pserver_idx) { + auto *accessor = table_accessor(table_id); + DownpourBrpcClosure *closure = + new DownpourBrpcClosure(1, [keys, values, accessor](void *done) { + int ret = 0; + auto *closure = (DownpourBrpcClosure *)done; + uint32_t shard_nums; + if (closure->check_response(0, PS_PULL_GEO_PARAM) != 0) { + ret = -1; + } + auto &res_io_buffer = closure->cntl(0)->response_attachment(); + butil::IOBufBytesIterator io_buffer_itr(res_io_buffer); + io_buffer_itr.copy_and_forward((void *)(&shard_nums), sizeof(uint32_t)); + keys->resize(shard_nums); + values->resize(shard_nums * accessor->update_dim()); + io_buffer_itr.copy_and_forward((void *)(keys->data()), + sizeof(uint64_t) * shard_nums); + io_buffer_itr.copy_and_forward((void *)(values->data()), + shard_nums * accessor->update_size()); + closure->set_promise_value(ret); + }); + auto promise = std::make_shared>(); + closure->add_promise(promise); + std::future fut = promise->get_future(); + + closure->request(0)->set_cmd_id(PS_PULL_GEO_PARAM); + closure->request(0)->set_table_id(table_id); + closure->request(0)->set_client_id(_client_id); + PsService_Stub rpc_stub(get_cmd_channel(pserver_idx)); + closure->cntl(0)->set_log_id(butil::gettimeofday_ms()); + rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0), + closure); + return fut; +} + +std::future BrpcPsClient::push_sparse_param( + size_t table_id, const uint64_t *keys, const float **update_values, + size_t num, void *done) { + auto *accessor = table_accessor(table_id); + // 发送RPC请求 + DownpourBrpcClosure *closure = reinterpret_cast(done); + auto promise = std::make_shared>(); + closure->add_promise(promise); + std::future fut = promise->get_future(); + size_t request_call_num = _server_channels.size(); + std::vector> ids; + std::vector> value_ptrs; + ids.resize(request_call_num); + value_ptrs.resize(request_call_num); + for (size_t i = 0; i < num; ++i) { + size_t pserver_idx = keys[i] % request_call_num; + ids[pserver_idx].push_back(keys[i]); + value_ptrs[pserver_idx].push_back(update_values[i]); + } + for (size_t shard_idx = 0; shard_idx < request_call_num; ++shard_idx) { + auto kvs = ids[shard_idx]; + auto value_ptr = value_ptrs[shard_idx]; + size_t kv_size = kvs.size(); + uint32_t value_size = accessor->update_size(); + // 发送RPC请求 + auto *push_request = closure->request(shard_idx); + push_request->set_cmd_id(PS_PUSH_SPARSE_PARAM); + push_request->set_table_id(table_id); + push_request->set_client_id(_client_id); + push_request->add_params((char *)&kv_size, sizeof(uint32_t)); + auto *push_data = push_request->mutable_data(); + push_data->resize(kv_size * (sizeof(uint64_t) + accessor->update_size())); + char *push_data_ptr = const_cast(push_data->data()); + memcpy(push_data_ptr, kvs.data(), kv_size * sizeof(uint64_t)); + push_data_ptr += kv_size * sizeof(uint64_t); + for (int i = 0; i < kv_size; ++i) { + memcpy(push_data_ptr, value_ptr[i], accessor->update_size()); + push_data_ptr += accessor->update_size(); + } + PsService_Stub rpc_stub(get_sparse_channel(shard_idx)); + closure->cntl(shard_idx)->set_request_compress_type( + (brpc::CompressType)FLAGS_pserver_communicate_compress_type); + rpc_stub.service(closure->cntl(shard_idx), closure->request(shard_idx), + closure->response(shard_idx), closure); + } + return fut; +} + +std::future BrpcPsClient::pull_dense(Region *regions, + size_t region_num, + size_t table_id) { + auto *accessor = table_accessor(table_id); + size_t request_call_num = _server_channels.size(); + 
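  // dense_dim_per_shard() is assumed to round the per-server slice up: e.g.
+  // fea_dim = 100 over 3 servers would give num_per_shard = 34, and each
+  // response attachment is then checked to be exactly
+  // num_per_shard * accessor->select_size() bytes.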
+  uint32_t num_per_shard =
+      dense_dim_per_shard(accessor->fea_dim(), request_call_num);
+  // The callback stitches each shard's result back into the caller's
+  // regions, in order
+  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
+      request_call_num, [request_call_num, num_per_shard, regions, region_num,
+                         accessor](void *done) {
+        int ret = 0;
+        size_t region_idx = 0;       // region currently being filled
+        size_t region_data_idx = 0;  // data offset inside the current region
+        auto *closure = (DownpourBrpcClosure *)done;
+        size_t shard_data_size = num_per_shard * accessor->select_size();
+        for (size_t i = 0; i < request_call_num; ++i) {
+          if (closure->check_response(i, PS_PULL_DENSE_TABLE) != 0) {
+            ret = -1;
+            break;
+          }
+          auto &res_io_buffer = closure->cntl(i)->response_attachment();
+
+          butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
+          size_t shard_buffer_remain = res_io_buffer.size();
+          if (shard_buffer_remain != shard_data_size) {
+            LOG(ERROR) << "expect res_size:" << shard_data_size
+                       << ", but size:" << shard_buffer_remain
+                       << ", ignore this response";
+            ret = -1;
+            break;
+          }
+          while (shard_buffer_remain > 0 && region_idx < region_num) {
+            auto &region = regions[region_idx];
+            if (region.size - region_data_idx >= shard_buffer_remain) {
+              // remaining space in the region >= the shard buffer: copy the
+              // whole buffer in directly
+              io_buffer_itr.copy_and_forward(
+                  (void *)(region.data + region_data_idx),
+                  shard_buffer_remain);
+              region_data_idx += shard_buffer_remain;
+              shard_buffer_remain = 0;
+            } else if (region.size - region_data_idx == 0) {
+              // region is full, switch to the next region
+              ++region_idx;
+              region_data_idx = 0;
+            } else {
+              // region cannot hold all the data: copy as much as fits
+              io_buffer_itr.copy_and_forward(
+                  (void *)(region.data + region_data_idx),
+                  region.size - region_data_idx);
+              shard_buffer_remain -= (region.size - region_data_idx);
+              ++region_idx;
+              region_data_idx = 0;
+            }
+          }
+        }
+        closure->set_promise_value(ret);
+      });
+  auto promise = std::make_shared<std::promise<int32_t>>();
+  closure->add_promise(promise);
+  std::future<int32_t> fut = promise->get_future();
+  for (size_t i = 0; i < request_call_num; ++i) {
+    closure->request(i)->set_cmd_id(PS_PULL_DENSE_TABLE);
+    closure->request(i)->set_table_id(table_id);
+    closure->request(i)->set_client_id(_client_id);
+    closure->request(i)->add_params((char *)&num_per_shard,
+                                    sizeof(num_per_shard));
+    PsService_Stub rpc_stub(get_dense_channel(i));
+    rpc_stub.service(closure->cntl(i), closure->request(i),
+                     closure->response(i), closure);
+  }
+  return fut;
+}
+
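+// A concrete walk-through of the region <-> shard arithmetic shared by
+// pull_dense above and push_dense_param below (sizes are illustrative): with
+// two regions of 6 and 4 floats and shard_data_size = 5 floats, shard 0
+// receives region0[0..5) and shard 1 receives region0[5..6) plus
+// region1[0..4); the pull path replays the same walk to scatter responses
+// back into the regions.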
+std::future<int32_t> BrpcPsClient::push_dense_param(const Region *regions,
+                                                    size_t region_num,
+                                                    size_t table_id) {
+  auto *accessor = table_accessor(table_id);
+  size_t request_call_num = _server_channels.size();
+  // 1. Split the Region data across shards; the shards are then copied and
+  // sent in parallel
+  std::vector<std::vector<Region>> regions_partition(request_call_num);
+  uint32_t num_per_shard =
+      dense_dim_per_shard(accessor->fea_dim(), request_call_num);
+  size_t shard_data_size = num_per_shard * accessor->update_size();
+  size_t current_region_idx = 0;
+  size_t current_region_data_idx = 0;
+  for (size_t i = 0; i < request_call_num; ++i) {
+    size_t shard_data_remain_size = shard_data_size;
+    while (shard_data_remain_size > 0 && current_region_idx < region_num) {
+      const auto &region = regions[current_region_idx];
+      size_t region_remain_size = region.size - current_region_data_idx;
+      if (shard_data_remain_size >= region_remain_size) {
+        regions_partition[i].push_back(
+            Region(region.data + current_region_data_idx, region_remain_size));
+        ++current_region_idx;
+        current_region_data_idx = 0;
+        shard_data_remain_size -= region_remain_size;
+      } else {
+        regions_partition[i].push_back(Region(
+            region.data + current_region_data_idx, shard_data_remain_size));
+        current_region_data_idx += shard_data_remain_size;
+        shard_data_remain_size = 0;
+      }
+    }
+  }
+
+  DownpourBrpcClosure *closure =
+      new DownpourBrpcClosure(request_call_num, [request_call_num](void *done) {
+        int ret = 0;
+        auto *closure = (DownpourBrpcClosure *)done;
+        for (size_t i = 0; i < request_call_num; ++i) {
+          if (closure->check_response(i, PS_PUSH_DENSE_PARAM) != 0) {
+            ret = -1;
+            break;
+          }
+        }
+        closure->set_promise_value(ret);
+      });
+  auto promise = std::make_shared<std::promise<int32_t>>();
+  closure->add_promise(promise);
+  std::future<int32_t> fut = promise->get_future();
+  static const int REGION_ASSIGN_BUFFER_SIZE = 1024 * 10;
+  static char region_assign_buffer[REGION_ASSIGN_BUFFER_SIZE];  // for padding
+  // start the per-shard parallel copy & requests
+  for (size_t i = 0; i < request_call_num; ++i) {
+    closure->request(i)->set_cmd_id(PS_PUSH_DENSE_PARAM);
+    closure->request(i)->set_table_id(table_id);
+    closure->request(i)->set_client_id(_client_id);
+    auto &request_buffer = closure->cntl(i)->request_attachment();
+    request_buffer.append((void *)&num_per_shard, sizeof(uint32_t));
+    auto &region_list = regions_partition[i];
+    size_t fill_remain_size = shard_data_size;
+    for (auto &region : region_list) {
+      fill_remain_size -= region.size;
+      request_buffer.append((void *)region.data, region.size);
+    }
+    // pad so that every shard sends a buffer of the same size
+    while (fill_remain_size > 0) {
+      size_t fill_num = fill_remain_size > REGION_ASSIGN_BUFFER_SIZE
+                            ? REGION_ASSIGN_BUFFER_SIZE
+                            : fill_remain_size;
+      request_buffer.append((void *)region_assign_buffer, fill_num);
+      fill_remain_size -= fill_num;
+    }
+    PsService_Stub rpc_stub(get_dense_channel(i));
+    rpc_stub.service(closure->cntl(i), closure->request(i),
+                     closure->response(i), closure);
+  }
+  return fut;
+}
+
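+// The sparse paths below route each key with keys[i] % request_call_num,
+// e.g. with 3 pserver channels key 10 goes to server 1 and key 12 to
+// server 0. Note this is a different mapping from get_sparse_shard(),
+// which assigns each server a contiguous block of shards.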
+std::future<int32_t> BrpcPsClient::push_sparse_raw_gradient(
+    size_t table_id, const uint64_t *keys, const float **update_values,
+    size_t num, void *done) {
+  auto *accessor = table_accessor(table_id);
+  // send the RPC requests
+  DownpourBrpcClosure *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
+  auto promise = std::make_shared<std::promise<int32_t>>();
+  closure->add_promise(promise);
+  std::future<int32_t> fut = promise->get_future();
+
+  size_t request_call_num = _server_channels.size();
+  std::vector<std::vector<uint64_t>> ids;
+  std::vector<std::vector<const float *>> value_ptrs;
+  ids.resize(request_call_num);
+  value_ptrs.resize(request_call_num);
+
+  for (size_t i = 0; i < num; ++i) {
+    size_t pserver_idx = keys[i] % request_call_num;
+    ids[pserver_idx].push_back(keys[i]);
+    value_ptrs[pserver_idx].push_back(update_values[i]);
+  }
+
+  for (size_t shard_idx = 0; shard_idx < request_call_num; ++shard_idx) {
+    auto kvs = ids[shard_idx];
+    auto value_ptr = value_ptrs[shard_idx];
+
+    size_t kv_size = kvs.size();
+    uint32_t value_size = accessor->update_size();
+
+    // build and send one request per shard
+    auto *push_request = closure->request(shard_idx);
+    push_request->set_cmd_id(PS_PUSH_SPARSE_TABLE);
+    push_request->set_table_id(table_id);
+    push_request->set_client_id(_client_id);
+    push_request->add_params((char *)&kv_size, sizeof(uint32_t));
+    auto *push_data = push_request->mutable_data();
+    push_data->resize(kv_size * (sizeof(uint64_t) + accessor->update_size()));
+    char *push_data_ptr = const_cast<char *>(push_data->data());
+    memcpy(push_data_ptr, kvs.data(), kv_size * sizeof(uint64_t));
+    push_data_ptr += kv_size * sizeof(uint64_t);
+
+    for (int i = 0; i < kv_size; ++i) {
+      memcpy(push_data_ptr, value_ptr[i], accessor->update_size());
+      push_data_ptr += accessor->update_size();
+    }
+    PsService_Stub rpc_stub(get_sparse_channel(shard_idx));
+    closure->cntl(shard_idx)->set_request_compress_type(
+        (brpc::CompressType)FLAGS_pserver_communicate_compress_type);
+    rpc_stub.service(closure->cntl(shard_idx), closure->request(shard_idx),
+                     closure->response(shard_idx), closure);
+  }
+  return fut;
+}
+
+std::future<int32_t> BrpcPsClient::push_dense_raw_gradient(
+    int table_id, float *total_send_data, size_t total_send_data_size,
+    void *done) {
+  size_t request_call_num = _server_channels.size();
+  DownpourBrpcClosure *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
+  auto promise = std::make_shared<std::promise<int32_t>>();
+  closure->add_promise(promise);
+  std::future<int32_t> fut = promise->get_future();
+  auto *accessor = table_accessor(table_id);
+  uint32_t num_per_shard =
+      dense_dim_per_shard(accessor->fea_dim(), request_call_num);
+  for (size_t i = 0; i < request_call_num; ++i) {
+    closure->request(i)->set_cmd_id(PS_PUSH_DENSE_TABLE);
+    closure->request(i)->set_table_id(table_id);
+    closure->request(i)->set_client_id(_client_id);
+    auto *push_data = closure->request(i)->mutable_data();
+    push_data->clear();
+    push_data->resize(sizeof(uint32_t) + num_per_shard * sizeof(float));
+    char *push_data_ptr = const_cast<char *>(push_data->data());
+    memcpy(push_data_ptr, &num_per_shard, sizeof(uint32_t));
+    memcpy(push_data_ptr + sizeof(uint32_t),
+           total_send_data + i * num_per_shard, num_per_shard * sizeof(float));
+    VLOG(1) << "push_dense_raw_gradient finish memcpy";
+    // closure->cntl(i)->set_request_compress_type(
+    //     (brpc::CompressType)FLAGS_pserver_communicate_compress_type);
+    PsService_Stub rpc_stub(get_dense_channel(i));
+    VLOG(1) << "push_dense_raw_gradient get_dense_channel " << i;
+    rpc_stub.service(closure->cntl(i), closure->request(i),
+                     closure->response(i), closure);
+    VLOG(1) << "push_dense_raw_gradient async service " << i;
+  }
+  return fut;
+}
+
+std::future<int32_t> BrpcPsClient::pull_sparse(float **select_values,
+                                               size_t table_id,
+                                               const uint64_t *keys,
+                                               size_t num) {
+  size_t request_call_num = _server_channels.size();
+
+  auto shard_sorted_kvs = std::make_shared<
+      std::vector<std::vector<std::pair<uint64_t, float *>>>>();
+  shard_sorted_kvs->resize(request_call_num);
+
+  for (size_t i = 0; i < num; ++i) {
+    size_t shard_id = keys[i] % request_call_num;
+    shard_sorted_kvs->at(shard_id).push_back({keys[i], select_values[i]});
+  }
+
+  auto *accessor = table_accessor(table_id);
+  size_t value_size = accessor->select_size();
+
+  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
+      request_call_num, [shard_sorted_kvs, value_size](void *done) {
+        int ret = 0;
+        auto *closure = (DownpourBrpcClosure *)done;
+        // one response per shard
+        for (size_t i = 0; i < shard_sorted_kvs->size(); ++i) {
+          if (closure->check_response(i, PS_PULL_SPARSE_TABLE) != 0) {
+            ret = -1;
+            break;
+          }
+
+          auto &request_kvs = shard_sorted_kvs->at(i);
+          auto &res_io_buffer = closure->cntl(i)->response_attachment();
+          butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
+          uint64_t last_key = UINT64_MAX;
+          float *last_value_data = NULL;
+
+          for (size_t kv_idx = 0; kv_idx < request_kvs.size(); ++kv_idx) {
+            auto *kv_pair = &(request_kvs[kv_idx]);
+            if (kv_pair->first == last_key) {
+              memcpy((void *)kv_pair->second, (void *)last_value_data,
+                     value_size);
+            } else {
+              last_key = kv_pair->first;
+              last_value_data = kv_pair->second;
+              if (value_size !=
+                  io_buffer_itr.copy_and_forward((void *)(last_value_data),
+                                                 value_size)) {
+                LOG(WARNING) << "res data is lack or not in format";
+                ret = -1;
+                break;
+              }
+            }
+          }
+        }
+        closure->set_promise_value(ret);
+      });
+
+  auto promise = std::make_shared<std::promise<int32_t>>();
+  closure->add_promise(promise);
+  std::future<int32_t> fut = promise->get_future();
+
+  for (size_t i = 0; i < request_call_num; ++i) {
+    auto &sorted_kvs =
shard_sorted_kvs->at(i); + std::sort(sorted_kvs.begin(), sorted_kvs.end(), + [](const std::pair<uint64_t, float *> &k1, + const std::pair<uint64_t, float *> &k2) { + return k1.first < k2.first; + }); + + uint64_t last_key = UINT64_MAX; + uint32_t kv_request_count = 0; + size_t sorted_kv_size = sorted_kvs.size(); + auto &request_buffer = closure->cntl(i)->request_attachment(); + // append each distinct key once; runs of duplicates are skipped here and + // reconstructed from the previous value on the response path + for (size_t kv_idx = 0; kv_idx < sorted_kv_size; ++kv_idx) { + ++kv_request_count; + last_key = sorted_kvs[kv_idx].first; + request_buffer.append((void *)&last_key, sizeof(uint64_t)); + while (kv_idx < sorted_kv_size - 1 && + last_key == sorted_kvs[kv_idx + 1].first) { + ++kv_idx; + } + } + + if (kv_request_count == 0) { + closure->Run(); + } else { + closure->request(i)->set_cmd_id(PS_PULL_SPARSE_TABLE); + closure->request(i)->set_table_id(table_id); + closure->request(i)->set_client_id(_client_id); + closure->request(i)->add_params((char *)&kv_request_count, + sizeof(uint32_t)); + PsService_Stub rpc_stub(get_cmd_channel(i)); + closure->cntl(i)->set_log_id(butil::gettimeofday_ms()); + rpc_stub.service(closure->cntl(i), closure->request(i), + closure->response(i), closure); + } + } + return fut; +} + +std::future<int32_t> BrpcPsClient::send_client2client_msg( + int msg_type, int to_client_id, const std::string &msg) { + auto promise = std::make_shared<std::promise<int32_t>>(); + std::future<int32_t> fut = promise->get_future(); + if (to_client_id >= static_cast<int>(_client_channels.size())) { + LOG(FATAL) << "to_client_id is out of range of clients, whose size is " + << _client_channels.size(); + promise->set_value(-1); + return fut; + } + auto *closure = new DownpourBrpcClosure(1, [msg_type](void *done) { + auto *closure = (DownpourBrpcClosure *)done; + int32_t ret = closure->check_response(0, msg_type + 1000); + closure->set_promise_value(ret); + }); + closure->add_promise(promise); + closure->request(0)->set_cmd_id(msg_type); + closure->request(0)->set_client_id(_client_id); + closure->request(0)->set_data(msg); + PsService_Stub rpc_stub(_client_channels[to_client_id].get()); + rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0), + closure); + return fut; +} + +std::future<int32_t> BrpcPsClient::push_sparse_raw_gradient_partial( + size_t table_id, const uint64_t *keys, const float **update_values, + uint32_t num, void *done, int pserver_idx) { + auto *accessor = table_accessor(table_id); + size_t value_size = accessor->update_size(); + DownpourBrpcClosure *closure = reinterpret_cast<DownpourBrpcClosure *>(done); + auto promise = std::make_shared<std::promise<int32_t>>(); + closure->add_promise(promise); + std::future<int32_t> fut = promise->get_future(); + + // build and send the RPC request + auto *push_request = closure->request(0); + push_request->set_cmd_id(PS_PUSH_SPARSE_TABLE); + push_request->set_table_id(table_id); + push_request->set_client_id(_client_id); + push_request->add_params((char *)&num, sizeof(uint32_t)); + auto *push_data = push_request->mutable_data(); + push_data->resize(num * (sizeof(uint64_t) + value_size)); + char *push_data_ptr = const_cast<char *>(push_data->data()); + memcpy(push_data_ptr, keys, num * sizeof(uint64_t)); + push_data_ptr += num * sizeof(uint64_t); + for (uint32_t i = 0; i < num; ++i) { + memcpy(push_data_ptr, update_values[i], value_size); + push_data_ptr += value_size; + } + PsService_Stub rpc_stub(get_sparse_channel(pserver_idx)); + closure->cntl(0)->set_request_compress_type( + (brpc::CompressType)FLAGS_pserver_communicate_compress_type); + rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0), + closure); + return fut; +} + +} // namespace distributed +} // namespace paddle
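// Reviewer note (illustrative sketch, not part of the original patch):
// every async call in this client follows the same closure pattern --
// allocate a DownpourBrpcClosure over N sub-requests, bridge it to a
// future via a promise, fire the RPCs, and let the last completed
// response run the callback and free the closure. A minimal sketch,
// assuming `request_call_num` shards and the PS_BARRIER command:
//
//   auto *closure = new DownpourBrpcClosure(
//       request_call_num, [request_call_num](void *done) {
//         auto *c = reinterpret_cast<DownpourBrpcClosure *>(done);
//         int ret = 0;
//         for (size_t i = 0; i < request_call_num; ++i) {
//           if (c->check_response(i, PS_BARRIER) != 0) {
//             ret = -1;
//             break;
//           }
//         }
//         c->set_promise_value(ret);
//       });
//   auto promise = std::make_shared<std::promise<int32_t>>();
//   closure->add_promise(promise);
//   std::future<int32_t> fut = promise->get_future();
//   // fill closure->request(i), then rpc_stub.service(...) per shard
//   fut.wait();
diff --git 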
a/paddle/fluid/distributed/service/brpc_ps_client.h b/paddle/fluid/distributed/service/brpc_ps_client.h new file mode 100644 index 00000000000000..c0716515150795 --- /dev/null +++ b/paddle/fluid/distributed/service/brpc_ps_client.h @@ -0,0 +1,212 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include "brpc/channel.h" +#include "brpc/controller.h" +#include "brpc/server.h" +#include "paddle/fluid/distributed/service/ps_client.h" + +namespace paddle { +namespace distributed { + +class DownpourPsClientService : public PsService { + public: + DownpourPsClientService() {} + virtual ~DownpourPsClientService() {} + + virtual int32_t configure(PSClient *client, size_t rank_id) { + _client = client; + _rank = rank_id; + return 0; + } + virtual void service(::google::protobuf::RpcController *controller, + const ::paddle::PsRequestMessage *request, + ::paddle::PsResponseMessage *response, + ::google::protobuf::Closure *done) override; + + protected: + size_t _rank; + PSClient *_client; +}; + +class DownpourBrpcClosure : public PSClientClosure { + public: + DownpourBrpcClosure(size_t num, PSClientCallBack callback) + : PSClientClosure(callback) { + _waiting_num = num; + + _cntls.resize(num); + _requests.resize(num); + _responses.resize(num); + for (size_t i = 0; i < num; ++i) { + _cntls[i].reset(new brpc::Controller()); + } + } + virtual ~DownpourBrpcClosure() {} + virtual void Run() override { + if (_waiting_num.fetch_sub(1) == 1) { + _callback(this); + delete this; + } + } + PsRequestMessage *request(size_t i) { return &_requests[i]; } + PsResponseMessage *response(size_t i) { return &_responses[i]; } + brpc::Controller *cntl(size_t i) { return _cntls[i].get(); } + int check_response(size_t request_idx, int cmd_id); + int check_save_response(size_t request_idx, int cmd_id); + std::string get_response(size_t request_idx, int cmd_id); + + private: + std::atomic _waiting_num; + std::vector _requests; + std::vector _responses; + std::vector> _cntls; +}; + +template +struct array_deleter { + void operator()(T *&x) const { delete[] x; } +}; + +class BrpcPsClient : public PSClient { + public: + BrpcPsClient() {} + virtual ~BrpcPsClient() { + // _running = false; + // try { + // _async_push_dense_thread.join(); + // _async_push_sparse_thread.join(); + //} catch (...) 
{ + //} + } + virtual int32_t create_client2client_connection( + int pserver_timeout_ms, int pserver_connect_timeout_ms, int max_retry); + virtual std::future shrink(uint32_t table_id) override; + virtual std::future load(const std::string &epoch, + const std::string &mode) override; + virtual std::future load(uint32_t table_id, const std::string &epoch, + const std::string &mode) override; + + virtual std::future save(const std::string &epoch, + const std::string &mode) override; + + virtual std::future save(uint32_t table_id, const std::string &epoch, + const std::string &mode) override; + + virtual std::future clear() override; + + virtual std::future clear(uint32_t table_id) override; + + virtual std::future stop_server() override; + + virtual std::future start_profiler() override; + virtual std::future stop_profiler() override; + + virtual void finalize_worker() override; + + virtual std::future pull_dense(Region *regions, size_t region_num, + size_t table_id); + + virtual std::future push_dense_param(const Region *regions, + size_t region_num, + size_t table_id); + + virtual std::future pull_sparse(float **select_values, + size_t table_id, + const uint64_t *keys, size_t num); + + virtual std::future print_table_stat(uint32_t table_id); + + virtual std::future barrier(size_t table_id, uint32_t barrier_type); + + virtual std::future pull_geo_param(size_t table_id, + std::vector *values, + std::vector *keys, + int pserver_idx); + + virtual std::future flush(); + + virtual std::future send_client2client_msg( + int msg_type, int to_client_id, const std::string &msg) override; + + private: + virtual int32_t initialize() override; + + inline uint32_t dense_dim_per_shard(uint32_t dense_dim_total, + uint32_t shard_num) { + return dense_dim_total / shard_num + 1; + } + + std::future send_cmd(uint32_t table_id, int cmd_id, + const std::vector ¶m); + + std::future send_save_cmd(uint32_t table_id, int cmd_id, + const std::vector ¶m); + + inline brpc::Channel *get_sparse_channel(size_t server_id) { + return _server_channels[server_id][0].get(); + } + inline brpc::Channel *get_dense_channel(size_t server_id) { + return _server_channels[server_id][1].get(); + } + inline brpc::Channel *get_cmd_channel(size_t server_id) { + return _server_channels[server_id][2].get(); + } + + bool _running = false; + bool _flushing = false; + std::atomic _async_call_num; //异步请求计数 + + std::vector> + _client_channels; // client2client + std::vector, 3>> + _server_channels; // client2server + virtual std::future push_dense_raw_gradient( + int table_id, float *total_send_data, size_t total_send_data_size, + void *done) override; + + virtual std::future push_sparse_raw_gradient( + size_t table_id, const uint64_t *keys, const float **update_values, + size_t num, void *done) override; + + virtual std::future push_sparse_raw_gradient_partial( + size_t table_id, const uint64_t *keys, const float **update_values, + uint32_t num, void *done, int pserver_idx) override; + + virtual std::future push_sparse_param(size_t table_id, + const uint64_t *keys, + const float **update_values, + size_t num, + void *done) override; + + virtual size_t get_server_nums() { return _server_channels.size(); } + + private: + int32_t start_client_service(); + + float _mae = 0; + float _mse = 0; + uint16_t _push_times = 0; + brpc::Server _server; + DownpourPsClientService _service; + std::atomic_uint grad_num_{0}; +}; +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/service/brpc_ps_server.cc 
b/paddle/fluid/distributed/service/brpc_ps_server.cc new file mode 100644 index 00000000000000..1386e83447567f --- /dev/null +++ b/paddle/fluid/distributed/service/brpc_ps_server.cc @@ -0,0 +1,530 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/distributed/service/brpc_ps_server.h" +#include // NOLINT +#include "Eigen/Dense" +#include "butil/endpoint.h" +#include "iomanip" +#include "paddle/fluid/distributed/table/table.h" +#include "paddle/fluid/framework/archive.h" +#include "paddle/fluid/platform/profiler.h" + +namespace paddle { +namespace distributed { + +int32_t BrpcPsServer::initialize() { + auto &service_config = _config.downpour_server_param().service_param(); + if (!service_config.has_service_class()) { + LOG(ERROR) << "miss service_class in ServerServiceParameter"; + return -1; + } + auto *service = CREATE_CLASS(PsBaseService, service_config.service_class()); + if (service == NULL) { + LOG(ERROR) << "service is unregistered, service_name:" + << service_config.service_class(); + return -1; + } + + _service.reset(service); + if (service->configure(this) != 0 || service->initialize() != 0) { + LOG(ERROR) << "service initialize failed, service_name:" + << service_config.service_class(); + return -1; + } + if (_server.AddService(service, brpc::SERVER_DOESNT_OWN_SERVICE) != 0) { + LOG(ERROR) << "service add to brpc failed, service:" + << service_config.service_class(); + return -1; + } + return 0; +} + +uint64_t BrpcPsServer::start(const std::string &ip, uint32_t port) { + std::unique_lock lock(mutex_); + + std::string ip_port = ip + ":" + std::to_string(port); + VLOG(3) << "server of rank " << _rank << " starts at " << ip_port; + int num_threads = std::thread::hardware_concurrency(); + brpc::ServerOptions options; + options.num_threads = num_threads; + + if (_server.Start(ip_port.c_str(), &options) != 0) { + LOG(ERROR) << "BrpcPsServer start failed, ip_port=" << ip_port; + return 0; + } + VLOG(0) << "BrpcPsServer::start registe_ps_server"; + _environment->registe_ps_server(ip, port, _rank); + VLOG(0) << "BrpcPsServer::start wait"; + cv_.wait(lock, [&] { return stoped_; }); + + PSHost host; + host.ip = ip; + host.port = port; + host.rank = _rank; + VLOG(0) << "BrpcPsServer::start return host.rank"; + return host.rank; +} + +int32_t BrpcPsServer::port() { return _server.listen_address().port; } + +int32_t PsService::initialize() { + _is_initialize_shard_info = false; + _service_handler_map[PS_STOP_SERVER] = &PsService::stop_server; + _service_handler_map[PS_PULL_DENSE_TABLE] = &PsService::pull_dense; + _service_handler_map[PS_PUSH_DENSE_TABLE] = &PsService::push_dense; + _service_handler_map[PS_PULL_SPARSE_TABLE] = &PsService::pull_sparse; + _service_handler_map[PS_PUSH_SPARSE_TABLE] = &PsService::push_sparse; + _service_handler_map[PS_SAVE_ONE_TABLE] = &PsService::save_one_table; + _service_handler_map[PS_SAVE_ALL_TABLE] = &PsService::save_all_table; + 
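// Dispatch note (added for clarity; illustrative): PsService::service()
// below looks the incoming request.cmd_id() up in this map and invokes
// the handler through a member-function pointer, e.g.:
//   serviceHandlerFunc f = _service_handler_map[PS_PULL_DENSE_TABLE];
//   int rc = (this->*f)(table, *request, *response, cntl);
// A non-zero rc is reported back to the client as "server internal error".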
_service_handler_map[PS_SHRINK_TABLE] = &PsService::shrink_table; + _service_handler_map[PS_LOAD_ONE_TABLE] = &PsService::load_one_table; + _service_handler_map[PS_LOAD_ALL_TABLE] = &PsService::load_all_table; + _service_handler_map[PS_CLEAR_ONE_TABLE] = &PsService::clear_one_table; + _service_handler_map[PS_CLEAR_ALL_TABLE] = &PsService::clear_all_table; + _service_handler_map[PS_PUSH_DENSE_PARAM] = &PsService::push_dense_param; + _service_handler_map[PS_PRINT_TABLE_STAT] = &PsService::print_table_stat; + _service_handler_map[PS_PULL_GEO_PARAM] = &PsService::pull_geo_param; + _service_handler_map[PS_PUSH_SPARSE_PARAM] = &PsService::push_sparse_param; + _service_handler_map[PS_BARRIER] = &PsService::barrier; + _service_handler_map[PS_START_PROFILER] = &PsService::start_profiler; + _service_handler_map[PS_STOP_PROFILER] = &PsService::stop_profiler; + + // Shard initialization: the shard info of the server list can only be + // fetched from env after the server has started. + initialize_shard_info(); + + return 0; +} + +#define CHECK_TABLE_EXIST(table, request, response) \ + if (table == NULL) { \ + std::string err_msg("table not found with table_id:"); \ + err_msg.append(std::to_string(request.table_id())); \ + set_response_code(response, -1, err_msg.c_str()); \ + return -1; \ + } + +int32_t PsService::initialize_shard_info() { + // double-checked locking: only the first caller does the work + if (!_is_initialize_shard_info) { + std::lock_guard<std::mutex> guard(_initialize_shard_mutex); + if (_is_initialize_shard_info) { + return 0; + } + size_t shard_num = _server->environment()->get_ps_servers().size(); + auto &table_map = *(_server->table()); + for (auto &itr : table_map) { + itr.second->set_shard(_rank, shard_num); + } + _is_initialize_shard_info = true; + } + return 0; +} + +void PsService::service(google::protobuf::RpcController *cntl_base, + const PsRequestMessage *request, + PsResponseMessage *response, + google::protobuf::Closure *done) { + brpc::ClosureGuard done_guard(done); + std::string log_label("ReceiveCmd-"); + if (!request->has_table_id()) { + set_response_code(*response, -1, "PsRequestMessage.table_id is required"); + return; + } + + response->set_err_code(0); + response->set_err_msg(""); + auto *table = _server->table(request->table_id()); + brpc::Controller *cntl = static_cast<brpc::Controller *>(cntl_base); + auto itr = _service_handler_map.find(request->cmd_id()); + if (itr == _service_handler_map.end()) { + std::string err_msg( + "undefined cmd_id, should match PsCmdID in ps.proto, cmd_id:"); + err_msg.append(std::to_string(request->cmd_id())); + set_response_code(*response, -1, err_msg.c_str()); + return; + } + serviceHandlerFunc handler_func = itr->second; + int service_ret = (this->*handler_func)(table, *request, *response, cntl); + if (service_ret != 0) { + response->set_err_code(service_ret); + response->set_err_msg("server internal error"); + } +} + +int32_t PsService::pull_dense(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + platform::RecordEvent record_event("PsService->pull_dense"); + CHECK_TABLE_EXIST(table, request, response) + if (request.params_size() < 1) { + set_response_code( + response, -1, + "PsRequestMessage.params is required, at least 1 for num of dense"); + return 0; + } + uint32_t num = *(const uint32_t *)request.params(0).c_str(); + + std::vector<float> res_data; + res_data.resize(num * table->value_accesor()->select_size() / sizeof(float)); + table->pull_dense(res_data.data(), num); + + cntl->response_attachment().append((char 
*)res_data.data(), + res_data.size() * sizeof(float)); + + return 0; +} + +int32_t PsService::push_dense_param(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + platform::RecordEvent record_event("PsService->push_dense_param"); + CHECK_TABLE_EXIST(table, request, response) + thread_local std::string push_buffer; + auto &req_io_buffer = cntl->request_attachment(); + auto req_buffer_size = req_io_buffer.size(); + if (req_buffer_size < 1) { + set_response_code(response, -1, "req attachment is empty"); + return 0; + } + push_buffer.resize(0); + push_buffer.reserve(req_buffer_size); + const char *data = (const char *)cntl->request_attachment().fetch( + const_cast(push_buffer.data()), req_buffer_size); + + uint32_t num = *(const uint32_t *)data; + + const float *values = (const float *)(data + sizeof(uint32_t)); + if (table->push_dense_param(values, num) != 0) { + set_response_code(response, -1, "push_dense_param failed"); + } + return 0; +} + +int32_t PsService::push_dense(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + platform::RecordEvent record_event("PsService->push_dense"); + CHECK_TABLE_EXIST(table, request, response) + auto req_buffer_size = request.data().size(); + if (req_buffer_size < 1) { + // set_response_code(response, 0, "push dense data is empty"); + return 0; + } + + /* + Push Content: + |--num--|---valuesData---| + |--4B---|----------------| + */ + uint32_t num = *(const uint32_t *)(request.data().data()); + const float *values = + (const float *)(request.data().data() + sizeof(uint32_t)); + if (table->push_dense(values, num) != 0) { + set_response_code(response, -1, "push_dense failed"); + } + + return 0; +} + +int32_t PsService::barrier(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + CHECK_TABLE_EXIST(table, request, response) + + if (request.params_size() < 1) { + set_response_code(response, -1, + "PsRequestMessage.params is requeired at " + "least 1 for num of sparse_key"); + return 0; + } + + auto trainer_id = request.client_id(); + auto barrier_type = request.params(0); + table->barrier(trainer_id, barrier_type); + return 0; +} + +int32_t PsService::push_sparse_param(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + platform::RecordEvent record_event("PsService->push_sparse_param"); + CHECK_TABLE_EXIST(table, request, response) + auto &push_data = request.data(); + if (push_data.size() < 1) { + // set_response_code(response, 0, "push sparse data is empty"); + return 0; + } + if (request.params_size() < 1) { + set_response_code(response, -1, + "PsRequestMessage.params is requeired at " + "least 1 for num of sparse_key"); + return 0; + } + uint32_t num = *(uint32_t *)(request.params(0).c_str()); + /* + Push Content: + |---keysData---|---valuesData---| + |---8*{num}B---|----------------| + */ + const uint64_t *keys = (const uint64_t *)push_data.data(); + const float *values = + (const float *)(push_data.data() + sizeof(uint64_t) * num); + if (table->push_sparse_param(keys, values, num) != 0) { + set_response_code(response, -1, "push_sparse_param error"); + } + return 0; +} + +int32_t PsService::pull_geo_param(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + platform::RecordEvent record_event("PsService->pull_geo_param"); + CHECK_TABLE_EXIST(table, request, response) + thread_local 
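/* Response layout (illustrative summary, inferred from the appends
   below): the geo-pull reply is packed as
     |--num: 4B--|--ids: 8B * num--|--values: 4B * num * dim--|
   where `dim` denotes the floats per row and only rows touched since the
   trainer's last geo step are returned. */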
std::string push_sparse_request_buffer; + + auto trainer_id = request.client_id(); + + std::vector values; + std::vector ids; + table->pull_geo_param(trainer_id, &values, &ids); + + uint32_t num = ids.size(); + cntl->response_attachment().append((char *)(&num), sizeof(uint32_t)); + cntl->response_attachment().append((char *)ids.data(), + ids.size() * sizeof(uint64_t)); + cntl->response_attachment().append((char *)values.data(), + values.size() * sizeof(float)); + return 0; +} + +int32_t PsService::pull_sparse(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + platform::RecordEvent record_event("PsService->pull_sparse"); + CHECK_TABLE_EXIST(table, request, response) + thread_local std::string push_sparse_request_buffer; + auto &req_io_buffer = cntl->request_attachment(); + auto req_buffer_size = req_io_buffer.size(); + if (req_buffer_size < 1) { + set_response_code(response, -1, "req attachment is empty"); + return 0; + } + if (request.params_size() < 1) { + set_response_code(response, -1, + "PsRequestMessage.params is requeired at " + "least 1 for num of sparse_key"); + return 0; + } + uint32_t num = *(uint32_t *)(request.params(0).c_str()); + push_sparse_request_buffer.resize(0); + push_sparse_request_buffer.reserve(req_buffer_size); + const char *data = (const char *)cntl->request_attachment().fetch( + const_cast(push_sparse_request_buffer.data()), req_buffer_size); + /* + Attachment Content: + |---keysData---| + |---8*{num}B---| + */ + const uint64_t *keys = (const uint64_t *)data; + std::vector res_data; + res_data.resize(num * table->value_accesor()->select_size() / sizeof(float)); + table->pull_sparse(res_data.data(), keys, num); + cntl->response_attachment().append((char *)res_data.data(), + res_data.size() * sizeof(float)); + return 0; +} + +int32_t PsService::push_sparse(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + platform::RecordEvent record_event("PsService->push_sparse"); + CHECK_TABLE_EXIST(table, request, response) + auto &push_data = request.data(); + if (push_data.size() < 1) { + // set_response_code(response, 0, "push sparse data is empty"); + return 0; + } + if (request.params_size() < 1) { + set_response_code(response, -1, + "PsRequestMessage.params is requeired at " + "least 1 for num of sparse_key"); + return 0; + } + uint32_t num = *(uint32_t *)(request.params(0).c_str()); + /* + Push Content: + |---keysData---|---valuesData---| + |---8*{num}B---|----------------| + */ + const uint64_t *keys = (const uint64_t *)push_data.data(); + const float *values = + (const float *)(push_data.data() + sizeof(uint64_t) * num); + if (table->push_sparse(keys, values, num) != 0) { + set_response_code(response, -1, "push_sparse error"); + } + return 0; +} + +int32_t PsService::print_table_stat(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + CHECK_TABLE_EXIST(table, request, response) + std::pair ret = table->print_table_stat(); + paddle::framework::BinaryArchive ar; + ar << ret.first << ret.second; + std::string table_info(ar.Buffer(), ar.Length()); + response.set_data(table_info); + + return 0; +} + +int32_t PsService::load_one_table(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + CHECK_TABLE_EXIST(table, request, response) + if (request.params_size() < 2) { + set_response_code( + response, -1, + "PsRequestMessage.datas is requeired at least 2 
for path & load_param"); + return -1; + } + if (table->load(request.params(0), request.params(1)) != 0) { + set_response_code(response, -1, "table load failed"); + return -1; + } + return 0; +} + +int32_t PsService::load_all_table(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + auto &table_map = *(_server->table()); + for (auto &itr : table_map) { + if (load_one_table(itr.second.get(), request, response, cntl) != 0) { + LOG(ERROR) << "load table[" << itr.first << "] failed"; + return -1; + } + } + return 0; +} + +int32_t PsService::save_one_table(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + CHECK_TABLE_EXIST(table, request, response) + if (request.params_size() < 2) { + set_response_code( + response, -1, + "PsRequestMessage.datas is requeired at least 2, path&mode"); + return -1; + } + table->flush(); + + int32_t feasign_size = 0; + feasign_size = table->save(request.params(0), request.params(1)); + if (feasign_size < 0) { + set_response_code(response, -1, "table save failed"); + return -1; + } + return feasign_size; +} + +int32_t PsService::save_all_table(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + auto &table_map = *(_server->table()); + int32_t all_feasign_size = 0; + int32_t feasign_size = 0; + + for (auto &itr : table_map) { + feasign_size = save_one_table(itr.second.get(), request, response, cntl); + if (feasign_size < 0) { + LOG(ERROR) << "save table[" << itr.first << "] failed"; + return -1; + } + } + return 0; +} + +int32_t PsService::shrink_table(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + CHECK_TABLE_EXIST(table, request, response) + table->flush(); + if (table->shrink() != 0) { + set_response_code(response, -1, "table shrink failed"); + } + return 0; +} + +int32_t PsService::clear_one_table(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + CHECK_TABLE_EXIST(table, request, response) + table->flush(); + table->clear(); + return 0; +} + +int32_t PsService::clear_all_table(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + auto &table_map = *(_server->table()); + for (auto &itr : table_map) { + if (clear_one_table(itr.second.get(), request, response, cntl) != 0) { + return -1; + } + } + return 0; +} + +int32_t PsService::stop_server(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + auto *p_server = _server; + std::thread t_stop([p_server]() { + p_server->stop(); + LOG(INFO) << "Server Stoped"; + }); + t_stop.detach(); + return 0; +} + +int32_t PsService::stop_profiler(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + platform::DisableProfiler(platform::EventSortingKey::kDefault, + string::Sprintf("server_%s_profile", _rank)); + return 0; +} + +int32_t PsService::start_profiler(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + platform::EnableProfiler(platform::ProfilerState::kCPU); + return 0; +} + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/service/brpc_ps_server.h b/paddle/fluid/distributed/service/brpc_ps_server.h new file mode 100644 index 00000000000000..0a053848e1eb3c --- /dev/null +++ 
b/paddle/fluid/distributed/service/brpc_ps_server.h @@ -0,0 +1,153 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "brpc/channel.h" +#include "brpc/controller.h" +#include "brpc/server.h" + +#include +#include +#include "paddle/fluid/distributed/service/server.h" + +namespace paddle { +namespace distributed { + +class BrpcPsServer : public PSServer { + public: + BrpcPsServer() {} + virtual ~BrpcPsServer() {} + virtual uint64_t start(const std::string &ip, uint32_t port); + virtual int32_t stop() { + std::unique_lock lock(mutex_); + stoped_ = true; + cv_.notify_all(); + + _server.Stop(1000); + _server.Join(); + return 0; + } + virtual int32_t port(); + + private: + virtual int32_t initialize(); + + mutable std::mutex mutex_; + std::condition_variable cv_; + bool stoped_ = false; + brpc::Server _server; + std::shared_ptr _service; + std::vector> _pserver_channels; +}; + +class PsService; + +typedef int32_t (PsService::*serviceHandlerFunc)( + Table *table, const PsRequestMessage &request, PsResponseMessage &response, + brpc::Controller *cntl); + +class PsService : public PsBaseService { + public: + virtual int32_t initialize() override; + + virtual void service(::google::protobuf::RpcController *controller, + const ::paddle::PsRequestMessage *request, + ::paddle::PsResponseMessage *response, + ::google::protobuf::Closure *done) override; + + private: + int32_t initialize_shard_info(); + int32_t pull_dense(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t push_dense(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t push_dense_param(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t push_sparse_param(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl); + int32_t pull_sparse(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t pull_geo_param(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t barrier(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t push_sparse(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t load_one_table(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t load_all_table(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t save_one_table(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t save_all_table(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t 
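// (Reviewer note: every handler below shares the serviceHandlerFunc
// signature so that it can be registered in _service_handler_map and
// dispatched by cmd_id.)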
shrink_table(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t clear_one_table(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t clear_all_table(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t stop_server(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t start_profiler(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t stop_profiler(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + + int32_t print_table_stat(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + + bool _is_initialize_shard_info; + std::mutex _initialize_shard_mutex; + std::unordered_map _service_handler_map; + std::unordered_map _msg_handler_map; + std::vector _ori_values; +}; + +class DownpourPServerBrpcClosure : public PServerClosure { + public: + DownpourPServerBrpcClosure(size_t num, PServerCallBack callback) + : PServerClosure(callback) { + _waiting_num = num; + _cntls.resize(num); + _requests.resize(num); + _responses.resize(num); + for (size_t i = 0; i < num; ++i) { + _cntls[i].reset(new brpc::Controller()); + } + } + virtual ~DownpourPServerBrpcClosure() {} + + virtual void Run() override { + if (_waiting_num.fetch_sub(1) == 1) { + _callback(this); + delete this; + } + } + PsRequestMessage *request(size_t i) { return &_requests[i]; } + PsResponseMessage *response(size_t i) { return &_responses[i]; } + brpc::Controller *cntl(size_t i) { return _cntls[i].get(); } + int check_response(size_t request_idx, int cmd_id) { return 1; } + int check_save_response(size_t request_idx, int cmd_id) { return 1; } + + private: + std::atomic _waiting_num; + std::vector _requests; + std::vector _responses; + std::vector> _cntls; +}; +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/service/brpc_utils.cc b/paddle/fluid/distributed/service/brpc_utils.cc new file mode 100644 index 00000000000000..abd58bf028c2c1 --- /dev/null +++ b/paddle/fluid/distributed/service/brpc_utils.cc @@ -0,0 +1,314 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/distributed/service/brpc_utils.h" +#include +#include +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/profiler.h" + +namespace paddle { +namespace framework { +class Scope; +class Variable; +} // namespace framework +namespace platform { +class DeviceContext; +} // namespace platform +} // namespace paddle + +namespace paddle { +namespace distributed { + +framework::proto::VarType::Type VarMessageToVarType( + VariableMessage::Type type) { + switch (type) { + case VariableMessage::FP32: + return framework::proto::VarType::FP32; // NOLINT + case VariableMessage::FP64: + return framework::proto::VarType::FP64; // NOLINT + case VariableMessage::INT32: + return framework::proto::VarType::INT32; // NOLINT + case VariableMessage::INT64: + return framework::proto::VarType::INT64; // NOLINT + case VariableMessage::BOOL: + return framework::proto::VarType::BOOL; // NOLINT + default: + PADDLE_THROW(platform::errors::InvalidArgument( + "VarMessageToVarType:Unsupported type %d", type)); + } +} + +void SerializeToMultiVarMsgAndIOBuf( + const std::string& message_name, + const std::vector& send_var_name_val, + const std::vector& recv_var_name_val, + const platform::DeviceContext& ctx, const framework::Scope* scope, + MultiVarMsg* request, butil::IOBuf* iobuf) { + // 1. message_name + request->set_message_name(message_name); + + // 2. var_names + for (auto& send_var_name : send_var_name_val) { + request->add_send_var_names(send_var_name); + } + for (auto& recv_var_name : recv_var_name_val) { + request->add_recv_var_names(recv_var_name); + } + + // 3. VarMessage + for (auto& send_var_name : send_var_name_val) { + auto* send_var_msg = request->add_var_messages(); + butil::IOBuf temp_iobuf; + send_var_msg->set_varname(send_var_name); + + framework::Variable* var = scope->FindVar(send_var_name); + + if (var->IsType()) { + SerializeLodTensor(var, ctx, send_var_msg, &temp_iobuf); + } else if (var->IsType()) { + SerializeSelectedRows(var, ctx, send_var_msg, &temp_iobuf); + } + iobuf->append(temp_iobuf); + } +} + +void SerializeLodTensor(framework::Variable* var, + const platform::DeviceContext& ctx, VarMsg* var_msg, + butil::IOBuf* iobuf) { + auto* tensor = var->GetMutable(); + var_msg->set_type(::paddle::LOD_TENSOR); + const framework::LoD lod = tensor->lod(); + if (lod.size() > 0) { + var_msg->set_lod_level(lod.size()); + for (auto& each : lod) { + VarMsg::LodData* lod_inner = var_msg->add_lod(); + for (auto& d : each) { + lod_inner->add_lod_data(d); + } + } + } + var_msg->set_data_type(static_cast(tensor->type())); + for (auto& dim : framework::vectorize(tensor->dims())) { + var_msg->add_dims(dim); + } + // IO Buffer + if (platform::is_cpu_place(tensor->place())) { + auto data_len = tensor->numel() * framework::SizeOfType(tensor->type()); + iobuf->append(reinterpret_cast(&data_len), 8); + iobuf->append(reinterpret_cast(tensor->data()), + data_len); + } else { +#ifdef PADDLE_WITH_CUDA + char* temp_ptr = + new char[tensor->numel() * framework::SizeOfType(tensor->type())]; + auto stream = + reinterpret_cast(ctx).stream(); + memory::Copy(platform::CPUPlace(), temp_ptr, + BOOST_GET_CONST(platform::CUDAPlace, tensor->place()), + tensor->data(), + tensor->numel() * framework::SizeOfType(tensor->type()), + stream); + auto data_len = tensor->numel() * framework::SizeOfType(tensor->type()); + iobuf->append(reinterpret_cast(&data_len), 8); + iobuf->append(reinterpret_cast(temp_ptr), data_len); + delete[] temp_ptr; +#endif + } +} + +void 
SerializeSelectedRows(framework::Variable* var, + const platform::DeviceContext& ctx, VarMsg* var_msg, + butil::IOBuf* iobuf) { + framework::SelectedRows* slr = var->GetMutable(); + auto* tensor = slr->mutable_value(); + auto* rows = slr->mutable_rows(); + + var_msg->set_type(::paddle::SELECTED_ROWS); + var_msg->set_slr_height(slr->height()); + + auto* var_data = var_msg->mutable_data(); + var_data->clear(); + var_data->resize(rows->size() * sizeof(int64_t)); + char* data_ptr = const_cast(var_data->data()); + + if (platform::is_cpu_place(tensor->place())) { + memcpy(data_ptr, &(*rows)[0], rows->size() * sizeof(int64_t)); + } else { +#ifdef PADDLE_WITH_CUDA + auto stream = + reinterpret_cast(ctx).stream(); + memory::Copy(platform::CPUPlace(), data_ptr, + BOOST_GET_CONST(platform::CUDAPlace, tensor->place()), + &(*rows)[0], rows->size() * sizeof(int64_t), stream); +#endif + } + var_msg->set_data_type(static_cast(tensor->type())); + for (auto& dim : framework::vectorize(tensor->dims())) { + var_msg->add_dims(dim); + } + + // IO Buffer + if (platform::is_cpu_place(tensor->place())) { + auto data_len = tensor->numel() * framework::SizeOfType(tensor->type()); + iobuf->append(reinterpret_cast(&data_len), 8); + iobuf->append(reinterpret_cast(tensor->data()), + data_len); + } else { +#ifdef PADDLE_WITH_CUDA + char* temp_ptr = + new char[tensor->numel() * framework::SizeOfType(tensor->type())]; + auto stream = + reinterpret_cast(ctx).stream(); + memory::Copy(platform::CPUPlace(), temp_ptr, + BOOST_GET_CONST(platform::CUDAPlace, tensor->place()), + tensor->data(), + tensor->numel() * framework::SizeOfType(tensor->type()), + stream); + auto data_len = tensor->numel() * framework::SizeOfType(tensor->type()); + iobuf->append(reinterpret_cast(&data_len), 8); + iobuf->append(reinterpret_cast(temp_ptr), data_len); + delete[] temp_ptr; +#endif + } +} + +void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg, + const butil::IOBuf* iobuf, + const platform::DeviceContext& ctx, + framework::Scope* scope) { + butil::IOBufBytesIterator io_buffer_itr(*iobuf); + // size_t shard_buffer_remain = res_io_buffer.size(); + for (int recv_var_index = 0; recv_var_index < multi_msg.send_var_names_size(); + ++recv_var_index) { + const auto& msg = multi_msg.var_messages(recv_var_index); + auto* var = scope->Var(msg.varname()); + if (msg.type() == ::paddle::LOD_TENSOR) { + DeserializeLodTensor(var, msg, io_buffer_itr, ctx); + } else if (msg.type() == ::paddle::SELECTED_ROWS) { + DeserializeSelectedRows(var, msg, io_buffer_itr, ctx); + } + } +} + +void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg, + const butil::IOBuf* iobuf, + const platform::DeviceContext& ctx, + const framework::Scope* scope) { + butil::IOBufBytesIterator io_buffer_itr(*iobuf); + // size_t shard_buffer_remain = res_io_buffer.size(); + for (int recv_var_index = 0; recv_var_index < multi_msg.send_var_names_size(); + ++recv_var_index) { + const auto& msg = multi_msg.var_messages(recv_var_index); + auto* var = scope->FindVar(msg.varname()); + PADDLE_ENFORCE_NE(var, nullptr, + platform::errors::InvalidArgument( + "Not find variable %s in scope.", msg.varname())); + if (msg.type() == ::paddle::LOD_TENSOR) { + DeserializeLodTensor(var, msg, io_buffer_itr, ctx); + } else if (msg.type() == ::paddle::SELECTED_ROWS) { + DeserializeSelectedRows(var, msg, io_buffer_itr, ctx); + } + } +} + +void DeserializeLodTensor(framework::Variable* var, const VarMsg& msg, + butil::IOBufBytesIterator& io_buffer_itr, + const platform::DeviceContext& ctx) 
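/* Wire format consumed here (illustrative, mirrors SerializeLodTensor
   above): each tensor body travels in the IOBuf as
     |--data_len: 8B--|--raw tensor bytes: data_len--|
   following the VarMsg metadata (dims, lod, dtype). */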
{ + const auto place = ctx.GetPlace(); + framework::LoDTensor* tensor = var->GetMutable(); + std::vector vec_dim; + for (auto& x : msg.dims()) { + vec_dim.push_back(x); + } + tensor->Resize(framework::make_ddim(vec_dim)); + + framework::LoD lod; + for (int i = 0; i < msg.lod_level(); ++i) { + framework::Vector v; + for (int j = 0; j < msg.lod(i).lod_data_size(); ++j) { + v.push_back(msg.lod(i).lod_data(j)); + } + lod.push_back(v); + } + tensor->set_lod(lod); + + void* tensor_data = + tensor->mutable_data(place, VarMessageToVarType(msg.data_type())); + + // IO Buffer + if (platform::is_cpu_place(place)) { + unsigned long data_len; + io_buffer_itr.copy_and_forward((void*)(&data_len), 8); + io_buffer_itr.copy_and_forward(tensor_data, data_len); + } else if (platform::is_gpu_place(place)) { +#ifdef PADDLE_WITH_CUDA + unsigned long data_len; + char* temp_ptr = + new char[tensor->numel() * framework::SizeOfType(tensor->type())]; + io_buffer_itr.copy_and_forward((void*)(&data_len), 8); + io_buffer_itr.copy_and_forward((void*)temp_ptr, data_len); + auto stream = + reinterpret_cast(ctx).stream(); + memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, place), tensor_data, + platform::CPUPlace(), (void*)temp_ptr, + tensor->numel() * framework::SizeOfType(tensor->type()), + stream); + delete[] temp_ptr; +#endif + } +} + +void DeserializeSelectedRows(framework::Variable* var, const VarMsg& msg, + butil::IOBufBytesIterator& io_buffer_itr, + const platform::DeviceContext& ctx) { + const auto place = ctx.GetPlace(); + auto* slr = var->GetMutable(); + framework::Tensor* tensor = slr->mutable_value(); + slr->set_height(msg.slr_height()); + std::vector tmp_rows(msg.slr_height()); + memcpy(&tmp_rows[0], msg.data().data(), msg.slr_height() * sizeof(int64_t)); + slr->set_rows(tmp_rows); + std::vector vec_dim; + for (auto& x : msg.dims()) { + vec_dim.push_back(x); + } + tensor->Resize(framework::make_ddim(vec_dim)); + void* tensor_data = + tensor->mutable_data(place, VarMessageToVarType(msg.data_type())); + // IO Buffer + if (platform::is_cpu_place(place)) { + unsigned long data_len; + io_buffer_itr.copy_and_forward((void*)(&data_len), 8); + io_buffer_itr.copy_and_forward(tensor_data, data_len); + } else if (platform::is_gpu_place(place)) { +#ifdef PADDLE_WITH_CUDA + char* temp_ptr = + new char[tensor->numel() * framework::SizeOfType(tensor->type())]; + unsigned long data_len; + io_buffer_itr.copy_and_forward((void*)(&data_len), 8); + io_buffer_itr.copy_and_forward(temp_ptr, data_len); + auto stream = + reinterpret_cast(ctx).stream(); + memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, place), tensor_data, + platform::CPUPlace(), temp_ptr, + tensor->numel() * framework::SizeOfType(tensor->type()), + stream); + delete[] temp_ptr; +#endif + } +} + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/service/brpc_utils.h b/paddle/fluid/distributed/service/brpc_utils.h new file mode 100644 index 00000000000000..aa340c58a7b8b0 --- /dev/null +++ b/paddle/fluid/distributed/service/brpc_utils.h @@ -0,0 +1,86 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include + +#include "brpc/channel.h" +#include "paddle/fluid/distributed/service/sendrecv.pb.h" +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/framework/var_type.h" +#include "paddle/fluid/platform/port.h" + +namespace grpc { +class ByteBuffer; +} // namespace grpc +namespace paddle { +namespace framework { +class Scope; +class Variable; +} // namespace framework +namespace platform { +class DeviceContext; +} // namespace platform +} // namespace paddle + +namespace paddle { +namespace distributed { + +using MultiVarMsg = ::paddle::MultiVariableMessage; +using VarMsg = ::paddle::VariableMessage; + +void SerializeToMultiVarMsgAndIOBuf( + const std::string& message_name, + const std::vector& send_var_name_val, + const std::vector& recv_var_name_val, + const platform::DeviceContext& ctx, const framework::Scope* scope, + MultiVarMsg* var_msg, butil::IOBuf* iobuf); + +void SerializeLodTensor(framework::Variable* var, + const platform::DeviceContext& ctx, VarMsg* var_msg, + butil::IOBuf* iobuf); + +void SerializeSelectedRows(framework::Variable* var, + const platform::DeviceContext& ctx, VarMsg* request, + butil::IOBuf* iobuf); + +// Deserialize for Server +void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg, + const butil::IOBuf* iobuf, + const platform::DeviceContext& ctx, + framework::Scope* scope); + +// Deserialize for Client +void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg, + const butil::IOBuf* iobuf, + const platform::DeviceContext& ctx, + const framework::Scope* scope); + +void DeserializeLodTensor(framework::Variable* var, const VarMsg& msg, + butil::IOBufBytesIterator& iobuf, + const platform::DeviceContext& ctx); + +void DeserializeSelectedRows(framework::Variable* var, const VarMsg& msg, + butil::IOBufBytesIterator& iobuf, + const platform::DeviceContext& ctx); + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/service/communicator.cc b/paddle/fluid/distributed/service/communicator.cc new file mode 100644 index 00000000000000..18776a61a5cee7 --- /dev/null +++ b/paddle/fluid/distributed/service/communicator.cc @@ -0,0 +1,1171 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/distributed/service/communicator.h" +#include +#include "paddle/fluid/distributed/table/table.h" + +#include +#include + +#include +#include // NOLINT +#include +#include // NOLINT +#include + +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/framework/threadpool.h" +#include "paddle/fluid/framework/variable_helper.h" +#include "paddle/fluid/platform/profiler.h" +#include "paddle/fluid/string/printf.h" +#include "paddle/fluid/string/split.h" + +namespace paddle { +namespace distributed { + +using framework::LoDTensor; +using framework::SelectedRows; + +inline double GetCurrentUS() { + struct timeval time; + gettimeofday(&time, NULL); + return 1e+6 * time.tv_sec + time.tv_usec; +} + +Communicator::Communicator() {} + +void Communicator::init_gflag(const std::string &gflags) { + VLOG(0) << "Init With Gflags:" << gflags; + std::vector flags = paddle::string::split_string(gflags); + if (flags.size() < 1) { + flags.push_back("-max_body_size=314217728"); + flags.push_back("-bthread_concurrency=40"); + flags.push_back("-socket_max_unwritten_bytes=2048000000"); + flags.push_back("-max_connection_pool_size=1950"); + } + auto it = flags.begin(); + flags.insert(it, "exe default"); + char *flags_ptr[flags.size()]; + for (size_t i = 0; i < flags.size(); ++i) { + flags_ptr[i] = (char *)(flags[i].c_str()); + } + int params_cnt = flags.size(); + char **params_ptr = &(flags_ptr[0]); + ::google::ParseCommandLineFlags(¶ms_cnt, ¶ms_ptr, true); +} + +std::once_flag Communicator::init_flag_; +std::shared_ptr Communicator::communicator_(nullptr); + +void Communicator::InitBrpcClient( + const std::string &dist_desc, + const std::vector &host_sign_list) { + // not used, just for psclient's init + std::map> + _dense_pull_regions; + for (auto &iter : recv_varname_to_ctx_) { + auto tid = iter.first; + auto var_names = iter.second; + + auto ®ions = _dense_pull_regions[tid]; + regions.reserve(var_names.size()); + for (auto &t : var_names) { + Variable *var = recv_scope_->FindVar(t); + LoDTensor *tensor = var->GetMutable(); + float *w = tensor->data(); + paddle::distributed::Region reg(w, tensor->numel()); + regions.emplace_back(std::move(reg)); + } + } + + if (_worker_ptr.get() == nullptr) { + google::protobuf::TextFormat::ParseFromString(dist_desc, &_ps_param); + init_gflag(_ps_param.init_gflags()); + servers_ = host_sign_list.size(); + _ps_env = paddle::distributed::PaddlePSEnvironment(); + _ps_env.set_ps_servers(&host_sign_list, servers_); + _worker_ptr = std::shared_ptr( + paddle::distributed::PSClientFactory::create(_ps_param)); + _worker_ptr->configure(_ps_param, _dense_pull_regions, _ps_env, + trainer_id_); + } + return; +} + +void Communicator::RpcRecvDense(const std::vector &varnames, + int table_id, Scope *scope) { + platform::RecordEvent record_event("Communicator->RpcRecvDense"); + std::vector regions; + regions.reserve(varnames.size()); + for (auto &t : varnames) { + Variable *var = scope->Var(t); + LoDTensor *tensor = var->GetMutable(); + if (platform::is_gpu_place(tensor->place())) { +#ifdef PADDLE_WITH_CUDA + Variable *temp_var = xpu_temp_scope_->Var(t); + LoDTensor *temp_tensor = temp_var->GetMutable(); + temp_tensor->Resize(tensor->dims()); + float *temp_data = temp_tensor->mutable_data(platform::CPUPlace()); + paddle::distributed::Region reg(temp_data, tensor->numel()); + regions.emplace_back(std::move(reg)); + VLOG(1) << "AsyncCommunicator::RpcRecvDense 
Var " << t << " table_id " + << table_id << " Temp_data[0] " << temp_data[0] + << " Temp_data[-1] " << temp_data[tensor->numel() - 1]; +#endif + } else { + float *w = tensor->mutable_data(tensor->place()); + paddle::distributed::Region reg(w, tensor->numel()); + regions.emplace_back(std::move(reg)); + } + } + auto status = + _worker_ptr->pull_dense(regions.data(), regions.size(), table_id); + status.wait(); + + for (auto &t : varnames) { + Variable *var = scope->FindVar(t); + LoDTensor *tensor = var->GetMutable(); + VLOG(1) << "AsyncCommunicator::RecvNoBarrier Var " << t << " On gpu? " + << platform::is_gpu_place(tensor->place()); + if (platform::is_gpu_place(tensor->place())) { +#ifdef PADDLE_WITH_CUDA + LoDTensor *temp_tensor = + xpu_temp_scope_->FindVar(t)->GetMutable(); + framework::TensorCopy(*temp_tensor, tensor->place(), tensor); + float *temp_data = temp_tensor->mutable_data(platform::CPUPlace()); + VLOG(1) << "AsyncCommunicator::RpcRecvDense Var " << t << " table_id " + << table_id << " Temp_data[0] " << temp_data[0] + << " Temp_data[-1] " << temp_data[tensor->numel() - 1]; +#endif + } + } + + return; +} + +void Communicator::RpcSendDenseParam(const std::vector &varnames, + int table_id, const Scope &scope) { + platform::RecordEvent record_event("Communicator->RpcSendDenseParam"); + auto place = platform::CPUPlace(); + std::vector regions; + for (auto &t : varnames) { + Variable *var = scope.FindVar(t); + CHECK(var != nullptr) << "var[" << t << "] not found"; + LoDTensor *tensor = var->GetMutable(); + if (platform::is_gpu_place(tensor->place())) { +#ifdef PADDLE_WITH_CUDA + Variable *temp_var = xpu_temp_scope_->Var(t); + LoDTensor *temp_tensor = temp_var->GetMutable(); + temp_tensor->Resize(tensor->dims()); + float *temp_data = temp_tensor->mutable_data(platform::CPUPlace()); + framework::TensorCopy(*tensor, platform::CPUPlace(), temp_tensor); + paddle::distributed::Region reg(temp_data, tensor->numel()); + regions.emplace_back(std::move(reg)); + VLOG(1) << "AsyncCommunicator::RpcSendDenseParam Var " << t + << " table_id " << table_id << " Temp_data[0] " << temp_data[0] + << " Temp_data[-1] " << temp_data[tensor->numel() - 1]; +#endif + } else { + float *w = tensor->mutable_data(place); + paddle::distributed::Region reg(w, tensor->numel()); + regions.emplace_back(std::move(reg)); + VLOG(1) << "AsyncCommunicator::RpcSendDenseParam Var " << t + << " talbe_id " << table_id << " Temp_data[0] " << w[0] + << " Temp_data[-1] " << w[tensor->numel() - 1]; + } + } + auto status = + _worker_ptr->push_dense_param(regions.data(), regions.size(), table_id); + status.wait(); + VLOG(4) << "RPC Send Dense Param " << table_id << " done!"; + return; +} + +void Communicator::RpcSendDense(const CommContext &ctx, const Scope &scope) { + platform::RecordEvent record_event("Communicator->RpcSendDense"); + auto &var_names = ctx.origin_varnames; + auto &table_id = ctx.table_id; + auto dense_data = std::make_shared>(); + size_t request_call_num = _worker_ptr->get_server_nums(); + uint32_t num_per_shard = + dense_dim_per_shard(ctx.height_sections[0], request_call_num); + dense_data->resize(num_per_shard * + request_call_num); // accessor->update_dim() = 1 + float *data = dense_data->data(); + uint32_t pos = 0; + for (size_t i = 0; i < var_names.size(); ++i) { + const LoDTensor tensor = scope.FindVar(var_names[i])->Get(); + size_t count = static_cast(tensor.numel()); + const float *g = tensor.data(); + CHECK(pos + count <= dense_data->size()) + << "invalid dense size, cur pos[" << pos << "]" + << " data_num[" 
<< count << "] size[" << dense_data->size() << "]"; + memcpy(data + pos, g, count * sizeof(float)); + pos += count; + } + + ++_async_call_num; + DownpourBrpcClosure *closure = new DownpourBrpcClosure( + request_call_num, [this, request_call_num](void *done) { + int ret = 0; + auto *closure = (DownpourBrpcClosure *)done; + for (size_t i = 0; i < request_call_num; ++i) { + if (closure->check_response(i, PS_PUSH_DENSE_TABLE) != 0) { + ret = -1; + break; + } + } + closure->set_promise_value(ret); + --_async_call_num; + }); + auto status = _worker_ptr->push_dense_raw_gradient( + table_id, data, dense_data->size(), closure); + status.wait(); + return; +} + +void Communicator::RpcSendSparseParam(const std::string &varname, int table_id, + const Scope &scope) { + platform::RecordEvent record_event("Communicator->RpcSendSparseParam"); + size_t request_call_num = _worker_ptr->get_server_nums(); + std::vector push_g_vec; + + auto *send_var = scope.FindVar(varname); + auto *tensor = send_var->GetMutable(); + auto dim = tensor->dims()[1]; + uint64_t sparse_num = static_cast(tensor->dims()[0]); + std::vector sparse_push_keys(sparse_num); + std::iota(sparse_push_keys.begin(), sparse_push_keys.end(), 0); + push_g_vec.reserve(sparse_num); + + for (auto i = 0; i < static_cast(sparse_push_keys.size()); ++i) { + push_g_vec.push_back(tensor->data() + i * dim); + } + + DownpourBrpcClosure *closure = new DownpourBrpcClosure( + request_call_num, [this, request_call_num](void *done) { + int ret = 0; + auto *closure = (DownpourBrpcClosure *)done; + for (size_t i = 0; i < request_call_num; ++i) { + if (closure->check_response(i, PS_PUSH_SPARSE_PARAM) != 0) { + ret = -1; + break; + } + } + closure->set_promise_value(ret); + }); + auto status = _worker_ptr->push_sparse_param( + table_id, sparse_push_keys.data(), (const float **)push_g_vec.data(), + sparse_push_keys.size(), closure); + status.wait(); + return; +} + +void Communicator::RpcSendSparse(const std::string &var_name, int table_id, + const Scope &scope) { + platform::RecordEvent record_event("Communicator->RpcSendSparse"); + size_t request_call_num = _worker_ptr->get_server_nums(); + std::vector sparse_push_keys; + std::vector push_g_vec; + + auto *send_var = scope.FindVar(var_name); + auto *tensor = send_var->GetMutable(); + auto dim = tensor->value().dims()[1]; + std::transform(tensor->rows().begin(), tensor->rows().end(), + std::back_inserter(sparse_push_keys), + [&](int id) { return static_cast(id); }); + + for (auto i = 0; i < static_cast(sparse_push_keys.size()); ++i) { + push_g_vec.push_back(tensor->mutable_value()->data() + i * dim); + } + + ++_async_call_num; + DownpourBrpcClosure *closure = new DownpourBrpcClosure( + request_call_num, [this, request_call_num](void *done) { + int ret = 0; + auto *closure = (DownpourBrpcClosure *)done; + for (size_t i = 0; i < request_call_num; ++i) { + if (closure->check_response(i, PS_PUSH_SPARSE_TABLE) != 0) { + ret = -1; + break; + } + } + closure->set_promise_value(ret); + --_async_call_num; + }); + auto status = _worker_ptr->push_sparse_raw_gradient( + table_id, sparse_push_keys.data(), (const float **)push_g_vec.data(), + sparse_push_keys.size(), closure); + status.wait(); + return; +} + +void Communicator::RpcRecvSparse(const std::string &varname, int table_id, + Scope *scope) { + platform::RecordEvent record_event("Communicator->RpcRecvSparse"); + auto *send_var = scope->Var(varname); + auto *tensor = send_var->GetMutable(); + auto dim = tensor->dims()[1]; + uint64_t sparse_num = 
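/* Reviewer note (illustrative): the dense parameter tensor is treated as
   one sparse row per embedding id -- the keys are simply 0..rows-1 (iota
   below) and each pulled value is written straight into row i of the
   tensor, so the whole table is refreshed in place. */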
static_cast(tensor->dims()[0]); + + std::vector sparse_push_keys(sparse_num); + std::iota(sparse_push_keys.begin(), sparse_push_keys.end(), 0); + + std::vector push_g_vec; + for (auto i = 0; i < static_cast(sparse_push_keys.size()); ++i) { + push_g_vec.push_back(tensor->data() + i * dim); + } + + auto status = _worker_ptr->pull_sparse((float **)push_g_vec.data(), table_id, + sparse_push_keys.data(), + sparse_push_keys.size()); + status.wait(); + return; +} + +void Communicator::InitParams(const RecvCtxMap &recv_varname_to_ctx) { + if (trainer_id_ == 0) { + for (auto &iter : recv_varname_to_ctx) { + auto &table_id = iter.first; + auto &varnames = iter.second; + RpcSendDenseParam(varnames, table_id, *recv_scope_); + VLOG(1) << "push dense param to table " << table_id + << " from 0' trainer done"; + } + BarrierWithTable(1); + } else { + BarrierWithTable(1); + for (auto &iter : recv_varname_to_ctx) { + auto &table_id = iter.first; + auto &varnames = iter.second; + RpcRecvDense(varnames, table_id, recv_scope_); + VLOG(1) << "pull dense param to table " << table_id + << " from 0' trainer done"; + } + } + BarrierWithTable(1); + return; +} + +void Communicator::RpcProfilerControl() { + if (trainer_id_ == 0) { + if (!do_server_profiler_ && platform::IsProfileEnabled()) { + // send profiler start flag + do_server_profiler_ = true; + auto start_status = _worker_ptr->start_profiler(); + start_status.wait(); + } else if (do_server_profiler_ && !platform::IsProfileEnabled()) { + // send profiler end flag + auto stop_status = _worker_ptr->stop_profiler(); + stop_status.wait(); + do_server_profiler_ = false; + } + } +} + +void AsyncCommunicator::RecvThread() { + if (!independent_recv_) return; + VLOG(3) << "Independent RecvThread Start and Wait"; + + while (running_) { + int grad_num = grad_num_.load(); + if (grad_num > min_send_grad_num_before_recv_) { + RecvByCommunicator(); + grad_num_.store(0); + } else { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + } + VLOG(1) << "communicator stopped, independent recv thread exit"; +} + +void AsyncCommunicator::RecvByCommunicator() { + if (!running_) return; + RecvNoBarrier(); + VLOG(3) << "run recv graph end"; +} + +void AsyncCommunicator::RecvNoBarrier() { + for (auto &iter : recv_varname_to_ctx_) { + auto &table_id = iter.first; + auto &varnames = iter.second; + RpcRecvDense(varnames, table_id, recv_scope_); + } + + for (auto &iter : recv_varname_to_ctx_) { + auto var_names = iter.second; + for (auto &t : var_names) { + Variable *var = recv_scope_->FindVar(t); + LoDTensor *tensor = var->GetMutable(); + VLOG(1) << "AsyncCommunicator::RecvNoBarrier Var " << t << " On gpu? 
" + << platform::is_gpu_place(tensor->place()); + if (platform::is_gpu_place(tensor->place())) { +#ifdef PADDLE_WITH_CUDA + LoDTensor *temp_tensor = + xpu_temp_scope_->FindVar(t)->GetMutable(); + framework::TensorCopy(*temp_tensor, tensor->place(), tensor); +#endif + } + } + } + + return; +} + +void AsyncCommunicator::SendByCommunicator() { + std::vector> tasks; + tasks.reserve(send_varname_to_ctx_.size()); + + for (auto &iter : send_varname_to_ctx_) { + auto &ctx = iter.second; + + auto send_recv_task = [this, &ctx] { + auto &varnames = ctx.origin_varnames; + auto &table_id = ctx.table_id; + size_t var_nums = varnames.size(); + auto &check_queue = send_varname_to_queue_[varnames[0]]; + std::vector>> vars; + vars.resize(var_nums); + int merged_var_num = 0; + int wait_times = 0; + while (merged_var_num < max_merge_var_num_) { + if (check_queue->Size() == 0) { + VLOG(4) << "wait_times -> " << wait_times; + if (wait_times >= send_wait_times_) { + break; + } + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + wait_times++; + continue; + } else { + wait_times = 0; + for (size_t i = 0; i < var_nums; i++) { + auto &var_name = varnames[i]; + auto &var_queue = send_varname_to_queue_[var_name]; + vars[i].push_back(var_queue->Pop()); + } + merged_var_num++; + } + } + if (merged_var_num == 0) return; + + for (size_t i = 0; i < var_nums; i++) { + auto &var_name = varnames[i]; + MergeVars(var_name, vars[i], send_scope_.get(), 1); + } + + if (ctx.is_sparse) { + PADDLE_ENFORCE_EQ( + varnames.size(), 1, + platform::errors::InvalidArgument( + "sparse variables can only be merged by one variables")); + RpcSendSparse(varnames[0], table_id, *send_scope_); + } else { + RpcSendDense(ctx, *send_scope_); + if (!independent_recv_ && + recv_varname_to_ctx_.find(table_id) != recv_varname_to_ctx_.end()) { + auto recv_varnames = recv_varname_to_ctx_.at(table_id); + RpcRecvDense(recv_varnames, table_id, recv_scope_); + } + } + if (independent_recv_) { + grad_num_.fetch_add(1, std::memory_order_relaxed); + } + }; + tasks.emplace_back(send_threadpool_->enqueue(std::move(send_recv_task))); + } + for (auto &task : tasks) { + task.wait(); + } + return; +} + +void AsyncCommunicator::MainThread() { + VLOG(3) << "AsyncCommunicator MainThread start and wait"; + + while (waiting_ && running_) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + VLOG(3) << "wait for running"; + } + + while (running_) { + SendByCommunicator(); + RpcProfilerControl(); + } + VLOG(1) << "communicator stopped, send thread exit"; +} + +void HalfAsyncCommunicator::MainThread() { + VLOG(3) << "HalfAsyncCommunicator MainThread start and wait"; + + while (waiting_ && running_) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + VLOG(3) << "wait for running"; + } + + while (running_) { + SendByCommunicator(); + BarrierSend(); + RecvByCommunicator(); + BarrierRecv(); + BarrierWeakUp(); + } + VLOG(1) << "communicator stopped, send thread exit"; +} + +void AsyncCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx, + const RecvCtxMap &recv_varname_to_ctx, + Scope *recv_scope) { + send_varname_to_ctx_ = std::move(send_varname_to_ctx); + recv_varname_to_ctx_ = std::move(recv_varname_to_ctx); + recv_scope_ = std::move(recv_scope); + send_scope_.reset(new Scope()); + xpu_temp_scope_.reset(new Scope()); + for (auto &iter : send_varname_to_ctx_) { + auto &ctx = iter.second; + auto &varnames = ctx.origin_varnames; + for (auto &var_name : varnames) { + send_varname_to_queue_[var_name] = + std::make_shared>>( + 
send_queue_size_); + } + } + send_threadpool_.reset(new ::ThreadPool(thread_pool_size_)); +} + +AsyncCommunicator::~AsyncCommunicator() { + running_ = false; + if (main_thread_) main_thread_->join(); + if (recv_thread_) recv_thread_->join(); +} + +void AsyncCommunicator::Start() { + VLOG(1) << "Communicator start"; + if (!communicator_) { + VLOG(0) << "Communicator is not inited, do nothing"; + } else { + VLOG(1) << "start send thread and recv thread"; + waiting_ = true; + running_ = true; + // flushing_ = false; + BarrierTriggerReset(max_merge_var_num_); + // start send and recv thread + main_thread_.reset( + new std::thread(std::bind(&AsyncCommunicator::MainThread, this))); + if (independent_recv_) { + recv_thread_.reset( + new std::thread(std::bind(&AsyncCommunicator::RecvThread, this))); + } + } +} + +void AsyncCommunicator::Stop() { + VLOG(1) << "Communicator stop"; + running_ = false; + if (!communicator_) { + VLOG(0) << "Communicator is not inited, do nothing"; + } else { + if (recv_thread_) { + VLOG(1) << "stop recv thread"; + recv_thread_->join(); + recv_thread_.reset(nullptr); + } + if (main_thread_) { + VLOG(1) << "stop main thread"; + main_thread_->join(); + main_thread_.reset(nullptr); + } + } + VLOG(1) << "Communicator stop done"; +} + +bool AsyncCommunicator::Check(const std::vector &var_tables) { + PADDLE_ENFORCE_EQ( + var_tables.size(), 1, + platform::errors::InvalidArgument("var_tables.size() == 1 is permitted")); + + auto table_name = var_tables[0]; + if (send_varname_to_ctx_.find(table_name) == send_varname_to_ctx_.end()) + return false; + return true; +} + +bool AsyncCommunicator::Check(const int table_id) { + for (auto &iter : send_varname_to_ctx_) { + auto &ctx = iter.second; + if (ctx.table_id == table_id) return true; + } + return false; +} + +void AsyncCommunicator::Send(const std::vector &var_names, + const framework::Scope &scope) { + waiting_ = false; + for (size_t i = 0; i < var_names.size(); i++) { + auto *var = scope.FindVar(var_names[i]); + auto tmp_grad_var = std::make_shared(); + framework::CopyVariable(*var, tmp_grad_var.get()); + send_varname_to_queue_[var_names[i]]->Push(tmp_grad_var); + } +} + +void HalfAsyncCommunicator::Clean() { + for (auto &iter : send_varname_to_queue_) { + auto &var_name = iter.first; + auto &var_queue = iter.second; + + while (var_queue->Size() > 0) { + var_queue->Pop(); + } + + VLOG(3) << "clean var: " << var_name << " done"; + } +} + +void HalfAsyncCommunicator::BarrierTriggerDecrement() { + barrier_trigger_--; + VLOG(3) << "BarrierTriggerDecrement decrement barrier trigger to " + << barrier_trigger_.load(); +} + +void HalfAsyncCommunicator::BarrierTriggerReset(int initial_val) { + barrier_trigger_.store(initial_val); + + VLOG(3) << "BarrierTriggerReset reset barrier trigger to " + << barrier_trigger_.load(); +} + +void HalfAsyncCommunicator::Barrier() { + barrier_counter_++; + + if (!running_) { + VLOG(3) << "Communicator is not running, release barrier"; + return; + } + + { + std::unique_lock lk(barrier_mutex_); + barrier_cond_.wait(lk, [this] { return (barrier_counter_ == 0); }); + } +} + +int HalfAsyncCommunicator::BatchesCounter() { + while (running_) { + if (barrier_counter_.load() >= barrier_trigger_.load() && + barrier_trigger_.load() != 0) { + break; + } else { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + } + + return barrier_counter_.load(); +} + +void HalfAsyncCommunicator::SendByCommunicator() { + int batches = BatchesCounter(); + VLOG(1) << "HalfAsyncCommunicator::BatchesCounter = " << 
batches; + if (batches <= 0) return; + + std::vector<std::future<void>> tasks; + tasks.reserve(send_varname_to_ctx_.size()); + + for (auto &iter : send_varname_to_ctx_) { + auto &ctx = iter.second; + auto send_recv_task = [this, &ctx, batches] { + auto &varnames = ctx.origin_varnames; + auto &table_id = ctx.table_id; + size_t var_nums = varnames.size(); + + std::vector<std::vector<std::shared_ptr<Variable>>> vars; + vars.resize(var_nums); + for (size_t i = 0; i < var_nums; i++) { + auto &var_name = varnames[i]; + auto &var_queue = send_varname_to_queue_[var_name]; + for (int j = 0; j < batches; j++) vars[i].push_back(var_queue->Pop()); + MergeVars<float>(var_name, vars[i], send_scope_.get(), 1); + } + + if (ctx.is_sparse) { + PADDLE_ENFORCE_EQ( + varnames.size(), 1, + platform::errors::InvalidArgument( + "sparse variables can only be merged by one variable")); + RpcSendSparse(varnames[0], table_id, *send_scope_); + } else { + RpcSendDense(ctx, *send_scope_); + } + }; + tasks.emplace_back(send_threadpool_->enqueue(std::move(send_recv_task))); + } + for (auto &task : tasks) { + task.wait(); + } + return; +} + +void HalfAsyncCommunicator::BarrierWeakUp() { + barrier_counter_.store(0); + barrier_cond_.notify_all(); +} + +void SyncCommunicator::BarrierSend() { + if (!running_) return; + BarrierWithTable(0); + VLOG(4) << "BarrierSend with SyncCommunicator"; +} + +void SyncCommunicator::BarrierRecv() { + if (!running_) return; + BarrierWithTable(1); + + VLOG(4) << "BarrierRecv with SyncCommunicator"; +} + +void GeoCommunicator::Send(const std::vector<std::string> &var_names, + const framework::Scope &scope) { + waiting_ = false; + auto before_send = GetCurrentUS(); + auto table_name = var_names[0]; + + size_t splited_var_nums = + send_varname_to_ctx_[table_name].splited_varnames.size(); + + std::unordered_map<std::string, std::unordered_set<int64_t>> ids_table; + + for (size_t j = 0; j < splited_var_nums; j++) { + ids_table.insert(std::pair<std::string, std::unordered_set<int64_t>>( + send_varname_to_ctx_[table_name].splited_varnames[j], + std::unordered_set<int64_t>())); + } + + auto *var = scope.FindVar(table_name); + + PADDLE_ENFORCE_EQ(var->IsType<framework::SelectedRows>(), true, + platform::errors::InvalidArgument( + "Only need to send Sparse Grad in Geo mode.")); + auto &rows = var->Get<framework::SelectedRows>().rows(); + + // insert ids which have not been recorded + for (size_t j = 0; j < rows.size(); j++) { + auto ep_idx = rows[j] % splited_var_nums; + ids_table.at(send_varname_to_ctx_[table_name].splited_varnames[ep_idx]) + .insert(rows[j]); + } + + for (auto &iter : ids_table) { + auto &key = iter.first; + auto &sparse_ids_set = iter.second; + auto sparse_ids_vec = std::make_shared<std::vector<int64_t>>(); + sparse_ids_vec->assign(sparse_ids_set.begin(), sparse_ids_set.end()); + sparse_id_queues_.at(key)->Push(sparse_ids_vec); + VLOG(3) << "push " << sparse_ids_vec->size() << " ids to " << key + << "'s queue"; + } + + auto after_send = GetCurrentUS(); + VLOG(2) << "run send op finish. use time " << (after_send - before_send); +} + +void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx, + const RecvCtxMap &recv_varname_to_ctx, + Scope *recv_scope) { + send_varname_to_ctx_ = std::move(send_varname_to_ctx); + recv_varname_to_ctx_ = std::move(recv_varname_to_ctx); + recv_scope_ = std::move(recv_scope); + + PADDLE_ENFORCE_GT( + send_varname_to_ctx.size(), 0, + platform::errors::InvalidArgument("send var contexts can not be zero")); + + for (auto &iter : send_varname_to_ctx_) { + auto &ctx = iter.second; + if (!ctx.is_sparse) continue; + auto &varnames = ctx.origin_varnames; + PADDLE_ENFORCE_EQ( + varnames.size(), 1, + platform::errors::InvalidArgument( + "sparse variables can only be merged by one variable")); + for (auto &splited_var : ctx.splited_varnames) { + parallel_task_nums_ += 1; + sparse_id_queues_.insert( + std::pair<std::string, std::shared_ptr<BlockingQueue< + std::shared_ptr<std::vector<int64_t>>>>>( + splited_var, + std::make_shared< + BlockingQueue<std::shared_ptr<std::vector<int64_t>>>>( + send_queue_size_))); + } + } + + send_threadpool_.reset(new ::ThreadPool(thread_pool_size_)); + + delta_scope_.reset(new Scope()); + old_scope_.reset(new Scope()); + pserver_scope_.reset(new Scope()); +} + +void GeoCommunicator::InitParams(const RecvCtxMap &recv_varname_to_ctx) { + std::vector<std::future<void>> tasks; + tasks.reserve(recv_varname_to_ctx_.size()); + + for (auto &iter : recv_varname_to_ctx_) { + auto &table_id = iter.first; + auto &varnames = iter.second; + + auto recv_task = [this, &table_id, &varnames] { + InitDense(varnames, table_id); + }; + tasks.emplace_back(send_threadpool_->enqueue(std::move(recv_task))); + } + + for (auto &task : tasks) { + task.wait(); + } + + for (auto &iter : send_varname_to_ctx_) { + auto &ctx = iter.second; + if (!ctx.is_sparse) continue; + auto &varname = ctx.origin_varnames[0]; + auto &table_id = ctx.table_id; + auto param = varname.substr(0, varname.size() - 5); + InitSparse(param, table_id); + } + return; +} + +void GeoCommunicator::InitDense(std::vector<std::string> &varnames, + int table_id) { + if (trainer_id_ == 0) { + RpcSendDenseParam(varnames, table_id, *recv_scope_); + BarrierWithTable(1); + VLOG(0) << "push dense param to table " << table_id + << " from 0' trainer done"; + } else { + BarrierWithTable(1); + RpcRecvDense(varnames, table_id, recv_scope_); + VLOG(0) << "pull dense param to table " << table_id + << " from 0' trainer done"; + } + + // copy to old_scope + for (auto &t : varnames) { + auto *global_var = recv_scope_->FindVar(t); + global_var->GetMutable<framework::LoDTensor>(); + auto *old_var = old_scope_->Var(t); + old_var->GetMutable<framework::LoDTensor>(); + framework::CopyVariable(*global_var, old_var); + } + VLOG(1) << "init dense table " << table_id << " done"; +} + +void GeoCommunicator::SendDense(const CommContext &send_ctx) { + platform::RecordEvent record_event("GeoCommunicator->SendDense"); + auto &var_names = send_ctx.origin_varnames; + auto &table_id = send_ctx.table_id; + for (auto &varname : var_names) { + auto param_name = GradToParam(varname); + auto *var_latest = recv_scope_->FindVar(param_name); + auto *var_timestamp = old_scope_->FindVar(param_name); + + PADDLE_ENFORCE_EQ(var_latest->IsInitialized(), true, + platform::errors::Unavailable( + "%s is not initialized, please check", param_name)); + PADDLE_ENFORCE_EQ(var_timestamp->IsInitialized(), true, + platform::errors::Unavailable( + "%s is not initialized, please check", param_name)); + + auto &t_latest = var_latest->Get<framework::LoDTensor>(); + auto t_timestamp = var_timestamp->GetMutable<framework::LoDTensor>(); + + auto cpu_ctx = paddle::platform::CPUDeviceContext(); + auto *var_delta = delta_scope_->Var(varname); + auto *t_delta = var_delta->GetMutable<framework::LoDTensor>(); + t_delta->mutable_data<float>(t_latest.dims(), cpu_ctx.GetPlace()); + + auto blas = + paddle::operators::math::GetBlas<paddle::platform::CPUDeviceContext, float>( + cpu_ctx); + blas.VSUB(t_latest.numel(), t_latest.data<float>(), + t_timestamp->data<float>(), t_delta->data<float>()); + + float coefficient = 1.0 / static_cast<float>(trainers_); + blas.SCAL(t_latest.numel(), coefficient, t_delta->data<float>()); + + blas.VADD(t_latest.numel(), t_timestamp->data<float>(), + t_delta->data<float>(), t_timestamp->data<float>()); + } + RpcSendDense(send_ctx, *delta_scope_); + VLOG(1) << "Finish Send Dense " << var_names[0] << ", table_id: " << table_id; + return; +} + +void GeoCommunicator::RecvDense(const CommContext &send_ctx) { + platform::RecordEvent record_event("GeoCommunicator->RecvDense"); + auto &table_id = send_ctx.table_id; + auto &varnames = recv_varname_to_ctx_.at(table_id); + // 1. recv from pserver + RpcRecvDense(varnames, table_id, pserver_scope_.get()); + + // 2.1 pserver - old => delta; 2.2 latest + old => latest 2.3 old => pserver + auto cpu_ctx = paddle::platform::CPUDeviceContext(); + for (auto &varname : varnames) { + auto *var_latest = recv_scope_->FindVar(varname); + auto t_latest = var_latest->GetMutable<framework::LoDTensor>(); + + auto *var_old = old_scope_->FindVar(varname); + auto t_old = var_old->GetMutable<framework::LoDTensor>(); + + auto *var_pserver = pserver_scope_->FindVar(varname); + auto t_pserver = var_pserver->Get<framework::LoDTensor>(); + + auto *var_delta = delta_scope_->Var(varname); + auto *t_delta = var_delta->GetMutable<framework::LoDTensor>(); + t_delta->mutable_data<float>(t_latest->dims(), cpu_ctx.GetPlace()); + + auto blas = + paddle::operators::math::GetBlas<paddle::platform::CPUDeviceContext, float>( + cpu_ctx); + blas.VSUB(t_latest->numel(), t_pserver.data<float>(), t_old->data<float>(), + t_delta->data<float>()); + blas.VADD(t_latest->numel(), t_latest->data<float>(), + t_delta->data<float>(), t_latest->data<float>()); + blas.VCOPY(t_latest->numel(), t_pserver.data<float>(), + t_old->data<float>()); + } + VLOG(1) << "Finish Recv Dense " << varnames[0] << ", table_id: " << table_id; + return; +} + +void GeoCommunicator::InitSparse(const std::string &var_name, int table_id) { + VLOG(0) << "Init Sparse " << var_name << " : table " << table_id << " begin."; + if (trainer_id_ == 0) { + RpcSendSparseParam(var_name, table_id, *recv_scope_); + BarrierWithTable(1); + VLOG(0) << "push sparse param to table " << table_id + << " from 0' trainer done"; + } else { + BarrierWithTable(1); + RpcRecvSparse(var_name, table_id, recv_scope_); + VLOG(0) << "pull sparse param to table " << table_id + << " from 0' trainer done"; + } + + VLOG(0) << "Init Sparse " << var_name << " : table " << table_id << " done."; + auto *global_var = recv_scope_->FindVar(var_name); + auto *var = old_scope_->Var(var_name); + framework::CopyVariable(*global_var, var); + return; +} + +std::vector<int64_t> GeoCommunicator::MergeSparseIds( + const std::string &send_varname) { + size_t merge_num = 0, wait_times = 0; + std::unordered_set<int64_t> sparse_ids; + while (merge_num < static_cast<size_t>(max_merge_var_num_)) { + VLOG(3) << "Merge Number of " << send_varname << " = " << merge_num; + if (sparse_id_queues_.at(send_varname)->Size() > 0) { + wait_times = 0; + std::shared_ptr<std::vector<int64_t>> pop_ids = + sparse_id_queues_.at(send_varname)->Pop(); + for (size_t j = 0; j < pop_ids->size(); j++) { + sparse_ids.insert(pop_ids->at(j)); + } + merge_num += 1; + VLOG(3) << "sparse_id_queues_(" << send_varname << ") pushed"; + } else if (sparse_id_queues_.at(send_varname)->Size() == 0) { + VLOG(3) << "wait_times -> " << wait_times; + if (wait_times >= static_cast<size_t>(send_wait_times_)) { + break; + } + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + wait_times++; + continue; + } + } + 
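+ // Note: sparse_ids now holds the de-duplicated union of up to max_merge_var_num_ queued id batches, so each id's delta is computed and pushed at most once per communication round. 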
std::vector<int64_t> res; + res.assign(sparse_ids.begin(), sparse_ids.end()); + return res; +} + +void GeoCommunicator::SendSparse(const std::string &varname, + std::vector<int64_t> &sparse_ids, int table_id, + int ep_idx) { + platform::RecordEvent record_event("GeoCommunicator->SendSparse"); + std::string param_name = SplitedGradToParam(varname); + VLOG(1) << "In GeoCommunicator::SendSparse(" << varname << " " << param_name + << ", ids.size = " << sparse_ids.size() << ", table_id: " << table_id + << ", ep_idx: " << ep_idx; + + auto *var_latest = recv_scope_->FindVar(param_name); + auto *var_old = old_scope_->FindVar(param_name); + + PADDLE_ENFORCE_EQ(var_latest->IsInitialized(), true, + platform::errors::Unavailable( + "%s is not initialized, please check", param_name)); + PADDLE_ENFORCE_EQ(var_old->IsInitialized(), true, + platform::errors::Unavailable( + "%s is not initialized, please check", param_name)); + + auto &t_latest = var_latest->Get<framework::LoDTensor>(); + auto *t_old = var_old->GetMutable<framework::LoDTensor>(); + + auto dims1 = t_latest.dims()[1]; + auto cpu_ctx = paddle::platform::CPUDeviceContext(); + + auto *var_delta = delta_scope_->Var(varname); + auto *t_delta = var_delta->GetMutable<framework::SelectedRows>(); + auto *var_t_value = t_delta->mutable_value(); + var_t_value->Resize({static_cast<int64_t>(sparse_ids.size()), dims1}); + auto *t_value = var_t_value->mutable_data<float>(cpu_ctx.GetPlace()); + + t_delta->set_rows(sparse_ids); + t_delta->set_height(t_latest.dims()[0]); + + auto blas = + paddle::operators::math::GetBlas<paddle::platform::CPUDeviceContext, float>( + cpu_ctx); + float coefficient = 1.0 / static_cast<float>(trainers_); + + std::vector<float *> push_g_vec; + for (auto j = 0; j < static_cast<int>(sparse_ids.size()); ++j) { + blas.VSUB(dims1, t_latest.data<float>() + sparse_ids[j] * dims1, + t_old->data<float>() + sparse_ids[j] * dims1, + t_value + j * dims1); + blas.SCAL(dims1, coefficient, t_value + j * dims1); + blas.VADD(dims1, t_old->data<float>() + sparse_ids[j] * dims1, + t_value + j * dims1, + t_old->data<float>() + sparse_ids[j] * dims1); + push_g_vec.push_back(t_value + j * dims1); + } + + ++_async_call_num; + DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [this](void *done) { + int ret = 0; + auto *closure = (DownpourBrpcClosure *)done; + if (closure->check_response(0, PS_PUSH_SPARSE_TABLE) != 0) { + ret = -1; + } + closure->set_promise_value(ret); + --_async_call_num; + }); + auto status = _worker_ptr->push_sparse_raw_gradient_partial( + table_id, (const uint64_t *)sparse_ids.data(), + (const float **)push_g_vec.data(), sparse_ids.size(), closure, ep_idx); + status.wait(); + + VLOG(1) << "Finish Send Sparse " << varname + << ", ids.size = " << sparse_ids.size() << ", table_id: " << table_id; + return; +} + +void GeoCommunicator::RecvSparse(const std::string &varname, int table_id, + int ep_idx) { + platform::RecordEvent record_event("GeoCommunicator->RecvSparse"); + // 1. recv from pserver + std::vector<uint64_t> keys; + std::vector<float> values; + auto status = _worker_ptr->pull_geo_param(table_id, &values, &keys, ep_idx); + status.wait(); + + std::string param = SplitedGradToParam(varname); + VLOG(1) << "RecvSparse receive var: " << varname << " " << param << ", " + << table_id << "; ids Size: " << keys.size() + << "; values size: " << values.size(); + + auto *var_latest = recv_scope_->FindVar(param); + auto *var_old = old_scope_->FindVar(param); + + auto *t_latest = var_latest->GetMutable<framework::LoDTensor>(); + auto *t_old = var_old->GetMutable<framework::LoDTensor>(); + + auto dims1 = t_latest->dims()[1]; + auto numel = keys.size() * dims1; + + std::vector<float> v_delta; + v_delta.resize(numel); + + auto cpu_ctx = paddle::platform::CPUDeviceContext(); + auto blas = + paddle::operators::math::GetBlas<paddle::platform::CPUDeviceContext, float>( + cpu_ctx); + + for (auto j = 0; j < static_cast<int>(keys.size()); ++j) { + float *latest_data = t_latest->data<float>() + keys[j] * dims1; + float *old_data = t_old->data<float>() + keys[j] * dims1; + // pserver - old => delta + blas.VSUB(dims1, values.data() + j * dims1, old_data, + v_delta.data() + j * dims1); + // latest + delta => latest + blas.VADD(dims1, latest_data, v_delta.data() + j * dims1, latest_data); + // pserver => old + blas.VCOPY(dims1, values.data() + j * dims1, old_data); + } + VLOG(1) << "Finish Recv Sparse " << param << ", table_id: " << table_id; +} + +void GeoCommunicator::MainThread() { + VLOG(3) << "MainThread start and wait"; + + while (waiting_ && running_) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + VLOG(3) << "wait for running"; + } + + while (running_) { + std::vector<std::future<void>> tasks; + tasks.reserve(parallel_task_nums_); + + for (auto &iter : send_varname_to_ctx_) { + auto &ctx = iter.second; + auto &varnames = ctx.origin_varnames; + auto &table_id = ctx.table_id; + + if (ctx.is_sparse) { + PADDLE_ENFORCE_EQ( + varnames.size(), 1, + platform::errors::InvalidArgument( + "sparse variables can only be merged by one variable")); + int pserver_num = static_cast<int>(ctx.epmap.size()); + for (int ep_idx = 0; ep_idx < pserver_num; ep_idx++) { + // varname: emb@GRAD, param_name: emb, splited_varname: emb.delta0 + auto send_recv_task = [this, table_id, ep_idx, &ctx] { + auto splited_varname = ctx.splited_varnames[ep_idx]; + auto sparse_ids = MergeSparseIds(splited_varname); + SendSparse(splited_varname, sparse_ids, table_id, ep_idx); + RecvSparse(splited_varname, table_id, ep_idx); + }; + tasks.emplace_back( + send_threadpool_->enqueue(std::move(send_recv_task))); + } + } else { + auto send_recv_task = [this, &ctx] { + SendDense(ctx); + RecvDense(ctx); + }; + tasks.emplace_back( + send_threadpool_->enqueue(std::move(send_recv_task))); + } + } + for (auto &task : tasks) { + task.wait(); + } + } +} + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/service/communicator.h b/paddle/fluid/distributed/service/communicator.h new file mode 100644 index 00000000000000..a22b006013461c --- /dev/null +++ b/paddle/fluid/distributed/service/communicator.h @@ -0,0 +1,561 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "gflags/gflags.h" +#include "paddle/fluid/distributed/communicator_common.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/variable.h" +#include "paddle/fluid/framework/variable_helper.h" +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/operators/math/selected_rows_functor.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/place.h" +#include "paddle/fluid/string/split.h" + +#include "paddle/fluid/distributed/ps.pb.h" +#include "paddle/fluid/distributed/service/brpc_ps_client.h" +#include "paddle/fluid/distributed/service/ps_client.h" + +DECLARE_bool(communicator_is_sgd_optimizer); + +namespace paddle { +namespace distributed { + +using Scope = framework::Scope; +using Variable = framework::Variable; + +template +class BlockingQueue { + public: + explicit BlockingQueue(size_t capacity) : capacity_(capacity) { + PADDLE_ENFORCE_GT(capacity_, 0, + platform::errors::InvalidArgument( + "The capacity must be greater than 0.")); + } + + bool Push(const T &elem) { + { + std::unique_lock lock(mutex_); + cv_.wait(lock, [&] { return queue_.size() < capacity_; }); + queue_.push_back(elem); + } + cv_.notify_one(); + return true; + } + + bool Push(T &&elem) { + { + std::unique_lock lock(mutex_); + cv_.wait(lock, [&] { return queue_.size() < capacity_; }); + queue_.emplace_back(std::move(elem)); + } + cv_.notify_one(); + return true; + } + + T Pop() { + std::unique_lock lock(mutex_); + cv_.wait(lock, [=] { return !queue_.empty(); }); + T rc(std::move(queue_.front())); + queue_.pop_front(); + cv_.notify_one(); + return rc; + } + + size_t Cap() const { + std::lock_guard lock(mutex_); + return capacity_; + } + + size_t Size() const { + std::lock_guard lock(mutex_); + return queue_.size(); + } + + private: + const size_t capacity_; + std::deque queue_; + + mutable std::mutex mutex_; + std::condition_variable cv_; +}; + +template +using EigenVector = framework::EigenVector; + +template +inline void MergeVars(const std::string &var_name, + const std::vector> &vars, + Scope *scope, bool merge_add = true) { + PADDLE_ENFORCE_NE(vars.empty(), true, platform::errors::InvalidArgument( + "vector vars are empty.")); + auto cpu_place = platform::CPUPlace(); + auto &var0 = vars[0]; + auto *out_var = scope->Var(var_name); + + if (var0->IsType()) { + auto dims = var0->Get().dims(); + VLOG(3) << "merge " << var_name << " LoDTensor dims " << dims + << "; merge add: " << merge_add; + // init output tensor + auto *out_t = out_var->GetMutable(); + out_t->mutable_data(dims, cpu_place); + // check the input dims + for (auto &var : vars) { + auto &var_t = var->Get(); + PADDLE_ENFORCE_EQ( + var_t.dims(), dims, + platform::errors::InvalidArgument("vars should have the same dims.")); + } + + // set output tensor to 0. 
+ auto cpu_ctx = paddle::platform::CPUDeviceContext(); + paddle::operators::math::SetConstant + constant_functor; + constant_functor(cpu_ctx, out_t, static_cast(0)); + // sum all vars to out + auto result = EigenVector::Flatten(*out_t); + for (auto &var : vars) { + auto &in_t = var->Get(); + auto in = EigenVector::Flatten(in_t); + result.device(*cpu_ctx.eigen_device()) = result + in; + } + if (!merge_add) { + result.device(*cpu_ctx.eigen_device()) = + result / static_cast(vars.size()); + } + } else if (var0->IsType()) { + auto &slr0 = var0->Get(); + auto *out_slr = out_var->GetMutable(); + out_slr->mutable_rows()->clear(); + out_slr->mutable_value()->mutable_data({{}}, cpu_place); + std::vector inputs; + inputs.reserve(vars.size()); + for (auto &var : vars) { + inputs.push_back(&var->Get()); + } + auto dev_ctx = paddle::platform::CPUDeviceContext(); + if (merge_add) { + paddle::operators::math::scatter::MergeAdd< + paddle::platform::CPUDeviceContext, T> + merge_add; + merge_add(dev_ctx, inputs, out_slr); + } else { + paddle::operators::math::scatter::MergeAverage< + paddle::platform::CPUDeviceContext, T> + merge_average; + merge_average(dev_ctx, inputs, out_slr); + } + + VLOG(3) << "merge " << var_name << " SelectedRows height: " << slr0.height() + << " dims: " << slr0.value().dims() << "; merge add: " << merge_add; + } else { + PADDLE_THROW(platform::errors::InvalidArgument("unsupported var type: %s!", + var0->Type())); + } +} + +using RpcCtxMap = std::unordered_map; +using RecvCtxMap = std::unordered_map>; +using SparseValue = std::unordered_map>; + +class Communicator { + public: + Communicator(); + + explicit Communicator(const std::map &envs_) { + VLOG(0) << "Communicator Init Envs"; + for (auto &iter : envs_) { + envs[iter.first] = iter.second; + VLOG(0) << iter.first << ": " << iter.second; + } + barrier_table_id_ = std::stoi(envs.at("barrier_table_id")); + trainer_id_ = std::stoi(envs.at("trainer_id")); + trainers_ = std::stoi(envs.at("trainers")); + } + + virtual void InitBrpcClient(const std::string &dist_desc, + const std::vector &host_sign_list); + // 1. recv dense param + virtual void RpcRecvDense(const std::vector &varnames, + int table_id, Scope *scope); + // 2. send dense param + virtual void RpcSendDenseParam(const std::vector &varnames, + int table_id, const Scope &scope); + // 3. send dense grad + virtual void RpcSendDense(const CommContext &ctx, const Scope &scope); + // 4. send sparse grad + virtual void RpcSendSparse(const std::string &var_name, int table_id, + const Scope &scope); + // 5. send sparse param + virtual void RpcSendSparseParam(const std::string &varname, int table_id, + const Scope &scope); + // 6. 
recv sparse param + virtual void RpcRecvSparse(const std::string &varname, int table_id, + Scope *scope); + + virtual ~Communicator() {} + virtual void RpcProfilerControl(); + + virtual void InitParams(const RecvCtxMap &recv_varname_to_ctx); + + virtual void Start() = 0; + + virtual void Stop() = 0; + + virtual bool IsRunning() { return running_; } + + virtual void Clean() {} + + virtual bool Check(const int table_id) = 0; + virtual bool Check(const std::vector<std::string> &var_tables) = 0; + + virtual void Send(const std::vector<std::string> &var_names, + const framework::Scope &scope) = 0; + + virtual void RecvNoBarrier() {} + + virtual void Barrier() {} + + virtual void BarrierWithTable(uint32_t barrier_type) { + auto rets = _worker_ptr->barrier(barrier_table_id_, barrier_type); + rets.wait(); + } + + virtual void BarrierTriggerDecrement() {} + + virtual void BarrierTriggerReset(int init_counter) {} + + virtual void InitEnvs() = 0; + + virtual void InitImpl(const RpcCtxMap &send_varname_to_ctx, + const RecvCtxMap &recv_varname_to_ctx, + Scope *recv_scope) {} + + static Communicator *GetInstance() { return communicator_.get(); } + + static std::shared_ptr<Communicator> GetInstantcePtr() { + return communicator_; + } + + template <typename T> + static Communicator *InitInstance( + const RpcCtxMap &send_ctx, const RecvCtxMap &recv_ctx, + const std::string &dist_desc, + const std::vector<std::string> &host_sign_list, Scope *recv_scope, + const std::map<std::string, std::string> &envs) { + std::call_once(init_flag_, &Communicator::InitWithRpcCtx<T>, send_ctx, + recv_ctx, dist_desc, host_sign_list, recv_scope, + std::ref(envs)); + return communicator_.get(); + } + + // Init is called by InitInstance. + template <typename T> + static void InitWithRpcCtx(const RpcCtxMap &send_ctx, + const RecvCtxMap &recv_ctx, + const std::string &dist_desc, + const std::vector<std::string> &host_sign_list, + Scope *recv_scope, + const std::map<std::string, std::string> &envs) { + if (communicator_.get() == nullptr) { + communicator_.reset(new T(std::ref(envs))); + communicator_->InitEnvs(); + communicator_->InitBrpcClient(dist_desc, host_sign_list); + communicator_->InitImpl(send_ctx, recv_ctx, recv_scope); + } + } + + PSClient *GetPsClient() { return _worker_ptr.get(); } + + std::shared_ptr<paddle::distributed::PSClient> GetPsClientPtr() { + return _worker_ptr; + } + + std::shared_ptr<paddle::distributed::PSClient> _worker_ptr; // pointer to worker + + protected: + bool running_ = false; + bool waiting_ = true; + bool flushing_ = false; + bool do_server_profiler_ = false; + static std::shared_ptr<Communicator> communicator_; + static std::once_flag init_flag_; + + std::unordered_map<std::string, std::string> envs; + + // compute how much of the dense storage each shard holds + inline uint32_t dense_dim_per_shard(uint32_t dense_dim_total, + uint32_t shard_num) { + return dense_dim_total / shard_num + 1; + } + + void init_gflag(const std::string &gflags); + paddle::distributed::PSParameter _ps_param; + paddle::distributed::PaddlePSEnvironment _ps_env; + int servers_ = 0; + int trainers_; + int trainer_id_ = 0; + int barrier_table_id_ = 0; + RpcCtxMap send_varname_to_ctx_; + RecvCtxMap recv_varname_to_ctx_; + + Scope *recv_scope_; // should be global scope + std::unique_ptr<Scope> xpu_temp_scope_; + std::atomic<uint32_t> _async_call_num{0}; +}; + +class AsyncCommunicator : public Communicator { + public: + AsyncCommunicator() : Communicator() {} + + explicit AsyncCommunicator(const std::map<std::string, std::string> &envs) + : Communicator(envs) {} + + ~AsyncCommunicator(); + + void InitEnvs() { + independent_recv_ = static_cast<bool>( + std::stoi(envs.at("communicator_independent_recv_thread"))); + min_send_grad_num_before_recv_ = + std::stoi(envs.at("communicator_min_send_grad_num_before_recv")); + thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size")); + max_merge_var_num_ = std::stoi(envs.at("communicator_max_merge_var_num")); + send_wait_times_ = std::stoi(envs.at("communicator_send_wait_times")); + send_queue_size_ = std::stoi(envs.at("communicator_send_queue_size")); + need_global_step_ = + static_cast<bool>(std::stoi(envs.at("need_global_step"))); + } + + void Start() override; + + void Stop() override; + + void InitImpl(const RpcCtxMap &send_varname_to_ctx, + const RecvCtxMap &recv_varname_to_ctx, + Scope *recv_scope) override; + + virtual void MainThread(); + virtual void RecvThread(); + + virtual bool Check(const int table_id); + virtual bool Check(const std::vector<std::string> &var_tables); + + void Send(const std::vector<std::string> &var_names, + const framework::Scope &scope) override; + + virtual void SendByCommunicator(); + + virtual void SendGlobalStep(int batches) {} + + virtual void RecvByCommunicator(); + + virtual void RecvNoBarrier(); + + virtual int BatchesCounter() { return 1; } + + virtual void BarrierSend() {} + + virtual void BarrierRecv() {} + + virtual void BarrierWeakUp() {} + + protected: + std::unordered_map<std::string, + std::shared_ptr<BlockingQueue<std::shared_ptr<Variable>>>> + send_varname_to_queue_; + std::unique_ptr<::ThreadPool> send_threadpool_{nullptr}; + + int min_send_grad_num_before_recv_; + int thread_pool_size_; + int max_merge_var_num_; + int send_wait_times_; + int send_queue_size_; + bool need_global_step_ = false; + bool independent_recv_ = true; + int parallel_task_nums_ = 0; + + std::unique_ptr<std::thread> main_thread_{nullptr}; + std::unique_ptr<std::thread> recv_thread_{nullptr}; + + std::unique_ptr<Scope> send_scope_; // an independent scope + std::atomic_uint grad_num_{0}; // the num of gradient sent since last recv +}; + +class HalfAsyncCommunicator : public AsyncCommunicator { + public: + HalfAsyncCommunicator() {} + + explicit HalfAsyncCommunicator(const std::map<std::string, std::string> &envs) + : AsyncCommunicator(envs) {} + + void InitEnvs() { + // enforce recv after send + independent_recv_ = false; + min_send_grad_num_before_recv_ = 0; + thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size")); + max_merge_var_num_ = std::stoi(envs.at("communicator_max_merge_var_num")); + send_wait_times_ = std::stoi(envs.at("communicator_send_wait_times")); + send_queue_size_ = std::stoi(envs.at("communicator_send_queue_size")); + need_global_step_ = + static_cast<bool>(std::stoi(envs.at("need_global_step"))); + + VLOG(0) << "HalfAsyncCommunicator Initialized"; + } + + void MainThread() override; + + void SendByCommunicator() override; + + void Clean() override; + + void Barrier() override; + + void BarrierTriggerDecrement() override; + + void BarrierTriggerReset(int initial_val) override; + + int BatchesCounter(); + + void BarrierWeakUp(); + + protected: + // mutex for Wait for barrier + std::mutex barrier_mutex_; + std::condition_variable barrier_cond_; + std::atomic<int64_t> barrier_trigger_{0}; + std::atomic<int64_t> barrier_counter_{0}; +}; + +class SyncCommunicator : public HalfAsyncCommunicator { + public: + SyncCommunicator() : HalfAsyncCommunicator() {} + + explicit SyncCommunicator(const std::map<std::string, std::string> &envs) + : HalfAsyncCommunicator(envs) {} + + void InitEnvs() { + // enforce recv after send + independent_recv_ = false; + min_send_grad_num_before_recv_ = 0; + max_merge_var_num_ = std::stoi(envs.at("communicator_max_merge_var_num")); + send_wait_times_ = std::stoi(envs.at("communicator_send_wait_times")); + thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size")); + send_queue_size_ = std::stoi(envs.at("communicator_send_queue_size")); + need_global_step_ = + static_cast<bool>(std::stoi(envs.at("need_global_step"))); + + VLOG(0) << "SyncCommunicator Initialized"; + } + + void BarrierSend(); + + void BarrierRecv(); + + private: + std::vector<std::string> pserver_endpoints_{}; +}; + +class GeoCommunicator : public AsyncCommunicator { + public: + GeoCommunicator() : AsyncCommunicator() {} + + explicit GeoCommunicator(const std::map<std::string, std::string> &envs) + : AsyncCommunicator(envs) {} + + void InitImpl(const RpcCtxMap &send_varname_to_ctx, + const RecvCtxMap &recv_varname_to_ctx, + Scope *recv_scope) override; + + void InitParams(const RecvCtxMap &recv_varname_to_ctx) override; + void InitDense(std::vector<std::string> &varnames, int table_id); + void InitSparse(const std::string &var_name, int table_id); + + void SendDense(const CommContext &send_ctx); + void RecvDense(const CommContext &send_ctx); + + std::vector<int64_t> MergeSparseIds(const std::string &varname); + void SendSparse(const std::string &varname, std::vector<int64_t> &sparse_ids, + int table_id, int ep_idx); + void RecvSparse(const std::string &varname, int table_id, int ep_idx); + + void MainThread() override; + + void InitEnvs() { + independent_recv_ = false; + min_send_grad_num_before_recv_ = 0; + send_wait_times_ = std::stoi(envs.at("communicator_send_wait_times")); + thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size")); + // id_queue's size + max_merge_var_num_ = std::stoi(envs.at("communicator_max_merge_var_num")); + send_queue_size_ = max_merge_var_num_; + VLOG(0) << "GeoCommunicator Initialized"; + } + + void Send(const std::vector<std::string> &var_names, + const framework::Scope &scope) override; + + void SendByCommunicator() { return; } + + void SendGlobalStep(int batches) override { return; } + + void RecvByCommunicator() override { return; } + + inline std::string GradToParam(const std::string var_name) { + std::string param_name = var_name.substr(0, var_name.size() - 5); + return param_name; + } + + inline std::string SplitedGradToParam(const std::string delta_name) { + // delta_name: emb.delta0 + auto pos = delta_name.find(".block"); + std::string param_name = delta_name.substr(0, pos); + return param_name; + } + + private: + // parameter for delta calc and send + std::shared_ptr<Scope> delta_scope_; + // parameter for storage the pserver param after last recv + std::shared_ptr<Scope> old_scope_; + // parameter on pserver + std::shared_ptr<Scope> pserver_scope_; + + std::unordered_map< + std::string, + std::shared_ptr<BlockingQueue<std::shared_ptr<std::vector<int64_t>>>>> + sparse_id_queues_; +}; + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/service/env.cc b/paddle/fluid/distributed/service/env.cc new file mode 100644 index 00000000000000..25bc2cc366aaac --- /dev/null +++ b/paddle/fluid/distributed/service/env.cc @@ -0,0 +1,19 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
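+ // Illustrative example with assumed values: the env.h header included below defines PSHost, which packs (ip, port, rank) into one uint64 via serialize_to_uint64() with layout |--ip:32--|--port:20--|--rank:12--|. + // E.g. ip = "127.0.0.1" (inet_addr() yields 0x0100007F on little-endian hosts), port = 8000 (0x1F40) and rank = 3 give (0x0100007FULL << 32) | (0x1F40 << 12) | 3 = 0x0100007F01F40003. + // parse_from_uint64() recovers the fields with its 12-bit and 20-bit masks, so the round trip is lossless only while port < 2^20 and rank < 2^12. 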
+ +#include "paddle/fluid/distributed/service/env.h" + +namespace paddle { +namespace distributed {} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/service/env.h b/paddle/fluid/distributed/service/env.h new file mode 100644 index 00000000000000..42f31717f7fba4 --- /dev/null +++ b/paddle/fluid/distributed/service/env.h @@ -0,0 +1,284 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace paddle { +namespace distributed { + +struct PSHost { + std::string ip; + uint32_t port; + uint32_t rank; + + PSHost() = default; + PSHost(const std::string ip, uint32_t port, uint32_t rank) + : ip(ip), port(port), rank(rank) {} + + // |---ip---|---port---|--rank--| + // |-32bit--|--20bit---|--12bit-| + // for pslib + uint64_t serialize_to_uint64() { + uint64_t host_label = 0; + host_label = inet_addr(ip.c_str()); + host_label = host_label << 32; + host_label += (port << 12); + host_label += rank; + return host_label; + } + + void parse_from_uint64(uint64_t host_label) { + static uint64_t rank_label_mask = (1L << 12) - 1; + static uint64_t port_label_mask = (1L << 20) - 1; + rank = host_label & rank_label_mask; + port = (host_label >> 12) & port_label_mask; + uint32_t ip_addr = (host_label >> 32); + ip = inet_ntoa(*(in_addr *)&ip_addr); + } + + std::string to_string() { + std::stringstream s; + s << "host: " << ip; + s << " port: " << port; + s << " rank: " << rank; + s << " uint: " << serialize_to_uint64(); + return s.str(); + } + + // for open source parameter server + std::string serialize_to_string() { + std::stringstream s; + s << ip << ":"; + s << port << ":"; + s << rank; + return s.str(); + } + + void parse_from_string(std::string endpoint) { + std::vector endpoint_info; + string_split(endpoint, ':', &endpoint_info); + ip = endpoint_info[0]; + port = std::stoi(endpoint_info[1]); + rank = std::stoi(endpoint_info[2]); + } + + void string_split(const std::string &str, char sep, + std::vector *pieces, bool ignore_null = true) { + pieces->clear(); + if (str.empty()) { + if (!ignore_null) { + pieces->push_back(str); + } + return; + } + size_t pos = 0; + size_t next = str.find(sep, pos); + while (next != std::string::npos) { + pieces->push_back(str.substr(pos, next - pos)); + pos = next + 1; + next = str.find(sep, pos); + } + if (!str.substr(pos).empty()) { + pieces->push_back(str.substr(pos)); + } + } +}; + +class PSEnvironment { + public: + explicit PSEnvironment() {} + virtual ~PSEnvironment() {} + + virtual int32_t set_ps_servers(uint64_t *host_sign_list, int node_num) { + return 0; + } + virtual int32_t set_ps_servers( + const std::vector *host_endpoint_list, int node_num) { + return 0; + } + + virtual int32_t set_ps_clients(uint64_t *host_sign_list, int node_num) { + return 0; + } + + virtual int32_t set_ps_clients(std::string *host_endpoint_list, + int node_num) { 
+ return 0; + } + virtual uint64_t get_local_host_sign() { return 0; } + virtual std::vector<PSHost> get_ps_servers() const { return _ps_server_list; } + virtual int32_t registe_ps_server(const std::string &ip, uint32_t port, + int32_t rank) { + return registe_ps_host(ip, port, rank, _ps_server_list, + _ps_server_sign_set); + } + + virtual std::vector<PSHost> get_ps_clients() const { return _ps_client_list; } + virtual int32_t registe_ps_client(const std::string &ip, uint32_t port, + int32_t rank) { + return registe_ps_host(ip, port, rank, _ps_client_list, + _ps_client_sign_set); + } + + virtual std::vector<uint64_t> get_client_info() { + std::vector<uint64_t> client_info; + for (auto &i : _ps_client_sign_set) { + client_info.push_back(i); + } + return client_info; + } + + virtual std::vector<std::string> get_client_info(bool use_string_endpoint) { + if (use_string_endpoint) { + std::vector<std::string> client_info; + for (auto &i : _ps_client_list) { + client_info.push_back(i.serialize_to_string()); + } + return client_info; + } + return {}; + } + + protected: + // register a host + virtual int32_t registe_ps_host(const std::string &ip, uint32_t port, + int32_t rank, std::vector<PSHost> &host_list, + std::unordered_set<uint64_t> &sign_set) { + PSHost host; + host.ip = ip; + host.port = port; + host.rank = rank; + if (sign_set.count(rank) > 0) { + LOG(WARNING) << "ps-host :" << host.ip << ":" << host.port + << ", rank:" << host.rank + << " already register, ignore register"; + } else { + host_list.push_back(host); + sign_set.insert(rank); + } + // if (sign_set.count(host.serialize_to_uint64()) > 0) { + // LOG(WARNING) << "ps-host :" << host.ip << ":" << host.port + // << ", rank:" << host.rank + // << " already register, ignore register"; + // } else { + // host_list.push_back(host); + // sign_set.insert(host.serialize_to_uint64()); + // } + return 0; + } + + std::vector<PSHost> _ps_client_list; + std::unordered_set<uint64_t> _ps_client_sign_set; // for unique filter + + std::vector<PSHost> _ps_server_list; + std::unordered_set<uint64_t> _ps_server_sign_set; // for unique filter +}; + +class PaddlePSEnvironment : public PSEnvironment { + public: + explicit PaddlePSEnvironment() {} + virtual ~PaddlePSEnvironment() {} + + virtual int32_t set_ps_servers(uint64_t *host_sign_list, int node_num) { + _ps_server_list.clear(); + _ps_server_sign_set.clear(); + for (int i = 0; i < node_num; ++i) { + if (host_sign_list[i] > 0) { + PSHost host; + host.parse_from_uint64(host_sign_list[i]); + _ps_server_list.push_back(host); + _ps_server_sign_set.insert(host.serialize_to_uint64()); + } + } + std::sort( + _ps_server_list.begin(), _ps_server_list.end(), + [](const PSHost &h1, const PSHost &h2) { return h1.rank < h2.rank; }); + return 0; + } + + virtual int32_t set_ps_servers(const std::vector<std::string> *host_sign_list, + int node_num) { + _ps_server_list.clear(); + _ps_server_sign_set.clear(); + for (int i = 0; i < node_num; ++i) { + if (host_sign_list->at(i) != "") { + PSHost host; + host.parse_from_string(host_sign_list->at(i)); + _ps_server_list.push_back(host); + _ps_server_sign_set.insert(host.rank); + } + } + std::sort( + _ps_server_list.begin(), _ps_server_list.end(), + [](const PSHost &h1, const PSHost &h2) { return h1.rank < h2.rank; }); + return 0; + } + + virtual int32_t set_ps_clients(uint64_t *host_sign_list, int node_num) { + _ps_client_list.clear(); + _ps_client_sign_set.clear(); + for (int i = 0; i < node_num; ++i) { + if (host_sign_list[i] > 0) { + PSHost host; + host.parse_from_uint64(host_sign_list[i]); + _ps_client_list.push_back(host); + _ps_client_sign_set.insert(host.serialize_to_uint64()); + } + } + std::sort( 
_ps_client_list.begin(), _ps_client_list.end(), + [](const PSHost &h1, const PSHost &h2) { return h1.rank < h2.rank; }); + return 0; + } + + virtual int32_t set_ps_clients(std::vector *host_sign_list, + int node_num) { + _ps_client_list.clear(); + _ps_client_sign_set.clear(); + for (int i = 0; i < node_num; ++i) { + if (host_sign_list->at(i) != "") { + PSHost host; + host.parse_from_string(host_sign_list->at(i)); + _ps_client_list.push_back(host); + _ps_client_sign_set.insert(host.rank); + } + } + std::sort( + _ps_client_list.begin(), _ps_client_list.end(), + [](const PSHost &h1, const PSHost &h2) { return h1.rank < h2.rank; }); + return 0; + } + + virtual uint64_t get_local_host_sign() { + if (_ps_client_list.size() > 0) { + return _ps_client_list[0].serialize_to_uint64(); + } else { + return 0; + } + } +}; + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/service/heter_client.cc b/paddle/fluid/distributed/service/heter_client.cc new file mode 100644 index 00000000000000..f4d1f27377f0e6 --- /dev/null +++ b/paddle/fluid/distributed/service/heter_client.cc @@ -0,0 +1,168 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/distributed/service/heter_client.h" +#include +#include +#include "paddle/fluid/framework/channel.h" +#include "paddle/fluid/framework/data_feed.h" +#include "paddle/fluid/framework/device_worker.h" +#include "paddle/fluid/framework/io/fs.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/platform/profiler.h" +#include "paddle/fluid/platform/timer.h" + +DECLARE_int32(rpc_deadline); +namespace paddle { +namespace distributed { + +DEFINE_int32(pserver_timeout_ms, 10800000, "pserver request server timeout_ms"); + +std::shared_ptr HeterClient::s_instance_ = NULL; +bool HeterClient::is_initialized_ = false; + +void HeterClient::MainThread() { + while (running_) { + RpcProfilerControl(); + } +} + +void HeterClient::Stop() { + running_ = false; + if (!is_initialized_) { + VLOG(0) << "HeterClient is not inited, do nothing"; + } else { + if (main_thread_) { + auto status = StopHeterWorker(); + status.wait(); + main_thread_->join(); + main_thread_.reset(nullptr); + } + VLOG(1) << "HeterClient Stop Done"; + } +} + +void HeterClient::RpcProfilerControl() { + if (trainer_id_ == 0) { + if (!do_server_profiler_ && platform::IsProfileEnabled()) { + // send profiler start flag + do_server_profiler_ = true; + auto start_status = StartProfiler(); + start_status.wait(); + } else if (do_server_profiler_ && !platform::IsProfileEnabled()) { + // send profiler end flag + auto stop_status = StopProfiler(); + stop_status.wait(); + do_server_profiler_ = false; + } + } +} + +void HeterClient::CreateClient2XpuConnection() { + brpc::ChannelOptions options; + options.protocol = "baidu_std"; + options.connection_type = "single"; + options.timeout_ms = pserver_timeout_ms; + + xpu_channels_.resize(xpu_list_.size()); + for (size_t i = 0; i < 
xpu_list_.size(); ++i) { + xpu_channels_[i].reset(new brpc::Channel()); + if (xpu_channels_[i]->Init(xpu_list_[i].c_str(), "", &options) != 0) { + VLOG(0) << "HeterServer channel init fail"; + } + } +} + +void HeterClient::SendAndRecvAsync( + const std::vector& ep, const platform::DeviceContext& ctx, + const framework::Scope& scope, const std::string& message_name, + const std::vector& send_var_name, + const std::vector& recv_var_name) { + platform::RecordEvent record_event("HeterClient->SendAndRecvAsync"); + const platform::DeviceContext* p_ctx = &ctx; + const framework::Scope* p_scope = &scope; + const std::string message_name_val = message_name; + const std::vector send_var_name_val = send_var_name; + const std::vector recv_var_name_val = recv_var_name; + + VLOG(3) << "GRPCClient::SendAndRecv Begin, message_name: " + << message_name_val; + // Todo: get correct channel + int num = trainer_id_ % xpu_channels_.size(); + + brpc::Controller cntl; + cntl.set_timeout_ms(pserver_timeout_ms); + distributed::MultiVarMsg request, response; + auto& request_io_buffer = cntl.request_attachment(); + ::paddle::PsService_Stub stub(xpu_channels_[num].get()); + distributed::SerializeToMultiVarMsgAndIOBuf( + message_name_val, send_var_name_val, recv_var_name_val, *p_ctx, p_scope, + &request, &request_io_buffer); + stub.SendAndRecvVariable(&cntl, &request, &response, NULL); + PADDLE_ENFORCE_NE( + cntl.Failed(), true, + platform::errors::Unimplemented( + "HeterClient::SendAndRecv meets brpc error, error message is %s", + cntl.ErrorText())); + VLOG(4) << "call heter_worker success"; + auto& response_io_buffer = cntl.response_attachment(); + distributed::DeserializeFromMultiVarMsgAndIOBuf(response, &response_io_buffer, + ctx, p_scope); +} + +std::future HeterClient::SendCmd( + uint32_t table_id, int cmd_id, const std::vector& params) { + size_t request_call_num = xpu_channels_.size(); + paddle::distributed::DownpourBrpcClosure* closure = + new paddle::distributed::DownpourBrpcClosure( + request_call_num, [request_call_num, cmd_id](void* done) { + int ret = 0; + auto* closure = (paddle::distributed::DownpourBrpcClosure*)done; + for (size_t i = 0; i < request_call_num; ++i) { + if (closure->check_response(i, cmd_id) != 0) { + ret = -1; + break; + } + } + closure->set_promise_value(ret); + }); + auto promise = std::make_shared>(); + closure->add_promise(promise); + std::future fut = promise->get_future(); + for (size_t i = 0; i < request_call_num; ++i) { + closure->request(i)->set_cmd_id(cmd_id); + closure->request(i)->set_table_id(table_id); + closure->request(i)->set_client_id(trainer_id_); + for (const auto& param : params) { + closure->request(i)->add_params(param); + } + ::paddle::PsService_Stub rpc_stub(xpu_channels_[i].get()); + closure->cntl(i)->set_timeout_ms( + pserver_timeout_ms); // cmd msg don't limit timeout for save/load + rpc_stub.service(closure->cntl(i), closure->request(i), + closure->response(i), closure); + } + return fut; +} + +std::future HeterClient::StartProfiler() { + return SendCmd(-1, PS_START_PROFILER, {}); +} + +std::future HeterClient::StopProfiler() { + return SendCmd(-1, PS_STOP_PROFILER, {}); +} + +} // end namespace distributed +} // end namespace paddle diff --git a/paddle/fluid/distributed/service/heter_client.h b/paddle/fluid/distributed/service/heter_client.h new file mode 100644 index 00000000000000..b1c268c3231f92 --- /dev/null +++ b/paddle/fluid/distributed/service/heter_client.h @@ -0,0 +1,127 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include "brpc/channel.h" +#include "brpc/controller.h" +#include "brpc/server.h" +#include "paddle/fluid/distributed/service/brpc_ps_client.h" +#include "paddle/fluid/distributed/service/brpc_utils.h" +#include "paddle/fluid/distributed/service/sendrecv.pb.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/framework/variable_helper.h" +#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN + +namespace paddle { +namespace distributed { + +using MultiVarMsg = ::paddle::MultiVariableMessage; +using VarMsg = ::paddle::VariableMessage; + +typedef std::function HeterRpcCallbackFunc; + +class OnHeterRpcDone : public google::protobuf::Closure { + public: + OnHeterRpcDone(HeterRpcCallbackFunc func) : handler_(func) {} + virtual ~OnHeterRpcDone() {} + void Run() { + std::unique_ptr self_guard(this); + handler_(this); + } + + HeterRpcCallbackFunc handler_; + MultiVariableMessage response; + brpc::Controller cntl; +}; + +class HeterClient { + public: + virtual ~HeterClient() {} + + HeterClient() { + running_ = true; + main_thread_.reset( + new std::thread(std::bind(&HeterClient::MainThread, this))); + } + + void CreateClient2XpuConnection(); + + void SendAndRecvAsync(const std::vector& ep, + const platform::DeviceContext& ctx, + const framework::Scope& scope, + const std::string& message_name, + const std::vector& send_var_name, + const std::vector& recv_var_name); + + // HeterClient singleton + static std::shared_ptr GetInstance( + const std::vector& endpoint, const int& trainer_id) { + if (NULL == s_instance_) { + is_initialized_ = true; + s_instance_.reset(new paddle::distributed::HeterClient()); + std::vector xpu_list = {endpoint}; + s_instance_->SetXpuList(endpoint); + s_instance_->SetTrainerID(trainer_id); + s_instance_->CreateClient2XpuConnection(); + } + return s_instance_; + } + + void Stop(); + + void MainThread(); + + void RpcProfilerControl(); + + std::future SendCmd(uint32_t table_id, int cmd_id, + const std::vector& params); + + std::future StartProfiler(); + std::future StopProfiler(); + std::future StopHeterWorker(); + + std::vector& GetXpuList() { return xpu_list_; } + + void SetXpuList(const std::vector& xpu_list) { + xpu_list_ = xpu_list; + }; + + void SetTrainerID(const int& trainer_id) { trainer_id_ = trainer_id; } + + private: + static std::shared_ptr s_instance_; + + protected: + static bool is_initialized_; + std::unique_ptr main_thread_{nullptr}; + std::vector> xpu_channels_; + DISABLE_COPY_AND_ASSIGN(HeterClient); + std::vector xpu_list_; + + bool running_ = false; + int trainer_id_; + bool do_server_profiler_ = false; +}; + +} // end namespace distributed +} // end namespace paddle diff --git a/paddle/fluid/distributed/service/heter_server.cc b/paddle/fluid/distributed/service/heter_server.cc new file mode 100644 index 00000000000000..d9daf8be1ccb66 --- /dev/null +++ 
b/paddle/fluid/distributed/service/heter_server.cc @@ -0,0 +1,91 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/distributed/service/heter_server.h" +#include +#include +#include "paddle/fluid/framework/fleet/heter_wrapper.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/platform/timer.h" + +namespace paddle { +namespace distributed { + +std::shared_ptr HeterServer::s_instance_ = NULL; + +void HeterServer::RegisterServiceHandler(std::string message_name, + HeterServiceHandler func) { + service_.RegisterServiceHandler(message_name, func); +} + +void HeterServer::StartHeterService() { + server_.AddService(&service_, brpc::SERVER_DOESNT_OWN_SERVICE); + brpc::ServerOptions options; + if (server_.Start(endpoint_.c_str(), &options) != 0) { + VLOG(0) << "heter server start fail"; + } else { + VLOG(0) << "heter server start success! listen on " << endpoint_; + } + + { + std::lock_guard lock(this->mutex_ready_); + ready_ = 1; + } + condition_ready_.notify_all(); + + server_.Join(); +} + +void HeterServer::SetEndPoint(std::string& endpoint) { + endpoint_ = endpoint; + service_.SetEndpoint(endpoint); +} + +void HeterServer::SetFanin(int& fan_in) { service_.SetFanin(fan_in); } + +void HeterServer::WaitServerReady() { + std::unique_lock lock(this->mutex_ready_); + condition_ready_.wait(lock, [=] { return this->ready_ == 1; }); +} + +int32_t HeterService::stop_profiler(const PsRequestMessage& request, + PsResponseMessage& response, + brpc::Controller* cntl) { + platform::DisableProfiler( + platform::EventSortingKey::kDefault, + string::Sprintf("heter_worker_%s_profile", endpoint_)); + return 0; +} + +int32_t HeterService::start_profiler(const PsRequestMessage& request, + PsResponseMessage& response, + brpc::Controller* cntl) { + platform::EnableProfiler(platform::ProfilerState::kAll); + return 0; +} + +int32_t HeterService::stop_heter_worker(const PsRequestMessage& request, + PsResponseMessage& response, + brpc::Controller* cntl) { + auto client_id = request.client_id(); + stop_cpu_worker_set_.insert(client_id); + if (stop_cpu_worker_set_.size() == fan_in_) { + is_exit_ = true; + } + return 0; +} + +} // end namespace distributed +} // end namespace paddle diff --git a/paddle/fluid/distributed/service/heter_server.h b/paddle/fluid/distributed/service/heter_server.h new file mode 100644 index 00000000000000..07fff7adc6e94a --- /dev/null +++ b/paddle/fluid/distributed/service/heter_server.h @@ -0,0 +1,243 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
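A sketch of the intended bring-up on the heter-worker side, combining HeterServer with the RequestSendAndRecvHandler declared in heter_server.h below; the scope, device context, executor, program, and prepared-context map are assumed to be owned by the caller, and all names other than the HeterServer/handler APIs are illustrative:

    paddle::distributed::RequestSendAndRecvHandler handler;
    handler.SetScope(&scope);     // trainer-owned scope
    handler.SetDevCtx(&dev_ctx);  // device context used for (de)serialization
    handler.SetExecutor(&executor);
    handler.SetProgram(&program);
    handler.SetGradToPreparedCtx(&message_to_prepared_ctx);

    auto server = paddle::distributed::HeterServer::GetInstance();
    std::string endpoint = "127.0.0.1:8010";  // placeholder
    server->SetEndPoint(endpoint);
    server->RegisterServiceHandler(
        "forward_step",
        [&](const paddle::distributed::MultiVarMsg* req,
            paddle::distributed::MultiVarMsg* res,
            brpc::Controller* cntl) { return handler.Handle(req, res, cntl); });
    std::thread server_thread([&] { server->StartHeterService(); });  // blocks in Join()
    server->WaitServerReady();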
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include "brpc/channel.h" +#include "brpc/controller.h" +#include "brpc/server.h" +#include "paddle/fluid/distributed/service/brpc_utils.h" +#include "paddle/fluid/distributed/service/sendrecv.pb.h" +#include "paddle/fluid/framework/executor.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/framework/variable_helper.h" +#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN +#include "paddle/fluid/platform/profiler.h" + +namespace paddle { +namespace distributed { + +using MultiVarMsg = ::paddle::MultiVariableMessage; +using VarMsg = ::paddle::VariableMessage; + +class HeterService; +typedef int32_t (HeterService::*serviceHandlerFunc)( + const PsRequestMessage& request, PsResponseMessage& response, + brpc::Controller* cntl); + +typedef std::function HeterRpcCallbackFunc; +typedef std::function + HeterServiceHandler; + +class HeterService : public ::paddle::PsService { + public: + HeterService() { + _service_handler_map[PS_STOP_SERVER] = &HeterService::stop_heter_worker; + _service_handler_map[PS_START_PROFILER] = &HeterService::start_profiler; + _service_handler_map[PS_STOP_PROFILER] = &HeterService::stop_profiler; + } + + virtual ~HeterService() {} + + virtual void service(::google::protobuf::RpcController* controller, + const ::paddle::PsRequestMessage* request, + ::paddle::PsResponseMessage* response, + ::google::protobuf::Closure* done) { + brpc::ClosureGuard done_guard(done); + std::string log_label("ReceiveCmd-"); + + response->set_err_code(0); + response->set_err_msg(""); + brpc::Controller* cntl = static_cast(controller); + auto itr = _service_handler_map.find(request->cmd_id()); + if (itr == _service_handler_map.end()) { + std::string err_msg( + "undefined cmd_id, should match PsCmdID in ps.proto, cmd_id:"); + err_msg.append(std::to_string(request->cmd_id())); + return; + } + serviceHandlerFunc handler_func = itr->second; + int service_ret = (this->*handler_func)(*request, *response, cntl); + if (service_ret != 0) { + response->set_err_code(service_ret); + response->set_err_msg("server internal error"); + } + }; + + void SendAndRecvVariable(::google::protobuf::RpcController* controller, + const MultiVarMsg* request, MultiVarMsg* response, + ::google::protobuf::Closure* done) { + brpc::ClosureGuard done_guard(done); + std::string message_name = request->message_name(); + auto itr = handler_map_.find(message_name); + brpc::Controller* cntl = static_cast(controller); + PADDLE_ENFORCE_NE( + itr, handler_map_.end(), + platform::errors::InvalidArgument( + "HeterService::SendAndRecvVariable Get illegal message_name: %s " + "which is not in HeterService::handler_map_", + message_name)); + itr->second(request, response, cntl); + } + + void RegisterServiceHandler(std::string message_name, + HeterServiceHandler func) { + handler_map_[message_name] = func; + } + + void SetEndpoint(const std::string& end_point) { endpoint_ = end_point; } + void SetFanin(const int& fan_in) 
{ fan_in_ = fan_in; } + bool IsExit() { return is_exit_; } + + private: + int32_t stop_profiler(const PsRequestMessage& request, + PsResponseMessage& response, brpc::Controller* cntl); + + int32_t start_profiler(const PsRequestMessage& request, + PsResponseMessage& response, brpc::Controller* cntl); + + int32_t stop_heter_worker(const PsRequestMessage& request, + PsResponseMessage& response, + brpc::Controller* cntl); + + private: + std::string endpoint_; + std::unordered_map handler_map_; + std::unordered_map _service_handler_map; + std::unordered_set stop_cpu_worker_set_; + int fan_in_; + bool is_exit_ = false; +}; + +class HeterServer { + public: + virtual ~HeterServer() {} + + void Stop() { + server_.Stop(1000); + server_.Join(); + } + + bool IsExit() { return service_.IsExit(); } + + HeterServer() {} + + void RegisterServiceHandler(std::string message_name, + HeterServiceHandler func); + + void StartHeterService(); + + void SetEndPoint(std::string& endpoint); + void SetFanin(int& fan_in); + + // HeterWrapper singleton + static std::shared_ptr GetInstance() { + if (NULL == s_instance_) { + s_instance_.reset(new HeterServer()); + } + return s_instance_; + } + + void WaitServerReady(); + + private: + static std::shared_ptr s_instance_; + std::string endpoint_; + + protected: + brpc::Server server_; + HeterService service_; + DISABLE_COPY_AND_ASSIGN(HeterServer); + std::mutex mutex_ready_; + std::condition_variable condition_ready_; + int ready_; +}; + +class HeterRequestHandler { + public: + HeterRequestHandler() + : dev_ctx_(nullptr), + executor_(nullptr), + scope_(nullptr), + program_(nullptr) {} + + virtual ~HeterRequestHandler() {} + + void SetScope(framework::Scope* scope) { scope_ = scope; } + void SetDevCtx(const platform::DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; } + void SetProgram(framework::ProgramDesc* program) { program_ = program; } + void SetExecutor(framework::Executor* executor) { executor_ = executor; } + + void SetGradToPreparedCtx( + std::unordered_map< + std::string, std::shared_ptr>* g) { + message_to_prepared_ctx_ = g; + } + + virtual int Handle(const MultiVarMsg* request, MultiVarMsg* response, + brpc::Controller* cntl) = 0; + + protected: + const platform::DeviceContext* dev_ctx_; + framework::Executor* executor_; + framework::Scope* scope_; + framework::ProgramDesc* program_; + + std::unordered_map>* + message_to_prepared_ctx_; +}; + +class RequestSendAndRecvHandler final : public HeterRequestHandler { + public: + RequestSendAndRecvHandler() {} + virtual ~RequestSendAndRecvHandler() {} + int Handle(const MultiVarMsg* request, MultiVarMsg* response, + brpc::Controller* cntl) override { + platform::RecordEvent record_event("RequestSendAndRecvHandler->Handle"); + auto& local_scope = scope_->NewScope(); + auto message_name = request->message_name(); + auto& request_io_buffer = cntl->request_attachment(); + distributed::DeserializeFromMultiVarMsgAndIOBuf( + *request, &request_io_buffer, *dev_ctx_, &local_scope); + executor_->RunPreparedContext( + (*message_to_prepared_ctx_)[message_name].get(), &local_scope, false); + + auto response_var_nums = request->recv_var_names_size(); + std::vector response_var_names(response_var_nums), + empty_var_names{}; + + for (int var_idx = 0; var_idx < response_var_nums; ++var_idx) { + response_var_names[var_idx] = request->recv_var_names(var_idx); + } + auto& response_io_buffer = cntl->response_attachment(); + distributed::SerializeToMultiVarMsgAndIOBuf( + message_name, response_var_names, empty_var_names, *dev_ctx_, + 
&local_scope, response, &response_io_buffer); + scope_->DeleteScope(&local_scope); + return 0; + } +}; + +} // end namespace distributed +} // end namespace paddle diff --git a/paddle/fluid/distributed/service/ps_client.cc b/paddle/fluid/distributed/service/ps_client.cc new file mode 100644 index 00000000000000..dd5fb9c24b32ce --- /dev/null +++ b/paddle/fluid/distributed/service/ps_client.cc @@ -0,0 +1,89 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/distributed/service/ps_client.h" + +#include + +#include "brpc/server.h" +#include "glog/logging.h" +#include "paddle/fluid/distributed/service/brpc_ps_client.h" +#include "paddle/fluid/distributed/table/table.h" + +namespace paddle { +namespace distributed { +REGISTER_CLASS(PSClient, BrpcPsClient); + +int32_t PSClient::configure( + const PSParameter &config, + const std::map> ®ions, + PSEnvironment &env, size_t client_id) { + _env = &env; + _config = config; + _dense_pull_regions = regions; + _client_id = client_id; + _config.mutable_worker_param() + ->mutable_downpour_worker_param() + ->mutable_downpour_table_param() + ->CopyFrom(_config.server_param() + .downpour_server_param() + .downpour_table_param()); + + const auto &work_param = _config.worker_param().downpour_worker_param(); + + for (size_t i = 0; i < work_param.downpour_table_param_size(); ++i) { + auto *accessor = CREATE_CLASS( + ValueAccessor, + work_param.downpour_table_param(i).accessor().accessor_class()); + accessor->configure(work_param.downpour_table_param(i).accessor()); + accessor->initialize(); + _table_accessors[work_param.downpour_table_param(i).table_id()].reset( + accessor); + } + return initialize(); +} + +PSClient *PSClientFactory::create(const PSParameter &ps_config) { + const auto &config = ps_config.server_param(); + if (!config.has_downpour_server_param()) { + LOG(ERROR) << "miss downpour_server_param in ServerParameter"; + return NULL; + } + + if (!config.downpour_server_param().has_service_param()) { + LOG(ERROR) << "miss service_param in ServerParameter.downpour_server_param"; + return NULL; + } + + if (!config.downpour_server_param().service_param().has_client_class()) { + LOG(ERROR) << "miss client_class in " + "ServerParameter.downpour_server_param.service_param"; + return NULL; + } + + const auto &service_param = config.downpour_server_param().service_param(); + PSClient *client = CREATE_CLASS(PSClient, service_param.client_class()); + if (client == NULL) { + LOG(ERROR) << "client is not registered, server_name:" + << service_param.client_class(); + return NULL; + } + + TableManager::instance().initialize(); + LOG(INFO) << "Create PSClient[" << service_param.client_class() + << "] success"; + return client; +} +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/service/ps_client.h b/paddle/fluid/distributed/service/ps_client.h new file mode 100644 index 00000000000000..23b00b3c816088 --- /dev/null +++ 
b/paddle/fluid/distributed/service/ps_client.h
@@ -0,0 +1,208 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <future>
+#include <map>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+#include "paddle/fluid/distributed/ps.pb.h"
+#include "paddle/fluid/distributed/service/env.h"
+#include "paddle/fluid/distributed/service/sendrecv.pb.h"
+#include "paddle/fluid/distributed/table/accessor.h"
+
+namespace paddle {
+namespace distributed {
+
+typedef std::function<void(void *)> PSClientCallBack;
+class PSClientClosure : public google::protobuf::Closure {
+ public:
+  PSClientClosure(PSClientCallBack callback) : _callback(callback) {}
+  virtual ~PSClientClosure() {}
+  virtual void set_promise_value(int value) {
+    for (auto &promise : _promises) {
+      promise->set_value(value);
+    }
+  }
+
+  void add_promise(std::shared_ptr<std::promise<int32_t>> &promise) {
+    _promises.push_back(promise);
+  }
+
+ protected:
+  PSClientCallBack _callback;
+  std::vector<std::shared_ptr<std::promise<int32_t>>> _promises;
+};
+
+class PSClient {
+ public:
+  PSClient() {}
+  virtual ~PSClient() {}
+  PSClient(PSClient &&) = delete;
+  PSClient(const PSClient &) = delete;
+
+  virtual int32_t configure(
+      const PSParameter &config,
+      const std::map<uint64_t, std::vector<paddle::distributed::Region>>
+          &regions,
+      PSEnvironment &_env, size_t client_id) final;
+
+  virtual int32_t create_client2client_connection(
+      int pserver_timeout_ms, int pserver_connect_timeout_ms,
+      int max_retry) = 0;
+
+  // trigger data eviction on a table
+  virtual std::future<int32_t> shrink(uint32_t table_id) = 0;
+
+  // load data for all tables
+  virtual std::future<int32_t> load(const std::string &epoch,
+                                    const std::string &mode) = 0;
+  // load data for the specified table
+  virtual std::future<int32_t> load(uint32_t table_id,
+                                    const std::string &epoch,
+                                    const std::string &mode) = 0;
+  // save data for all tables; depending on mode, the value_accessor may
+  // apply different save conditions
+  virtual std::future<int32_t> save(const std::string &epoch,
+                                    const std::string &mode) = 0;
+  // save data for the specified table; depending on mode, the value_accessor
+  // may apply different save conditions
+  virtual std::future<int32_t> save(uint32_t table_id,
+                                    const std::string &epoch,
+                                    const std::string &mode) = 0;
+
+  // clear table data
+  virtual std::future<int32_t> clear() = 0;
+  virtual std::future<int32_t> clear(uint32_t table_id) = 0;
+
+  // pull the dense part of the parameters and fill the local network
+  // parameters block by block
+  // start and num select the subset of parameters to pull
+  // the keys and values buffers must not be reused before the future completes
+  // the client unpacks values by block and hands them to multiple senders
+  // each sender aggregates requests for the same block into fill buffers
+  // the server extracts and returns the configured dimension of each block
+  // the returned data is unpacked and written into the accumulated buffers
+  virtual std::future<int32_t> pull_dense(Region *regions, size_t region_num,
+                                          size_t table_id) = 0;  // reserved
+
+  // first push of the dense params to the parameter server; this is
+  // necessary because dense weights are initialized on the trainer at cold
+  // start
+  virtual std::future<int32_t> push_dense_param(const Region *regions,
+                                                size_t region_num,
+                                                size_t table_id) = 0;
+
+  // issue a pull request with keys; the results fill values
+  // there are num keys and num values; each value occupies select_size space
+  // the keys and values buffers must not be reused before the future completes
+  // keys requested by multiple threads are merged, grouped, and scattered to
+  // the servers; once the results return, the buffers are traversed and the
+  // values are assigned
+  virtual std::future<int32_t> pull_sparse(float **select_values,
+                                           size_t table_id,
+                                           const uint64_t *keys,
+                                           size_t num) = 0;
+
+  virtual std::future<int32_t> print_table_stat(uint32_t table_id) = 0;
+
+  // make sure all accumulated requests are actually sent
+  virtual std::future<int32_t> flush() = 0;
+  // graceful server shutdown
+  virtual std::future<int32_t> stop_server() = 0;
+
+  // server profiler
+  virtual std::future<int32_t> start_profiler() = 0;
+  virtual std::future<int32_t> stop_profiler() = 0;
+
+  virtual std::future<int32_t> barrier(size_t table_id,
+                                       uint32_t barrier_type) = 0;
+
+  virtual std::future<int32_t> pull_geo_param(size_t table_id,
+                                              std::vector<float> *values,
+                                              std::vector<uint64_t> *keys,
+                                              int pserver_idx) = 0;
+
+  virtual void finalize_worker() = 0;
+  // client-to-client message sending
+  virtual std::future<int32_t> send_client2client_msg(int msg_type,
+                                                      int to_client_id,
+                                                      const std::string &msg) {
+    LOG(FATAL) << "Did not implement";
+    std::promise<int32_t> promise;
+    std::future<int32_t> fut = promise.get_future();
+    promise.set_value(-1);
+    return fut;
+  }
+  // client2client message handling:
+  // std::function ret(msg_type, from_client_id, msg)
+  typedef std::function<int32_t(int, int, const std::string &)> MsgHandlerFunc;
+  virtual int registe_client2client_msg_handler(int msg_type,
+                                                MsgHandlerFunc handler) {
+    _msg_handler_map[msg_type] = handler;
+    return 0;
+  }
+  virtual int handle_client2client_msg(int msg_type, int from_client_id,
+                                       const std::string &msg) {
+    auto itr = _msg_handler_map.find(msg_type);
+    if (itr == _msg_handler_map.end()) {
+      LOG(WARNING) << "unknown client2client_msg type:" << msg_type;
+      return -1;
+    }
+    return itr->second(msg_type, from_client_id, msg);
+  }
+
+  virtual ValueAccessor *table_accessor(size_t table_id) {
+    auto itr = _table_accessors.find(table_id);
+    if (itr == _table_accessors.end()) {
+      return NULL;
+    }
+    return itr->second.get();
+  }
+
+  virtual size_t get_server_nums() = 0;
+
+  virtual std::future<int32_t> push_dense_raw_gradient(
+      int table_id, float *total_send_data, size_t total_send_data_size,
+      void *done) = 0;
+
+  virtual std::future<int32_t> push_sparse_raw_gradient(
+      size_t table_id, const uint64_t *keys, const float **update_values,
+      size_t num, void *done) = 0;
+
+  virtual std::future<int32_t> push_sparse_raw_gradient_partial(
+      size_t table_id, const uint64_t *keys, const float **update_values,
+      uint32_t num, void *done, int pserver_idx) = 0;
+
+  virtual std::future<int32_t> push_sparse_param(size_t table_id,
+                                                 const uint64_t *keys,
+                                                 const float **update_values,
+                                                 size_t num, void *done) = 0;
+
+ protected:
+  virtual int32_t initialize() = 0;
+  size_t _client_id;
+  PSParameter _config;
+  std::map<uint64_t, std::vector<paddle::distributed::Region>>
+      _dense_pull_regions;
+  PSEnvironment *_env;
+  std::unordered_map<uint32_t, std::shared_ptr<ValueAccessor>>
+      _table_accessors;
+  std::unordered_map<int32_t, MsgHandlerFunc>
+      _msg_handler_map;  // handles client2client messages
+};
+REGISTER_REGISTERER(PSClient);
+
+class PSClientFactory {
+ public:
+  static PSClient *create(const PSParameter &config);
+};
+}  // namespace distributed
+}  // namespace paddle
diff --git a/paddle/fluid/distributed/service/sendrecv.proto b/paddle/fluid/distributed/service/sendrecv.proto
new file mode 100644
index 00000000000000..8f5c8baa2f8242
--- /dev/null
+++ b/paddle/fluid/distributed/service/sendrecv.proto
@@ -0,0 +1,113 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
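Every RPC on the PSClient interface above returns a std::future<int32_t>, so callers block explicitly when they need the result. A minimal sketch, assuming `client` was produced by PSClientFactory::create plus configure(), and assuming a 100-float dense table with id 0 (both placeholders):

    std::vector<float> buf(100);  // placeholder dense dimension
    paddle::distributed::Region region(buf.data(), buf.size());
    auto status = client->pull_dense(&region, /*region_num=*/1, /*table_id=*/0);
    status.wait();
    CHECK_EQ(status.get(), 0);  // non-zero signals an RPC or server-side error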
+ +syntax = "proto2"; +package paddle; +option cc_generic_services = true; +option cc_enable_arenas = true; + +enum PsCmdID { + PS_PULL_DENSE_TABLE = 0; + PS_PUSH_DENSE_TABLE = 1; + PS_PULL_SPARSE_TABLE = 2; + PS_PUSH_SPARSE_TABLE = 3; + PS_SHRINK_TABLE = 4; + PS_SAVE_ONE_TABLE = 5; + PS_SAVE_ALL_TABLE = 6; + PS_LOAD_ONE_TABLE = 7; + PS_LOAD_ALL_TABLE = 8; + PS_CLEAR_ONE_TABLE = 9; + PS_CLEAR_ALL_TABLE = 10; + PS_PUSH_DENSE_PARAM = 11; + PS_STOP_SERVER = 12; + PS_SAVE_ONE_CACHE_TABLE = 13; + PS_GET_CACHE_THRESHOLD = 14; + PS_CACHE_SHUFFLE = 15; + PS_COPY_TABLE = 16; + PS_COPY_TABLE_BY_FEASIGN = 17; + PS_PULL_SPARSE_TABLE_WITH_DEPENDENCY = 18; + PS_PUSH_SPARSE_TABLE_WITH_DEPENDENCY = 19; + PS_PRINT_TABLE_STAT = 20; + PS_SAVE_ONE_TABLE_PREFIX = 21; + PS_SAVE_ONE_TABLE_WITH_WHITELIST = 22; + PS_LOAD_ONE_TABLE_WITH_WHITELIST = 23; + PS_PULL_GEO_PARAM = 24; + PS_BARRIER = 25; + PS_PUSH_SPARSE_PARAM = 26; + PS_START_PROFILER = 27; + PS_STOP_PROFILER = 28; +} + +message PsRequestMessage { + required uint32 cmd_id = 1; + optional uint32 table_id = 2; + repeated bytes params = 3; + optional int32 client_id = 4; + optional bytes data = 5; +}; + +message PsResponseMessage { + required int32 err_code = 1 [ default = 0 ]; + required string err_msg = 2 [ default = "" ]; + optional bytes data = 3; +}; + +enum VarType { + LOD_TENSOR = 0; + SELECTED_ROWS = 1; +} + +message VariableMessage { + enum Type { + // Pod Types + BOOL = 0; + INT16 = 1; + INT32 = 2; + INT64 = 3; + FP16 = 4; + FP32 = 5; + FP64 = 6; + } + + message LodData { repeated int64 lod_data = 1; } + optional string varname = 1; + // TODO(Yancey1989): reference framework::proto::VarDesc::VarType + optional VarType type = 2; + // bool persistable is not needed for sending. + // tensor info: + optional Type data_type = 3; + repeated int64 dims = 4; + + // lod details: + optional int64 lod_level = 5; + repeated LodData lod = 6; + // selected_rows height, aka. original dim0 + optional int64 slr_height = 7; + // tensor data + optional bytes data = 8; +} + +// for SendAndRecv RPC method +message MultiVariableMessage { + // message flags + required string message_name = 1; + repeated string send_var_names = 2; + repeated string recv_var_names = 3; + repeated VariableMessage var_messages = 4; +}; + +service PsService { + rpc service(PsRequestMessage) returns (PsResponseMessage); + rpc SendAndRecvVariable(MultiVariableMessage) returns (MultiVariableMessage); +}; \ No newline at end of file diff --git a/paddle/fluid/distributed/service/server.cc b/paddle/fluid/distributed/service/server.cc new file mode 100644 index 00000000000000..1582b8739c1775 --- /dev/null +++ b/paddle/fluid/distributed/service/server.cc @@ -0,0 +1,87 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
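For the cmd-style `service` RPC defined above, each command travels entirely inside a PsRequestMessage; this is what the DownpourBrpcClosure requests built in heter_client.cc's SendCmd are filled with. A save request might look like this (the path and mode arguments are illustrative):

    paddle::PsRequestMessage request;
    request.set_cmd_id(paddle::PS_SAVE_ONE_TABLE);
    request.set_table_id(0);
    request.set_client_id(0);
    request.add_params("/path/to/output");  // epoch / output path
    request.add_params("0");                // save mode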
+ +#include "paddle/fluid/distributed/service/server.h" +#include "glog/logging.h" +#include "paddle/fluid/distributed/service/brpc_ps_server.h" +#include "paddle/fluid/distributed/table/table.h" + +namespace paddle { +namespace distributed { + +REGISTER_CLASS(PSServer, BrpcPsServer); +REGISTER_CLASS(PsBaseService, PsService); + +PSServer *PSServerFactory::create(const PSParameter &ps_config) { + const auto &config = ps_config.server_param(); + + if (!config.has_downpour_server_param()) { + LOG(ERROR) << "miss downpour_server_param in ServerParameter"; + return NULL; + } + + if (!config.downpour_server_param().has_service_param()) { + LOG(ERROR) << "miss service_param in ServerParameter.downpour_server_param"; + return NULL; + } + + if (!config.downpour_server_param().service_param().has_server_class()) { + LOG(ERROR) << "miss server_class in " + "ServerParameter.downpour_server_param.service_param"; + return NULL; + } + + const auto &service_param = config.downpour_server_param().service_param(); + PSServer *server = CREATE_CLASS(PSServer, service_param.server_class()); + if (server == NULL) { + LOG(ERROR) << "server is not registered, server_name:" + << service_param.server_class(); + return NULL; + } + TableManager::instance().initialize(); + return server; +} + +int32_t PSServer::configure(const PSParameter &config, PSEnvironment &env, + size_t server_rank) { + _config = config.server_param(); + _rank = server_rank; + _environment = &env; + _shuffled_ins = + paddle::framework::MakeChannel>(); + const auto &downpour_param = _config.downpour_server_param(); + + uint32_t barrier_table = UINT32_MAX; + + for (size_t i = 0; i < downpour_param.downpour_table_param_size(); ++i) { + auto *table = CREATE_CLASS( + Table, downpour_param.downpour_table_param(i).table_class()); + + if (downpour_param.downpour_table_param(i).table_class() == + "BarrierTable") { + barrier_table = downpour_param.downpour_table_param(i).table_id(); + } + table->initialize(downpour_param.downpour_table_param(i), + config.fs_client_param()); + _table_map[downpour_param.downpour_table_param(i).table_id()].reset(table); + } + + if (barrier_table != UINT32_MAX) { + _table_map[barrier_table]->set_table_map(&_table_map); + } + + return initialize(); +} +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/service/server.h b/paddle/fluid/distributed/service/server.h new file mode 100644 index 00000000000000..4faa0f9db2c4c5 --- /dev/null +++ b/paddle/fluid/distributed/service/server.h @@ -0,0 +1,150 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#pragma once
+
+#include <future>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+#include "butil/endpoint.h"
+#include "google/protobuf/service.h"
+#include "paddle/fluid/distributed/common/registerer.h"
+#include "paddle/fluid/distributed/ps.pb.h"
+#include "paddle/fluid/distributed/service/env.h"
+#include "paddle/fluid/distributed/service/sendrecv.pb.h"
+#include "paddle/fluid/framework/channel.h"
+
+namespace paddle {
+namespace distributed {
+
+class Table;
+
+class PSServer {
+ public:
+  PSServer() {}
+  virtual ~PSServer() {}
+  PSServer(PSServer &&) = delete;
+  PSServer(const PSServer &) = delete;
+
+  virtual int32_t configure(const PSParameter &config, PSEnvironment &env,
+                            size_t server_rank) final;
+
+  // return server_ip
+  virtual std::string ip() { return butil::my_ip_cstr(); }
+  // return server_port
+  virtual int32_t port() = 0;
+
+  virtual uint64_t start(const std::string &ip, uint32_t port) = 0;
+  virtual int32_t stop() = 0;
+
+  inline size_t rank() const { return _rank; }
+
+  inline PSEnvironment *environment() { return _environment; }
+
+  inline const ServerParameter *config() const { return &_config; }
+  inline Table *table(size_t table_id) {
+    auto itr = _table_map.find(table_id);
+    if (itr != _table_map.end()) {
+      return itr->second.get();
+    }
+    return NULL;
+  }
+
+  inline std::unordered_map<uint32_t, std::shared_ptr<Table>> *table() {
+    return &_table_map;
+  }
+
+  typedef std::function<int32_t(int, int, const std::string &)> MsgHandlerFunc;
+  virtual int registe_pserver2pserver_msg_handler(int msg_type,
+                                                  MsgHandlerFunc handler) {
+    _msg_handler_map[msg_type] = handler;
+    return 0;
+  }
+
+  paddle::framework::Channel<std::pair<uint64_t, std::string>> _shuffled_ins;
+
+ protected:
+  virtual int32_t initialize() = 0;
+
+ protected:
+  size_t _rank;
+  ServerParameter _config;
+  PSEnvironment *_environment;
+  std::unordered_map<uint32_t, std::shared_ptr<Table>> _table_map;
+  std::unordered_map<int32_t, MsgHandlerFunc> _msg_handler_map;
+};
+
+REGISTER_REGISTERER(PSServer);
+
+typedef std::function<void(void *)> PServerCallBack;
+
+class PServerClosure : public google::protobuf::Closure {
+ public:
+  PServerClosure(PServerCallBack callback) : _callback(callback) {}
+  virtual ~PServerClosure() {}
+  virtual void set_promise_value(int value) {
+    for (auto &promise : _promises) {
+      promise->set_value(value);
+    }
+  }
+  void add_promise(std::shared_ptr<std::promise<int32_t>> &promise) {
+    _promises.push_back(promise);
+  }
+
+ protected:
+  PServerCallBack _callback;
+  std::vector<std::shared_ptr<std::promise<int32_t>>> _promises;
+};
+
+class PsBaseService : public PsService {
+ public:
+  PsBaseService() : _rank(0), _server(NULL), _config(NULL) {}
+  virtual ~PsBaseService() {}
+
+  virtual int32_t configure(PSServer *server) {
+    _server = server;
+    _rank = _server->rank();
+    _config = _server->config();
+    return 0;
+  }
+  virtual void service(::google::protobuf::RpcController *controller,
+                       const ::paddle::PsRequestMessage *request,
+                       ::paddle::PsResponseMessage *response,
+                       ::google::protobuf::Closure *done) override = 0;
+
+  virtual void set_response_code(PsResponseMessage &response, int err_code,
+                                 const char *err_msg) {
+    response.set_err_msg(err_msg);
+    response.set_err_code(err_code);
+    LOG(WARNING) << "Response err_code:" << err_code << " msg:" << err_msg;
+  }
+
+  virtual int32_t initialize() = 0;
+
+ protected:
+  size_t _rank;
+  PSServer *_server;
+  const ServerParameter *_config;
+};
+REGISTER_REGISTERER(PsBaseService);
+
+class PSServerFactory {
+ public:
+  static PSServer *create(const PSParameter &config);
+};
+}  // namespace distributed
+}  // namespace paddle
diff --git a/paddle/fluid/distributed/service/service.cc b/paddle/fluid/distributed/service/service.cc
new file mode 100644
index 
00000000000000..40a6d2e1227187 --- /dev/null +++ b/paddle/fluid/distributed/service/service.cc @@ -0,0 +1,129 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/fluid/distributed/service/service.h" + +#include +#include +#include +#include +#include "paddle/fluid/distributed/service/communicator.h" +#include "paddle/fluid/string/string_helper.h" + +using namespace std; + +namespace paddle { +namespace distributed { + +paddle::distributed::PSParameter load_from_prototxt( + const std::string& filename) { + paddle::distributed::PSParameter param; + int file_descriptor = open(filename.c_str(), O_RDONLY); + + if (file_descriptor == -1) { + VLOG(3) << "FATAL: fail to parse " << filename; + exit(-1); + } + + google::protobuf::io::FileInputStream fileInput(file_descriptor); + if (!google::protobuf::TextFormat::Parse(&fileInput, ¶m)) { + VLOG(3) << "FATAL: fail to parse " << filename; + exit(-1); + } + + close(file_descriptor); + return param; +} + +void PSCore::init_gflag(const std::string& gflags) { + LOG(INFO) << "Init With Gflags:" << gflags; + std::vector flags = paddle::string::split_string(gflags); + if (flags.size() < 1) { + flags.push_back("-max_body_size=314217728"); + flags.push_back("-bthread_concurrency=40"); + flags.push_back("-socket_max_unwritten_bytes=2048000000"); + flags.push_back("-max_connection_pool_size=1950"); + } + auto it = flags.begin(); + flags.insert(it, "exe default"); + char* flags_ptr[flags.size()]; + for (size_t i = 0; i < flags.size(); ++i) { + flags_ptr[i] = (char*)(flags[i].c_str()); + } + int params_cnt = flags.size(); + char** params_ptr = &(flags_ptr[0]); + ::google::ParseCommandLineFlags(¶ms_cnt, ¶ms_ptr, true); +} + +int PSCore::init_server(const std::string& dist_desc, + const std::vector* host_sign_list, + int node_num, int index) { + google::protobuf::TextFormat::ParseFromString(dist_desc, &_ps_param); + init_gflag(_ps_param.init_gflags()); + _ps_env = paddle::distributed::PaddlePSEnvironment(); + _ps_env.set_ps_servers(host_sign_list, node_num); + int ret = 0; + _server_ptr = std::shared_ptr( + paddle::distributed::PSServerFactory::create(_ps_param)); + ret = _server_ptr->configure(_ps_param, _ps_env, index); + CHECK(ret == 0) << "failed to configure server"; + return ret; +} + +int PSCore::init_worker( + const std::string& dist_desc, + const std::map>& regions, + const std::vector* host_sign_list, int node_num, int index) { + google::protobuf::TextFormat::ParseFromString(dist_desc, &_ps_param); + init_gflag(_ps_param.init_gflags()); + _ps_env = paddle::distributed::PaddlePSEnvironment(); + _ps_env.set_ps_servers(host_sign_list, node_num); + int ret = 0; + VLOG(1) << "PSCore::init_worker"; + auto* communicator = Communicator::GetInstance(); + ret = communicator->GetPsClient()->configure(_ps_param, regions, _ps_env, + index); + communicator->Start(); + return ret; +} + +std::vector PSCore::get_client_info() { + return _ps_env.get_client_info(); +} + +int 
PSCore::create_client2client_connection(int pserver_timeout_ms, + int pserver_connect_timeout_ms, + int max_retry) { + int ret = _worker_ptr->create_client2client_connection( + pserver_timeout_ms, pserver_connect_timeout_ms, max_retry); + return ret; +} + +uint64_t PSCore::run_server(const std::string& ip, uint32_t port) { + return _server_ptr->start(ip, port); +} + +int PSCore::finalize_worker() { + _worker_ptr->finalize_worker(); + return 0; +} + +int PSCore::stop_server() { + auto stop_status = _worker_ptr->stop_server(); + stop_status.wait(); + return 0; +} +paddle::distributed::PSParameter* PSCore::get_param() { return &_ps_param; } +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/service/service.h b/paddle/fluid/distributed/service/service.h new file mode 100644 index 00000000000000..97cb864e344bf8 --- /dev/null +++ b/paddle/fluid/distributed/service/service.h @@ -0,0 +1,64 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include + +#include +#include "paddle/fluid/distributed/ps.pb.h" +#include "paddle/fluid/distributed/service/ps_client.h" +#include "paddle/fluid/distributed/service/sendrecv.pb.h" +#include "paddle/fluid/distributed/service/server.h" + +namespace paddle { +namespace distributed { + +class PSCore { + public: + explicit PSCore() {} + virtual ~PSCore() {} + + virtual int init_server(const std::string& dist_desc, + const std::vector* host_sign_list, + int node_num, int index); + virtual int init_worker( + const std::string& dist_desc, + const std::map>& + regions, + const std::vector* host_sign_list, int node_num, int index); + virtual uint64_t run_server(const std::string& ip, uint32_t port); + virtual int stop_server(); + virtual int finalize_worker(); + virtual std::vector get_client_info(); + virtual int create_client2client_connection(int pserver_timeout_ms, + int pserver_connect_timeout_ms, + int max_retry); + std::shared_ptr + _server_ptr; // pointer to server + std::shared_ptr + _worker_ptr; // pointer to worker + virtual paddle::distributed::PSParameter* get_param(); + + private: + void init_gflag(const std::string& gflags); + paddle::distributed::PSParameter _ps_param; + paddle::distributed::PaddlePSEnvironment _ps_env; +}; + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/test/CMakeLists.txt b/paddle/fluid/distributed/test/CMakeLists.txt index e4cc93c9adf65c..405fe7561115e6 100644 --- a/paddle/fluid/distributed/test/CMakeLists.txt +++ b/paddle/fluid/distributed/test/CMakeLists.txt @@ -16,3 +16,16 @@ cc_test(geo_table_test SRCS geo_table_test.cc DEPS common_table table tensor_acc set_source_files_properties(barrier_table_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) cc_test(barrier_table_test SRCS barrier_table_test.cc DEPS common_table table tensor_accessor ps_framework_proto ${COMMON_DEPS}) + + +# open it until CI support brpc +return() + 
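For orientation, the PSCore facade declared in service.h above is driven roughly as follows; `dist_desc` (a text-format PSParameter), the host sign list, and the regions map come from the fleet runtime, and init_worker additionally assumes the Communicator singleton has already been initialized, so this is a sketch rather than a runnable unit:

    paddle::distributed::PSCore core;
    // server process:
    core.init_server(dist_desc, &host_sign_list, host_sign_list.size(),
                     /*index=*/0);
    core.run_server("127.0.0.1", 8200);  // placeholder ip/port
    // trainer process:
    core.init_worker(dist_desc, regions, &host_sign_list,
                     host_sign_list.size(), /*index=*/0);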
+set_source_files_properties(brpc_service_dense_sgd_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +cc_test(brpc_service_dense_sgd_test SRCS brpc_service_dense_sgd_test.cc DEPS scope server client communicator ps_service boost table ps_framework_proto ${COMMON_DEPS}) + +set_source_files_properties(brpc_service_sparse_sgd_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +cc_test(brpc_service_sparse_sgd_test SRCS brpc_service_sparse_sgd_test.cc DEPS scope server client communicator ps_service boost table ps_framework_proto ${COMMON_DEPS}) + +set_source_files_properties(brpc_utils_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +cc_test(brpc_utils_test SRCS brpc_utils_test.cc DEPS brpc_utils scope math_function ${COMMON_DEPS} ${RPC_DEPS}) diff --git a/paddle/fluid/distributed/test/brpc_service_dense_sgd_test.cc b/paddle/fluid/distributed/test/brpc_service_dense_sgd_test.cc new file mode 100644 index 00000000000000..3b2f808a2a82d5 --- /dev/null +++ b/paddle/fluid/distributed/test/brpc_service_dense_sgd_test.cc @@ -0,0 +1,272 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include // NOLINT +#include +#include // NOLINT + +#include "google/protobuf/text_format.h" +#include "gtest/gtest.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/framework/variable.h" + +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/platform/place.h" +#include "paddle/fluid/string/printf.h" + +#include "paddle/fluid/distributed/ps.pb.h" +#include "paddle/fluid/distributed/service/brpc_ps_client.h" +#include "paddle/fluid/distributed/service/brpc_ps_server.h" +#include "paddle/fluid/distributed/service/env.h" +#include "paddle/fluid/distributed/service/ps_client.h" +#include "paddle/fluid/distributed/service/sendrecv.pb.h" +#include "paddle/fluid/distributed/service/service.h" + +namespace framework = paddle::framework; +namespace platform = paddle::platform; +namespace operators = paddle::operators; +namespace math = paddle::operators::math; +namespace memory = paddle::memory; +namespace distributed = paddle::distributed; + +void CreateVarsOnScope(framework::Scope* scope, platform::CPUPlace* place) { + auto x_var = scope->Var("x"); + x_var->GetMutable(); +} + +void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place, + int64_t rows_numel) { + CreateVarsOnScope(scope, place); + + auto x_var = scope->Var("x")->GetMutable(); + float* x_ptr = + x_var->mutable_data(framework::DDim({1, rows_numel}), *place); + for (int64_t i = 0; i < rows_numel; ++i) x_ptr[i] = 1.0 * (float)i; +} + +void GetDownpourDenseTableProto( + ::paddle::distributed::TableParameter* dense_table_proto) { + dense_table_proto->set_table_id(0); + dense_table_proto->set_table_class("CommonDenseTable"); + dense_table_proto->set_shard_num(256); + 
dense_table_proto->set_type(::paddle::distributed::PS_DENSE_TABLE); + ::paddle::distributed::TableAccessorParameter* accessor_proto = + dense_table_proto->mutable_accessor(); + ::paddle::distributed::CommonAccessorParameter* common_proto = + dense_table_proto->mutable_common(); + + accessor_proto->set_accessor_class("CommMergeAccessor"); + accessor_proto->set_fea_dim(100); + accessor_proto->set_embedx_dim(1); + + common_proto->set_name("sgd"); + common_proto->set_table_name("MergedDense"); + common_proto->set_trainer_num(1); + common_proto->set_sync(false); + common_proto->add_params("Param"); + common_proto->add_dims(100); + common_proto->add_initializers("fill_constant&1.0"); + common_proto->add_params("LearningRate"); + common_proto->add_dims(1); + common_proto->add_initializers("fill_constant&1.0"); +} + +::paddle::distributed::PSParameter GetServerProto() { + // Generate server proto desc + ::paddle::distributed::PSParameter server_fleet_desc; + ::paddle::distributed::ServerParameter* server_proto = + server_fleet_desc.mutable_server_param(); + ::paddle::distributed::DownpourServerParameter* downpour_server_proto = + server_proto->mutable_downpour_server_param(); + ::paddle::distributed::ServerServiceParameter* server_service_proto = + downpour_server_proto->mutable_service_param(); + server_service_proto->set_service_class("PsService"); + server_service_proto->set_server_class("BrpcPsServer"); + server_service_proto->set_client_class("BrpcPsClient"); + server_service_proto->set_start_server_port(0); + server_service_proto->set_server_thread_num(12); + + ::paddle::distributed::TableParameter* dense_table_proto = + downpour_server_proto->add_downpour_table_param(); + GetDownpourDenseTableProto(dense_table_proto); + return server_fleet_desc; +} + +::paddle::distributed::PSParameter GetWorkerProto() { + ::paddle::distributed::PSParameter worker_fleet_desc; + ::paddle::distributed::WorkerParameter* worker_proto = + worker_fleet_desc.mutable_worker_param(); + + ::paddle::distributed::DownpourWorkerParameter* downpour_worker_proto = + worker_proto->mutable_downpour_worker_param(); + + ::paddle::distributed::TableParameter* worker_dense_table_proto = + downpour_worker_proto->add_downpour_table_param(); + GetDownpourDenseTableProto(worker_dense_table_proto); + + ::paddle::distributed::ServerParameter* server_proto = + worker_fleet_desc.mutable_server_param(); + ::paddle::distributed::DownpourServerParameter* downpour_server_proto = + server_proto->mutable_downpour_server_param(); + ::paddle::distributed::ServerServiceParameter* server_service_proto = + downpour_server_proto->mutable_service_param(); + server_service_proto->set_service_class("PsService"); + server_service_proto->set_server_class("BrpcPsServer"); + server_service_proto->set_client_class("BrpcPsClient"); + server_service_proto->set_start_server_port(0); + server_service_proto->set_server_thread_num(12); + + ::paddle::distributed::TableParameter* server_dense_table_proto = + downpour_server_proto->add_downpour_table_param(); + GetDownpourDenseTableProto(server_dense_table_proto); + + return worker_fleet_desc; +} + +/*-------------------------------------------------------------------------*/ + +std::string ip_ = "127.0.0.1"; +uint32_t port_ = 4214; + +std::vector host_sign_list_; + +std::shared_ptr pserver_ptr_; + +std::shared_ptr worker_ptr_; + +void RunServer() { + ::paddle::distributed::PSParameter server_proto = GetServerProto(); + + auto _ps_env = paddle::distributed::PaddlePSEnvironment(); + LOG(INFO) << "RUN 
set_ps_servers"; + _ps_env.set_ps_servers(&host_sign_list_, 1); + pserver_ptr_ = std::shared_ptr( + paddle::distributed::PSServerFactory::create(server_proto)); + LOG(INFO) << "RUN configure"; + pserver_ptr_->configure(server_proto, _ps_env, 0); + LOG(INFO) << "RUN start"; + pserver_ptr_->start(ip_, port_); + LOG(INFO) << "End start"; +} + +void RunClient(std::map>& + dense_regions) { + ::paddle::distributed::PSParameter worker_proto = GetWorkerProto(); + paddle::distributed::PaddlePSEnvironment _ps_env; + auto servers_ = host_sign_list_.size(); + _ps_env = paddle::distributed::PaddlePSEnvironment(); + LOG(INFO) << "Run set_ps_servers"; + _ps_env.set_ps_servers(&host_sign_list_, servers_); + LOG(INFO) << "Run Create PSClient"; + worker_ptr_ = std::shared_ptr( + paddle::distributed::PSClientFactory::create(worker_proto)); + LOG(INFO) << "Run configure"; + worker_ptr_->configure(worker_proto, dense_regions, _ps_env, 0); +} + +void RunBrpcPushDense() { + setenv("http_proxy", "", 1); + setenv("https_proxy", "", 1); + auto ph_host = paddle::distributed::PSHost(ip_, port_, 0); + host_sign_list_.push_back(ph_host.serialize_to_string()); + + // Srart Server + std::thread server_thread(RunServer); + sleep(1); + + // Start Client + LOG(INFO) << "Run InitTensorsOnClient"; + framework::Scope client_scope; + platform::CPUPlace place; + InitTensorsOnClient(&client_scope, &place, 100); + std::map> dense_regions; + dense_regions.insert( + std::pair>(0, {})); + auto regions = dense_regions[0]; + framework::Variable* var = client_scope.FindVar("x"); + framework::LoDTensor* tensor = var->GetMutable(); + float* w = tensor->data(); + paddle::distributed::Region reg(w, tensor->numel()); + regions.emplace_back(std::move(reg)); + + LOG(INFO) << "Run RunClient"; + RunClient(dense_regions); + + /*-----------------------Test Server Init----------------------------------*/ + LOG(INFO) << "Run pull_dense_param"; + float* temp = new float[tensor->numel()](); + std::vector temp_region; + paddle::distributed::Region temp_reg(temp, tensor->numel()); + temp_region.emplace_back(std::move(temp_reg)); + auto pull_status = + worker_ptr_->pull_dense(temp_region.data(), temp_region.size(), 0); + pull_status.wait(); + + for (size_t idx = 0; idx < tensor->numel(); ++idx) { + EXPECT_FLOAT_EQ(temp[idx], 1.0); + } + + /*-----------------------Test Push Param----------------------------------*/ + + LOG(INFO) << "Run push_dense_param"; + auto push_status = + worker_ptr_->push_dense_param(regions.data(), regions.size(), 0); + push_status.wait(); + + pull_status = worker_ptr_->pull_dense(regions.data(), regions.size(), 0); + pull_status.wait(); + + for (size_t idx = 0; idx < tensor->numel(); ++idx) { + EXPECT_FLOAT_EQ(w[idx], float(idx)); + } + + /*-----------------------Test Push Grad----------------------------------*/ + + paddle::distributed::DownpourBrpcClosure* closure = + new paddle::distributed::DownpourBrpcClosure(1, [&](void* done) { + int ret = 0; + auto* closure = (paddle::distributed::DownpourBrpcClosure*)done; + for (size_t i = 0; i < 1; ++i) { + if (closure->check_response(i, paddle::PS_PUSH_DENSE_TABLE) != 0) { + ret = -1; + break; + } + } + closure->set_promise_value(ret); + }); + + LOG(INFO) << "Run pull_dense_grad"; + auto push_grad_status = + worker_ptr_->push_dense_raw_gradient(0, temp, tensor->numel(), closure); + push_grad_status.wait(); + + auto pull_update_status = + worker_ptr_->pull_dense(regions.data(), regions.size(), 0); + pull_update_status.wait(); + + for (size_t idx = 0; idx < tensor->numel(); ++idx) { + 
EXPECT_FLOAT_EQ(w[idx], float(idx) - 1.0); + } + + LOG(INFO) << "Run stop_server"; + worker_ptr_->stop_server(); + LOG(INFO) << "Run finalize_worker"; + worker_ptr_->finalize_worker(); + server_thread.join(); +} + +TEST(RunBrpcPushDense, Run) { RunBrpcPushDense(); } diff --git a/paddle/fluid/distributed/test/brpc_service_sparse_sgd_test.cc b/paddle/fluid/distributed/test/brpc_service_sparse_sgd_test.cc new file mode 100644 index 00000000000000..224b9ba2fc780a --- /dev/null +++ b/paddle/fluid/distributed/test/brpc_service_sparse_sgd_test.cc @@ -0,0 +1,285 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include // NOLINT +#include +#include // NOLINT + +#include "google/protobuf/text_format.h" +#include "gtest/gtest.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/framework/variable.h" + +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/platform/place.h" +#include "paddle/fluid/string/printf.h" + +#include "paddle/fluid/distributed/ps.pb.h" +#include "paddle/fluid/distributed/service/brpc_ps_client.h" +#include "paddle/fluid/distributed/service/brpc_ps_server.h" +#include "paddle/fluid/distributed/service/env.h" +#include "paddle/fluid/distributed/service/ps_client.h" +#include "paddle/fluid/distributed/service/sendrecv.pb.h" +#include "paddle/fluid/distributed/service/service.h" + +namespace framework = paddle::framework; +namespace platform = paddle::platform; +namespace operators = paddle::operators; +namespace math = paddle::operators::math; +namespace memory = paddle::memory; +namespace distributed = paddle::distributed; + +void CreateVarsOnScope(framework::Scope* scope, platform::CPUPlace* place) { + auto x_var = scope->Var("x"); + x_var->GetMutable(); +} + +void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place, + int64_t rows_numel) { + CreateVarsOnScope(scope, place); + + auto x_var = scope->Var("x")->GetMutable(); + float* x_ptr = + x_var->mutable_data(framework::DDim({1, rows_numel}), *place); + for (int64_t i = 0; i < rows_numel; ++i) x_ptr[i] = 1.0; +} + +void GetDownpourSparseTableProto( + ::paddle::distributed::TableParameter* sparse_table_proto) { + sparse_table_proto->set_table_id(0); + sparse_table_proto->set_table_class("CommonSparseTable"); + sparse_table_proto->set_shard_num(256); + sparse_table_proto->set_type(::paddle::distributed::PS_SPARSE_TABLE); + ::paddle::distributed::TableAccessorParameter* accessor_proto = + sparse_table_proto->mutable_accessor(); + ::paddle::distributed::CommonAccessorParameter* common_proto = + sparse_table_proto->mutable_common(); + + accessor_proto->set_accessor_class("CommMergeAccessor"); + accessor_proto->set_fea_dim(0); + accessor_proto->set_embedx_dim(10); + + common_proto->set_name("sgd"); + common_proto->set_table_name("MergedDense"); + common_proto->set_trainer_num(1); + common_proto->set_sync(false); + 
common_proto->add_params("Param"); + common_proto->add_dims(10); + common_proto->add_initializers("uniform_random&0&-1.0&1.0"); + common_proto->add_params("LearningRate"); + common_proto->add_dims(1); + common_proto->add_initializers("fill_constant&1.0"); +} + +::paddle::distributed::PSParameter GetServerProto() { + // Generate server proto desc + ::paddle::distributed::PSParameter server_fleet_desc; + ::paddle::distributed::ServerParameter* server_proto = + server_fleet_desc.mutable_server_param(); + ::paddle::distributed::DownpourServerParameter* downpour_server_proto = + server_proto->mutable_downpour_server_param(); + ::paddle::distributed::ServerServiceParameter* server_service_proto = + downpour_server_proto->mutable_service_param(); + server_service_proto->set_service_class("PsService"); + server_service_proto->set_server_class("BrpcPsServer"); + server_service_proto->set_client_class("BrpcPsClient"); + server_service_proto->set_start_server_port(0); + server_service_proto->set_server_thread_num(12); + + ::paddle::distributed::TableParameter* sparse_table_proto = + downpour_server_proto->add_downpour_table_param(); + GetDownpourSparseTableProto(sparse_table_proto); + return server_fleet_desc; +} + +::paddle::distributed::PSParameter GetWorkerProto() { + ::paddle::distributed::PSParameter worker_fleet_desc; + ::paddle::distributed::WorkerParameter* worker_proto = + worker_fleet_desc.mutable_worker_param(); + + ::paddle::distributed::DownpourWorkerParameter* downpour_worker_proto = + worker_proto->mutable_downpour_worker_param(); + + ::paddle::distributed::TableParameter* worker_sparse_table_proto = + downpour_worker_proto->add_downpour_table_param(); + GetDownpourSparseTableProto(worker_sparse_table_proto); + + ::paddle::distributed::ServerParameter* server_proto = + worker_fleet_desc.mutable_server_param(); + ::paddle::distributed::DownpourServerParameter* downpour_server_proto = + server_proto->mutable_downpour_server_param(); + ::paddle::distributed::ServerServiceParameter* server_service_proto = + downpour_server_proto->mutable_service_param(); + server_service_proto->set_service_class("PsService"); + server_service_proto->set_server_class("BrpcPsServer"); + server_service_proto->set_client_class("BrpcPsClient"); + server_service_proto->set_start_server_port(0); + server_service_proto->set_server_thread_num(12); + + ::paddle::distributed::TableParameter* server_sparse_table_proto = + downpour_server_proto->add_downpour_table_param(); + GetDownpourSparseTableProto(server_sparse_table_proto); + + return worker_fleet_desc; +} + +/*-------------------------------------------------------------------------*/ + +std::string ip_ = "127.0.0.1"; +uint32_t port_ = 4209; + +std::vector host_sign_list_; + +std::shared_ptr pserver_ptr_; + +std::shared_ptr worker_ptr_; + +void RunServer() { + ::paddle::distributed::PSParameter server_proto = GetServerProto(); + + auto _ps_env = paddle::distributed::PaddlePSEnvironment(); + _ps_env.set_ps_servers(&host_sign_list_, 1); + pserver_ptr_ = std::shared_ptr( + paddle::distributed::PSServerFactory::create(server_proto)); + pserver_ptr_->configure(server_proto, _ps_env, 0); + pserver_ptr_->start(ip_, port_); +} + +void RunClient(std::map>& + dense_regions) { + ::paddle::distributed::PSParameter worker_proto = GetWorkerProto(); + paddle::distributed::PaddlePSEnvironment _ps_env; + auto servers_ = host_sign_list_.size(); + _ps_env = paddle::distributed::PaddlePSEnvironment(); + _ps_env.set_ps_servers(&host_sign_list_, servers_); + worker_ptr_ = 
std::shared_ptr<paddle::distributed::PSClient>(
+      paddle::distributed::PSClientFactory::create(worker_proto));
+  worker_ptr_->configure(worker_proto, dense_regions, _ps_env, 0);
+}
+
+void RunBrpcPushSparse() {
+  setenv("http_proxy", "", 1);
+  setenv("https_proxy", "", 1);
+  auto ph_host = paddle::distributed::PSHost(ip_, port_, 0);
+  host_sign_list_.push_back(ph_host.serialize_to_string());
+
+  // Start Server
+  std::thread server_thread(RunServer);
+  sleep(1);
+
+  // Start Client
+  framework::Scope client_scope;
+  platform::CPUPlace place;
+  InitTensorsOnClient(&client_scope, &place, 100);
+  std::map<uint64_t, std::vector<paddle::distributed::Region>> dense_regions;
+  dense_regions.insert(
+      std::pair<uint64_t, std::vector<paddle::distributed::Region>>(0, {}));
+  auto regions = dense_regions[0];
+  framework::Variable* var = client_scope.FindVar("x");
+  framework::LoDTensor* tensor = var->GetMutable<framework::LoDTensor>();
+
+  RunClient(dense_regions);
+  std::vector<uint64_t> fea_keys(10);
+  std::vector<float> fea_values(100);
+  std::vector<float> fea_temp_values(100);
+  std::vector<float*> fea_value_ptr(10);
+  std::vector<float*> fea_temp_value_ptr(10);
+
+  for (size_t idx = 0; idx < fea_keys.size(); ++idx) {
+    fea_keys[idx] = (uint64_t)idx;
+    fea_value_ptr[idx] = fea_values.data() + idx * 10;
+    fea_temp_value_ptr[idx] = fea_temp_values.data() + idx * 10;
+  }
+
+  /*-----------------------Test Server Init----------------------------------*/
+  LOG(INFO) << "Run pull_sparse_param";
+  auto pull_status = worker_ptr_->pull_sparse(
+      fea_value_ptr.data(), 0, fea_keys.data(), fea_keys.size());
+  pull_status.wait();
+  for (size_t idx = 0; idx < tensor->numel(); ++idx) {
+    fea_values.data()[idx] *= 2.0;
+  }
+
+  /*-----------------------Test Push Param----------------------------------*/
+
+  LOG(INFO) << "Run push_sparse_param";
+  paddle::distributed::DownpourBrpcClosure* closure_push_param =
+      new paddle::distributed::DownpourBrpcClosure(1, [&](void* done) {
+        int ret = 0;
+        auto* closure = (paddle::distributed::DownpourBrpcClosure*)done;
+        for (size_t i = 0; i < 1; ++i) {
+          if (closure->check_response(i, paddle::PS_PUSH_SPARSE_PARAM) != 0) {
+            ret = -1;
+            break;
+          }
+        }
+        closure->set_promise_value(ret);
+      });
+  auto push_status = worker_ptr_->push_sparse_param(
+      0, fea_keys.data(), (const float**)fea_value_ptr.data(), fea_keys.size(),
+      closure_push_param);
+  push_status.wait();
+
+  auto pull_param_status = worker_ptr_->pull_sparse(
+      fea_temp_value_ptr.data(), 0, fea_keys.data(), fea_keys.size());
+  pull_param_status.wait();
+
+  for (size_t idx = 0; idx < tensor->numel(); ++idx) {
+    EXPECT_FLOAT_EQ(fea_temp_values[idx], fea_values[idx]);
+  }
+
+  /*-----------------------Test Push Grad----------------------------------*/
+
+  paddle::distributed::DownpourBrpcClosure* closure_push_grad =
+      new paddle::distributed::DownpourBrpcClosure(1, [&](void* done) {
+        int ret = 0;
+        auto* closure = (paddle::distributed::DownpourBrpcClosure*)done;
+        for (size_t i = 0; i < 1; ++i) {
+          if (closure->check_response(i, paddle::PS_PUSH_SPARSE_TABLE) != 0) {
+            ret = -1;
+            break;
+          }
+        }
+        closure->set_promise_value(ret);
+      });
+
+  LOG(INFO) << "Run push_sparse_grad";
+  std::vector<float*> push_g_vec;
+  for (auto i = 0; i < static_cast<int>(fea_keys.size()); ++i) {
+    push_g_vec.push_back(tensor->data<float>() + i * 10);
+  }
+  auto push_grad_status = worker_ptr_->push_sparse_raw_gradient(
+      0, fea_keys.data(), (const float**)push_g_vec.data(), fea_keys.size(),
+      closure_push_grad);
+  push_grad_status.wait();
+
+  auto pull_update_status = worker_ptr_->pull_sparse(
+      fea_temp_value_ptr.data(), 0, fea_keys.data(), fea_keys.size());
+  pull_update_status.wait();
+
+  for (size_t idx = 0; idx < tensor->numel(); ++idx) {
+ 
EXPECT_FLOAT_EQ(fea_temp_values[idx], fea_values[idx] - 1.0); + } + + LOG(INFO) << "Run stop_server"; + worker_ptr_->stop_server(); + LOG(INFO) << "Run finalize_worker"; + worker_ptr_->finalize_worker(); + server_thread.join(); +} + +TEST(RunBrpcPushSparse, Run) { RunBrpcPushSparse(); } diff --git a/paddle/fluid/distributed/test/heter_serde_test.cc b/paddle/fluid/distributed/test/brpc_utils_test.cc similarity index 98% rename from paddle/fluid/distributed/test/heter_serde_test.cc rename to paddle/fluid/distributed/test/brpc_utils_test.cc index 21380921958dbb..ce33cbe6ea3971 100644 --- a/paddle/fluid/distributed/test/heter_serde_test.cc +++ b/paddle/fluid/distributed/test/brpc_utils_test.cc @@ -23,7 +23,7 @@ limitations under the License. */ #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/variable.h" -#include "paddle/fluid/distributed/service/heter_serde.h" +#include "paddle/fluid/distributed/service/brpc_utils.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/platform/place.h" #include "paddle/fluid/string/printf.h" diff --git a/paddle/fluid/distributed/test/geo_table_test.cc b/paddle/fluid/distributed/test/geo_table_test.cc index fffecbe199e055..5ec1e87dcb6938 100644 --- a/paddle/fluid/distributed/test/geo_table_test.cc +++ b/paddle/fluid/distributed/test/geo_table_test.cc @@ -109,7 +109,7 @@ TEST(SparseGeoTable, SSUM) { auto id = geo_pull_ids[i][j]; for (int k = 0; k < emb_dim; k++) { ASSERT_TRUE(abs(geo_pull_values[i][j * emb_dim + k] - - pull_values[id * emb_dim + k]) < 1e-6); + pull_values[id * emb_dim + k]) < 1e-5); } } } diff --git a/paddle/fluid/distributed/test/sparse_table_test.cc b/paddle/fluid/distributed/test/sparse_table_test.cc index 65439014e8f0e2..6db95c5fac211b 100644 --- a/paddle/fluid/distributed/test/sparse_table_test.cc +++ b/paddle/fluid/distributed/test/sparse_table_test.cc @@ -103,7 +103,7 @@ TEST(CommonSparseTable, SGD) { table->pull_sparse(pull_values.data(), init_keys.data(), init_keys.size()); for (size_t i = 0; i < init_values.size(); ++i) { auto update_val = init_values[i] - 1.0 * total_gradients[i]; - ASSERT_TRUE(abs(update_val - pull_values[i]) < 1e-6); + ASSERT_TRUE(abs(update_val - pull_values[i]) < 1e-5); } } diff --git a/paddle/fluid/framework/data_layout_transform.cc b/paddle/fluid/framework/data_layout_transform.cc index 30464bbca90b87..a42d2913187df5 100644 --- a/paddle/fluid/framework/data_layout_transform.cc +++ b/paddle/fluid/framework/data_layout_transform.cc @@ -181,8 +181,8 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout, if (in_format != out_format) { void* in_data = GetDataFromTensor(in, in_type); - const std::string key = - platform::CreateKey(in_tz, in_format, out_format, in_type); + std::string key = + platform::CreateKey(*dev_ctx, in_tz, in_format, out_format, in_type); platform::ReorderMKLDNNHandler handler(in_tz, in.type(), in_type, *dev_ctx, cpu_engine, key); diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto old mode 100644 new mode 100755 index 9f3af174f60779..914e27d6f1f5e6 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -26,6 +26,8 @@ message RecomputeConfig { repeated string checkpoints = 1; } message ShardingConfig { optional float fuse_broadcast_MB = 1 [ default = 32.0 ]; + optional bool hybrid_dp = 2 [ default = false ]; + optional int32 sharding_group_size = 3 [ default = 8 ]; } message AMPConfig { 
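The two fields added to ShardingConfig above are read back through the standard protoc-generated C++ accessors. A minimal sketch of a consumer, assuming the generated header path and the paddle::fleet namespace (neither is shown in this patch):

    #include "glog/logging.h"
    #include "paddle/fluid/framework/distributed_strategy.pb.h"

    // Logs the new sharding knobs; protoc generates hybrid_dp() and
    // sharding_group_size() for the two optional fields declared above.
    void LogShardingConfig(const paddle::fleet::ShardingConfig& cfg) {
      VLOG(1) << "hybrid_dp=" << cfg.hybrid_dp()  // defaults to false
              << ", sharding_group_size="
              << cfg.sharding_group_size();  // defaults to 8
    }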
diff --git a/paddle/fluid/framework/fleet/gloo_wrapper.cc b/paddle/fluid/framework/fleet/gloo_wrapper.cc index f4b2d2d7d1881d..8780db89e854a5 100644 --- a/paddle/fluid/framework/fleet/gloo_wrapper.cc +++ b/paddle/fluid/framework/fleet/gloo_wrapper.cc @@ -272,8 +272,7 @@ void GlooWrapper::Init() { attr.iface = iface_; std::shared_ptr file_store = nullptr; std::shared_ptr http_store = nullptr; - auto context = - std::make_shared(rank_, size_); + auto context = std::make_shared(rank_, size_); context->setTimeout(run_timeout_); auto dev = gloo::transport::tcp::CreateDevice(attr); switch (store_type_) { @@ -295,6 +294,7 @@ void GlooWrapper::Init() { http_store->SetTimeoutSeconds(init_timeout_.count()); context->connectFullMesh(*http_store, dev); http_store->Finalize(); + VLOG(3) << "after calling http_store->Finalize."; break; } default: @@ -304,6 +304,7 @@ void GlooWrapper::Init() { context_ = std::move(context); #endif is_initialized_ = true; + VLOG(3) << "gloo initialized done."; } template std::vector GlooWrapper::AllReduce( diff --git a/paddle/fluid/inference/tests/api/lite_resnet50_test.cc b/paddle/fluid/inference/tests/api/lite_resnet50_test.cc index b88f09ae6a6a86..da56a7978a2e48 100644 --- a/paddle/fluid/inference/tests/api/lite_resnet50_test.cc +++ b/paddle/fluid/inference/tests/api/lite_resnet50_test.cc @@ -26,11 +26,7 @@ namespace inference { TEST(AnalysisPredictor, use_gpu) { std::string model_dir = FLAGS_infer_model + "/" + "model"; AnalysisConfig config; -#if defined(PADDLE_WITH_CUDA) config.EnableUseGpu(100, 0); -#elif defined(LITE_SUBGRAPH_WITH_XPU) - config.EnableXpu(100); -#endif config.SetModel(model_dir + "/model", model_dir + "/params"); config.EnableLiteEngine(paddle::AnalysisConfig::Precision::kFloat32, true); @@ -73,6 +69,54 @@ TEST(AnalysisPredictor, use_gpu) { } } +#ifdef LITE_SUBGRAPH_WITH_XPU +TEST(AnalysisPredictor, use_xpu) { + std::string model_dir = FLAGS_infer_model + "/" + "model"; + AnalysisConfig config; + config.EnableLiteEngine(paddle::AnalysisConfig::Precision::kFloat32, true); + config.EnableXpu(100); + config.SetModel(model_dir + "/model", model_dir + "/params"); + + std::vector inputs; + auto predictor = CreatePaddlePredictor(config); + const int batch = 1; + const int channel = 3; + const int height = 318; + const int width = 318; + const int input_num = batch * channel * height * width; + std::vector input(input_num, 1); + + PaddleTensor in; + in.shape = {batch, channel, height, width}; + in.data = + PaddleBuf(static_cast(input.data()), input_num * sizeof(float)); + in.dtype = PaddleDType::FLOAT32; + inputs.emplace_back(in); + + std::vector outputs; + ASSERT_TRUE(predictor->Run(inputs, &outputs)); + + const std::vector truth_values = { + 127.84, 738.088, 1013.22, -438.055, 366.451, 927.585, 736.341, + -633.776, -329.904, -430.149, -633.082, -146.597, -1324.19, -1349.29, + -242.68, 117.541, -801.704, -391.428, -404.756, 453.995, 515.373, + -133.003, 69.3941, 590.056, -1434.66, -1070.81, 307.093, 400.463, + -316.094, -587.089, -161.033, 800.357, -96.4212, 748.706, 868.226, + -447.936, 112.782, 1127.24, 47.4587, 677.698, 593.126, -336.462, + 551.328, 397.816, 78.3572, -715.269, 406.002, 404.149, 246.067, + -8.4649, 131.345, -647.951, + }; + + const size_t expected_size = 1; + EXPECT_EQ(outputs.size(), expected_size); + float* data_o = static_cast(outputs[0].data.data()); + for (size_t j = 0; j < outputs[0].data.length() / sizeof(float); j += 10) { + EXPECT_NEAR((data_o[j] - truth_values[j / 10]) / truth_values[j / 10], 0., + 10e-5); + } +} +#endif + } 
// namespace inference } // namespace paddle diff --git a/paddle/fluid/operators/affine_channel_op_xpu.cc b/paddle/fluid/operators/affine_channel_op_xpu.cc new file mode 100644 index 00000000000000..db3eedea7ca67b --- /dev/null +++ b/paddle/fluid/operators/affine_channel_op_xpu.cc @@ -0,0 +1,186 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef PADDLE_WITH_XPU + +#include <memory> +#include <string> +#include <unordered_map> +#include "paddle/fluid/framework/data_layout.h" +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template <typename DeviceContext, typename T> +class AffineChannelXPUKernel : public framework::OpKernel<T> { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input<framework::Tensor>("X"); + auto* scale = ctx.Input<framework::Tensor>("Scale"); + auto* bias = ctx.Input<framework::Tensor>("Bias"); + + auto* y = ctx.Output<framework::Tensor>("Out"); + y->mutable_data<T>(ctx.GetPlace()); + + const framework::DataLayout layout = + framework::StringToDataLayout(ctx.Attr<std::string>("data_layout")); + + auto dims = x->dims(); + int N = dims[0]; + int C = layout == framework::DataLayout::kNCHW ? dims[1] + : dims[dims.size() - 1]; + int HxW = x->numel() / N / C; + + auto* scale_d = scale->data<T>(); + auto* bias_d = bias->data<T>(); + + auto* x_d = x->data<T>(); + auto* y_d = y->data<T>(); + auto& dev_ctx = ctx.template device_context<DeviceContext>(); + std::vector<int> x_shape; + std::vector<int> b_shape; + if (layout == framework::DataLayout::kNCHW) { + x_shape.push_back(N); + x_shape.push_back(C); + x_shape.push_back(HxW); + b_shape.push_back(1); + b_shape.push_back(C); + b_shape.push_back(1); + } else { + x_shape.push_back(N * HxW); + x_shape.push_back(C); + b_shape.push_back(1); + b_shape.push_back(C); + } + int r = 0; + r = xpu::broadcast_mul(dev_ctx.x_context(), x_d, scale_d, y_d, x_shape, + b_shape); + PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS, + platform::errors::External( + "The broadcast_mul XPU OP returned wrong value[%d %s]", + r, XPUAPIErrorMsg[r])); + r = xpu::broadcast_add(dev_ctx.x_context(), y_d, bias_d, y_d, x_shape, + b_shape); + PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS, + platform::errors::External( + "The broadcast_add XPU OP returned wrong value[%d %s]", + r, XPUAPIErrorMsg[r])); + } +}; + +template <typename DeviceContext, typename T> +class AffineChannelGradXPUKernel : public framework::OpKernel<T> { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input<framework::Tensor>("X"); + auto* scale = ctx.Input<framework::Tensor>("Scale"); + auto* dy = ctx.Input<framework::Tensor>(framework::GradVarName("Out")); + + auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X")); + auto* dscale = + ctx.Output<framework::Tensor>(framework::GradVarName("Scale")); + auto* dbias = ctx.Output<framework::Tensor>(framework::GradVarName("Bias")); + + const framework::DataLayout layout = + framework::StringToDataLayout(ctx.Attr<std::string>("data_layout")); + + auto dims = x->dims(); + int N = dims[0]; + int C = layout == framework::DataLayout::kNCHW ? dims[1] + : dims[dims.size() - 1]; + int HxW = x->numel() / N / C;
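// Shape bookkeeping below mirrors the forward kernel: NCHW views the data
// as [N, C, H*W] with scale/bias broadcast as [1, C, 1], while NHWC
// collapses it to [N*H*W, C] with broadcast shape [1, C]; `rdims` lists the
// axes that reduce_sum folds away when accumulating the scale/bias grads.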
+ + auto* dy_d = dy->data<T>(); + auto* scale_d = scale->data<T>(); + + T* dx_d = dx ? dx->mutable_data<T>(ctx.GetPlace()) : nullptr; + T* dscale_d = dscale ? dscale->mutable_data<T>(ctx.GetPlace()) : nullptr; + T* dbias_d = dbias ? dbias->mutable_data<T>(ctx.GetPlace()) : nullptr; + + auto& dev_ctx = ctx.template device_context<DeviceContext>(); + std::vector<int> x_shape; + std::vector<int> b_shape; + std::vector<int> rdims; + if (layout == framework::DataLayout::kNCHW) { + x_shape.push_back(N); + x_shape.push_back(C); + x_shape.push_back(HxW); + b_shape.push_back(1); + b_shape.push_back(C); + b_shape.push_back(1); + rdims.push_back(0); + rdims.push_back(2); + } else { + x_shape.push_back(N * HxW); + x_shape.push_back(C); + b_shape.push_back(1); + b_shape.push_back(C); + rdims.push_back(0); + } + + int r = 0; + if (dscale_d && dbias_d) { + r = xpu::reduce_sum<T>(dev_ctx.x_context(), dy_d, dbias_d, x_shape, + rdims); + PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS, + platform::errors::External( + "The reduce_sum XPU OP returned wrong value[%d %s]", + r, XPUAPIErrorMsg[r])); + T* tmp = nullptr; + r = xpu_malloc(reinterpret_cast<void**>(&tmp), dy->numel() * sizeof(T)); + PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS, + platform::errors::External("not enough memory on the XPU device")); + + r = xpu::mul<T>(dev_ctx.x_context(), dy_d, x->data<T>(), tmp, + dy->numel()); + PADDLE_ENFORCE_EQ( + r, xpu::Error_t::SUCCESS, + platform::errors::External("The mul XPU OP returned wrong value[%d %s]", + r, XPUAPIErrorMsg[r])); + r = xpu::reduce_sum<T>(dev_ctx.x_context(), tmp, dscale_d, x_shape, + rdims); + PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS, + platform::errors::External( + "The reduce_sum XPU OP returned wrong value[%d %s]", + r, XPUAPIErrorMsg[r])); + if (dev_ctx.x_context()->xpu_stream) { + dev_ctx.Wait(); + } + xpu_free(tmp); + } + if (dx_d) { + r = xpu::broadcast_mul(dev_ctx.x_context(), dy_d, scale_d, dx_d, x_shape, + b_shape); + PADDLE_ENFORCE_EQ( + r, xpu::Error_t::SUCCESS, + platform::errors::External( + "The broadcast_mul XPU OP returned wrong value[%d %s]", r, + XPUAPIErrorMsg[r])); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +using XPU = paddle::platform::XPUDeviceContext; + +REGISTER_OP_XPU_KERNEL(affine_channel, ops::AffineChannelXPUKernel<XPU, float>); +REGISTER_OP_XPU_KERNEL(affine_channel_grad, + ops::AffineChannelGradXPUKernel<XPU, float>); + +#endif diff --git a/paddle/fluid/operators/collective/CMakeLists.txt b/paddle/fluid/operators/collective/CMakeLists.txt index 686b3039d4dea9..395b54c8b6c30a 100644 --- a/paddle/fluid/operators/collective/CMakeLists.txt +++ b/paddle/fluid/operators/collective/CMakeLists.txt @@ -28,11 +28,13 @@ foreach(src ${OPS}) set_source_files_properties(${src} PROPERTIES COMPILE_FLAGS ${COLLECTIVE_COMPILE_FLAGS}) endforeach() -register_operators(EXCLUDES c_gen_nccl_id_op DEPS ${COLLECTIVE_DEPS}) +register_operators(EXCLUDES c_gen_nccl_id_op gen_nccl_id_op DEPS ${COLLECTIVE_DEPS}) if(WITH_NCCL) set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} nccl_common collective_helper) - op_library(c_gen_nccl_id_op DEPS ${COLLECTIVE_DEPS} nccl_common) + cc_library(gen_nccl_id_op_helper SRCS gen_nccl_id_op_helper.cc) + op_library(c_gen_nccl_id_op DEPS ${COLLECTIVE_DEPS} nccl_common gen_nccl_id_op_helper) + op_library(gen_nccl_id_op DEPS ${COLLECTIVE_DEPS} nccl_common gen_nccl_id_op_helper) endif() if(WITH_GLOO) diff --git a/paddle/fluid/operators/collective/c_gen_nccl_id_op.cc b/paddle/fluid/operators/collective/c_gen_nccl_id_op.cc index ed478b1f0a02cf..93a6b50c4db466
100644 --- a/paddle/fluid/operators/collective/c_gen_nccl_id_op.cc +++ b/paddle/fluid/operators/collective/c_gen_nccl_id_op.cc @@ -21,14 +21,12 @@ limitations under the License. */ #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/var_type_traits.h" -#include "paddle/fluid/operators/distributed/distributed.h" -#include "paddle/fluid/operators/distributed/request_handler.h" -#include "paddle/fluid/operators/distributed/request_handler_impl.h" -#include "paddle/fluid/operators/distributed/rpc_client.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/place.h" +#include "paddle/fluid/operators/collective/gen_nccl_id_op_helper.h" + namespace paddle { namespace operators { @@ -42,80 +40,23 @@ class CGenNCCLIdOp : public framework::OperatorBase { void RunImpl(const framework::Scope& scope, const platform::Place& dev_place) const override { - platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); - // put nccl id in CPUPlace - auto& dev_ctx = *pool.Get(platform::CPUPlace()); int rank = Attr("rank"); framework::Scope& local_scope = scope.NewScope(); + std::function func = [&](size_t i) -> std::string { + return Output("Out"); + }; + if (rank == 0) { - GenerateAndSend(&local_scope, dev_ctx); + std::vector endpoint_list = + Attr>("other_endpoints"); + SendBroadCastNCCLID(endpoint_list, 1, func, local_scope); } else { - GetIdByServer(&local_scope, dev_ctx); + std::string endpoint = Attr("endpoint"); + RecvBroadCastNCCLID(endpoint, 1, func, local_scope); } scope.DeleteScope(&local_scope); } - - private: - void GenerateAndSend(framework::Scope* scope, - const platform::DeviceContext& dev_ctx) const { - std::string var_name = Output("Out"); - auto var = scope->FindVar(var_name); - PADDLE_ENFORCE_NOT_NULL( - var, platform::errors::InvalidArgument("Output can not be Null")); - auto id = var->GetMutable(); - PADDLE_ENFORCE_EQ(platform::dynload::ncclGetUniqueId(id), 0, - platform::errors::InvalidArgument( - "ncclGetUniqueId failed with id %s", id)); - - std::vector endpoint_list = - Attr>("other_endpoints"); - distributed::RPCClient* client = - distributed::RPCClient::GetInstance(0); - - for (auto& ep : endpoint_list) { - VLOG(3) << "sending nccl id to " << ep; - client->AsyncSendVar(ep, dev_ctx, *scope, var_name); - } - client->Wait(); - for (auto& ep : endpoint_list) { - client->AsyncSendBatchBarrier(ep); - } - client->Wait(); - VLOG(3) << "sending completed..."; - } - - void GetIdByServer(framework::Scope* scope, - const platform::DeviceContext& dev_ctx) const { - std::string endpoint = Attr("endpoint"); - // NOTE: Can not use unique_ptr here because the default - // deleter will call GRPC Server's base class's dtor and - // that will cause a wired crash. 
- distributed::RequestSendHandler rpc_h(distributed::DistributedMode::kSync); - std::unique_ptr rpc_service( - new RPCSERVER_T(endpoint, 1)); - - rpc_service->RegisterRPC(distributed::kRequestSend, &rpc_h); - rpc_h.SetRPCServer(rpc_service.get()); - - framework::ProgramDesc empty_program; - framework::Executor executor(dev_ctx.GetPlace()); - rpc_h.SetScope(scope); - rpc_h.SetDevCtx(&dev_ctx); - rpc_h.SetProgram(&empty_program); - rpc_h.SetExecutor(&executor); - - std::thread server_thread( - std::bind(&distributed::RPCServer::StartServer, rpc_service.get())); - - rpc_service->SetCond(distributed::kRequestSend); - VLOG(3) << "start getting nccl id from trainer 0..."; - rpc_service->WaitBarrier(distributed::kRequestSend); - VLOG(3) << "got nccl id and stop server..."; - rpc_service->ShutDown(); - VLOG(3) << "rpc server stopped"; - server_thread.join(); - } }; class CGenNCCLIdOpMaker : public framework::OpProtoAndCheckerMaker { diff --git a/paddle/fluid/operators/collective/gen_nccl_id_op.cc b/paddle/fluid/operators/collective/gen_nccl_id_op.cc new file mode 100644 index 00000000000000..98b1df9efc9038 --- /dev/null +++ b/paddle/fluid/operators/collective/gen_nccl_id_op.cc @@ -0,0 +1,201 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +#include "glog/logging.h" +#include "paddle/fluid/framework/executor.h" +#include "paddle/fluid/framework/op_proto_maker.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/var_type_traits.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/place.h" +#include "paddle/fluid/string/split.h" + +#include "paddle/fluid/operators/collective/gen_nccl_id_op_helper.h" + +namespace paddle { +namespace operators { + +class GenNCCLIdOp : public framework::OperatorBase { + public: + GenNCCLIdOp(const std::string& type, const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, + const framework::AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + + void RunImpl(const framework::Scope& scope, + const platform::Place& dev_place) const override { + std::vector trainers = + Attr>("trainers"); + int trainer_id = Attr("trainer_id"); + std::string endpoint = trainers[trainer_id]; + + PADDLE_ENFORCE_GE(trainer_id, 0, platform::errors::InvalidArgument( + "trainer_id %d is less than 0. Its " + "valid range is [0, trainer_size)")); + PADDLE_ENFORCE_LT( + trainer_id, static_cast(trainers.size()), + platform::errors::OutOfRange("trainer_id %d is out of range. 
Its valid " + "range is [0, trainer_size)", + trainer_id)); + + int nccl_comm_num = Attr("nccl_comm_num"); + int use_hierarchical_allreduce = Attr("use_hierarchical_allreduce"); + int inter_nranks = Attr("hierarchical_allreduce_inter_nranks"); + int inter_trainer_id = -1; + int exter_trainer_id = -1; + + if (use_hierarchical_allreduce) { + PADDLE_ENFORCE_GT( + trainers.size(), 1, + platform::errors::PreconditionNotMet( + "The number of collective trainers %llu <= 1", trainers.size())); + PADDLE_ENFORCE_GT( + inter_nranks, 1, + platform::errors::PreconditionNotMet( + "inter_nranks %d <= 1 while in hierarchical allreduce mode", + inter_nranks)); + PADDLE_ENFORCE_EQ( + trainers.size() % inter_nranks, 0, + platform::errors::PreconditionNotMet( + "The number of trainers %llu mod inter_nranks %d is not equal 0", + trainers.size(), inter_nranks)); + + inter_trainer_id = trainer_id % inter_nranks; + + if (trainer_id % inter_nranks == 0) { + exter_trainer_id = trainer_id / inter_nranks; + } + } + + std::ostringstream ss; + for (size_t i = 0; i < trainers.size(); i++) { + ss << trainers[i] << ","; + } + + VLOG(1) << "trainer_id:" << trainer_id + << ", use_hierarchical_allreduce:" << use_hierarchical_allreduce + << ", nccl_comm_num:" << nccl_comm_num + << ", inter_nranks:" << inter_nranks + << ", inter_trainer_id:" << inter_trainer_id + << ", exter_trainer_id:" << exter_trainer_id + << ", trainers:" << ss.str(); + + int server_fd = -1; + + /// 1. init flat + std::function func = platform::GetFlatNCCLVarName; + if (trainer_id == 0) { + // server endpoints + std::vector flat_endpoints; + flat_endpoints.insert(flat_endpoints.begin(), trainers.begin() + 1, + trainers.end()); + SendBroadCastNCCLID(flat_endpoints, nccl_comm_num, func, scope); + } else { + server_fd = CreateListenSocket(endpoint); + RecvBroadCastNCCLID(server_fd, endpoint, nccl_comm_num, func, scope); + } + + /// 2. hierarchical inter ncclid + func = platform::GetHierarchicalInterNCCLVarName; + if (inter_trainer_id == 0) { + std::ostringstream ss; + ss << endpoint; + std::vector inter_endpoints; + for (int i = trainer_id + 1; i < trainer_id + inter_nranks && + i < static_cast(trainers.size()); + i++) { + ss << ","; + inter_endpoints.push_back(trainers[i]); + ss << trainers[i]; + } + VLOG(1) << "Hierarchical inter ring endpoints:" << ss.str(); + + SendBroadCastNCCLID(inter_endpoints, nccl_comm_num, func, scope); + } else if (inter_trainer_id > 0) { + VLOG(1) << "Hierarchical inter ring"; + RecvBroadCastNCCLID(server_fd, endpoint, nccl_comm_num, func, scope); + } + + /// 3. 
hierarchical exter ncclid + func = platform::GetHierarchicalExterNCCLVarName; + if (exter_trainer_id == 0) { + std::ostringstream ss; + std::vector<std::string> exter_endpoints; + ss << endpoint; + for (size_t i = inter_nranks; i < trainers.size(); i += inter_nranks) { + ss << ","; + exter_endpoints.push_back(trainers[i]); + ss << trainers[i]; + } + VLOG(1) << "Hierarchical exter ring endpoints:" << ss.str(); + + SendBroadCastNCCLID(exter_endpoints, nccl_comm_num, func, scope); + } else if (exter_trainer_id > 0) { + VLOG(1) << "Hierarchical exter ring"; + RecvBroadCastNCCLID(server_fd, endpoint, nccl_comm_num, func, scope); + } + + // close socket server + if (trainer_id != 0) { + CloseSocket(server_fd); + } + } +}; + +class GenNCCLIdOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddOutput("NCCLID", "Raw variable containing a NCCL UniqueId instance."); + AddComment(R"DOC( +GenNCCLId operator + +For trainer 0: generate a new UniqueId and send it to all the other trainers. +For trainer 1~n: start a socket server to get the UniqueId; once received, stop the server. +)DOC"); + AddAttr<std::vector<std::string>>( + "trainers", + "['trainer0_ip:port', 'trainer1_ip:port', ...] " + "list of all trainer endpoints") + .SetDefault({}); + AddAttr<int>("trainer_id", + "(int) " + "The index of the trainer in distributed training."); + AddAttr<int>("nccl_comm_num", + "(int default 1) " + "The number of NCCL communicators.") + .SetDefault(1); + AddAttr<bool>("use_hierarchical_allreduce", + "(bool default false) " + "Whether to use hierarchical allreduce.") + .SetDefault(false); + AddAttr<int>("hierarchical_allreduce_inter_nranks", + "(int default -1) " + "Number of ranks in the inner allreduce ring of hierarchical allreduce.") + .SetDefault(-1); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(gen_nccl_id, ops::GenNCCLIdOp, ops::GenNCCLIdOpMaker); diff --git a/paddle/fluid/operators/collective/gen_nccl_id_op_helper.cc b/paddle/fluid/operators/collective/gen_nccl_id_op_helper.cc new file mode 100644 index 00000000000000..f448084019c605 --- /dev/null +++ b/paddle/fluid/operators/collective/gen_nccl_id_op_helper.cc @@ -0,0 +1,351 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
*/ + +#include "paddle/fluid/operators/collective/gen_nccl_id_op_helper.h" + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "glog/logging.h" +#include "paddle/fluid/framework/executor.h" +#include "paddle/fluid/framework/op_proto_maker.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/var_type_traits.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/place.h" +#include "paddle/fluid/string/split.h" + +namespace paddle { +namespace operators { + +constexpr char COMM_HEAD[] = "_pd_gen_comm_id_"; + +// Check system calls, such as socket, bind. +#define CHECK_SYS_CALL(call, name) \ + do { \ + int retval; \ + CHECK_SYS_CALL_VAL(call, name, retval); \ + } while (false) + +#define CHECK_SYS_CALL_VAL(call, name, retval) \ + do { \ + RETRY_SYS_CALL_VAL(call, name, retval); \ + if (retval == -1) { \ + PADDLE_THROW(platform::errors::Unavailable("Call to %s failed: %s", \ + name, strerror(errno))); \ + } \ + } while (false) + +#define RETRY_SYS_CALL_VAL(call, name, retval) \ + do { \ + retval = (call); \ + if (retval == -1 && \ + (errno == EINTR || errno == EWOULDBLOCK || errno == EAGAIN)) { \ + LOG(WARNING) << "Call " << name << " returned " << strerror(errno) \ + << " retry"; \ + } else { \ + break; \ + } \ + } while (true) + +static int SocketSend(int fd, const char* buffer, int size) { + int offset = 0; + int bytes = 0; + while (offset < size) { + bytes = send(fd, buffer + offset, size - offset, 0); + if (bytes == -1) { + if (errno != EINTR && errno != EWOULDBLOCK && errno != EAGAIN) { + // send failed + return -1; + } else { + bytes = 0; + } + } + offset += bytes; + } + return offset; +} + +static int SocketRecv(int fd, char* buffer, int size) { + int offset = 0; + int bytes = 0; + while (offset < size) { + bytes = recv(fd, buffer + offset, size - offset, 0); + if (bytes == 0) { + // closed by client, maybe probing alive client + return 0; + } + if (bytes == -1) { + if (errno != EINTR && errno != EWOULDBLOCK && errno != EAGAIN) { + return -1; + } else { + bytes = 0; + } + } + offset += bytes; + } + return offset; +} + +static void BindOrConnectFailed(int timeout, int* try_times, int* total_time, + const char* op, const std::string& ep) { + PADDLE_ENFORCE_LT( + *total_time, timeout, + platform::errors::Unavailable("%s addr=%s timeout, failed reason: %s", op, + ep.c_str(), strerror(errno))); + ++(*try_times); + int retry_time = std::min(*try_times * 500, 3000); // max 3 seconds + *total_time += retry_time; + + LOG(WARNING) << op << " addr=" << ep << " failed " << *try_times + << " times with reason: " << strerror(errno) << " retry after " + << retry_time / 1000.0 << " seconds"; + std::this_thread::sleep_for(std::chrono::milliseconds(retry_time)); +} + +int CreateListenSocket(const std::string& ep) { + auto addr = paddle::string::Split(ep, ':'); + PADDLE_ENFORCE_EQ( + addr.size(), 2UL, + platform::errors::InvalidArgument( + "The endpoint should contain host and port, but got %s.", ep)); + std::string host = addr[0]; + int port = std::stoi(addr[1]); + + // creating socket fd + int server_fd = -1; + CHECK_SYS_CALL_VAL(socket(AF_INET, SOCK_STREAM, 0), "socket", server_fd); + + // NOTE. Solutions to `Address already in use`. + // 1. Reuse addr&port. 
Otherwise, once the server closes the socket + // before client, the server will enter TIME-WAIT status. If we bind port + // again, the error `Address already in use` will appear. + // 2. Or we can close the client first to ensure that the server does + // not enter the TIME-WAIT state. But this is obviously not as convenient + // as the reuse method. + int opt = 1; +#if defined(SO_REUSEPORT) + // since Linux kernel 3.9 + CHECK_SYS_CALL(setsockopt(server_fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, + &opt, sizeof(opt)), + "setsockopt"); +#else + CHECK_SYS_CALL( + setsockopt(server_fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)), + "setsockopt"); +#endif + + struct sockaddr_in address; + address.sin_family = AF_INET; + address.sin_addr.s_addr = INADDR_ANY; + address.sin_port = htons(port); + + // TODO(wangxi) Set from env, default 900s=15min + int timeout = 900 * 1000; + int try_times = 0; + int total_time = 0; + while (true) { + int ret_val = -1; + RETRY_SYS_CALL_VAL( + bind(server_fd, (struct sockaddr*)&address, sizeof(address)), "bind", + ret_val); + + if (ret_val == -1) { + BindOrConnectFailed(timeout, &try_times, &total_time, "bind", ep); + continue; + } + break; + } + + CHECK_SYS_CALL(listen(server_fd, 3), "listen"); + LOG(INFO) << "Server listening on: " << ep << " successful."; + return server_fd; +} + +void CloseSocket(int fd) { CHECK_SYS_CALL(close(fd), "close"); } + +static int SocketAccept(int server_fd, const char* head) { + struct sockaddr_in client_addr; + socklen_t addr_length = sizeof(client_addr); + char buffer[1024] = {0}; + int conn = -1; + + while (true) { + CHECK_SYS_CALL_VAL( + accept(server_fd, reinterpret_cast(&client_addr), + &addr_length), + "accept", conn); + + int ret_val = SocketRecv(conn, buffer, strlen(head)); + if (ret_val > 0 && strncmp(buffer, head, strlen(head)) == 0) { + break; // accept client + } else { + VLOG(3) << "socket read failed with ret_val=" << ret_val; + CloseSocket(conn); + } + } + return conn; +} + +static int ConnectAddr(const std::string& ep, const char* head) { + auto addr = paddle::string::Split(ep, ':'); + PADDLE_ENFORCE_EQ( + addr.size(), 2UL, + platform::errors::InvalidArgument( + "The endpoint should contain host and port, but got %s.", ep)); + std::string host = addr[0]; + int port = std::stoi(addr[1]); + + int sock = -1; + CHECK_SYS_CALL_VAL(socket(AF_INET, SOCK_STREAM, 0), "socket", sock); + + struct sockaddr_in server_addr; + memset(&server_addr, 0, sizeof(server_addr)); + server_addr.sin_family = AF_INET; + server_addr.sin_port = htons(port); + + char* ip = NULL; + struct hostent* hp = NULL; + hp = gethostbyname(host.c_str()); + PADDLE_ENFORCE_NOT_NULL(hp, platform::errors::InvalidArgument( + "Fail to get host by name %s.", host)); + + int i = 0; + while (hp->h_addr_list[i] != NULL) { + ip = inet_ntoa(*(struct in_addr*)hp->h_addr_list[i]); + VLOG(3) << "gethostbyname host:" << host << " ->ip: " << ip; + break; + } + + PADDLE_ENFORCE_GT(inet_pton(AF_INET, ip, &server_addr.sin_addr), 0, + platform::errors::Unavailable("Open address %s failed: %s", + ep, strerror(errno))); + + // TODO(wangxi) Set from env, default 900s=15min + int timeout = 900 * 1000; + int try_times = 0; + int total_time = 0; + while (true) { + int ret_val = -1; + RETRY_SYS_CALL_VAL( + connect(sock, (struct sockaddr*)&server_addr, sizeof(server_addr)), + "connect", ret_val); + + if (ret_val == -1) { + BindOrConnectFailed(timeout, &try_times, &total_time, "connect", ep); + continue; + } + + CHECK_SYS_CALL(SocketSend(sock, head, strlen(head)), "send"); + break; 
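// The COMM_HEAD magic sent right after a successful connect() is what
// SocketAccept() matches against on the server side, so stray connections
// (port probes, liveness checks) get dropped instead of being mistaken for
// a trainer requesting the NCCL id.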
+ } + return sock; +} + +static void RecvNCCLID(int conn, ncclUniqueId* nccl_id) { + char buffer[1024] = {0}; + static_assert(NCCL_UNIQUE_ID_BYTES <= 1024, + "nccl id bytes must be <= buffer size"); + + CHECK_SYS_CALL(SocketRecv(conn, buffer, NCCL_UNIQUE_ID_BYTES), "recv nccl id"); + memcpy(nccl_id, buffer, NCCL_UNIQUE_ID_BYTES); +} + +static void SendNCCLID(int conn, ncclUniqueId* nccl_id) { + char buffer[1024] = {0}; + memcpy(buffer, nccl_id, NCCL_UNIQUE_ID_BYTES); + + CHECK_SYS_CALL(SocketSend(conn, buffer, NCCL_UNIQUE_ID_BYTES), + "send nccl id"); +} + +void SendBroadCastNCCLID(std::vector<std::string> servers, int nccl_comm_num, + std::function<std::string(int)> func, + const framework::Scope& scope) { + // connect with server + std::vector<int> connects; + for (auto server : servers) { + VLOG(3) << "connecting endpoint: " << server; + int conn = ConnectAddr(server, COMM_HEAD); + connects.push_back(conn); + } + VLOG(3) << "connecting completed..."; + + for (int i = 0; i < nccl_comm_num; ++i) { + std::string var_name = func(i); + auto var = scope.FindVar(var_name); + PADDLE_ENFORCE_NOT_NULL( + var, platform::errors::NotFound("Variable with name %s is not found", + var_name.c_str())); + auto nccl_id = var->GetMutable<ncclUniqueId>(); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGetUniqueId(nccl_id)); + + int j = 0; + for (auto conn : connects) { + VLOG(3) << "sending nccl_id_var: " << var_name << " to " << servers[j] + << " nccl_comm_no: " << i; + SendNCCLID(conn, nccl_id); + ++j; + } + VLOG(3) << "sending completed..."; + } + + // close client + for (auto conn : connects) { + CloseSocket(conn); + } +} + +void RecvBroadCastNCCLID(std::string endpoint, int nccl_comm_num, + std::function<std::string(int)> func, + const framework::Scope& scope) { + int server = CreateListenSocket(endpoint); + RecvBroadCastNCCLID(server, endpoint, nccl_comm_num, func, scope); + CloseSocket(server); +} + +void RecvBroadCastNCCLID(int server_fd, std::string endpoint, int nccl_comm_num, + std::function<std::string(int)> func, + const framework::Scope& scope) { + int client = SocketAccept(server_fd, COMM_HEAD); + + for (int i = 0; i < nccl_comm_num; ++i) { + std::string var_name = func(i); + auto var = scope.FindVar(var_name); + PADDLE_ENFORCE_NOT_NULL( + var, platform::errors::NotFound("Variable with name %s is not found", + var_name.c_str())); + auto nccl_id = var->GetMutable<ncclUniqueId>(); + + VLOG(3) << "trainer: " << endpoint << " receiving nccl_id_var: " << var_name + << " from trainer 0, nccl_comm_no: " << i; + RecvNCCLID(client, nccl_id); + } + VLOG(3) << "receiving completed..."; + CloseSocket(client); +} + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/collective/gen_nccl_id_op_helper.h b/paddle/fluid/operators/collective/gen_nccl_id_op_helper.h new file mode 100644 index 00000000000000..38751805191e3e --- /dev/null +++ b/paddle/fluid/operators/collective/gen_nccl_id_op_helper.h @@ -0,0 +1,48 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
*/ + +#pragma once + +#include +#include +#include + +namespace paddle { +namespace framework { +class Scope; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace operators { + +int CreateListenSocket(const std::string& ep); + +void CloseSocket(int fd); + +void SendBroadCastNCCLID(std::vector servers, int nccl_comm_num, + std::function func, + const framework::Scope& scope); + +// server listen on endpoint, then recv nccl id +void RecvBroadCastNCCLID(std::string endpoint, int nccl_comm_num, + std::function func, + const framework::Scope& scope); + +// recv nccl id from socket +void RecvBroadCastNCCLID(int server_fd, std::string endpoint, int nccl_comm_num, + std::function func, + const framework::Scope& scope); +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/distributed_ops/CMakeLists.txt b/paddle/fluid/operators/distributed_ops/CMakeLists.txt index 79f14d75d279d0..ec48a51baa212a 100644 --- a/paddle/fluid/operators/distributed_ops/CMakeLists.txt +++ b/paddle/fluid/operators/distributed_ops/CMakeLists.txt @@ -32,7 +32,6 @@ register_operators(EXCLUDES gen_nccl_id_op DEPS ${DISTRIBUTE_DEPS}) if(WITH_NCCL) set(DISTRIBUTE_DEPS ${DISTRIBUTE_DEPS} nccl_common) - op_library(gen_nccl_id_op DEPS ${DISTRIBUTE_DEPS} nccl_common) endif() set(OPERATOR_DEPS ${OPERATOR_DEPS} ${DISTRIBUTE_DEPS} PARENT_SCOPE) diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.h b/paddle/fluid/operators/elementwise/elementwise_add_op.h index acda31e0f2309b..0e8d202a9aa384 100644 --- a/paddle/fluid/operators/elementwise/elementwise_add_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_add_op.h @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include +#include #include "paddle/fluid/operators/elementwise/elementwise_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h" @@ -116,6 +118,135 @@ elementwise_add_grad(const framework::ExecutionContext &ctx, default_elementwise_add_grad(ctx, x, y, out, dout, dx, dy); } +#ifdef PADDLE_WITH_CUDA +#ifdef __NVCC__ + +template +__global__ void MatrixColReduce(const T *__restrict__ in, T *__restrict__ out, + size_t width, size_t height) { + __shared__ T sdata[BLOCK_H][BLOCK_W + 1]; + size_t idx = threadIdx.x + blockDim.x * blockIdx.x; + size_t width_stride = gridDim.x * blockDim.x; + size_t full_width = (width & (~((uint64_t)(BLOCK_W - 1)))) + + ((width & (BLOCK_W - 1)) ? BLOCK_W : 0); + +#pragma unroll + for (size_t w = idx; w < full_width; w += width_stride) { + sdata[threadIdx.y][threadIdx.x] = 0; + __syncthreads(); + size_t offset = w + threadIdx.y * width; +#pragma unroll + for (size_t h = threadIdx.y; h < height; + h += BLOCK_H) { // block-stride loop across matrix height + sdata[threadIdx.y][threadIdx.x] += + (w < width) ? 
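// Reduction layout used by MatrixColReduce: each 32x32 tile is staged in
// shared memory and partially summed down the matrix height by the
// block-stride loop; the tile is then read back transposed so one warp can
// finish each column with CudaShuffleXorSync, keeping the whole column
// reduction inside a single kernel launch.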
in[offset] : (static_cast(0)); + offset += width * BLOCK_H; + } + __syncthreads(); + + T val = sdata[threadIdx.x][threadIdx.y]; + for (int i = warpSize >> 1; i > 0; i >>= 1) + val += platform::CudaShuffleXorSync(0xFFFFFFFF, val, i); + + __syncthreads(); + if (threadIdx.x == 0) sdata[0][threadIdx.y] = val; + __syncthreads(); + if ((threadIdx.y == 0) && ((w) < width)) out[w] = sdata[0][threadIdx.x]; + } +} + +template +__global__ void FP16MatrixColReduce( + const paddle::platform::float16 *__restrict__ in, + paddle::platform::float16 *__restrict__ out, size_t width, size_t height) { + constexpr int repeats = BLOCK_H / BLOCK_W; + __shared__ paddle::platform::float16 sdata[BLOCK_H][BLOCK_W + 1]; + size_t idx = threadIdx.x + blockDim.x * blockIdx.x; + size_t width_stride = gridDim.x * blockDim.x; + size_t full_width = (width & (~((uint64_t)(BLOCK_W - 1)))) + + ((width & (BLOCK_W - 1)) ? BLOCK_W : 0); + +#pragma unroll + for (size_t w = idx; w < full_width; w += width_stride) { + for (int r = 0; r < repeats; r++) { + sdata[threadIdx.y + r * BLOCK_W][threadIdx.x] = 0; + } + __syncthreads(); + for (int r = 0; r < repeats; r++) { + size_t offset = w + (r * BLOCK_W + threadIdx.y) * width; +#pragma unroll + for (size_t h = r * BLOCK_H + threadIdx.y; h < height; + h += BLOCK_H) { // block-stride loop across matrix height + sdata[r * BLOCK_W + threadIdx.y][threadIdx.x] += + (w < width) ? in[offset + r * BLOCK_W * width] + : (static_cast(0)); + offset += width * BLOCK_H; + } + } + __syncthreads(); + + paddle::platform::float16 result = + static_cast(0); + for (int r = 0; r < repeats; r++) { + paddle::platform::float16 val = + sdata[threadIdx.x + r * BLOCK_W][threadIdx.y]; + for (int i = warpSize >> 1; i > 0; i >>= 1) + val += platform::CudaShuffleXorSync(0xFFFFFFFF, val, i); + __syncthreads(); + result += val; + } + if (threadIdx.x == 0) sdata[0][threadIdx.y] = result; + __syncthreads(); + if ((threadIdx.y == 0) && ((w) < width)) out[w] = sdata[0][threadIdx.x]; + } +} +#endif +#endif +bool static RunSpecialDims(const framework::DDim &dx_dims, + const framework::DDim &dy_dims, + const framework::DDim &dout_dims, int axis) { + auto smaller_dims = dx_dims; + auto bigger_dims = dy_dims; + auto smaller_dims_size = smaller_dims.size(); + auto bigger_dims_size = bigger_dims.size(); + int smaller_ignore_size = 0; + int bigger_ignore_size = 0; + for (int i = 0; i < smaller_dims_size; i++) { + if (smaller_dims[i] == 1) + smaller_ignore_size++; + else + break; + } + for (int i = 0; i < bigger_dims_size; i++) { + if (bigger_dims[i] == 1) + bigger_ignore_size++; + else + break; + } + + int smaller_real_size = smaller_dims.size() - smaller_ignore_size; + int bigger_real_size = bigger_dims.size() - bigger_ignore_size; + + if (smaller_real_size == bigger_real_size) return false; + + if (bigger_real_size < smaller_real_size) { + smaller_dims = dy_dims; + bigger_dims = dx_dims; + std::swap(smaller_real_size, bigger_real_size); + } + int big_size = bigger_dims.size(); + int small_size = smaller_dims.size(); + for (int i = 1; i <= smaller_real_size; i++) { + if (bigger_dims[big_size - i] != smaller_dims[small_size - i]) return false; + } + + if (axis != -1 && (axis != (bigger_real_size - smaller_real_size))) { + return false; + } + + return true; +} + #ifdef PADDLE_WITH_CUDA // cuda definition template @@ -144,6 +275,63 @@ class ElementwiseAddGradKernel : public ElemwiseGradKernel { // skip out auto *out = dout; +#ifdef PADDLE_WITH_CUDA +#ifdef __NVCC__ + + int axis = ctx.Attr("axis"); + if (ctx.GetPlace() == 
platform::CUDAPlace() && dx != nullptr && + dy != nullptr && dout != nullptr && dx->numel() != dy->numel() && + RunSpecialDims(dx->dims(), dy->dims(), dout->dims(), axis)) { + auto *dx_data = dx->mutable_data(ctx.GetPlace()); + auto *dy_data = dy->mutable_data(ctx.GetPlace()); + auto *dout_data = dout->data(); + auto stream = ctx.cuda_device_context().stream(); + auto *out_data = dx_data; + int width = dx->numel(); + int height = dout->numel() / width; + if (dx->dims() == dout->dims()) { + width = dy->numel(); + height = dout->numel() / width; + out_data = dy_data; + framework::TensorCopy( + *dout, ctx.GetPlace(), + ctx.template device_context(), dx); + } else { + framework::TensorCopy( + *dout, ctx.GetPlace(), + ctx.template device_context(), dy); + } + + constexpr int block_x = 32; + constexpr int block_y = 32; + dim3 blocks(block_x, block_y); + + int max_physical_threads = + ctx.cuda_device_context().GetMaxPhysicalThreadCount(); + int max_blocks = std::max(max_physical_threads / (block_x * block_y), 1); + int theory_block = (width + blocks.x - 1) / blocks.x; + dim3 grids(std::min(theory_block, max_blocks)); + if (std::is_same::value) { + const paddle::platform::float16 *ptr1 = + reinterpret_cast(dout_data); + paddle::platform::float16 *ptr2 = + reinterpret_cast(out_data); + if (height <= 32) { + FP16MatrixColReduce<32, 32><<>>( + ptr1, ptr2, width, height); + } else { + FP16MatrixColReduce<32, 64><<>>( + ptr1, ptr2, width, height); + } + return; + } + MatrixColReduce<<>>( + dout_data, out_data, width, height); + return; + } + +#endif +#endif // Special case when dy is not needed and dx doesn't reduce if (dx != nullptr && dy == nullptr && dx->dims() == dout->dims()) { VLOG(4) << "Special case when dy is not needed and dx doesn't " diff --git a/paddle/fluid/operators/fused/mkldnn/fusion_gru_mkldnn_op.cc b/paddle/fluid/operators/fused/mkldnn/fusion_gru_mkldnn_op.cc index e51d94e4b1e05a..1eed49de784089 100644 --- a/paddle/fluid/operators/fused/mkldnn/fusion_gru_mkldnn_op.cc +++ b/paddle/fluid/operators/fused/mkldnn/fusion_gru_mkldnn_op.cc @@ -39,20 +39,15 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT { const std::string& unique_name) : platform::MKLDNNHandlerT( dev_ctx, dev_ctx.GetEngine(), cpu_place, - CreateKey(unique_name, MKLDNNGetDataType(), Ti)), + CreateKey(dev_ctx, unique_name, MKLDNNGetDataType(), Ti)), N(N), Ti(Ti), IC(IC), OC(OC) { // Create memory key without Ti because weights, bias and h0 memories // do not depend on Ti size but primitive and input/output memory do - if (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() != - platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default) { - memory_key_ = CreateKey(unique_name, MKLDNNGetDataType()); - } else { - memory_key_ = CreateKey(unique_name, MKLDNNGetDataType(), "-t:", - platform::ThreadIDasStr()); - } + memory_key_ = platform::ExtendKeyWithThreadInfoIfNeeded( + dev_ctx, CreateKey(dev_ctx, unique_name, MKLDNNGetDataType())); // Is it int8 kernel const bool is_INT8 = std::is_same::value; diff --git a/paddle/fluid/operators/fused/mkldnn/multi_gru_mkldnn_op.cc b/paddle/fluid/operators/fused/mkldnn/multi_gru_mkldnn_op.cc index b7fd40f78ff9d3..11711bab81735e 100644 --- a/paddle/fluid/operators/fused/mkldnn/multi_gru_mkldnn_op.cc +++ b/paddle/fluid/operators/fused/mkldnn/multi_gru_mkldnn_op.cc @@ -109,13 +109,8 @@ class MultiGRUHandler { const std::string unique_name = ctx.OutputName("Hidden"); // Create memory key without Ti because weights, bias and h0 memories // do not depend on Ti size 
but primitive and input/output memory do - if (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() != - platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default) { - memory_key_ = CreateKey(unique_name, MKLDNNGetDataType()); - } else { - memory_key_ = CreateKey(unique_name, MKLDNNGetDataType(), "-t:", - platform::ThreadIDasStr()); - } + memory_key_ = platform::ExtendKeyWithThreadInfoIfNeeded( + dev_ctx, CreateKey(dev_ctx, unique_name, MKLDNNGetDataType())); key_ = memory_key_; key_.append("T").append(std::to_string(Ti_)); diff --git a/paddle/fluid/operators/layer_norm_op.cu b/paddle/fluid/operators/layer_norm_op.cu index d5a57dd9ddcad9..ad15b18d7feaeb 100644 --- a/paddle/fluid/operators/layer_norm_op.cu +++ b/paddle/fluid/operators/layer_norm_op.cu @@ -109,7 +109,7 @@ struct PairForLayerNormAddFunctor { template __inline__ __device__ T rsqrt(const T val) { - return ::rsqrt(val); + return static_cast(1) / sqrt(val); } template <> @@ -117,10 +117,17 @@ __inline__ __device__ float rsqrt(const float val) { return rsqrtf(val); } +template <> +__inline__ __device__ double rsqrt(const double val) { + return rsqrt(val); +} + +#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) template <> __inline__ __device__ half rsqrt(const half val) { return hrsqrt(val); } +#endif template __global__ void LayerNormForward(const T *x, const U *scale, const U *bias, @@ -841,6 +848,7 @@ class LayerNormKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &ctx) const override { + using U = LayerNormParamType; const float epsilon = ctx.Attr("epsilon"); auto *scale = ctx.Input("Scale"); auto *bias = ctx.Input("Bias"); @@ -854,12 +862,10 @@ class LayerNormKernel const auto x_dims = x->dims(); auto *x_data = x->data(); auto *y_data = y->mutable_data(ctx.GetPlace()); - auto *mean_data = mean->mutable_data>(ctx.GetPlace()); - auto *var_data = var->mutable_data>(ctx.GetPlace()); - auto *scale_data = - (scale == nullptr ? nullptr : scale->data>()); - auto *bias_data = - (bias == nullptr ? nullptr : bias->data>()); + auto *mean_data = mean->mutable_data(ctx.GetPlace()); + auto *var_data = var->mutable_data(ctx.GetPlace()); + auto *scale_data = (scale == nullptr ? nullptr : scale->data()); + auto *bias_data = (bias == nullptr ? 
nullptr : bias->data()); auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis); int batch_size = static_cast(matrix_dim[0]); @@ -869,7 +875,7 @@ class LayerNormKernel switch (GetDesiredBlockDim(feature_size)) { FIXED_BLOCK_DIM_CASE( - LayerNormForward, + LayerNormForward<<>>( x_data, scale_data, bias_data, y_data, mean_data, var_data, epsilon, feature_size)); diff --git a/paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc index 98f368aa7a9085..622d6685dfa718 100644 --- a/paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc @@ -48,7 +48,8 @@ class BatchNormMKLDNNHandler : platform::MKLDNNHandlerT( dev_ctx, dev_ctx.GetEngine(), cpu_place, - platform::CreateKey(framework::vectorize(x->dims()), unique_name)) { + platform::CreateKey(dev_ctx, framework::vectorize(x->dims()), + unique_name)) { if (!this->isCached()) { const float epsilon = ctx.Attr("epsilon"); const bool fuse_with_relu = ctx.Attr("fuse_with_relu"); @@ -89,7 +90,7 @@ class BatchNormMKLDNNHandler : platform::MKLDNNHandlerT( dev_ctx, dev_ctx.GetEngine(), cpu_place, - platform::CreateKey(dims, uniq_name)) { + platform::CreateKey(dev_ctx, dims, uniq_name)) { auto diff_dst_md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType(), diff_fmt); auto src_md = diff --git a/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc index 114daaecb59369..63aa2357beea07 100644 --- a/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc @@ -144,6 +144,7 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel { platform::errors::InvalidArgument( "The axis is expected to be in range of [%d, %d), but got %d", -rank, rank, concat_axis)); + platform::MKLDNNDeviceContext::tls().log_lib_version(); if (concat_axis < 0) { concat_axis = concat_axis + rank; } @@ -158,9 +159,10 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel { // If one of the multiple inputs of concat has an input size of 0, the // actual size of the multi_input will change std::string key = platform::CreateKey( - paddle::framework::vectorize(multi_input[0]->dims()), + dev_ctx, paddle::framework::vectorize(multi_input[0]->dims()), multi_input.size(), ctx.OutputName("Out"), dt, - platform::ThreadIDasStr(), dev_ctx.GetKeySuffix()); + platform::ThreadIDasStr()); + key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key); const std::string key_prim = key + "@concat_p"; const std::string key_concat_pd = key + "@concat_pd"; diff --git a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc index 2e6d809c988790..68fe5828388ee2 100644 --- a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc @@ -95,7 +95,7 @@ class ConvMKLDNNHandlerT const std::string& unique_name) : platform::MKLDNNHandlerT( dev_ctx, mkldnn_engine, cpu_place, - platform::CreateKey(framework::vectorize(input->dims()), + platform::CreateKey(dev_ctx, framework::vectorize(input->dims()), unique_name)) { if (!this->isCached()) { PADDLE_ENFORCE_EQ( @@ -521,8 +521,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { mkldnn::memory::data_type src_dt = paddle::framework::ToMKLDNNDataType(input->type()); - std::string key = platform::CreateKey( - src_tz, src_dt, ctx.InputName("Input") + ctx.InputName("Filter")); + std::string key = + platform::CreateKey(dev_ctx, 
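// Key scheme shared by all the oneDNN hunks in this patch: CreateKey() now
// takes the device context so cache keys can be derived per context, and
// ExtendKeyWithThreadInfoIfNeeded() appends the thread id only while the
// default MKL-DNN session id is active, replacing the hand-rolled
// "-t:" + ThreadIDasStr() branches that these hunks delete.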
src_tz, src_dt, + ctx.InputName("Input") + ctx.InputName("Filter")); const std::string key_conv_pd = key + "@conv_pd"; bool need_s8_to_u8 = false; @@ -537,21 +538,17 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { // This is workaround for hacky implementation // of conv int8 mkl-dnn. Once conv fp32 and conv int8 // are merged/unified, this will disappear - std::string key_tid = ""; - if (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() == - platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default) { - key_tid = "-t:" + platform::ThreadIDasStr(); - } - - auto prim_key = key + key_tid + "@conv_p"; - auto dst_key = key + key_tid + "@dst_mem_p"; - auto src_key = key + key_tid + "@src_mem_p"; - auto weights_key = key + key_tid + "@weights_mem_p"; - auto bias_key = key + key_tid + "@bias_mem_p"; - auto user_src_key = key + key_tid + "@user_src_mem_p"; - auto user_residual_key = key + key_tid + "@user_residual_data_mem_p"; - auto src_reorder_key = key + key_tid + "@src_mem_preorder_p"; - auto residual_reorder_key = key + key_tid + "@residual_data_mem_preorder_p"; + auto key_tid = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key); + + auto prim_key = key_tid + "@conv_p"; + auto dst_key = key_tid + "@dst_mem_p"; + auto src_key = key_tid + "@src_mem_p"; + auto weights_key = key_tid + "@weights_mem_p"; + auto bias_key = key_tid + "@bias_mem_p"; + auto user_src_key = key_tid + "@user_src_mem_p"; + auto user_residual_key = key_tid + "@user_residual_data_mem_p"; + auto src_reorder_key = key_tid + "@src_mem_preorder_p"; + auto residual_reorder_key = key_tid + "@residual_data_mem_preorder_p"; conv_p = std::static_pointer_cast( dev_ctx.GetBlob(prim_key)); @@ -972,10 +969,11 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel { // Get an unique name from "argument" name of "input" and "Filter" variable // as well as attributes of primitive to be created // This name will be used as key when saving info into device context - const std::string key = platform::CreateKey( - src_tz, ctx.InputName("Input") + ctx.InputName("Filter")); + std::string key = platform::CreateKey( + dev_ctx, src_tz, ctx.InputName("Input") + ctx.InputName("Filter")); const std::string key_conv_pd = key + "@fwd_pd"; + key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key); std::vector pipeline; // Create user memory descriptors @@ -1090,8 +1088,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel { mkldnn::memory::format_tag out_format = weights_tz.size() == 6 ? 
mkldnn::memory::format_tag::goidhw : mkldnn::memory::format_tag::goihw; - const std::string key = - platform::CreateKey(weights_tz, filter_fmt, out_format, in_type); + std::string key = platform::CreateKey(dev_ctx, weights_tz, filter_fmt, + out_format, in_type); + key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key); platform::ReorderMKLDNNHandler handler(weights_tz, filter_grad->type(), in_type, dev_ctx, mkldnn_engine, diff --git a/paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc index e9f32e7ac25d8e..1eb90451a69529 100644 --- a/paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc @@ -172,9 +172,8 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> { auto dst_tz = paddle::framework::vectorize(output->dims()); // Get unique name for storing MKLDNN primitives - const std::string key = - platform::CreateKey(src_tz, ctx.OutputName("Output")); + const std::string key = platform::CreateKey(dev_ctx, src_tz, ctx.OutputName("Output")); std::vector<mkldnn::primitive> pipeline; diff --git a/paddle/fluid/operators/mkldnn/dequantize_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/dequantize_mkldnn_op.cc index e036fd9aba04b2..8d41b750972352 100644 --- a/paddle/fluid/operators/mkldnn/dequantize_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/dequantize_mkldnn_op.cc @@ -67,8 +67,11 @@ class DeQuantOpKernel : public framework::OpKernel<T> { mkldnn::memory::data_type src_dt = paddle::framework::ToMKLDNNDataType(input->type()); MKLDNNMemoryFormat src_fmt = input->format(); - std::string key = platform::CreateKey(platform::ThreadIDasStr(), src_dt, - src_tz, ctx.OutputName("Output")); + + std::string key = + platform::CreateKey(dev_ctx, src_dt, src_tz, ctx.OutputName("Output")); + key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key); + const std::string key_prim = key + "@r"; const std::string key_src_mem = key + "@s"; const std::string key_dst_mem = key + "@d"; diff --git a/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc index 820c46c67d374f..613d193477b60e 100644 --- a/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc @@ -370,8 +370,9 @@ class FCPrimitiveFactory { void CacheWeightsAndBias(const MKLDNNDeviceContext& dev_ctx, const ExecutionContext& ctx) { - const std::string key = - platform::CreateKey(platform::ThreadIDasStr(), dev_ctx.GetKeySuffix()); + std::string key = platform::CreateKey(dev_ctx); + key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key); + const std::string weights_key = key + ctx.InputName("W"); const std::string bias_key = key + ctx.InputName("Bias"); dev_ctx.SetBlob(weights_key, weights_); @@ -541,10 +542,11 @@ static void ExecuteFc(const ExecutionContext& ctx, const LoDTensor* input, const Tensor* w, const Tensor* bias, LoDTensor* output, bool fuse_relu, bool force_fp32_output) { auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>(); - const std::string prim_key = platform::CreateKey( - platform::ThreadIDasStr(), dev_ctx.GetKeySuffix(), input->format(), - input->dims()[0], framework::vectorize(w->dims()), - ctx.OutputName("Out")); + std::string prim_key = platform::CreateKey( + dev_ctx, input->format(), input->dims()[0], + framework::vectorize(w->dims()), ctx.OutputName("Out")); + prim_key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, prim_key); + constexpr bool is_int8 = std::is_same<T_in, int8_t>::value || std::is_same<T_in, uint8_t>::value; bool is_bfloat16 = std::is_same<T_in, paddle::platform::bfloat16>::value; @@
-570,6 +572,7 @@ class FCMKLDNNOpKernel : public framework::OpKernel { PADDLE_ENFORCE_EQ( platform::is_cpu_place(ctx.GetPlace()), true, platform::errors::PreconditionNotMet("FC MKL-DNN must use CPUPlace.")); + platform::MKLDNNDeviceContext::tls().log_lib_version(); auto input = ctx.Input("Input"); auto w = ctx.Input("W"); auto bias = ctx.Input("Bias"); diff --git a/paddle/fluid/operators/mkldnn/layer_norm_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/layer_norm_mkldnn_op.cc index 22261e948aa7b6..65dcb328f20839 100644 --- a/paddle/fluid/operators/mkldnn/layer_norm_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/layer_norm_mkldnn_op.cc @@ -30,7 +30,7 @@ class LayerNormMKLDNNHandler const std::string& uniq_name) : platform::MKLDNNHandlerT( dev_ctx, dev_ctx.GetEngine(), cpu_place, - platform::CreateKey(dims, uniq_name)) { + platform::CreateKey(dev_ctx, dims, uniq_name)) { if (!this->isCached()) { auto md = dnnl::memory::desc(dims, platform::MKLDNNGetDataType(), fmt); if (!is_test) { diff --git a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc index 1f2216cbed2b25..fddc4b4b2e5596 100644 --- a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc @@ -336,9 +336,8 @@ static std::shared_ptr> GetPrimitiveFactory( const auto& out_name = ctx.OutputName("Out"); const auto& dev_ctx = ctx.template device_context(); const auto batch_size = ctx.Input("X")->dims()[0]; - - const std::string key = platform::CreateKey( - platform::ThreadIDasStr(), dev_ctx.GetKeySuffix(), batch_size, out_name); + std::string key = platform::CreateKey(dev_ctx, batch_size, out_name); + key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key); auto factory = std::static_pointer_cast>(dev_ctx.GetBlob(key)); @@ -379,6 +378,7 @@ class DNNLMatMulKernel : public framework::OpKernel { platform::errors::Unimplemented( "DNNL matmul doesn't support multiple heads.")); } + platform::MKLDNNDeviceContext::tls().log_lib_version(); ExecuteMatMul(ctx); } }; diff --git a/paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc index 258b6971a0d295..46d51606d42da8 100644 --- a/paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc @@ -305,9 +305,11 @@ std::shared_ptr> GetPrimitiveFactory( const MKLDNNDeviceContext &dev_ctx, const ExecutionContext &ctx, const Tensor *input_x, const Tensor *input_y, const mkldnn::engine &mkldnn_engine) { - const std::string key = platform::CreateKey( - input_x->type(), framework::vectorize(input_x->dims()), input_y->type(), - framework::vectorize(input_y->dims()), ctx.OutputName("Out")); + std::string key = platform::CreateKey( + dev_ctx, input_x->type(), framework::vectorize(input_x->dims()), + input_y->type(), framework::vectorize(input_y->dims()), + ctx.OutputName("Out")); + key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key); auto prim_creator = std::static_pointer_cast>( dev_ctx.GetBlob(key)); @@ -351,6 +353,7 @@ class MulMKLDNNKernel : public framework::OpKernel { PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true, paddle::platform::errors::PreconditionNotMet( "Operator DNNL Mul must use CPUPlace")); + platform::MKLDNNDeviceContext::tls().log_lib_version(); auto &dev_ctx = ctx.template device_context(); const auto &mkldnn_engine = dev_ctx.GetEngine(); diff --git a/paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc index 4e689f5bccf4b4..9488a1a4405a46 
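The operator hunks above and below all repeat one two-step pattern: the cache key is first built with platform::CreateKey(dev_ctx, ...), which now appends the executor's key suffix itself, and is then optionally extended with the calling thread's id by platform::ExtendKeyWithThreadInfoIfNeeded (defined in the mkldnn_helper.h hunk further down). A minimal Python model of that key layout; the class, field names, and the "E9417" suffix value are illustrative stand-ins, not Paddle API:

```python
# Toy model of the refactored oneDNN cache-key assembly (illustrative only).
import threading

DEFAULT_SESSION_ID = 0  # assumed stand-in for kMKLDNNSessionID_Default

class DevCtx:
    def __init__(self, key_suffix="", attach_thread_id=True):
        self.key_suffix = key_suffix               # e.g. "E<ptr>", set per executor
        self.attach_thread_id = attach_thread_id   # cleared by DisableThreadInfoInKey()

def create_key(dev_ctx, *args):
    # AppendKey(...) each argument, then the executor suffix.
    return "".join(str(a) for a in args) + dev_ctx.key_suffix

def extend_key_with_thread_info_if_needed(dev_ctx, key,
                                          session_id=DEFAULT_SESSION_ID):
    # The thread id is appended only while it is still enabled and the
    # default session is in use.
    if dev_ctx.attach_thread_id and session_id == DEFAULT_SESSION_ID:
        return key + "-t:" + str(threading.get_ident())
    return key

ctx = DevCtx(key_suffix="E9417")
key = extend_key_with_thread_info_if_needed(
    ctx, create_key(ctx, [1, 3, 224, 224], "Out"))
assert key.startswith("[1, 3, 224, 224]OutE9417-t:")
ctx.attach_thread_id = False  # what NaiveExecutor/Executor now request
assert extend_key_with_thread_info_if_needed(ctx, "k") == "k"
```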
100644 --- a/paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc @@ -140,7 +140,7 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel { // Get a unique name from "argument" name of "Out" variable // This name will be used as key when referring info from device context const std::string key = platform::CreateKey( - diff_src_tz, pooling_type, ksize, strides, paddings, + dev_ctx, diff_src_tz, pooling_type, ksize, strides, paddings, memory::data_type::f32, in_x->format(), ctx.InputName("Out")); platform::PoolingMKLDNNHandler handler( diff --git a/paddle/fluid/operators/mkldnn/quantize_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/quantize_mkldnn_op.cc index 3e04e2dcf00bb1..7a03c6ce86d4bd 100644 --- a/paddle/fluid/operators/mkldnn/quantize_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/quantize_mkldnn_op.cc @@ -64,9 +64,11 @@ class QuantOpKernel : public framework::OpKernel { bool is_negative_input = ctx.Attr("is_negative_input"); bool bfloat16 = ctx.Attr("bfloat16"); - std::string key = platform::CreateKey( - platform::ThreadIDasStr(), src_tz, scale_data, scale_shift, - is_negative_input, ctx.OutputName("Output")); + std::string key = + platform::CreateKey(dev_ctx, src_tz, scale_data, scale_shift, + is_negative_input, ctx.OutputName("Output")); + key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key); + const std::string key_prim = key + "@r"; const std::string key_src_mem = key + "@s"; const std::string key_dst_mem = key + "@d"; diff --git a/paddle/fluid/operators/mkldnn/requantize_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/requantize_mkldnn_op.cc index a3b078205e83dd..aa74a45e3a575f 100644 --- a/paddle/fluid/operators/mkldnn/requantize_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/requantize_mkldnn_op.cc @@ -65,9 +65,9 @@ class ReQuantOpKernel : public framework::OpKernel { float reorder_scale = scale_out / scale_in; - std::string key = - platform::CreateKey(platform::ThreadIDasStr(), src_tz, scale_in, - scale_out, ctx.OutputName("Output")); + std::string key = platform::CreateKey(dev_ctx, src_tz, scale_in, scale_out, + ctx.OutputName("Output")); + key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key); const std::string key_prim = key + "@r"; const std::string key_src_mem = key + "@s"; const std::string key_dst_mem = key + "@d"; diff --git a/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc index 9d9e1e2d8ded51..3eb2e7084a0b07 100644 --- a/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc @@ -53,8 +53,8 @@ class SoftmaxMKLDNNHandler mkldnn::softmax_backward>( dev_ctx, mkldnn_engine, cpu_place, // Softmax may be inplace, so uniq_name is no longer unique - platform::CreateKey(framework::vectorize(input->dims()), axis, - uniq_name)) { + platform::CreateKey(dev_ctx, framework::vectorize(input->dims()), + axis, uniq_name)) { if (!this->isCached()) { PADDLE_ENFORCE_EQ( input->dims(), output->dims(), @@ -78,7 +78,7 @@ class SoftmaxMKLDNNHandler : platform::MKLDNNHandlerT( dev_ctx, dev_ctx.GetEngine(), cpu_place, - platform::CreateKey(dims, axis, uniq_name)) { + platform::CreateKey(dev_ctx, dims, axis, uniq_name)) { auto data_softmax_md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType(), fmt); auto diff_softmax_md = diff --git a/paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc index e1031c02be3943..2b6f959472491e 100644 ---
a/paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc @@ -54,7 +54,8 @@ class SumMKLDNNHandler : public platform::MKLDNNHandlerT { : platform::MKLDNNHandlerT( dev_ctx, dev_ctx.GetEngine(), cpu_place, - platform::CreateKey(framework::vectorize(z->dims()), uniq_name)), + platform::CreateKey(dev_ctx, framework::vectorize(z->dims()), + uniq_name)), num_inputs_(0) { for (size_t i = 0; i < in_vars.size(); i++) { srcs_suffix_.push_back(std::string("-") + std::to_string(i)); @@ -184,8 +185,9 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel { // For in-place execution which sum does not have we need to fake it // so from oneDNN dst memory we reorder data into input if (in_place) { - const std::string reorder_key = platform::CreateKey( - framework::vectorize(output->dims()), ctx.OutputName("Out") + "-I"); + const std::string reorder_key = + platform::CreateKey(dev_ctx, framework::vectorize(output->dims()), + ctx.OutputName("Out") + "-I"); auto& in_out = in_vars[0]->Get(); auto output_tz = framework::vectorize(output->dims()); diff --git a/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc index 28cdd8413ab134..feda5645b4cfa2 100644 --- a/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc @@ -48,7 +48,8 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel { auto nchw_tz = paddle::framework::vectorize(input->dims()); - const std::string key = platform::CreateKey(nchw_tz, ctx.OutputName("Out")); + const std::string key = + platform::CreateKey(dev_ctx, nchw_tz, ctx.OutputName("Out")); platform::TransposeMKLDNNHandler handler(nchw_tz, axis, dev_ctx, mkldnn_engine, key); @@ -103,7 +104,7 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel { auto nchw_tz = paddle::framework::vectorize(out_grad->dims()); const std::string key = platform::CreateKey( - nchw_tz, ctx.OutputName(framework::GradVarName("X"))); + dev_ctx, nchw_tz, ctx.OutputName(framework::GradVarName("X"))); platform::TransposeMKLDNNHandler handler(nchw_tz, reversed_axis, dev_ctx, mkldnn_engine, key); diff --git a/paddle/fluid/operators/roi_align_op_xpu.cc b/paddle/fluid/operators/roi_align_op_xpu.cc index 699cc7b84a4e6d..f35cf06e5f704a 100644 --- a/paddle/fluid/operators/roi_align_op_xpu.cc +++ b/paddle/fluid/operators/roi_align_op_xpu.cc @@ -24,89 +24,202 @@ template class XPUROIAlignOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto* in = ctx.Input("X"); - auto* rois = ctx.Input("ROIs"); - auto* out = ctx.Output("Out"); + auto* in = ctx.Input("X"); + auto* rois = ctx.Input("ROIs"); + auto* out = ctx.Output("Out"); + auto pooled_height = ctx.Attr("pooled_height"); auto pooled_width = ctx.Attr("pooled_width"); auto spatial_scale = ctx.Attr("spatial_scale"); auto sampling_ratio = ctx.Attr("sampling_ratio"); - auto& dev_ctx = ctx.template device_context(); + auto in_dims = in->dims(); int batch_size = in_dims[0]; int channels = in_dims[1]; int height = in_dims[2]; int width = in_dims[3]; + int rois_num = rois->dims()[0]; - const T* input_data = in->data(); - framework::Tensor _roi_batch_list; - _roi_batch_list.Resize({rois_num}); - int* rois_lod = _roi_batch_list.mutable_data(ctx.GetPlace()); - int rois_batch_size = 1; + if (rois_num == 0) return; + + Tensor roi_batch_id_list; + roi_batch_id_list.Resize({rois_num}); + auto cplace = platform::CPUPlace(); + 
int* roi_batch_id_data = roi_batch_id_list.mutable_data(cplace); + auto& dev_ctx = ctx.template device_context(); + auto xplace = BOOST_GET_CONST(platform::XPUPlace, ctx.GetPlace()); + int rois_batch_size = 0; + int* cpu_lod = nullptr; if (ctx.HasInput("RoisNum")) { - auto* rois_num_t = ctx.Input("RoisNum"); + auto* rois_num_t = ctx.Input("RoisNum"); rois_batch_size = rois_num_t->numel(); PADDLE_ENFORCE_EQ( rois_batch_size, batch_size, platform::errors::InvalidArgument( - "The batch size of rois and the batch size of images " - " must be the same. But received the batch size of rois is %d, " - "and the batch size of images is %d", + "The rois_batch_size and imgs " + "batch_size must be the same. But received rois_batch_size = %d, " + "batch_size = %d", rois_batch_size, batch_size)); - auto* rois_num_data = rois_num_t->data(); - rois_lod[0] = 0; - for (int n = 0; n < rois_batch_size; ++n) { - rois_lod[n + 1] = rois_lod[n] + rois_num_data[n]; + + std::vector rois_num_list(rois_batch_size); + memory::Copy(cplace, rois_num_list.data(), xplace, + rois_num_t->data(), sizeof(int) * rois_batch_size); + cpu_lod = new int[rois_batch_size + 1]; + cpu_lod[0] = 0; + for (int i = 0; i < rois_batch_size; i++) { + cpu_lod[i + 1] = cpu_lod[i] + rois_num_list[i]; } } else { - auto _rois_lod = rois->lod().back(); - rois_batch_size = _rois_lod.size() - 1; - for (int n = 0; n < static_cast(_rois_lod.size()); ++n) { - rois_lod[n] = _rois_lod[n]; - } + auto lod = rois->lod(); + PADDLE_ENFORCE_EQ( + lod.empty(), false, + platform::errors::InvalidArgument("Input(ROIs) in ROIAlignOp does " + "not contain LoD information.")); + auto rois_lod = lod.back(); + rois_batch_size = rois_lod.size() - 1; PADDLE_ENFORCE_EQ( rois_batch_size, batch_size, platform::errors::InvalidArgument( - "The rois_batch_size and imgs batch_size of roi_align_xpu OP " - "must " - "be the same. But received rois_batch_size %d , batch_size %d", + "The batch size of rois and batch size " + "of images must be the same. But received rois batch size = %d, " + "and images batch size = %d", rois_batch_size, batch_size)); + int rois_num_with_lod = rois_lod[rois_batch_size]; + PADDLE_ENFORCE_EQ( + rois_num, rois_num_with_lod, + platform::errors::InvalidArgument( + "The actual number of rois and the number of rois " + "provided from Input(RoIsLoD) in RoIAlign must be the same." + " But received actual number of rois is %d, and the number " + "of rois from RoIsLoD is %d", + rois_num, rois_num_with_lod)); + for (int n = 0; n < rois_batch_size; ++n) { + for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { + roi_batch_id_data[i] = n; + } + } + cpu_lod = new int[rois_batch_size + 1]; + for (int i = 0; i < rois_batch_size + 1; i++) { + cpu_lod[i] = rois_lod[i]; + } } - int rois_num_with_lod = rois_lod[rois_batch_size]; - PADDLE_ENFORCE_EQ( - rois_num, rois_num_with_lod, - platform::errors::InvalidArgument( - "The rois_num from input and lod of roi_align_xpu OP must be the " - "same. 
But received input rois_num %d , input lod %d", - rois_num, rois_num_with_lod)); - T* output_data = out->mutable_data(ctx.GetPlace()); - const T* rois_data = rois->data(); - for (int n = 0; n < rois_batch_size; n++) { - int cur_batch_rois_num = rois_lod[n + 1] - rois_lod[n]; - if (cur_batch_rois_num != 0) { - int r = xpu::roi_align( - dev_ctx.x_context(), input_data + n * channels * height * width, - rois_data + rois_lod[n] * 4, cur_batch_rois_num, channels, height, - width, pooled_height, pooled_width, sampling_ratio, spatial_scale, - output_data + - rois_lod[n] * channels * pooled_height * pooled_width); - PADDLE_ENFORCE_EQ( - r, xpu::Error_t::SUCCESS, - platform::errors::External( - "The roi_align XPU OP return wrong value[%d], please check " - "where Baidu Kunlun Card is properly installed.", - r)); + + int* roi_id_data = nullptr; + int r = xpu_malloc(reinterpret_cast(&roi_id_data), + (rois_batch_size + 1) * sizeof(int)); + PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS, + platform::errors::External("not enough memory on XPU")); + memory::Copy(xplace, roi_id_data, cplace, cpu_lod, + (rois_batch_size + 1) * sizeof(int)); + delete[] cpu_lod; + r = xpu::roi_align( + dev_ctx.x_context(), in->data(), + out->mutable_data(ctx.GetPlace()), rois->data(), roi_id_data, + batch_size, channels, height, width, out->dims()[0], pooled_height, + pooled_width, spatial_scale, sampling_ratio, true); + PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS, + platform::errors::External( + "The roi_align XPU OP returns wrong value[%d %s]", r, + XPUAPIErrorMsg[r])); + if (dev_ctx.x_context()->xpu_stream) { + dev_ctx.Wait(); + } + xpu_free(roi_id_data); + } +}; + +template +class XPUROIAlignGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in = ctx.Input("X"); + auto* rois = ctx.Input("ROIs"); + + auto* out_grad = ctx.Input(framework::GradVarName("Out")); + auto* in_grad = ctx.Output(framework::GradVarName("X")); + + auto pooled_height = ctx.Attr("pooled_height"); + auto pooled_width = ctx.Attr("pooled_width"); + auto spatial_scale = ctx.Attr("spatial_scale"); + auto sampling_ratio = ctx.Attr("sampling_ratio"); + + int rois_num = rois->dims()[0]; + int channels = in->dims()[1]; + int height = in->dims()[2]; + int width = in->dims()[3]; + + if (!in_grad) { + return; + } + Tensor roi_batch_id_list; + roi_batch_id_list.Resize({rois_num}); + auto cplace = platform::CPUPlace(); + + auto& dev_ctx = ctx.template device_context(); + auto xplace = BOOST_GET_CONST(platform::XPUPlace, ctx.GetPlace()); + + int rois_batch_size = 0; + int* cpu_lod = nullptr; + if (ctx.HasInput("RoisNum")) { + auto* rois_num_t = ctx.Input("RoisNum"); + rois_batch_size = rois_num_t->numel(); + std::vector rois_num_list(rois_batch_size); + memory::Copy(cplace, rois_num_list.data(), xplace, + rois_num_t->data(), sizeof(int) * rois_batch_size); + cpu_lod = new int[rois_batch_size + 1]; + cpu_lod[0] = 0; + for (int i = 0; i < rois_batch_size; i++) { + cpu_lod[i + 1] = cpu_lod[i] + rois_num_list[i]; + } + } else { + auto rois_lod = rois->lod().back(); + rois_batch_size = rois_lod.size() - 1; + cpu_lod = new int[rois_batch_size + 1]; + for (int i = 0; i < rois_batch_size + 1; i++) { + cpu_lod[i] = rois_lod[i]; } } + int* roi_id_data = nullptr; + int r = xpu_malloc(reinterpret_cast(&roi_id_data), + (rois_batch_size + 1) * sizeof(int)); + PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS, + platform::errors::External("not enough memory on XPU")); + memory::Copy(xplace, roi_id_data, cplace,
cpu_lod, + (rois_batch_size + 1) * sizeof(int)); + in_grad->mutable_data(ctx.GetPlace()); + + int output_grad_size = out_grad->numel(); + + delete[] cpu_lod; + if (output_grad_size > 0) { + r = xpu::roi_align_grad( + dev_ctx.x_context(), out_grad->data(), in_grad->data(), + rois->data(), roi_id_data, in->dims()[0], channels, height, width, + out_grad->dims()[0], pooled_height, pooled_width, spatial_scale, + sampling_ratio, true); + PADDLE_ENFORCE_EQ( + r, xpu::Error_t::SUCCESS, + platform::errors::External( + "The roi_align_grad XPU OP returns wrong value[%d %s]", r, + XPUAPIErrorMsg[r])); + } + if (dev_ctx.x_context()->xpu_stream) { + dev_ctx.Wait(); + } + xpu_free(roi_id_data); } }; } // namespace operators } // namespace paddle + namespace ops = paddle::operators; REGISTER_OP_XPU_KERNEL( roi_align, ops::XPUROIAlignOpKernel); +REGISTER_OP_XPU_KERNEL( + roi_align_grad, + ops::XPUROIAlignGradOpKernel); #endif diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index f35b64bc0a89e2..beb1db93f483e9 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -472,6 +472,15 @@ MKLDNNDeviceContextThreadLocals::Body::get_cur_paddle_data_layout(void) { return cur_paddle_data_layout; } +void MKLDNNDeviceContextThreadLocals::Body::log_lib_version(void) { + if (!said_once) { + said_once = true; + auto dv = dnnl::version(); + LOG(INFO) << "oneDNN v" << dv->major << "." << dv->minor << "." + << dv->patch; + } +} + void MKLDNNDeviceContext::ResetBlobMap() { std::lock_guard lock(*p_mutex_); if (!block_next_cache_clearing_) { diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 47fd54f96ed028..f0ce89aa5efd86 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -483,6 +483,7 @@ class MKLDNNDeviceContextThreadLocals { typedef MKLDNNDeviceContextThreadLocals self; struct Body { + bool said_once = false; size_t cur_mkldnn_session_id; // Current data input shape string. // - For fixed-shape, it's a null string by default.
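Both XPU ROI-align kernels above reduce the ROI bookkeeping to one host-side offset table: per-image ROI counts (from the RoisNum tensor or from the LoD) are prefix-summed into cpu_lod, which is then copied into device memory as roi_id_data for xpu::roi_align. A small illustrative sketch of that bookkeeping in plain Python; the helper names are hypothetical, not Paddle API:

```python
# Sketch of the host-side ROI offset table both XPU kernels build before
# calling xpu::roi_align / xpu::roi_align_grad (illustrative names only).

def build_roi_offsets(rois_num_per_image):
    """Prefix-sum per-image ROI counts into lod-style offsets (cpu_lod).

    [3, 1, 2] -> [0, 3, 4, 6]; image i owns rois[offsets[i]:offsets[i + 1]].
    """
    offsets = [0]
    for n in rois_num_per_image:
        offsets.append(offsets[-1] + n)
    return offsets

def roi_batch_ids(offsets):
    """Expand offsets back to one batch id per ROI, as the LoD branch does."""
    ids = []
    for img, (lo, hi) in enumerate(zip(offsets, offsets[1:])):
        ids.extend([img] * (hi - lo))
    return ids

assert build_roi_offsets([3, 1, 2]) == [0, 3, 4, 6]
assert roi_batch_ids([0, 3, 4, 6]) == [0, 0, 0, 1, 2, 2]
```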
@@ -502,6 +503,7 @@ class MKLDNNDeviceContextThreadLocals { void set_cur_input_shape_cache_capacity(int input_shape_cache_capacity); void set_cur_paddle_data_layout(framework::DataLayout dl); framework::DataLayout get_cur_paddle_data_layout(void); + void log_lib_version(void); }; MKLDNNDeviceContextThreadLocals() = default; MKLDNNDeviceContextThreadLocals(const MKLDNNDeviceContextThreadLocals& c) = @@ -549,6 +551,10 @@ class MKLDNNDeviceContext : public CPUDeviceContext { void SetKeySuffix(const std::string& suffix) { key_suffix_ = suffix; } const std::string& GetKeySuffix(void) const { return key_suffix_; } + // Disable adding thread ID to the key + void DisableThreadInfoInKey(void) { key_attach_thread_id_ = false; }; + bool IsThreadIdUsedInKey(void) const { return key_attach_thread_id_; }; + // Prevent next ResetBlobMap() void BlockNextCacheClearing(); @@ -571,6 +577,7 @@ class MKLDNNDeviceContext : public CPUDeviceContext { std::shared_ptr p_mutex_; bool block_next_cache_clearing_ = false; std::string key_suffix_; // Key identifying current Executor + bool key_attach_thread_id_ = true; }; #endif diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h index 99044c53d2322a..2de08773df31fa 100644 --- a/paddle/fluid/platform/mkldnn_helper.h +++ b/paddle/fluid/platform/mkldnn_helper.h @@ -431,11 +431,6 @@ inline void AppendKey(std::string* key, const std::vector& dims) { } } -inline unsigned int HashPointer(uintptr_t ptr) { - // Get four less meaningful digits in decimal numerals - return ptr % 1000; -} - // If MKLDNN build and CPU place then register suffix in DeviceContext inline void AttachPointerHashToMKLDNNKey(void* ptr, const platform::Place& place) { @@ -443,20 +438,34 @@ inline void AttachPointerHashToMKLDNNKey(void* ptr, platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::MKLDNNDeviceContext* dev_ctx = (platform::MKLDNNDeviceContext*)pool.Get(place); - dev_ctx->SetKeySuffix("E" + std::to_string(platform::HashPointer( - reinterpret_cast(ptr)))); + dev_ctx->SetKeySuffix("E" + + std::to_string(reinterpret_cast(ptr))); + // When NaiveExecutor/Executor is used no info on thread id is needed in a + // key + dev_ctx->DisableThreadInfoInKey(); } } template -inline std::string CreateKey(ArgTypes&&... args) { +inline std::string CreateKey(const platform::MKLDNNDeviceContext& dev_ctx, + ArgTypes&&... args) { std::string key; key.reserve(64); using expand_type = int[]; expand_type{0, (AppendKey(&key, std::forward(args)), 0)...}; + key += dev_ctx.GetKeySuffix(); return key; } +inline std::string ExtendKeyWithThreadInfoIfNeeded( + const platform::MKLDNNDeviceContext& dev_ctx, const std::string& key) { + return ((dev_ctx.IsThreadIdUsedInKey() == true) && + (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() == + platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default)) + ? 
key + "-t:" + ThreadIDasStr() : key; +} + inline std::vector> ToMkldnnPadding( const std::vector& paddings) { if (paddings.size() == 6) { diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index 6976e55b2305aa..c053815aea796c 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -43,15 +43,10 @@ class MKLDNNHandlerT { engine_(engine), place_(cpu_place), key_common_(base_key), + key_(platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, base_key)), fwd_pd_(nullptr), bwd_pd_(nullptr) { - if (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() != - platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default) { - key_ = key_common_; - } else { - key_ = key_common_ + "-t:" + ThreadIDasStr(); - } - key_ += dev_ctx.GetKeySuffix(); + platform::MKLDNNDeviceContext::tls().log_lib_version(); } std::shared_ptr AcquireForwardPrimitive() { @@ -306,8 +301,8 @@ class MKLDNNHandlerT { const MKLDNNDeviceContext& dev_ctx_; mkldnn::engine engine_; platform::Place place_; - std::string key_; std::string key_common_; + std::string key_; std::shared_ptr fwd_pd_; std::shared_ptr bwd_pd_; }; @@ -317,14 +312,11 @@ class MKLDNNHandler { public: MKLDNNHandler(const MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine, const std::string& base_key) - : dev_ctx_(dev_ctx), engine_(engine), key_common_(base_key) { - if (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() != - platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default) { - key_ = key_common_; - } else { - key_ = key_common_ + "-t:" + ThreadIDasStr(); - } - key_ += dev_ctx.GetKeySuffix(); + : dev_ctx_(dev_ctx), + engine_(engine), + key_common_(base_key), + key_(platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, base_key)) { + platform::MKLDNNDeviceContext::tls().log_lib_version(); } std::shared_ptr AcquireSrcMemory( @@ -508,8 +500,8 @@ class MKLDNNHandler { protected: const MKLDNNDeviceContext& dev_ctx_; mkldnn::engine engine_; - std::string key_; std::string key_common_; + std::string key_; }; template @@ -524,7 +516,7 @@ class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT { : platform::MKLDNNHandlerT( dev_ctx, engine, cpu_place, platform::CreateKey( - framework::vectorize(x->dims()), + dev_ctx, framework::vectorize(x->dims()), uniq_name + (algo == dnnl::algorithm::binary_mul ? "M" : ""))) { // broadcasting combined with in-place may require auto rankdiff = x->dims().size() - y->dims().size(); @@ -627,7 +619,7 @@ class ActivationMKLDNNHandler : platform::MKLDNNHandlerT( dev_ctx, dev_ctx.GetEngine(), cpu_place, - platform::CreateKey(dims, "a", algorithm, unique_name)) { + platform::CreateKey(dev_ctx, dims, "a", algorithm, unique_name)) { auto md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType(), fmt); this->AcquireForwardPrimitiveDescriptor(mkldnn::prop_kind::forward_training, @@ -645,7 +637,7 @@ class ActivationMKLDNNHandler : platform::MKLDNNHandlerT( dev_ctx, dev_ctx.GetEngine(), cpu_place, - platform::CreateKey(dims, "a", algorithm, unique_name)) { + platform::CreateKey(dev_ctx, dims, "a", algorithm, unique_name)) { auto diff_dst_md = platform::MKLDNNMemDesc( dims, platform::MKLDNNGetDataType(), diff_fmt); auto src_md = @@ -676,7 +668,7 @@ class LRNMKLDNNHandler : platform::MKLDNNHandlerT( dev_ctx, mkldnn_engine, cpu_place, - platform::CreateKey(framework::vectorize(input->dims()), + platform::CreateKey(dev_ctx, framework::vectorize(input->dims()), unique_name)) { if (!this->isCached()) { const int n = ctx.Attr("n"); @@ -712,7 +704,7 @@ class LRNMKLDNNHandler : platform::MKLDNNHandlerT( dev_ctx, dev_ctx.GetEngine(), cpu_place, - platform::CreateKey(dims, unique_name)) { + platform::CreateKey(dev_ctx, dims, unique_name)) { auto src_md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType(), fmt); auto diff_md = @@ -752,7 +744,7 @@ class PoolingMKLDNNHandler : public MKLDNNHandlerT( dev_ctx, dev_ctx.GetEngine(), cpu_place, - platform::CreateKey(framework::vectorize(input->dims()), + platform::CreateKey(dev_ctx, framework::vectorize(input->dims()), framework::ToMKLDNNDataType(input->type()), unique_name)) { if (!this->isCached()) { @@ -861,7 +853,7 @@ class PoolingMKLDNNHandler : public MKLDNNHandlerT( dev_ctx, dev_ctx.GetEngine(), cpu_place, - platform::CreateKey(diff_src_dims, dt, unique_name)) { + platform::CreateKey(dev_ctx, diff_src_dims, dt, unique_name)) { auto diff_dst_md = mkldnn::memory::desc( diff_dst_dims, platform::MKLDNNGetDataType(), diff_dst_fmt); auto diff_src_md = diff --git a/paddle/scripts/musl_build/Dockerfile b/paddle/scripts/musl_build/Dockerfile index 120b47b21a761f..6621a90802e2b1 100644 --- a/paddle/scripts/musl_build/Dockerfile +++ b/paddle/scripts/musl_build/Dockerfile @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License.
-FROM python:3.7-alpine3.10 +FROM python:3.7-alpine3.11 USER root diff --git a/paddle/scripts/musl_build/package.txt b/paddle/scripts/musl_build/package.txt index ed6796a0d3cc3e..464748419f39f0 100644 --- a/paddle/scripts/musl_build/package.txt +++ b/paddle/scripts/musl_build/package.txt @@ -1,9 +1,9 @@ -linux-headers=4.19.36-r0 -freetype-dev=2.10.0-r1 -libjpeg-turbo-dev=2.0.4-r1 -zlib-dev=1.2.11-r1 -lapack-dev=3.8.0-r1 -openblas-dev=0.3.6-r0 -openssl-dev=1.1.1g-r0 -libuv-dev=1.29.1-r0 +linux-headers +freetype-dev +libjpeg-turbo-dev +zlib-dev +lapack-dev +openblas-dev +openssl-dev +libuv-dev graphviz diff --git a/paddle/scripts/paddle_build.bat b/paddle/scripts/paddle_build.bat index 5ad48734adb482..aee2739b5ab898 100644 --- a/paddle/scripts/paddle_build.bat +++ b/paddle/scripts/paddle_build.bat @@ -424,10 +424,8 @@ test_decoupled_py_reader^|^ test_decoupled_py_reader_data_check^|^ test_eager_deletion_delete_vars^|^ test_eager_deletion_while_op^|^ -test_feed_data_check_shape_type^|^ test_fetch_lod_tensor_array^|^ test_fleet_base_single^|^ -test_fuse_all_reduce_pass^|^ test_fuse_elewise_add_act_pass^|^ test_fuse_optimizer_pass^|^ test_generator_dataloader^|^ @@ -450,7 +448,6 @@ test_imperative_static_runner_while^|^ test_optimizer_in_control_flow^|^ test_fuse_bn_act_pass^|^ test_fuse_bn_add_act_pass^|^ -test_tsm^|^ test_gru_rnn_op^|^ test_rnn_op^|^ test_simple_rnn_op^|^ diff --git a/paddle/scripts/paddle_build.sh b/paddle/scripts/paddle_build.sh index 21eedc6066b49c..e555832ba09365 100755 --- a/paddle/scripts/paddle_build.sh +++ b/paddle/scripts/paddle_build.sh @@ -288,6 +288,7 @@ EOF -DWITH_GLOO=${gloo_flag} \ -DLITE_GIT_TAG=develop \ -DWITH_XPU=${WITH_XPU:-OFF} \ + -DXPU_SDK_ROOT=${XPU_SDK_ROOT:-""} \ -DWITH_LITE=${WITH_LITE:-OFF} \ -DWITH_UNITY_BUILD=${WITH_UNITY_BUILD:-OFF};build_error=$? if [ "$build_error" != 0 ];then diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py index cf6ab514b0bfe6..03b36262a4fb1e 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py @@ -71,7 +71,11 @@ def remove_cast_op(block, params, segment, offset): return inserted_op_num @staticmethod - def prune_fp16(block, shard, reduced_grads_to_param, nrings): + def prune_fp16(block, shard, reduced_grads_to_param, ring_id): + """ + 1. prune all cast_fp32_to_fp16 ops if the param does not belong to this shard + 2.
revise amp infinite grad checking for sharding + """ # remove cast for idx, op in reversed(list(enumerate(block.ops))): if not FP16Utils.is_fp32_cast_op(block, op): @@ -79,9 +83,9 @@ def prune_fp16(block, shard, reduced_grads_to_param, nrings): output_name = op.desc.output_arg_names()[0] param_name = output_name.strip("@GRAD") if param_name not in shard.global_params: - raise ValueError("Input 'X' of check_finite_and_unscale must" - "be grads, but {} is not a grad".format( - input_name)) + raise ValueError("Output 'X' of cast_op must be a grad of " + "model param, but {} is not a grad".format( + output_name)) if output_name in reduced_grads_to_param: continue if shard.has_param(param_name): @@ -137,10 +141,12 @@ def prune_fp16(block, shard, reduced_grads_to_param, nrings): type='c_allreduce_max', inputs={'X': inf_var_fp32}, outputs={'Out': inf_var_fp32}, - attrs={'ring_id': 0, + attrs={'ring_id': ring_id, OP_ROLE_KEY: OpRole.Optimize}) - comm_op_num = insert_sync_comm_ops( - block, update_loss_scaling_op_idx + 3, nrings, [inf_var_fp32]) + + comm_op_num = insert_sync_comm_op(block, update_loss_scaling_op_idx + 3, + ring_id, [inf_var_fp32]) + block._insert_op_without_sync( update_loss_scaling_op_idx + 3 + comm_op_num, type='cast', diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py index afa46f43fc0fe3..c6aee792fcf745 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py @@ -16,14 +16,19 @@ class GradientClipHelper(object): - def __init__(self): - pass + def __init__(self, sharding_ring_id): + self.sharding_ring_id = sharding_ring_id def _is_gradient_clip_op(self, op): return op.desc.has_attr("op_namescope") \ and op.desc.attr("op_namescope").startswith("/gradient_clip") def prune_gradient_clip(self, block, shard): + """ + prune gradient_clip related ops for params that do not belong to the current shard + prune: square, reduce_sum, elementwise_mul + keep: sum, sqrt, elementwise_max, elementwise_div + """ deperated_vars = set() deperate_op_idx = set() for idx, op in enumerate(block.ops): @@ -75,8 +80,10 @@ def prune_gradient_clip(self, block, shard): type='c_allreduce_sum', inputs={'X': sum_res}, outputs={'Out': sum_res}, - attrs={'ring_id': 0, - OP_ROLE_KEY: OpRole.Optimize}) + attrs={ + 'ring_id': self.sharding_ring_id, + OP_ROLE_KEY: OpRole.Optimize + }) block._insert_op_without_sync( idx + 1, type='c_sync_calc_stream', diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/prune.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/prune.py index 7348e5f6d1445a..70753b59ccc318 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding/prune.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/prune.py @@ -43,6 +43,7 @@ def get_var_deps(self, var_name): return None def _build_deps(self, ): + for var_name in self._start_vars: self._var_to_use_op[var_name] = [] self._var_to_generate_op[var_name] = [] diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/shard.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/shard.py index 27c63fc406fcbf..92e36e0ec1fff3 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding/shard.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/shard.py @@ -124,6 +124,14 @@ def is_opti_var(self, var_name): return True return False +
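For intuition on why prune_gradient_clip may drop the square/reduce_sum/elementwise_mul ops for params outside the shard: each rank squares and sums only the grads it owns, and the c_allreduce_sum it inserts on sharding_ring_id rebuilds the full sum of squares before the kept sqrt. A toy check, in illustrative plain Python rather than Paddle code:

```python
# Toy check that the pruned global-norm clip stays exact under sharding.

def local_square_sum(grads_in_shard):
    # per-rank partial: sum of squares over the grads this shard owns
    return sum(x * x for grad in grads_in_shard for x in grad)

def global_norm(partial_sums):
    # stands in for c_allreduce_sum across the sharding ring, then sqrt
    return sum(partial_sums) ** 0.5

rank0 = [[3.0], [4.0]]   # grads of params owned by shard 0
rank1 = [[12.0]]         # grads of params owned by shard 1
assert global_norm([local_square_sum(rank0),
                    local_square_sum(rank1)]) == 13.0  # sqrt(9 + 16 + 144)
```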
def filter_grads(self, grads): + grads_in_shard = [] + for grad in grads: + param = grad.split("@")[0] + if self.has_param(param): + grads_in_shard.append(grad) + return grads_in_shard + class ProgramSegment(object): def __init__(self, block): diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py index b5c34f87cdf225..ad1cd4f60826bb 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py @@ -78,52 +78,137 @@ def check_broadcast(block): return -def check_allreduce_sum(block): +def check_allreduce_sum(block, shard, dp_ring_id=-1): """ - if a Var is allreduced, the op order should be: - - 0: op that generate Var - - 1: sync_calc - - 2: allreduce_sum op - - 3: sync_comm - - 4: op that use Var + the op order should be: + grad: + - 0: op that generates Var + - 1: sync_calc + - 2: allreduce_sum_sharding + - 3: sync_comm + - 4: allreduce_sum_dp (dp_grads) + - 5: sync_comm (dp_grads) + - 6: op that uses Var (dp_grads & sum) """ - var_status = {} - for op in block.ops: + vars_status = {} + dp_grads_status = {} + idx_last_grad_allreduce = -1 + idx_amp_allreduce = -1 + idx_gradient_clip_allreduce = -1 + for idx, op in enumerate(block.ops): if op.type == "c_allreduce_sum": + ring_id = op.desc.attr("ring_id") var_name = op.desc.input_arg_names()[0] - var_status[var_name] = -1 + param = var_name.split("@")[0] + + assert 'sum' in var_name or ("@GRAD" in var_name) + if 'sum' in var_name or (not shard.has_param(param)): + vars_status[var_name] = -1 + else: + dp_grads_status[var_name] = -1 + + if ring_id != 0: + assert shard.has_param(param) + assert ring_id == dp_ring_id + + if "sum" in var_name: + idx_amp_allreduce = idx + elif "@GRAD" in var_name: + idx_last_grad_allreduce = idx + + if op.type == "c_allreduce_max": + idx_gradient_clip_allreduce = idx for op in block.ops: if op.type == "c_sync_calc_stream": - for var_name in var_status: - if var_name in var_status and var_status[var_name] == 0: - var_status[var_name] = 1 + for var_name in vars_status: + if var_name in vars_status and vars_status[var_name] == 0: + vars_status[var_name] = 1 + for var_name in dp_grads_status: + if var_name in dp_grads_status and dp_grads_status[ + var_name] == 0: + dp_grads_status[var_name] = 1 + elif op.type == "c_allreduce_sum": var_name = op.desc.input_arg_names()[0] - if var_status[var_name] == -1: - raise ValueError("{} is not generated, but you are" - "trying to all-reduce it".format(var_name)) - if var_status[var_name] == 0: - raise ValueError("There should be a sync_calc op " - "after generate Var: {} and before the" - "c_allreduce_sum op".format(var_name)) - assert (var_status[var_name] == 1) - var_status[var_name] = 2 + ring_id = op.desc.attr("ring_id") + if ring_id == 0: + if var_name in vars_status: + _status = vars_status[var_name] + else: + _status = dp_grads_status[var_name] + if _status == -1: + raise ValueError("{} is not generated, but you are " + "trying to all-reduce it".format(var_name)) + if _status == 0: + raise ValueError("There should be a sync_calc op " + "after generating Var: {} and before the " + "c_allreduce_sum op".format(var_name)) + assert (_status == 1) + if var_name in vars_status: + vars_status[var_name] = 2 + else: + dp_grads_status[var_name] = 2 + else: + assert ring_id == dp_ring_id + param = var_name.split("@")[0] + assert shard.has_param(param) + assert dp_grads_status[var_name] == 3 + dp_grads_status[var_name] = 4 +
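The numbered states in the new check_allreduce_sum docstring form a small per-variable state machine; a grad that also takes the DP allreduce must reach state 5 before any op may consume it. A toy model of those transitions, with hypothetical op labels rather than Paddle op types:

```python
# Toy model of the per-variable status progression check_allreduce_sum
# enforces for a grad owned by this shard under hybrid DP:
# -1 (unseen) -> 0 produced -> 1 sync_calc -> 2 allreduce (sharding ring)
# -> 3 sync_comm -> 4 allreduce (dp ring) -> 5 sync_comm -> usable.

TRANSITIONS = {
    ("produce", -1): 0,
    ("sync_calc", 0): 1,
    ("allreduce_sharding", 1): 2,
    ("sync_comm_sharding", 2): 3,
    ("allreduce_dp", 3): 4,
    ("sync_comm_dp", 4): 5,
}

def run(ops, status=-1):
    for op in ops:
        if (op, status) not in TRANSITIONS:
            raise ValueError("op {} illegal in status {}".format(op, status))
        status = TRANSITIONS[(op, status)]
    return status

assert run(["produce", "sync_calc", "allreduce_sharding",
            "sync_comm_sharding", "allreduce_dp", "sync_comm_dp"]) == 5
```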
elif op.type == "c_sync_comm_stream": - for var_name in op.desc.input_arg_names(): - if var_name in var_status and var_status[var_name] == 2: - var_status[var_name] = 3 + var_name = op.desc.input_arg_names()[0] + ring_id = op.desc.attr("ring_id") + if ring_id == 0: + for var_name in op.desc.input_arg_names(): + if var_name in vars_status: + assert vars_status[var_name] == 2 + vars_status[var_name] = 3 + elif var_name in dp_grads_status: + assert dp_grads_status[var_name] == 2 + dp_grads_status[var_name] = 3 + else: + for var_name in op.desc.input_arg_names(): + param = var_name.split("@")[0] + assert ring_id == dp_ring_id + assert shard.has_param(param) + assert dp_grads_status[var_name] == 4 + dp_grads_status[var_name] = 5 else: for input_name in op.desc.input_arg_names(): - if input_name in var_status: - if var_status[input_name] != 3: + if input_name in vars_status: + if vars_status[input_name] != 3: raise ValueError("There should be a sync_comm op " "after allreducing the Var: {}".format( - var_name)) + input_name)) + if input_name in dp_grads_status: + if dp_ring_id == -1: + if dp_grads_status[input_name] != 3: + raise ValueError("There should be a sync_comm op " + "after allreducing the Var: {}". + format(input_name)) + else: + if dp_grads_status[input_name] != 5: + raise ValueError( + "The grad in shard should be allreduced and synced " + "twice before usage: {}".format(input_name)) + for output_name in op.desc.output_arg_names(): - if output_name in var_status and \ - var_status[output_name] == -1: - var_status[output_name] = 0 + if output_name in vars_status and \ + vars_status[output_name] == -1: + vars_status[output_name] = 0 + if output_name in dp_grads_status and \ + dp_grads_status[output_name] == -1: + dp_grads_status[output_name] = 0 + + # check sharding with amp + if idx_amp_allreduce != -1: + assert idx_amp_allreduce > idx_last_grad_allreduce + + # check sharding with gradient_clip_by_global_norm + if idx_gradient_clip_allreduce != -1: + assert idx_gradient_clip_allreduce > idx_last_grad_allreduce + return @@ -155,20 +240,34 @@ def insert_sync_calc_op(block, insert_idx, calc_dep_vars): return -def insert_sync_comm_ops(block, insert_idx, nrings, comm_dep_vars): +def insert_sync_comm_op(block, insert_idx, ring_id, comm_dep_vars): """ - _insert_sync_comm_ops + insert a sync_comm_op for a single var """ op_role = get_valid_op_role(block, insert_idx) - for i in range(nrings): - block._insert_op_without_sync( - insert_idx, - type='c_sync_comm_stream', - inputs={'X': comm_dep_vars}, - outputs={'Out': comm_dep_vars}, - attrs={'ring_id': i, - OP_ROLE_KEY: op_role}) - return nrings + block._insert_op_without_sync( + insert_idx, + type='c_sync_comm_stream', + inputs={'X': comm_dep_vars}, + outputs={'Out': comm_dep_vars}, + attrs={'ring_id': ring_id, + OP_ROLE_KEY: op_role}) + return 1 + + +def insert_sync_comm_ops(block, insert_idx, ring_id, comm_dep_vars): + """ + insert a sync_comm_op for a list of vars + """ + op_role = get_valid_op_role(block, insert_idx) + block._insert_op_without_sync( + insert_idx, + type='c_sync_comm_stream', + inputs={'X': comm_dep_vars}, + outputs={'Out': comm_dep_vars}, + attrs={'ring_id': int(ring_id), + OP_ROLE_KEY: op_role}) + return 1 def insert_fill_constant_ops(block, insert_idx, fill_constant_vars): @@ -210,13 +309,11 @@ def insert_cast_ops(block, insert_idx, cast_ops): return -def insert_allreduce_ops(block, insert_idx, nrings, allreduce_vars): +def insert_allreduce_ops(block, insert_idx, ring_id, allreduce_vars): """ _add_allreduce_ops """ - ring_id = -1 for var in
allreduce_vars: - ring_id = (ring_id + 1) % nrings block._insert_op_without_sync( insert_idx, type='c_allreduce_sum', @@ -224,17 +321,16 @@ def insert_allreduce_ops(block, insert_idx, nrings, allreduce_vars): outputs={'Out': var}, attrs={'ring_id': ring_id, OP_ROLE_KEY: OpRole.Backward}) + return -def insert_broadcast_ops(block, insert_idx, nrings, broadcast2root): +def insert_broadcast_ops(block, insert_idx, ring_id, broadcast2root): """ _add_broadcast_ops """ - ring_id = -1 op_role = get_valid_op_role(block, insert_idx) for broadcast_name, root_device in broadcast2root: - ring_id = (ring_id + 1) % nrings block._insert_op_without_sync( insert_idx, type='c_broadcast', @@ -245,6 +341,7 @@ def insert_broadcast_ops(block, insert_idx, nrings, broadcast2root): 'root': root_device, OP_ROLE_KEY: op_role }) + return diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py index a449821f8c2122..a7f704361d31af 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py @@ -24,7 +24,7 @@ from paddle.distributed.fleet.meta_optimizers.sharding.gradient_clip_helper import GradientClipHelper from paddle.distributed.fleet.meta_optimizers.sharding.prune import ProgramDeps from paddle.distributed.fleet.meta_optimizers.sharding.utils import * - +import logging from functools import reduce __all__ = ["ShardingOptimizer"] @@ -37,6 +37,8 @@ def __init__(self, optimizer): self.meta_optimizers_white_list = [ "RecomputeOptimizer", "AMPOptimizer", + "LarsOptimizer", + "LambOptimizer", ] self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ] self._main_program = None @@ -69,9 +71,14 @@ def minimize_impl(self, startup_program=None, parameter_list=None, no_grad_set=None): - self._nrings = self.user_defined_strategy.nccl_comm_num + # TODO: (JZ-LIANG) support multiple comm in future + # self._nrings = self.user_defined_strategy.nccl_comm_num + self._nrings_sharding = 1 + self._nrings_dp = 1 self._fuse_broadcast_MB = self.user_defined_strategy.sharding_configs[ "fuse_broadcast_MB"] + self.hybrid_dp = self.user_defined_strategy.sharding_configs[ + "hybrid_dp"] if self.inner_opt is None: raise ValueError( @@ -108,28 +115,38 @@ def minimize_impl(self, # check op dependency check_broadcast(main_block) - check_allreduce_sum(main_block) + check_allreduce_sum(main_block, self._shard, self.dp_ring_id) self._wait() return optimize_ops, params_grads def _set_up(self, params_grads): # step 1: initialize nccl - worker_idx = self.role_maker._worker_index() - endpoints = self.role_maker._get_trainer_endpoints() - current_endpoint = endpoints[worker_idx] + self.global_word_size = self.role_maker._worker_num() + self.global_rank = self.role_maker._worker_index() + self.endpoints = self.role_maker._get_trainer_endpoints() + self.current_endpoint = self.endpoints[self.global_rank] self._collective_helper = CollectiveHelper(self.role_maker, - self._nrings) - for ring_id in range(self._nrings): + self._nrings_sharding) + # config sharding & dp groups + self._init_comm() + # sharding + self._collective_helper._init_communicator( + self._startup_program, self.current_endpoint, + self.sharding_group_endpoints, self.sharding_rank, + self.sharding_ring_id, True) + # dp + if self.hybrid_dp: self._collective_helper._init_communicator( - self._startup_program, current_endpoint, endpoints, worker_idx, - ring_id, None) + self._startup_program,
self.current_endpoint, + self.dp_group_endpoints, self.dp_rank, self.dp_ring_id, True) + startup_block = self._startup_program.global_block() startup_block._sync_with_cpp() # step 2: split params self._params = set([x[0].name for x in params_grads]) - self._shard.setup(params_grads, worker_idx, - self.role_maker._worker_num()) + self._shard.setup(params_grads, self.sharding_rank, + self.sharding_group_size) # step 3: get broadcast vars self._broadcast_vars = self._shard.find_broadcast_params( @@ -208,12 +225,18 @@ def _prune_main_program(self, block): """ calculate deps from allreduce op to optimize op, remove ops and vars not needed in this worker + + 1. prune regularization (weight decay) + 2. prune cast_fp32_to_fp16; update amp infinite-grad checking + 3. prune gradient_clip related; update global_norm_sum + 4. prune optimizer op + param + gradient + """ weightdecay_helper = WeightDecayHelper() weightdecay_helper.prune_weight_decay(block, self._shard) FP16Utils.prune_fp16(block, self._shard, self._reduced_grads_to_param, - self._nrings) - gradientclip_helper = GradientClipHelper() + self.sharding_ring_id) + gradientclip_helper = GradientClipHelper(self.sharding_ring_id) gradientclip_helper.prune_gradient_clip(block, self._shard) # build prog deps @@ -226,6 +249,7 @@ def _prune_main_program(self, block): output_name = output_names[0] reduced_grads.append(output_name) + # prune optimizer state and param pruned_opti_vars = [] for var_name in list(block.vars.keys()): if self._shard.is_opti_var(var_name) and \ @@ -273,6 +297,8 @@ def _prune_main_program(self, block): op.desc.set_input('Input', reversed_input_vars) op.desc.set_output('Out', reversed_output_vars) else: + # if all outputs of this op are in _should_removed_var + # _should_removed_var: optimizer state vars not owned by the current shard if program_deps.should_remove_op(idx): program_deps.remove_op(idx) @@ -283,16 +309,22 @@ def _add_broadcast_allreduce(self, block): """ _add_broadcast_allreduce """ - ring_id = -1 if len(self._segments) < 1: return - + # sharding if self._segments[-1]._allreduce_vars: + shard_allredue_vars = self._shard.filter_grads(self._segments[-1] + ._allreduce_vars) + if self.hybrid_dp and len(shard_allredue_vars) >= 1: + insert_sync_comm_ops(block, self._segments[-1]._end_idx, + self.dp_ring_id, shard_allredue_vars) + insert_allreduce_ops(block, self._segments[-1]._end_idx, + self.dp_ring_id, shard_allredue_vars) insert_sync_comm_ops(block, self._segments[-1]._end_idx, - self._nrings, + self.sharding_ring_id, self._segments[-1]._allreduce_vars) insert_allreduce_ops(block, self._segments[-1]._end_idx, - self._nrings, + self.sharding_ring_id, self._segments[-1]._allreduce_vars) for idx, segment in reversed(list(enumerate(self._segments))): @@ -331,13 +363,21 @@ def _add_broadcast_allreduce(self, block): segment, 0) # step2: add Sync ops - comm_dep_vars = allreduce_vars + [x[0] for x in broadcast_vars] - if len(comm_dep_vars) > 0: - insert_sync_comm_ops( - block, - segment._end_idx, - self._nrings, - comm_dep_vars, ) + shard_allredue_vars = self._shard.filter_grads(allreduce_vars) + if self.hybrid_dp and len(shard_allredue_vars) >= 1: + insert_sync_comm_ops(block, segment._end_idx, self.dp_ring_id, + shard_allredue_vars) + + broad_cast_vars = [x[0] for x in broadcast_vars] + if len(broad_cast_vars) > 0: + insert_sync_comm_ops(block, segment._end_idx, + self.sharding_ring_id, broad_cast_vars) + else: + comm_dep_vars = allreduce_vars + [x[0] for x in broadcast_vars] + if len(comm_dep_vars) > 0: + insert_sync_comm_ops(block, segment._end_idx, +
self.sharding_ring_id, comm_dep_vars) + calc_dep_vars = fill_constant_vars + [ k for k, v in cast_ops.items() ] + self._segments[idx]._allreduce_vars @@ -354,21 +394,27 @@ def _add_broadcast_allreduce(self, block): insert_cast_ops(block, segment._end_idx, cast_ops) # step5: add broadcast ops - insert_broadcast_ops(block, segment._start_idx, self._nrings, - broadcast_vars) - + insert_broadcast_ops(block, segment._start_idx, + self.sharding_ring_id, broadcast_vars) # step6: add all_reduce ops - insert_allreduce_ops(block, segment._start_idx, self._nrings, - allreduce_vars) + # dp + if self.hybrid_dp and len(shard_allredue_vars) >= 1: + insert_allreduce_ops(block, segment._start_idx, self.dp_ring_id, + shard_allredue_vars) + insert_sync_comm_ops(block, segment._start_idx, + self.sharding_ring_id, allreduce_vars) + # sharding + insert_allreduce_ops(block, segment._start_idx, + self.sharding_ring_id, allreduce_vars) block._sync_with_cpp() if self._segments[0]._broadcast_vars: - insert_sync_comm_ops( - block, self._segments[0]._start_idx, self._nrings, - [x[0] for x in self._segments[0]._broadcast_vars]) + broadcast_vars = [x[0] for x in self._segments[0]._broadcast_vars] + insert_sync_comm_ops(block, self._segments[0]._start_idx, + self.sharding_ring_id, broadcast_vars) insert_broadcast_ops(block, self._segments[0]._start_idx, - self._nrings, + self.sharding_ring_id, self._segments[0]._broadcast_vars) fill_constant_vars = [] @@ -409,3 +455,60 @@ def _prune_startup_program(self, block): continue block._remove_var(var_name, sync=False) block._sync_with_cpp() + + def _init_comm(self): + + if self.hybrid_dp: + self.sharding_group_size = self.user_defined_strategy.sharding_configs[ + "sharding_group_size"] + self.sharding_ring_id = 0 + self.sharding_rank = self.global_rank % self.sharding_group_size + + self.dp_group_size = self.global_word_size // self.sharding_group_size + self.dp_rank = self.global_rank // self.sharding_group_size + self.dp_ring_id = self.sharding_rank + 1 + + self.sharding_group_endpoints = [ + ep for idx, ep in enumerate(self.endpoints) + if (idx // self.sharding_group_size) == self.dp_rank + ] + self.dp_group_endpoints = [ + ep for idx, ep in enumerate(self.endpoints) + if (idx % self.sharding_group_size) == self.sharding_rank + ] + assert self.global_word_size > self.sharding_group_size, \ + "global_word_size: {} should be larger than sharding_group_size: {}".format(self.global_word_size, self.sharding_group_size) + assert self.global_word_size % self.sharding_group_size == 0, \ + "global_word_size: {} should be divisible by the sharding_group_size: {}".format(self.global_word_size, self.sharding_group_size) + assert self.dp_group_size * self.sharding_group_size == self.global_word_size, \ + "global_word_size: {} should be equal to the product of sharding_group_size: {} and dp_group_size: {}".format( + self.global_word_size, + self.sharding_group_size, + self.dp_group_size) + + logging.info("Using Sharding&DP mode!") + else: + self.sharding_ring_id = 0 + self.sharding_rank = self.global_rank + self.sharding_group_size = self.role_maker._worker_num() + self.sharding_group_endpoints = self.endpoints + self.dp_ring_id = -1 + self.dp_rank = -1 + self.dp_group_size = None + self.dp_group_endpoints = None + + logging.info("Using Sharding alone mode!") + + logging.info("global world size: {}".format(self.global_word_size)) + logging.info("global rank: {}".format(self.global_rank)) + logging.info("sharding group size: {}".format(self.sharding_group_size)) + logging.info("sharding
rank: {}".format(self.sharding_rank)) + logging.info("dp group size: {}".format(self.dp_group_size)) + logging.info("dp rank: {}".format(self.dp_rank)) + logging.info("current endpoint: {}".format(self.current_endpoint)) + logging.info("sharding group endpoints: {}".format( + self.sharding_group_endpoints)) + logging.info("dp group endpoints: {}".format(self.dp_group_endpoints)) + logging.info("global world endpoints: {}".format(self.endpoints)) + + return diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index f3ab02c62f9802..d299e63fd0073e 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -1028,18 +1028,40 @@ def _build_step_node(self): return step_node def _build_cond_stmt(self, step_node, compare_node): - return gast.Compare( - left=gast.BinOp( + if not isinstance(step_node, (gast.Constant, gast.UnaryOp)): + raise NotImplementedError( + "Dynamic-to-Static only supports a step value that is a constant or negative constant in 'for-range' statements, " + "such as '2', '-3'. But received: '{}'. Please fix code to be compatible with Dynamic-to-Static." + .format(ast_to_source_code(step_node).strip())) + + if isinstance(step_node, gast.UnaryOp) or step_node.value < 0: + # eg: + # range(max, min, -2) + # -> + # i > min + return gast.Compare( left=gast.Name( id=self.iter_var_name if self.is_for_range_iter() else self.iter_idx_name, ctx=gast.Load(), annotation=None, type_comment=None), - op=gast.Add(), - right=step_node), - ops=[gast.LtE()], - comparators=[compare_node]) + ops=[gast.Gt()], + comparators=[compare_node]) + else: + # eg: + # range(min, max, 2) + # -> + # i < max + return gast.Compare( + left=gast.Name( + id=self.iter_var_name + if self.is_for_range_iter() else self.iter_idx_name, + ctx=gast.Load(), + annotation=None, + type_comment=None), + ops=[gast.Lt()], + comparators=[compare_node]) def _build_index_increase_node(self, step_node): return gast.AugAssign( diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 10fb99dd971526..07889ea952b478 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -18,6 +18,7 @@ list(APPEND DIST_TEST_OPS test_parallel_dygraph_transformer) list(APPEND DIST_TEST_OPS test_fleet_pipeline_meta_optimizer) list(APPEND DIST_TEST_OPS test_listen_and_serv_op) list(APPEND DIST_TEST_OPS test_fleet_graph_execution_meta_optimizer) +list(APPEND DIST_TEST_OPS test_gen_nccl_id_op) set(MIXED_DIST_TEST_OPS ${DIST_TEST_OPS}) #remove distribute unittests.
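The _build_cond_stmt change in utils.py above picks the loop-guard comparison from the sign of the literal step: range(min, max, 2) lowers to "i < max", range(max, min, -2) to "i > min", and a variable step raises NotImplementedError. A plain-Python sketch of that rule (illustrative only, not the gast transform itself):

```python
# Sketch of the step-sign rule behind the new _build_cond_stmt.
import operator

def build_cond(step):
    if not isinstance(step, int):
        raise NotImplementedError(
            "only constant or negative constant steps are supported, "
            "got: {!r}".format(step))
    # range(max, min, -2) -> i > min;  range(min, max, 2) -> i < max
    return operator.gt if step < 0 else operator.lt

def static_range(start, bound, step):
    # emulate the lowered while-loop form of a transformed for-range
    cond, i, out = build_cond(step), start, []
    while cond(i, bound):
        out.append(i)
        i += step
    return out

assert static_range(1, 10, 2) == list(range(1, 10, 2))
assert static_range(10, 1, -2) == list(range(10, 1, -2))
```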
list(APPEND MIXED_DIST_TEST_OPS test_dgc_op) @@ -122,17 +123,6 @@ if(NOT WITH_DISTRIBUTE) LIST(REMOVE_ITEM TEST_OPS test_program_code_dist) endif() -if(WITH_MUSL) - # TODO: In the musl docker environment provided by SEC, - # the calculation accuracy of testcase in this unittest - # cannot meet the requirement, error like: - # AssertionError: - # 2.3044646853182973e-07 not less than or equal to 1e-07 - # SEC needs to follow up on this issue, and need to be - # resolved before CI requared - LIST(REMOVE_ITEM TEST_OPS test_sigmoid_focal_loss_op) -endif() - if(WIN32) LIST(REMOVE_ITEM TEST_OPS test_rnn_decode_api) LIST(REMOVE_ITEM TEST_OPS test_complex_matmul) @@ -348,7 +338,7 @@ function(parallel_bash_test_modules TARGET_NAME) endif() endfunction() - +list(REMOVE_ITEM TEST_OPS test_feed_data_check_shape_type) list(REMOVE_ITEM TEST_OPS test_warpctc_op) list(REMOVE_ITEM TEST_OPS test_parallel_executor_crf) list(REMOVE_ITEM TEST_OPS test_parallel_executor_profiler) @@ -375,6 +365,7 @@ list(REMOVE_ITEM TEST_OPS test_basic_gru_api) list(REMOVE_ITEM TEST_OPS test_basic_gru_unit_op) list(REMOVE_ITEM TEST_OPS test_basic_lstm_api) list(REMOVE_ITEM TEST_OPS test_basic_lstm_unit_op) +list(REMOVE_ITEM TEST_OPS test_fuse_all_reduce_pass) list(REMOVE_ITEM TEST_OPS test_fuse_bn_act_pass) list(REMOVE_ITEM TEST_OPS test_fuse_bn_add_act_pass) list(REMOVE_ITEM TEST_OPS test_imperative_static_runner_mnist) @@ -554,12 +545,24 @@ if(WITH_DISTRIBUTE) endif() py_test_modules(test_parallel_executor_crf MODULES test_parallel_executor_crf) -py_test_modules(test_parallel_executor_profiler MODULES test_parallel_executor_profiler) +# Coverage pipeline uses cuda 10.1 now, profiler will randomly hang in cuda 10.1, +# see https://github.com/PaddlePaddle/Paddle/issues/29082 for details. +# We guess there are some bugs in cuda 10.1 or 10.2, +# since this unittest is stable in cuda 11 (py3 pipeline) now.
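Stepping back to _init_comm in the sharding optimizer above: in hybrid mode the flat rank space is split so that consecutive ranks form one sharding group, while ranks holding the same shard slice form a DP group across groups. A sketch of that arithmetic under assumed sizes (illustrative plain Python, not the optimizer code):

```python
# Rank/group arithmetic of hybrid Sharding+DP, as in _init_comm.
# With 8 workers and sharding_group_size = 4: ranks 0-3 and 4-7 each form
# a sharding group; ranks with equal sharding_rank form a DP group.

def hybrid_groups(global_rank, world_size, sharding_group_size):
    assert world_size % sharding_group_size == 0
    sharding_rank = global_rank % sharding_group_size
    dp_rank = global_rank // sharding_group_size
    sharding_group = [r for r in range(world_size)
                      if r // sharding_group_size == dp_rank]
    dp_group = [r for r in range(world_size)
                if r % sharding_group_size == sharding_rank]
    return sharding_rank, dp_rank, sharding_group, dp_group

assert hybrid_groups(5, 8, 4) == (1, 1, [4, 5, 6, 7], [1, 5])
```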
+if(NOT WITH_COVERAGE) + py_test_modules(test_parallel_executor_profiler MODULES test_parallel_executor_profiler) + set_tests_properties(test_parallel_executor_profiler PROPERTIES LABELS "RUN_TYPE=DIST") + set_tests_properties(test_parallel_executor_profiler PROPERTIES TIMEOUT 120) +endif() py_test_modules(test_parallel_executor_transformer MODULES test_parallel_executor_transformer) if(WIN32) py_test_modules(test_parallel_executor_transformer_auto_growth MODULES test_parallel_executor_transformer_auto_growth ENVS FLAGS_allocator_strategy=auto_growth CUDA_VISIBLE_DEVICES=0) + py_test_modules(test_fuse_all_reduce_pass MODULES test_fuse_all_reduce_pass ENVS CUDA_VISIBLE_DEVICES=0) + py_test_modules(test_feed_data_check_shape_type MODULES test_feed_data_check_shape_type ENVS CUDA_VISIBLE_DEVICES=0) else() py_test_modules(test_parallel_executor_transformer_auto_growth MODULES test_parallel_executor_transformer_auto_growth ENVS FLAGS_allocator_strategy=auto_growth) + py_test_modules(test_fuse_all_reduce_pass MODULES test_fuse_all_reduce_pass) + py_test_modules(test_feed_data_check_shape_type MODULES test_feed_data_check_shape_type) endif() py_test_modules(test_data_norm_op MODULES test_data_norm_op) @@ -634,7 +637,6 @@ set_tests_properties(test_parallel_executor_crf test_sync_batch_norm_op test_inp test_parallel_executor_seresnext_base_gpu test_parallel_executor_seresnext_with_reduce_gpu test_parallel_executor_seresnext_with_fuse_all_reduce_gpu - test_parallel_executor_profiler test_parallel_executor_fetch_isolated_var PROPERTIES LABELS "RUN_TYPE=DIST") @@ -722,7 +724,6 @@ set_tests_properties(test_concat_op PROPERTIES TIMEOUT 120) set_tests_properties(test_partial_eager_deletion_transformer PROPERTIES TIMEOUT 120) set_tests_properties(test_parallel_executor_seresnext_with_reduce_gpu PROPERTIES TIMEOUT 120) set_tests_properties(test_dropout_op PROPERTIES TIMEOUT 120) -set_tests_properties(test_parallel_executor_profiler PROPERTIES TIMEOUT 120) set_tests_properties(test_argsort_op PROPERTIES TIMEOUT 120) set_tests_properties(test_sequence_pool PROPERTIES TIMEOUT 120) set_tests_properties(test_gather_nd_op PROPERTIES TIMEOUT 120) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/CMakeLists.txt b/python/paddle/fluid/tests/unittests/dygraph_to_static/CMakeLists.txt index d2b0d520874721..383ef293139b81 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/CMakeLists.txt @@ -1,27 +1,13 @@ file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") -if(WITH_MUSL) - # TODO: In the musl docker environment provided by SEC, - # the test_yolov3 will randomly calculate the result of - # nan, error like: - # AssertionError: - # dygraph_loss: [15742.11914062 9392.61047363] - # static_loss: [nan, nan] - # SEC needs to follow up on this issue, and need to be - # resolved before CI requared - LIST(REMOVE_ITEM TEST_OPS test_yolov3) -endif() - foreach(TEST_OP ${TEST_OPS}) py_test_modules(${TEST_OP} MODULES ${TEST_OP}) endforeach(TEST_OP) set_tests_properties(test_se_resnet PROPERTIES TIMEOUT 900) set_tests_properties(test_tsm PROPERTIES TIMEOUT 900) -if(NOT WITH_MUSL) - set_tests_properties(test_yolov3 PROPERTIES TIMEOUT 900 LABELS "RUN_TYPE=EXCLUSIVE") -endif() +set_tests_properties(test_yolov3 PROPERTIES TIMEOUT 900 LABELS "RUN_TYPE=EXCLUSIVE") set_tests_properties(test_mobile_net PROPERTIES TIMEOUT 120) set_tests_properties(test_seq2seq PROPERTIES 
 set_tests_properties(test_cycle_gan PROPERTIES TIMEOUT 120)
diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py
index 2f107e53ab4436..b6aa73d37639b8 100644
--- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py
+++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py
@@ -94,6 +94,28 @@ def for_loop_dyfunc2(max_len):
     return ret
 
 
+def for_loop_dyfunc3(max_len):
+    ret = fluid.layers.zeros(shape=[1], dtype='float32')
+    for i in range(1, 10, 2):
+        fluid.layers.increment(ret, value=2.0, in_place=True)
+    return ret
+
+
+def for_loop_dyfunc4(max_len):
+    ret = fluid.layers.zeros(shape=[1], dtype='float32')
+    for i in range(10, 1, -2):
+        fluid.layers.increment(ret, value=2.0, in_place=True)
+    return ret
+
+
+def for_loop_dyfunc_not_support(max_len):
+    ret = fluid.layers.zeros(shape=[1], dtype='float32')
+    a = -2
+    for i in range(10, 1, a):
+        fluid.layers.increment(ret, value=2.0, in_place=True)
+    return ret
+
+
 def while_loop_bool_op(x):
     i = fluid.dygraph.to_variable(x)
 
@@ -333,6 +355,16 @@ def _init_dyfunc(self):
         self.dyfunc = for_loop_dyfunc2
 
 
+class TestTransformForLoop3(TestTransformForLoop):
+    def _init_dyfunc(self):
+        self.dyfunc = for_loop_dyfunc3
+
+
+class TestTransformForLoop4(TestTransformForLoop):
+    def _init_dyfunc(self):
+        self.dyfunc = for_loop_dyfunc4
+
+
 class TestClassVarInForLoop(TestTransformForLoop):
     def _init_dyfunc(self):
         self.dyfunc = for_loop_class_var
@@ -343,5 +375,17 @@ def _init_dyfunc(self):
         self.dyfunc = var_create_in_for_loop
 
 
+class TestErrorInForLoop(TestTransformForLoop):
+    def _init_dyfunc(self):
+        self.dyfunc = for_loop_dyfunc_not_support
+
+    def test_ast_to_func(self):
+        with self.assertRaisesRegexp(
+                NotImplementedError,
+                "Dynamic-to-Static only supports the step value is a constant or negative constant "
+        ):
+            self._run_static()
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/tsm.yaml b/python/paddle/fluid/tests/unittests/dygraph_to_static/tsm.yaml
index 9b682dbd6fb201..ecd320348bb729 100644
--- a/python/paddle/fluid/tests/unittests/dygraph_to_static/tsm.yaml
+++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/tsm.yaml
@@ -15,7 +15,7 @@ TRAIN:
     target_size: 224
     num_reader_threads: 12
    buf_size: 1024
-    batch_size: 4 #128
+    batch_size: 2 #128
    use_gpu: True
    num_gpus: 1 #8
    filelist: "./data/dataset/kinetics/train.list"
@@ -24,7 +24,7 @@ TRAIN:
    decay_epochs: [40, 60]
    l2_weight_decay: 1e-4
    momentum: 0.9
-    total_videos: 8000 #239781
+    total_videos: 4000 #239781
 
 VALID:
    short_size: 256
diff --git a/python/paddle/fluid/tests/unittests/test_deform_conv2d.py b/python/paddle/fluid/tests/unittests/test_deform_conv2d.py
index 660625c9bf7561..dc57e87f94022e 100644
--- a/python/paddle/fluid/tests/unittests/test_deform_conv2d.py
+++ b/python/paddle/fluid/tests/unittests/test_deform_conv2d.py
@@ -22,11 +22,11 @@
 
 class TestDeformConv2D(TestCase):
     batch_size = 4
-    spatial_shape = (16, 16)
+    spatial_shape = (5, 5)
     dtype = "float32"
 
     def setUp(self):
-        self.in_channels = 3
+        self.in_channels = 2
         self.out_channels = 5
         self.kernel_size = [3, 3]
         self.padding = [0, 0]
@@ -36,6 +36,8 @@ def setUp(self):
         self.no_bias = True
 
     def prepare(self):
+        np.random.seed(1)
+        paddle.seed(1)
         if isinstance(self.kernel_size, int):
             filter_shape = (self.kernel_size, ) * 2
         else:
@@ -182,11 +184,11 @@ def test_identity(self):
 
 class TestDeformConv2DFunctional(TestCase):
     batch_size = 4
-    spatial_shape = (16, 16)
+    spatial_shape = (5, 5)
     dtype = "float32"
 
     def setUp(self):
-        self.in_channels = 3
+        self.in_channels = 2
         self.out_channels = 5
         self.kernel_size = [3, 3]
         self.padding = [0, 0]
@@ -196,6 +198,8 @@ def setUp(self):
         self.no_bias = True
 
     def prepare(self):
+        np.random.seed(1)
+        paddle.seed(1)
         if isinstance(self.kernel_size, int):
             filter_shape = (self.kernel_size, ) * 2
         else:
diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py
index 19d9031573df82..29ac46e81d85db 100644
--- a/python/paddle/fluid/tests/unittests/test_dist_base.py
+++ b/python/paddle/fluid/tests/unittests/test_dist_base.py
@@ -945,7 +945,7 @@ def _get_nccl2_trainer_cmd(self, model, ep, update_method, trainer_id,
             tr_cmd += " --use_cuda"
             env.update({
                 "FLAGS_selected_gpus": "{}".format(0),
-                "CUDA_VISIBLE_DEVICES": "{}".format(trainer_id % 2),
+                "CUDA_VISIBLE_DEVICES": "{}".format(trainer_id),
                 "PADDLE_TRAINERS_NUM": "{}".format(trainer_num),
                 "PADDLE_TRAINER_ID": "{}".format(trainer_id),
                 "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints,
@@ -960,7 +960,7 @@ def _get_nccl2_trainer_cmd(self, model, ep, update_method, trainer_id,
         if self._pipeline_mode:
             tr_cmd += " --use_pipeline"
         if self._mp_mode:
-            env = {"FLAGS_selected_gpus": "{}".format(trainer_id % 2)}
+            env = {"FLAGS_selected_gpus": "{}".format(trainer_id)}
 
         if self._nccl_comm_num > 1:
             tr_cmd += " --nccl_comm_num {}".format(self._nccl_comm_num)
@@ -992,6 +992,7 @@ def _run_cluster_nccl2(self, model, envs, nccl2_reduce_layer,
 
         global DIST_UT_PORT
         if DIST_UT_PORT == 0:
+            # NOTE(wangxi): the hallreduce test must use 4 cards with nccl >= 2.7
             for i in range(0, 4):
                 self._ps_endpoints += "127.0.0.1:%s," % (
                     self._find_free_port())
@@ -1110,7 +1111,8 @@ def _get_required_envs(self, check_error_log=False, need_envs={}):
             required_envs["GLOG_vmodule"] = \
                 "fused_all_reduce_op_handle=10,all_reduce_op_handle=10,alloc_continuous_space_op=10,fuse_all_reduce_op_pass=10," \
                 "alloc_continuous_space_for_grad_pass=10,fast_threaded_ssa_graph_executor=10,executor=10,operator=10," \
-                "sparse_all_reduce_op_handle=10,gen_nccl_id_op=10,nccl_helper=10,grpc_client=10,grpc_server=10,request_handler_impl=10"
+                "sparse_all_reduce_op_handle=10,gen_nccl_id_op=10,gen_nccl_id_op_help=10,nccl_helper=10,grpc_client=10," \
+                "grpc_server=10,request_handler_impl=10"
 
             required_envs["GLOG_logtostderr"] = "1"
 
         required_envs.update(need_envs)
diff --git a/python/paddle/fluid/tests/unittests/test_dist_mnist_hallreduce.py b/python/paddle/fluid/tests/unittests/test_dist_mnist_hallreduce.py
index 356c5573f95308..e1fbbebe171fce 100644
--- a/python/paddle/fluid/tests/unittests/test_dist_mnist_hallreduce.py
+++ b/python/paddle/fluid/tests/unittests/test_dist_mnist_hallreduce.py
@@ -29,6 +29,7 @@ def _setup_config(self):
         self._use_reduce = False
         self._use_reader_alloc = False
         self._nccl2_mode = True
+        # NOTE(wangxi): the hallreduce test must use 4 cards with nccl >= 2.7
         self._use_hallreduce = True
 
     def test_dist_train(self):
diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py
index c941d7c5f34352..49c2467c9ffeb3 100644
--- a/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py
+++ b/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py
@@ -351,6 +351,16 @@ def init_axis(self):
         self.axis = -1
 
 
+class TestElementwiseFP16AddOp_commonuse_add1(TestFP16ElementwiseAddOp):
+    def init_input_output(self):
+        self.x = np.random.rand(20, 30, 100).astype(self.dtype)
+        self.y = np.random.rand(1, 1, 100).astype(self.dtype)
+        self.out = self.x + self.y
+
+    def init_axis(self):
+        self.axis = -1
+
+
 class TestElementwiseAddOp_commonuse_add2(TestElementwiseAddOp):
     def init_input_output(self):
         self.x = np.random.rand(10, 3, 1, 4).astype(self.dtype)
@@ -429,4 +439,5 @@ def test_dygraph(self):
 
 
 if __name__ == '__main__':
+    paddle.enable_static()
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_gen_nccl_id_op.py b/python/paddle/fluid/tests/unittests/test_gen_nccl_id_op.py
new file mode 100644
index 00000000000000..bd186e09006d1e
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_gen_nccl_id_op.py
@@ -0,0 +1,118 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
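+
+# This unittest spawns one process per simulated trainer: each process builds a
+# startup program that creates persistable RAW variables to hold the generated
+# NCCL IDs (plus hierarchical inter/exter IDs when enabled), appends a
+# gen_nccl_id op, and runs it on CPU so the ranks can exchange IDs.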
+
+import unittest
+import os
+from launch_function_helper import wait, _find_free_port
+from multiprocessing import Pool, Process
+
+os.environ['GLOG_vmodule'] = str("gen_nccl_id_op*=10")
+
+import paddle
+from paddle.fluid import core
+
+paddle.enable_static()
+
+
+def run_gen_nccl_id(attr):
+    nccl_comm_num = attr['nccl_comm_num']
+    use_hallreduce = attr['use_hierarchical_allreduce']
+
+    startup_program = paddle.static.default_startup_program()
+    main_program = paddle.static.default_main_program()
+
+    with paddle.static.program_guard(main_program, startup_program):
+        nccl_id_var = startup_program.global_block().create_var(
+            name="NCCLID", persistable=True, type=core.VarDesc.VarType.RAW)
+
+        for i in range(1, nccl_comm_num):
+            startup_program.global_block().create_var(
+                name="NCCLID_{}".format(i),
+                persistable=True,
+                type=core.VarDesc.VarType.RAW)
+
+        if use_hallreduce:
+            for i in range(0, nccl_comm_num):
+                startup_program.global_block().create_var(
+                    name="Hierarchical_inter_NCCLID_{}".format(i),
+                    persistable=True,
+                    type=core.VarDesc.VarType.RAW)
+                startup_program.global_block().create_var(
+                    name="Hierarchical_exter_NCCLID_{}".format(i),
+                    persistable=True,
+                    type=core.VarDesc.VarType.RAW)
+
+        startup_program.global_block().append_op(
+            type="gen_nccl_id",
+            inputs={},
+            outputs={"NCCLID": nccl_id_var},
+            attrs=attr)
+
+    place = paddle.CPUPlace()
+
+    exe = paddle.static.Executor(place)
+    exe.run(startup_program)
+
+
+class TestGenNcclIdOp(unittest.TestCase):
+    def setUp(self):
+        try:
+            self._dist_ut_port_0 = int(os.environ["PADDLE_DIST_UT_PORT"])
+        except Exception as e:
+            self._dist_ut_port_0 = _find_free_port(set())
+
+    def gen_nccl_id(self, nranks=2):
+        nccl_comm_num = 1
+        if nranks == 2:
+            use_hallreduce = False
+            hallreduce_inter_nranks = -1
+        elif nranks == 4:
+            use_hallreduce = True
+            hallreduce_inter_nranks = 2
+
+        port = self._dist_ut_port_0
+        trainers = []
+        for i in range(nranks):
+            trainers.append('127.0.0.1:{}'.format(port + i))
+
+        attr = {
+            "trainers": trainers,
+            "trainer_id": 0,
+            "nccl_comm_num": nccl_comm_num,
+            "use_hierarchical_allreduce": use_hallreduce,
+            "hierarchical_allreduce_inter_nranks": hallreduce_inter_nranks,
+        }
+
+        procs = []
+        for i in range(nranks):
+            attr['trainer_id'] = i
+            p = Process(target=run_gen_nccl_id, args=(attr, ))
+            p.start()
+            procs.append(p)
+
+        wait(procs, timeout=120)
+
+    def test_flat(self):
+        print(">>> test gen flat nccl id")
+        self.gen_nccl_id(2)
+        print("<<< end test gen flat nccl id")
+
+    def test_hierarchical(self):
+        print(">>> test gen hierarchical nccl id")
+        self.gen_nccl_id(4)
+        print("<<< end test gen hierarchical nccl id")
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_affine_channel_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_affine_channel_op_xpu.py
new file mode 100644
index 00000000000000..3385d671d7332c
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/xpu/test_affine_channel_op_xpu.py
@@ -0,0 +1,148 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Unit testing for affine_channel_op
+"""
+
+from __future__ import print_function
+
+import sys
+sys.path.append("..")
+
+import unittest
+import numpy as np
+from op_test_xpu import XPUOpTest
+import paddle
+import paddle.fluid.core as core
+import paddle.fluid as fluid
+
+
+def affine_channel(x, scale, bias, layout):
+    C = x.shape[1] if layout == 'NCHW' else x.shape[-1]
+    if len(x.shape) == 4:
+        new_shape = (1, C, 1, 1) if layout == 'NCHW' else (1, 1, 1, C)
+    else:
+        new_shape = (1, C)
+    scale = scale.reshape(new_shape)
+    bias = bias.reshape(new_shape)
+    return x * scale + bias
+
+
+class TestAffineChannelOp(XPUOpTest):
+    def setUp(self):
+        self.op_type = "affine_channel"
+        self.init_test_case()
+
+        x = np.random.random(self.shape).astype("float32")
+        scale = np.random.random(self.C).astype("float32")
+        bias = np.random.random(self.C).astype("float32")
+
+        y = affine_channel(x, scale, bias, self.layout)
+
+        self.inputs = {'X': x, 'Scale': scale, 'Bias': bias}
+        self.attrs = {'data_layout': self.layout}
+        self.outputs = {'Out': y}
+
+    def test_check_output(self):
+        if core.is_compiled_with_xpu():
+            paddle.enable_static()
+            place = paddle.XPUPlace(0)
+            self.check_output_with_place(place)
+
+    def test_check_grad(self):
+        if core.is_compiled_with_xpu():
+            paddle.enable_static()
+            place = paddle.XPUPlace(0)
+            self.check_grad_with_place(place, ['X', 'Scale', 'Bias'], 'Out')
+
+    def test_check_grad_stopgrad_dx(self):
+        if core.is_compiled_with_xpu():
+            paddle.enable_static()
+            place = paddle.XPUPlace(0)
+            self.check_grad_with_place(
+                place, ['Scale', 'Bias'], 'Out', no_grad_set=set('X'))
+
+    def test_check_grad_stopgrad_dscale_dbias(self):
+        if core.is_compiled_with_xpu():
+            paddle.enable_static()
+            place = paddle.XPUPlace(0)
+            self.check_grad_with_place(
+                place, ['X'], 'Out', no_grad_set=set(['Scale', 'Bias']))
+
+    def init_test_case(self):
+        self.shape = [2, 100, 3, 3]
+        self.C = 100
+        self.layout = 'NCHW'
+
+
+class TestAffineChannelOpError(unittest.TestCase):
+    def test_errors(self):
+        with fluid.program_guard(fluid.Program()):
+
+            def test_x_type():
+                input_data = np.random.random(2, 1, 2, 2).astype("float32")
+                fluid.layers.affine_channel(input_data)
+
+            self.assertRaises(TypeError, test_x_type)
+
+            def test_x_dtype():
+                x2 = fluid.layers.data(
+                    name='x2', shape=[None, 1, 2, 2], dtype='int32')
+                fluid.layers.affine_channel(x2)
+
+            self.assertRaises(TypeError, test_x_dtype)
+
+            def test_scale_type():
+                x3 = fluid.layers.data(
+                    name='x3', shape=[None, 1, 2, 2], dtype='float32')
+                fluid.layers.affine_channel(x3, scale=1)
+
+            self.assertRaises(TypeError, test_scale_type)
+
+            def test_bias_type():
+                x4 = fluid.layers.data(
+                    name='x4', shape=[None, 1, 2, 2], dtype='float32')
+                fluid.layers.affine_channel(x4, bias=1)
+
+            self.assertRaises(TypeError, test_bias_type)
+
+
+class TestAffineChannelNHWC(TestAffineChannelOp):
+    def init_test_case(self):
+        self.shape = [2, 3, 3, 100]
+        self.C = 100
+        self.layout = 'NHWC'
+
+    def test_check_grad_stopgrad_dx(self):
+        return
+
+    def test_check_grad_stopgrad_dscale_dbias(self):
+        return
+
+
+class TestAffineChannel2D(TestAffineChannelOp):
+    def init_test_case(self):
+        self.shape = [2, 100]
+        self.C = 100
+        self.layout = 'NCHW'
+
+    def test_check_grad_stopgrad_dx(self):
+        return
+
+    def test_check_grad_stopgrad_dscale_dbias(self):
+        return
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_roi_align_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_roi_align_op_xpu.py
index 70f03edb6bac6e..2122223dbec1b4 100644
--- a/python/paddle/fluid/tests/unittests/xpu/test_roi_align_op_xpu.py
+++ b/python/paddle/fluid/tests/unittests/xpu/test_roi_align_op_xpu.py
@@ -20,13 +20,13 @@
 import numpy as np
 import paddle.fluid.core as core
 from op_test import OpTest, skip_check_grad_ci
+from op_test_xpu import XPUOpTest
 import paddle
 import paddle.fluid as fluid
 from paddle.fluid import Program, program_guard
 
 
-@skip_check_grad_ci(reason="There is no grad kernel for roi_align_xpu kernel.")
-class TestROIAlignOp(OpTest):
+class TestROIAlignOp(XPUOpTest):
     def set_data(self):
         self.init_test_case()
         self.make_rois()
@@ -59,16 +59,16 @@ def init_test_case(self):
         self.pooled_width = 2
         self.sampling_ratio = -1
 
-        self.x = np.random.random(self.x_dim).astype('float64')
+        self.x = np.random.random(self.x_dim).astype('float32')
 
     def pre_calc(self, x_i, roi_xmin, roi_ymin, roi_bin_grid_h, roi_bin_grid_w,
                  bin_size_h, bin_size_w):
         count = roi_bin_grid_h * roi_bin_grid_w
         bilinear_pos = np.zeros(
             [self.channels, self.pooled_height, self.pooled_width, count, 4],
-            np.float64)
+            np.float32)
         bilinear_w = np.zeros(
-            [self.pooled_height, self.pooled_width, count, 4], np.float64)
+            [self.pooled_height, self.pooled_width, count, 4], np.float32)
         for ph in range(self.pooled_width):
             for pw in range(self.pooled_height):
                 c = 0
@@ -118,7 +118,7 @@ def pre_calc(self, x_i, roi_xmin, roi_ymin, roi_bin_grid_h, roi_bin_grid_w,
     def calc_roi_align(self):
         self.out_data = np.zeros(
             (self.rois_num, self.channels, self.pooled_height,
-             self.pooled_width)).astype('float64')
+             self.pooled_width)).astype('float32')
 
         for i in range(self.rois_num):
             roi = self.rois[i]
@@ -166,7 +166,7 @@ def make_rois(self):
                 roi = [bno, x1, y1, x2, y2]
                 rois.append(roi)
         self.rois_num = len(rois)
-        self.rois = np.array(rois).astype("float64")
+        self.rois = np.array(rois).astype("float32")
 
     def setUp(self):
         self.op_type = "roi_align"
@@ -178,6 +178,12 @@ def test_check_output(self):
             place = paddle.XPUPlace(0)
             self.check_output_with_place(place)
 
+    def test_check_grad(self):
+        if core.is_compiled_with_xpu():
+            paddle.enable_static()
+            place = paddle.XPUPlace(0)
+            self.check_grad_with_place(place, ['X'], 'Out')
+
 
 class TestROIAlignInLodOp(TestROIAlignOp):
     def set_data(self):
diff --git a/python/paddle/hapi/dynamic_flops.py b/python/paddle/hapi/dynamic_flops.py
index 382227ea832977..9e2f78b559f186 100644
--- a/python/paddle/hapi/dynamic_flops.py
+++ b/python/paddle/hapi/dynamic_flops.py
@@ -229,7 +229,7 @@ def add_hooks(m):
         else:
             if m_type not in types_collection:
                 print(
-                    "Cannot find suitable count function for {}. Treat it as zero Macs.".
+                    "Cannot find suitable count function for {}. Treat it as zero FLOPs.".
                    format(m_type))
 
         if flops_fn is not None:
@@ -256,9 +256,9 @@ def add_hooks(m):
             continue
         total_ops += m.total_ops
         total_params += m.total_params
-
-    total_ops = int(total_ops)
-    total_params = int(total_params)
+    if hasattr(m, 'total_ops') and hasattr(m, 'total_params'):
+        total_ops = int(total_ops)
+        total_params = int(total_params)
 
     if training:
         model.train()
diff --git a/python/paddle/nn/functional/pooling.py b/python/paddle/nn/functional/pooling.py
index 50096f89d906a8..f02f673753bd7b 100755
--- a/python/paddle/nn/functional/pooling.py
+++ b/python/paddle/nn/functional/pooling.py
@@ -995,7 +995,7 @@ def adaptive_avg_pool2d(x, output_size, data_format='NCHW', name=None):
             # out.shape is [2, 3, 3, 3]
     """
     if not in_dygraph_mode():
-        check_variable_and_dtype(x, 'x', ['float32', 'float64'],
+        check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
                                  'adaptive_avg_pool2d')
         check_type(data_format, 'data_format', str, 'adaptive_avg_pool2d')
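
Note on the last hunk: with 'float16' added to the static-graph dtype check, paddle.nn.functional.adaptive_avg_pool2d accepts float16 inputs when building a program. A minimal sketch of the newly allowed usage (illustrative only, not part of the diff; float16 kernels are assumed available, e.g. on a CUDA build):

    import paddle
    import paddle.nn.functional as F

    paddle.enable_static()
    main, startup = paddle.static.Program(), paddle.static.Program()
    with paddle.static.program_guard(main, startup):
        # Before this change, check_variable_and_dtype() rejected float16 here.
        x = paddle.static.data(name='x', shape=[2, 3, 32, 32], dtype='float16')
        out = F.adaptive_avg_pool2d(x, output_size=3)  # out shape: [2, 3, 3, 3]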