Skip to content

Commit

Permalink
[CustomDevice] GetCCLComm add custom device support (#47168)
Browse files Browse the repository at this point in the history
* [CustomDevice] GetCCLComm add custom device support

* update

* update

* update
  • Loading branch information
ronny1996 authored Oct 31, 2022
1 parent 520adc0 commit 34d13d6
Show file tree
Hide file tree
Showing 7 changed files with 96 additions and 31 deletions.
9 changes: 2 additions & 7 deletions paddle/fluid/distributed/collective/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,6 @@ if(WITH_CUSTOM_DEVICE)
cc_library(
processgroup_custom
SRCS ProcessGroupCustom.cc CustomCCLTools.cc Common.cc
DEPS phi_backends
place
enforce
collective_helper
device_context
phi_api
eager_api)
DEPS processgroup phi_backends place enforce collective_helper
device_context)
endif()
18 changes: 14 additions & 4 deletions paddle/fluid/distributed/collective/ProcessGroupCustom.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/api/include/api.h"
#include "paddle/phi/common/place.h"

DECLARE_bool(xccl_blocking_wait);
Expand Down Expand Up @@ -386,15 +385,26 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Barrier(

for (auto& place : places) {
phi::DeviceGuard guard(place);
auto dt = full({1}, 0, phi::DataType::FLOAT32, place);
barrierTensors.push_back(
*std::dynamic_pointer_cast<phi::DenseTensor>(dt.impl()));
phi::DenseTensorMeta meta(phi::DataType::FLOAT32, phi::DDim({1}));
auto allocator = std::unique_ptr<phi::Allocator>(
new paddle::experimental::DefaultAllocator(place));
barrierTensors.emplace_back(allocator.get(), meta);
}
auto task = ProcessGroupCustom::AllReduce(barrierTensors, barrierTensors);
auto xccl_task = dynamic_cast<ProcessGroupCustom::CustomTask*>(task.get());
xccl_task->barrierTensors_ = std::move(barrierTensors);
return task;
}

// Returns the custom-device CCL communicator that this process group
// created for `place`. The communicator must already exist (it is created
// lazily by the first collective issued on that place); looking up a place
// with no communicator is a caller error.
phi::ccl::CCLComm ProcessGroupCustom::CustomCCLComm(const Place& place) const {
  std::vector<Place> places = {place};
  const auto& iter = places_to_customcomm_.find(GetKeyFromPlaces(places));
  PADDLE_ENFORCE_NE(iter,
                    places_to_customcomm_.end(),
                    platform::errors::InvalidArgument(
                        // Fixed message: this lookup is for the custom
                        // device CCL comm, not an NCCL comm.
                        "Cannot find the custom ccl comm in process group."));
  // One communicator per place in the key; single-place key => index 0.
  return iter->second[0]->GetCustomCCLComm();
}

} // namespace distributed
} // namespace paddle
2 changes: 2 additions & 0 deletions paddle/fluid/distributed/collective/ProcessGroupCustom.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ class ProcessGroupCustom : public ProcessGroup {
std::shared_ptr<ProcessGroup::Task> Barrier(
const BarrierOptions& = BarrierOptions()) override;

phi::ccl::CCLComm CustomCCLComm(const Place& place) const;

protected:
virtual std::shared_ptr<ProcessGroupCustom::CustomTask> CreateTask(
std::vector<Place> places,
Expand Down
12 changes: 12 additions & 0 deletions paddle/phi/backends/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,15 @@ if(WITH_CUSTOM_DEVICE)
SRCS custom/capi_test.cc
DEPS phi_capi)
endif()

# Dependency list for the processgroup_comm_utils library: always the base
# processgroup target, plus whichever backend-specific process groups are
# enabled in this build configuration.
set(COMM_UTILS_DEPS processgroup)
if(WITH_NCCL OR WITH_RCCL)
  # Bug fix: the original expanded the undefined variable
  # PROCESS_GROUP_UTILS_DEPS here, which silently dropped the base
  # `processgroup` dependency. Extend COMM_UTILS_DEPS itself instead.
  set(COMM_UTILS_DEPS ${COMM_UTILS_DEPS} processgroup_nccl)
endif()
if(WITH_CUSTOM_DEVICE)
  set(COMM_UTILS_DEPS ${COMM_UTILS_DEPS} processgroup_custom)
endif()
cc_library(
  processgroup_comm_utils
  SRCS processgroup_comm_utils.cc
  DEPS ${COMM_UTILS_DEPS})
65 changes: 65 additions & 0 deletions paddle/phi/backends/processgroup_comm_utils.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/distributed/collective/ProcessGroup.h"
#include "paddle/phi/backends/c_comm_lib.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h"
#endif
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
#include "paddle/fluid/distributed/collective/ProcessGroupCustom.h"
#endif

namespace phi {
namespace detail {

// FIXME(paddle-dev): Since the singleton of ProcessGroup in fluid is used in
// SyncBN, the fluid symbol will be dependent on external hardware access.
// Here, the part that depends on the fluid symbol is individually encapsulated
// as a temporary function to isolate external symbol dependencies.
// In the future, the dependence on the singleton in fluid in SyncBN needs
// to be removed.
// In principle, the PHI Kernel cannot use the global singleton internally,
// and the required members need to be passed in from the outside by the
// caller.
// Resolve the collective communicator for `place` from the process group
// registered under `global_gid`. Returns nullptr when no group is
// registered for the gid, when the place's backend is not compiled into
// this build, or when the place is neither a GPU nor a custom device.
ccl::CCLComm GetCCLComm(const Place& place, int global_gid) {
  // Without a registered process group there is no communicator to return.
  auto* gid_map = paddle::distributed::ProcessGroupMapFromGid::getInstance();
  if (!gid_map->has(global_gid)) {
    return nullptr;
  }
  paddle::distributed::ProcessGroup* pg = gid_map->get(global_gid);

  // Dispatch on the place kind; each backend is only reachable when the
  // corresponding support was compiled in.
  if (paddle::platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
    return static_cast<paddle::distributed::ProcessGroupNCCL*>(pg)->NCCLComm(
        place);
#else
    return nullptr;
#endif
  }

  if (paddle::platform::is_custom_place(place)) {
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
    return static_cast<paddle::distributed::ProcessGroupCustom*>(pg)
        ->CustomCCLComm(place);
#else
    return nullptr;
#endif
  }

  return nullptr;
}

} // namespace detail
} // namespace phi
1 change: 1 addition & 0 deletions paddle/phi/kernels/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} processgroup)
# NCCL/RCCL builds additionally link the NCCL process group library.
if(WITH_NCCL OR WITH_RCCL)
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} processgroup_nccl)
endif()
# Kernels reach process-group communicators through this utility wrapper
# (see processgroup_comm_utils.cc), so it is always a kernel dependency.
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} processgroup_comm_utils)

copy_if_different(${kernel_declare_file} ${kernel_declare_file_final})

Expand Down
20 changes: 0 additions & 20 deletions paddle/phi/kernels/gpu/sync_batch_norm_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,6 @@
#include "paddle/phi/kernels/gpu/sync_batch_norm_utils.h"

namespace phi {
namespace detail {

// Look up the NCCL communicator for `place` from the process group
// registered under `global_gid`. Returns nullptr when no group is
// registered for the gid, or when the build has no NCCL/RCCL support.
ccl::CCLComm GetCCLComm(const Place &place, int global_gid) {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
ncclComm_t comm = nullptr;

// Only hand out a communicator when a process group exists for this gid;
// otherwise fall through and return the null handle.
if (paddle::distributed::ProcessGroupMapFromGid::getInstance()->has(
global_gid)) {
// NOTE(review): unchecked static_cast — assumes every group registered
// here is a ProcessGroupNCCL; confirm registration sites uphold this.
auto *nccl_pg = static_cast<paddle::distributed::ProcessGroupNCCL *>(
paddle::distributed::ProcessGroupMapFromGid::getInstance()->get(
global_gid));
comm = nccl_pg->NCCLComm(place);
}
return comm;
#else
// Non-NCCL builds have no communicator to return.
return nullptr;
#endif
}

} // namespace detail

template <typename T, typename Context>
void SyncBatchNormKernel(const Context &ctx,
Expand Down

0 comments on commit 34d13d6

Please sign in to comment.