diff --git a/cmake/flags.cmake b/cmake/flags.cmake
index f90b71f9e60a8..5742a6b602ff3 100644
--- a/cmake/flags.cmake
+++ b/cmake/flags.cmake
@@ -244,3 +244,7 @@ if(WITH_ROCM)
   string (REPLACE "-Werror" "-Wno-error" CMAKE_C_FLAGS ${CMAKE_C_FLAGS})
 endif()
+if(WITH_PSCORE OR WITH_PSLIB)
+  string (REPLACE "-Wnon-virtual-dtor" "-Wno-non-virtual-dtor" CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS})
+  string (REPLACE "-Wnon-virtual-dtor" "-Wno-non-virtual-dtor" CMAKE_C_FLAGS ${CMAKE_C_FLAGS})
+endif()
diff --git a/paddle/fluid/distributed/collective/CMakeLists.txt b/paddle/fluid/distributed/collective/CMakeLists.txt
index 6fb805a72e4de..6d736d5543ce4 100644
--- a/paddle/fluid/distributed/collective/CMakeLists.txt
+++ b/paddle/fluid/distributed/collective/CMakeLists.txt
@@ -7,14 +7,14 @@ endif()
 
 if(WITH_NCCL)
   cc_library(processgroup_nccl SRCS ProcessGroupNCCL.cc NCCLTools.cc Common.cc DEPS place cuda_stream enforce collective_helper device_context phi phi_api eager_api)
-  if (WITH_DISTRIBUTE)
+  if (WITH_DISTRIBUTE AND WITH_PSCORE)
     cc_library(processgroup_heter SRCS ProcessGroupHeter.cc NCCLTools.cc Common.cc DEPS place cuda_stream enforce collective_helper device_context phi phi_api eager_api)
   endif()
 endif()
 
 if(WITH_ASCEND_CL)
   cc_library(processgroup_hccl SRCS ProcessGroupHCCL.cc HCCLTools.cc Common.cc DEPS place npu_stream enforce collective_helper device_context phi phi_api eager_api)
-  if (WITH_DISTRIBUTE)
+  if (WITH_DISTRIBUTE AND WITH_PSCORE)
     cc_library(processgroup_heter SRCS ProcessGroupHeter.cc HCCLTools.cc Common.cc DEPS place npu_stream enforce collective_helper device_context phi phi_api eager_api)
   endif()
 endif()
diff --git a/paddle/fluid/distributed/collective/ProcessGroup.cc b/paddle/fluid/distributed/collective/ProcessGroup.cc
index ab118dadd5d88..6da83a888683b 100644
--- a/paddle/fluid/distributed/collective/ProcessGroup.cc
+++ b/paddle/fluid/distributed/collective/ProcessGroup.cc
@@ -35,10 +35,10 @@ bool ProcessGroup::Task::Wait(std::chrono::milliseconds timeout) {
 void ProcessGroup::Task::Synchronize() {}
 
 ProcessGroup::ProcessGroup(int rank, int size, int gid)
-    : rank_(rank), size_(size) {
+    : rank_(rank), size_(size), gid_(gid) {
   if (gid != IGNORE_ID) {
     auto map = ProcessGroupMapFromGid::getInstance();
-    map->insert(gid, this);
+    map->insert(gid_, this);
   }
 }
diff --git a/paddle/fluid/distributed/collective/ProcessGroup.h b/paddle/fluid/distributed/collective/ProcessGroup.h
index c2ad1aa2c93ea..17d021852671e 100644
--- a/paddle/fluid/distributed/collective/ProcessGroup.h
+++ b/paddle/fluid/distributed/collective/ProcessGroup.h
@@ -93,8 +93,8 @@ class ProcessGroup {
   }
 
   virtual void Broadcast(const phi::DenseTensor* in, phi::DenseTensor* out) {
-    PADDLE_THROW(platform::errors::InvalidArgument(
-        "ProcessGroup%s does not support broadcast for static",
+    PADDLE_THROW(platform::errors::Fatal(
+        "ProcessGroup%s does not support broadcast for static mode runtime",
         GetBackendName()));
   }
 
@@ -148,6 +148,7 @@ class ProcessGroup {
  protected:
   const int rank_;
   const int size_;
+  const int gid_;
 };
 
 class ProcessGroupMapFromGid {
@@ -158,17 +159,20 @@ class ProcessGroupMapFromGid {
   }
 
   void insert(int gid, ProcessGroup* pg) {
+    // TODO(sandyhouse): address ut and uncomment the following codes
     // PADDLE_ENFORCE_EQ(has(gid), false,
-    //                   platform::errors::PreconditionNotMet(
-    //                       "The process group with id %d does exist.", gid));
+    //                   platform::errors::PreconditionNotMet(
+    //                       "The process group with id %d doesnot exist.",
+    //                       gid));
     map_[gid] = pg;
   }
 
   ProcessGroup* get(int gid) {
+    // TODO(sandyhouse): address ut and uncomment the following codes
     // PADDLE_ENFORCE_EQ(has(gid), true,
-    //                   platform::errors::PreconditionNotMet(
-    //                       "The process group with id %d doesnot exist.",
-    //                       gid));
+    //                   platform::errors::PreconditionNotMet(
+    //                       "The process group with id %d doesnot exist.",
+    //                       gid));
     return map_.find(gid)->second;
   }
diff --git a/paddle/fluid/distributed/collective/ProcessGroupHCCL.cc b/paddle/fluid/distributed/collective/ProcessGroupHCCL.cc
index b21155e09d06e..55945b5e0e396 100644
--- a/paddle/fluid/distributed/collective/ProcessGroupHCCL.cc
+++ b/paddle/fluid/distributed/collective/ProcessGroupHCCL.cc
@@ -30,12 +30,6 @@ constexpr int64_t kWaitBlockTImeout = 10;
 namespace paddle {
 namespace distributed {
 
-// bool CheckTensorsInNPUPlace(const std::vector<Tensor>& tensors) {
-//   return std::all_of(tensors.cbegin(), tensors.cend(), [&](const Tensor& t) {
-//     return t.place() == platform::DeviceType::NPU;
-//   });
-// }
-
 void SyncDefaultStream(
     const std::vector<Place>& places,
     std::vector<NPUEventManager>& hcclEvents,  // NOLINT
diff --git a/paddle/fluid/distributed/collective/ProcessGroupHeter.cc b/paddle/fluid/distributed/collective/ProcessGroupHeter.cc
index ffd653042494d..b3c9ddde50116 100644
--- a/paddle/fluid/distributed/collective/ProcessGroupHeter.cc
+++ b/paddle/fluid/distributed/collective/ProcessGroupHeter.cc
@@ -56,7 +56,8 @@ ProcessGroupHeter::ProcessGroupHeter(const std::shared_ptr<Store>& store,
       local_size_(local_size),
       gloo_rank_(gloo_rank),
       gloo_size_(gloo_size),
-      with_switch_(with_switch) {
+      with_switch_(with_switch),
+      switch_endpoint_(switch_endpoint) {
 #if defined(PADDLE_WITH_NCCL)
   inner_pg_ = std::make_shared<ProcessGroupNCCL>(store, local_rank, local_size,
                                                  IGNORE_ID);
@@ -64,14 +65,10 @@ ProcessGroupHeter::ProcessGroupHeter(const std::shared_ptr<Store>& store,
   inner_pg_ = std::make_shared<ProcessGroupHCCL>(store, local_rank, local_size,
                                                  IGNORE_ID);
 #else
-  PADDLE_THROW(platform::errors::InvalidArgument(
+  PADDLE_THROW(platform::errors::Fatal(
       "ProcessGroupHeter only supports NCCL and HCCL now.");
 #endif
-  if (with_switch_) {
-    // TODO(sandyhouse) starts a client to connect the cloud switch module
-    // std::shared_ptr<HeterClient> client_ =
-    //     HeterClient::GetInstance({switch_endpoint}, {}, 0);
-  } else if (local_rank_ == 0) {
+  if (local_rank_ == 0 && !with_switch_) {
     auto opts = ProcessGroupGloo::GlooOptions::create();
     opts->device = ProcessGroupGloo::createDefaultDevice();
     inter_pg_ = std::make_shared<ProcessGroupGloo>(store, gloo_rank_,
@@ -79,6 +76,15 @@ ProcessGroupHeter::ProcessGroupHeter(const std::shared_ptr<Store>& store,
   }
 }
 
+template <typename T>
+static void _do_add(T* dst, T* src, size_t size) {
+  for (size_t i = 0; i < size; i++) {
+    *dst += *src;
+    dst++;
+    src++;
+  }
+}
+
 std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::AllReduce(
     std::vector<Tensor>& tensors, const AllreduceOptions& opts) {
 #if defined(PADDLE_WITH_NCCL)
@@ -93,33 +99,92 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::AllReduce(
 
   // Step2: copy tensors to CPU
   if (local_rank_ == 0) {
-    std::vector<Tensor> cpu_tensors(tensors.size());
+    std::vector<Tensor> cpu_tensors;
+    cpu_tensors.reserve(tensors.size());
     for (size_t i = 0; i < tensors.size(); i++) {
       auto dense_gpu_tensor =
           std::dynamic_pointer_cast<phi::DenseTensor>(tensors[i].impl());
-      auto dense_cpu_tensor =
-          std::dynamic_pointer_cast<phi::DenseTensor>(cpu_tensors[i].impl());
-      dense_cpu_tensor->Resize(tensors[i].dims());
+      phi::DenseTensorMeta meta = phi::DenseTensorMeta(
+          dense_gpu_tensor->dtype(), dense_gpu_tensor->dims());
+      std::shared_ptr<phi::DenseTensor> dense_cpu_tensor =
+          std::make_shared<phi::DenseTensor>(
+              std::make_unique<paddle::experimental::DefaultAllocator>(
+                  paddle::platform::CPUPlace())
+                  .get(),
+              meta);
+      dense_cpu_tensor->ResizeAndAllocate(dense_gpu_tensor->dims());
+      cpu_tensors[i] = paddle::experimental::Tensor(dense_cpu_tensor);
       framework::TensorCopySync(*dense_gpu_tensor, platform::CPUPlace(),
                                 dense_cpu_tensor.get());
     }
     // Step3: do inter cluster allreduce
     if (with_switch_) {
-      // TODO(sandyhouse) send to and recv from switch, and do add
+      if (local_rank_ == 0) {
+        HeterClient* client_ =
+            HeterClient::GetInstance({switch_endpoint_}, {}, 0).get();
+        auto dense_cpu_tensor =
+            std::dynamic_pointer_cast<phi::DenseTensor>(cpu_tensors[0].impl());
+        std::vector<int64_t> send_size;
+        send_size.push_back(dense_cpu_tensor->numel());
+        int ret = client_->Send(
+            gid_, {dense_cpu_tensor->name()}, send_size,
+            dense_cpu_tensor->data(),
+            dense_cpu_tensor->numel() *
+                framework::DataTypeSize(dense_cpu_tensor->dtype()));
+        PADDLE_ENFORCE_EQ(ret, 0, platform::errors::PreconditionNotMet(
+                                      "Send to the switch module error."));
+        phi::DenseTensorMeta meta = phi::DenseTensorMeta(
+            dense_cpu_tensor->dtype(), dense_cpu_tensor->dims());
+        std::shared_ptr<phi::DenseTensor> dense_cpu_tensor2 =
+            std::make_shared<phi::DenseTensor>(
+                std::make_unique<paddle::experimental::DefaultAllocator>(
+                    paddle::platform::CPUPlace())
+                    .get(),
+                meta);
+        dense_cpu_tensor2->ResizeAndAllocate(dense_cpu_tensor->dims());
+        Tensor cpu_tensor_temp =
+            paddle::experimental::Tensor(dense_cpu_tensor2);
+        ret = client_->Recv(
+            gid_, {dense_cpu_tensor->name()}, dense_cpu_tensor2->data(),
+            dense_cpu_tensor2->numel() *
+                framework::DataTypeSize(dense_cpu_tensor2->dtype()));
+        PADDLE_ENFORCE_EQ(ret, 0, platform::errors::PreconditionNotMet(
+                                      "Recv from the switch module error."));
+
+        switch (dense_cpu_tensor->dtype()) {
+          case DataType::FLOAT32:
+            _do_add<float>(reinterpret_cast<float*>(dense_cpu_tensor->data()),
+                           reinterpret_cast<float*>(dense_cpu_tensor2->data()),
+                           dense_cpu_tensor->numel());
+            break;
+          case DataType::FLOAT64:
+            _do_add<double>(
+                reinterpret_cast<double*>(dense_cpu_tensor->data()),
+                reinterpret_cast<double*>(dense_cpu_tensor2->data()),
+                dense_cpu_tensor->numel());
+            break;
+          case DataType::INT32:
+            _do_add<int>(reinterpret_cast<int*>(dense_cpu_tensor->data()),
+                         reinterpret_cast<int*>(dense_cpu_tensor2->data()),
+                         dense_cpu_tensor->numel());
+            break;
+          default:
+            PADDLE_THROW(platform::errors::PreconditionNotMet(
+                "Unsupported data type (%s) to do add.",
+                framework::DataType2String(dense_cpu_tensor->dtype())));
+        }
+      }
     } else {
       auto gloo_task = inter_pg_->AllReduce(cpu_tensors, opts);
       gloo_task->Wait();
     }
     // Step4: copy cpu tensors to gpu
-    // TODO(sandyhouse)
     // copy cpu tensors to gpu
     for (size_t i = 0; i < tensors.size(); i++) {
       auto dense_gpu_tensor =
           std::dynamic_pointer_cast<phi::DenseTensor>(tensors[i].impl());
       auto dense_cpu_tensor =
           std::dynamic_pointer_cast<phi::DenseTensor>(cpu_tensors[i].impl());
-      // framework::TensorCopySync(*dense_cpu_tensor, tensors[i].place(),
-      //                           dense_gpu_tensor.get());
       framework::TensorCopySync(*dense_cpu_tensor, dense_cpu_tensor->place(),
                                 dense_gpu_tensor.get());
     }
@@ -147,18 +212,57 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Broadcast(
   inner_pg_->Broadcast(tensors, b_opts);
 
   if (local_rank_ == 0) {
-    std::vector<Tensor> cpu_tensors(tensors.size());
+    std::vector<Tensor> cpu_tensors;
+    cpu_tensors.reserve(tensors.size());
     for (size_t i = 0; i < tensors.size(); i++) {
       auto dense_gpu_tensor =
           std::dynamic_pointer_cast<phi::DenseTensor>(tensors[i].impl());
-      auto dense_cpu_tensor =
-          std::dynamic_pointer_cast<phi::DenseTensor>(cpu_tensors[i].impl());
-      dense_cpu_tensor->Resize(tensors[i].dims());
+      phi::DenseTensorMeta meta = phi::DenseTensorMeta(
+          dense_gpu_tensor->dtype(), dense_gpu_tensor->dims());
+      std::shared_ptr<phi::DenseTensor> dense_cpu_tensor =
+          std::make_shared<phi::DenseTensor>(
+              std::make_unique<paddle::experimental::DefaultAllocator>(
+                  paddle::platform::CPUPlace())
+                  .get(),
+              meta);
+      dense_cpu_tensor->ResizeAndAllocate(dense_gpu_tensor->dims());
+      cpu_tensors[i] = paddle::experimental::Tensor(dense_cpu_tensor);
       framework::TensorCopySync(*dense_gpu_tensor, platform::CPUPlace(),
                                 dense_cpu_tensor.get());
     }
     if (with_switch_) {
-      // TODO(sandyhouse) send to and recv
+      if (local_rank_ == 0) {
+        HeterClient* client_ =
+            HeterClient::GetInstance({switch_endpoint_}, {}, 0).get();
+        auto dense_cpu_tensor =
+            std::dynamic_pointer_cast<phi::DenseTensor>(cpu_tensors[0].impl());
+        if (gloo_rank_ == 0) {
+          std::vector<int64_t> send_size;
+          send_size.push_back(dense_cpu_tensor->numel());
+          int ret = client_->Send(
+              gid_, {dense_cpu_tensor->name()}, send_size,
+              dense_cpu_tensor->data(),
+              dense_cpu_tensor->numel() *
+                  framework::DataTypeSize(dense_cpu_tensor->dtype()));
+          PADDLE_ENFORCE_EQ(ret, 0, platform::errors::PreconditionNotMet(
+                                        "Send to the switch module error."));
+        } else {
+          int ret = client_->Recv(
+              gid_, {dense_cpu_tensor->name()}, dense_cpu_tensor->data(),
+              dense_cpu_tensor->numel() *
+                  framework::DataTypeSize(dense_cpu_tensor->dtype()));
+          PADDLE_ENFORCE_EQ(ret, 0,
+                            platform::errors::PreconditionNotMet(
+                                "Receive from the switch module error."));
+          ret = client_->Recv(
+              gid_, {dense_cpu_tensor->name()}, dense_cpu_tensor->data(),
+              dense_cpu_tensor->numel() *
+                  framework::DataTypeSize(dense_cpu_tensor->dtype()));
+          PADDLE_ENFORCE_EQ(ret, 0,
+                            platform::errors::PreconditionNotMet(
+                                "Receive from the switch module error."));
+        }
+      }
     } else {
       auto gloo_task = inter_pg_->Broadcast(cpu_tensors, opts);
       gloo_task->Wait();
@@ -168,8 +272,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Broadcast(
           std::dynamic_pointer_cast<phi::DenseTensor>(tensors[i].impl());
       auto dense_cpu_tensor =
           std::dynamic_pointer_cast<phi::DenseTensor>(cpu_tensors[i].impl());
-      // framework::TensorCopySync(*dense_cpu_tensor, tensors[i].place(),
-      //                           dense_gpu_tensor.get());
       framework::TensorCopySync(*dense_cpu_tensor, dense_cpu_tensor->place(),
                                 dense_gpu_tensor.get());
     }
@@ -185,22 +287,44 @@ void ProcessGroupHeter::Broadcast(const phi::DenseTensor* in,
   inner_pg_->Broadcast(in, out);
 
   if (local_rank_ == 0) {
-    Tensor cpu_tensor;
-    auto dense_cpu_tensor =
-        std::dynamic_pointer_cast<phi::DenseTensor>(cpu_tensor.impl());
-    dense_cpu_tensor->Resize(in->dims());
+    phi::DenseTensorMeta meta = phi::DenseTensorMeta(in->dtype(), in->dims());
+    std::shared_ptr<phi::DenseTensor> dense_cpu_tensor =
+        std::make_shared<phi::DenseTensor>(
+            std::make_unique<paddle::experimental::DefaultAllocator>(
+                paddle::platform::CPUPlace())
+                .get(),
+            meta);
+    dense_cpu_tensor->ResizeAndAllocate(in->dims());
+    Tensor cpu_tensor = paddle::experimental::Tensor(dense_cpu_tensor);
     framework::TensorCopySync(*in, platform::CPUPlace(),
                               dense_cpu_tensor.get());
     if (with_switch_) {
-      // TODO(sandyhouse) send to and recv
+      if (local_rank_ == 0) {
+        HeterClient* client_ =
+            HeterClient::GetInstance({switch_endpoint_}, {}, 0).get();
+        if (gloo_rank_ == 0) {
+          std::vector<int64_t> send_size;
+          send_size.push_back(in->numel());
+          int ret = client_->Send(
+              gid_, {in->name()}, send_size, dense_cpu_tensor->data(),
+              in->numel() * framework::DataTypeSize(in->dtype()));
+          PADDLE_ENFORCE_EQ(ret, 0, platform::errors::PreconditionNotMet(
+                                        "Send to the switch module error."));
+        } else {
+          int ret =
+              client_->Recv(gid_, {in->name()}, dense_cpu_tensor->data(),
+                            in->numel() * framework::DataTypeSize(in->dtype()));
+          PADDLE_ENFORCE_EQ(ret, 0,
+                            platform::errors::PreconditionNotMet(
+                                "Receive from the switch module error."));
+        }
+      }
     } else {
       std::vector<Tensor> cpu_tensors = {cpu_tensor};
-      // auto gloo_task = inter_pg_->Broadcast(cpu_tensors);
-      // gloo_task->Wait();
-      inter_pg_->Broadcast(cpu_tensors);
+      auto gloo_task = inter_pg_->Broadcast(cpu_tensors);
+      gloo_task->Wait();
     }
-    framework::TensorCopySync(*dense_cpu_tensor, dense_cpu_tensor->place(),
-                              out);
+    framework::TensorCopySync(*dense_cpu_tensor, out->place(), out);
   }
   inner_pg_->Broadcast(out, out);
 }
diff --git a/paddle/fluid/distributed/collective/ProcessGroupHeter.h b/paddle/fluid/distributed/collective/ProcessGroupHeter.h
index 8a26adbea4d78..892dbb9369e8d 100644
--- a/paddle/fluid/distributed/collective/ProcessGroupHeter.h
+++ b/paddle/fluid/distributed/collective/ProcessGroupHeter.h
@@ -23,7 +23,6 @@
 #include "paddle/fluid/distributed/collective/ProcessGroup.h"
 #include "paddle/fluid/distributed/collective/ProcessGroupGloo.h"
-// #include "paddle/fluid/distributed/ps/service/heter_client.h"
 #include "paddle/fluid/platform/device_context.h"
 
 #ifdef PADDLE_WITH_GLOO
@@ -48,6 +47,11 @@
 #include "paddle/fluid/distributed/collective/ProcessGroupHCCL.h"
 #endif
 
+#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
+    (defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_ASCEND_CL))
+#include "paddle/fluid/distributed/ps/service/heter_client.h"
+#endif
+
 #include "paddle/fluid/distributed/collective/Common.h"
 
 constexpr const char* HETER_BACKEND_NAME = "HETER_BACKEND";
@@ -108,6 +112,7 @@ class ProcessGroupHeter : public ProcessGroup {
   int gloo_rank_;
   int gloo_size_;
   bool with_switch_;
+  std::string switch_endpoint_;
 };
 
 }  // namespace distributed
diff --git a/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc b/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
index 7c0752b5f367c..eeb5e3b397c10 100644
--- a/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
+++ b/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
@@ -226,6 +226,43 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
   return task;
 }
 
+template <typename Fn>
+void ProcessGroupNCCL::Collective(const phi::DenseTensor* in,
+                                  phi::DenseTensor* out, Fn fn,
+                                  CommType op_type) {
+  std::vector<Place> places;
+  places.push_back(in->place());
+  const auto key = GetKeyFromPlaces(places);
+
+  {
+    std::lock_guard<std::mutex> lock(mutex_);
+    if (places_to_ncclcomm_.find(key) == places_to_ncclcomm_.end()) {
+      CreateNCCLManagerCache(key, places);
+    }
+  }
+
+  auto& nccl_comms = places_to_ncclcomm_[key];
+
+  SyncDefaultStream(places, places_to_events_[key], places_to_ctx_[key]);
+
+  // construct uninitialize guard for device
+  platform::CUDADeviceGuard cuda_guard;
+
+  if (FLAGS_use_stream_safe_cuda_allocator) {
+    cuda_guard.SetDevice(places[0]);
+    memory::RecordStream(in->Holder(), places_to_ctx_[key][0]->stream());
+  }
+
+  {
+    platform::NCCLGroupGuard nccl_guard;
+    cuda_guard.SetDevice(places[0]);
+    const auto& nccl_stream = places_to_ctx_[key][0]->stream();
+    fn(in, out, nccl_comms[0]->GetNcclComm(), nccl_stream);
+  }
+
+  cuda_guard.SetDevice(places[0]);
+}
+
 template <typename Fn>
 std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
     std::vector<Tensor>& tensors, Fn fn, int dst_rank, CommType op_type) {
diff --git a/paddle/fluid/distributed/collective/ProcessGroupNCCL.h b/paddle/fluid/distributed/collective/ProcessGroupNCCL.h
index 4ab5374dacaf4..fa73ed195b0c1 100644
--- a/paddle/fluid/distributed/collective/ProcessGroupNCCL.h
+++ b/paddle/fluid/distributed/collective/ProcessGroupNCCL.h
@@ -146,6 +146,10 @@ class ProcessGroupNCCL : public ProcessGroup {
       std::vector<Tensor>& outputs,  // NOLINT
       Fn fn, CommType op_type);
 
+  template <typename Fn>
+  void Collective(const phi::DenseTensor*, phi::DenseTensor*, Fn fn,
+                  CommType op_type);
+
   template <typename Fn>
   std::shared_ptr<ProcessGroup::Task> PointToPoint(
       std::vector<Tensor>& tensors,  // NOLINT
diff --git a/paddle/fluid/operators/collective/c_broadcast_op.cu.cc b/paddle/fluid/operators/collective/c_broadcast_op.cu.cc
index 0ad61bb16b51e..7bdf5f0c46ca6 100644
--- a/paddle/fluid/operators/collective/c_broadcast_op.cu.cc
+++ b/paddle/fluid/operators/collective/c_broadcast_op.cu.cc
@@ -37,7 +37,6 @@ class CBroadcastOpCUDAKernel : public framework::OpKernel<T> {
     int rid = ctx.Attr<int>("ring_id");
     auto place = ctx.GetPlace();
 
-    auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
     auto map = distributed::ProcessGroupMapFromGid::getInstance();
     if (map->has(rid)) {
       // Use ProcessGroup
@@ -46,6 +45,7 @@ class CBroadcastOpCUDAKernel : public framework::OpKernel<T> {
       return;
     }
 
+    auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
     gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
       auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt
index b190f429410f4..f8e7081de01bd 100644
--- a/paddle/fluid/pybind/CMakeLists.txt
+++ b/paddle/fluid/pybind/CMakeLists.txt
@@ -91,12 +91,18 @@ if(NOT ON_INFER)
   set (PYBIND_DEPS ${PYBIND_DEPS} processgroup eager_reducer)
   if (WITH_NCCL)
     set (PYBIND_DEPS ${PYBIND_DEPS} processgroup_nccl)
+    if (WITH_PSCORE)
+      set (PYBIND_DEPS ${PYBIND_DEPS} processgroup_heter)
+    endif()
   endif()
   if (WITH_GLOO)
     set (PYBIND_DEPS ${PYBIND_DEPS} processgroup_gloo)
   endif()
   if(WITH_ASCEND_CL)
     set (PYBIND_DEPS ${PYBIND_DEPS} processgroup_hccl)
+    if (WITH_PSCORE)
+      set (PYBIND_DEPS ${PYBIND_DEPS} processgroup_heter)
+    endif()
   endif()
   set(PYBIND_SRCS ${PYBIND_SRCS} distributed_py.cc)
 endif()
diff --git a/paddle/fluid/pybind/distributed_py.cc b/paddle/fluid/pybind/distributed_py.cc
index 6c74ea2eef4d0..38ed1d4f2bb5d 100644
--- a/paddle/fluid/pybind/distributed_py.cc
+++ b/paddle/fluid/pybind/distributed_py.cc
@@ -39,6 +39,11 @@ limitations under the License. */
 #include "paddle/fluid/distributed/collective/ProcessGroupHCCL.h"
 #endif
 
+#if defined(PADDLE_WITH_GLOO) && defined(PADDLE_WITH_PSCORE) && \
+    (defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_ASCEND_CL))
+#include "paddle/fluid/distributed/collective/ProcessGroupHeter.h"
+#endif
+
 #if defined(PADDLE_WITH_GLOO)
 #include "paddle/fluid/distributed/collective/ProcessGroupGloo.h"
 #include "paddle/fluid/distributed/store/tcp_store.h"
@@ -217,6 +222,21 @@ void BindDistributed(py::module *m) {
                     int>(),
           py::arg("store"), py::arg("rank"), py::arg("world_size"),
           py::arg("group_id") = 0, py::call_guard<py::gil_scoped_release>());
+
+#if defined(PADDLE_WITH_GLOO) && defined(PADDLE_WITH_PSCORE) && \
+    (defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_ASCEND_CL))
+  py::class_<distributed::ProcessGroupHeter,
+             std::shared_ptr<distributed::ProcessGroupHeter>>(
+      *m, "ProcessGroupHeter", ProcessGroup)
+      .def(py::init<const std::shared_ptr<distributed::Store> &, int, int, int,
+                    int, int, int, int, bool, std::string>(),
+           py::arg("store"), py::arg("rank"), py::arg("world_size"),
+           py::arg("gid") = 0, py::arg("local_rank") = 0,
+           py::arg("local_size") = 1, py::arg("gloo_rank") = 0,
+           py::arg("gloo_size") = 1, py::arg("with_switch") = false,
+           py::arg("switch_endpoint") = "",
+           py::call_guard<py::gil_scoped_release>());
+#endif
 #endif
 
 #if defined(PADDLE_WITH_ASCEND_CL)
@@ -227,6 +247,21 @@ void BindDistributed(py::module *m) {
                     int>(),
           py::arg("store"), py::arg("rank"), py::arg("world_size"),
           py::arg("group_id") = 0, py::call_guard<py::gil_scoped_release>());
+
+#if defined(PADDLE_WITH_GLOO) && defined(PADDLE_WITH_PSCORE) && \
+    (defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_ASCEND_CL))
+  py::class_<distributed::ProcessGroupHeter,
+             std::shared_ptr<distributed::ProcessGroupHeter>>(
+      *m, "ProcessGroupHeter", ProcessGroup)
+      .def(py::init<const std::shared_ptr<distributed::Store> &, int, int, int,
+                    int, int, int, int, bool, std::string>(),
+           py::arg("store"), py::arg("rank"), py::arg("world_size"),
+           py::arg("gid") = 0, py::arg("local_rank") = 0,
+           py::arg("local_size") = 1, py::arg("gloo_rank") = 0,
+           py::arg("gloo_size") = 1, py::arg("with_switch") = false,
+           py::arg("switch_endpoint") = "",
+           py::call_guard<py::gil_scoped_release>());
+#endif
 #endif
 
   py::class_