-
Notifications
You must be signed in to change notification settings - Fork 5.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
ee18280
commit bd0bc0a
Showing
9 changed files
with
1,417 additions
and
35 deletions.
There are no files selected for viewing
144 changes: 144 additions & 0 deletions
144
paddle/fluid/operators/collective/global_hierarchy_gather_op.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/fluid/operators/collective/global_hierarchy_gather_op.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
class GlobalHierarchyGatherOp : public framework::OperatorWithKernel { | ||
public: | ||
using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
||
void InferShape(framework::InferShapeContext* ctx) const override { | ||
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "GlobalHierarchyGather"); | ||
OP_INOUT_CHECK(ctx->HasInput("local_count"), "Input", "local_count", | ||
"GlobalHierarchyGather"); | ||
|
||
OP_INOUT_CHECK(ctx->HasInput("mp_global_count"), "Input", "mp_global_count", | ||
"GlobalHierarchyGather"); | ||
|
||
OP_INOUT_CHECK(ctx->HasInput("dp_global_count"), "Input", "dp_global_count", | ||
"GlobalHierarchyGather"); | ||
|
||
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", | ||
"GlobalHierarchyGather"); | ||
int inside_ring_id = ctx->Attrs().Get<int>("inside_ring_id"); | ||
PADDLE_ENFORCE_GE(inside_ring_id, 0, | ||
platform::errors::InvalidArgument( | ||
"The inside_ring_id (%d) for global hierarchy gather " | ||
"op must be non-negative.", | ||
inside_ring_id)); | ||
|
||
int outside_ring_id = ctx->Attrs().Get<int>("outside_ring_id"); | ||
PADDLE_ENFORCE_GE(outside_ring_id, 0, | ||
platform::errors::InvalidArgument( | ||
"The outside_ring_id (%d) for global hierarchy " | ||
"gather op must be non-negative.", | ||
outside_ring_id)); | ||
|
||
auto input_dims = ctx->GetInputDim("X"); | ||
auto ndim_input = input_dims.size(); | ||
// dim check | ||
PADDLE_ENFORCE_EQ(ndim_input, 2, | ||
platform::errors::InvalidArgument( | ||
"The input tensor's dimension must be 2. " | ||
"But received input's dimension = %d.", | ||
ndim_input)); | ||
framework::DDim out_dims = framework::make_ddim({-1, -1}); | ||
ctx->SetOutputDim("Out", out_dims); | ||
} | ||
|
||
protected: | ||
framework::OpKernelType GetExpectedKernelType( | ||
const framework::ExecutionContext& ctx) const override { | ||
return framework::OpKernelType( | ||
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); | ||
} | ||
}; | ||
|
||
class GlobalHierarchyGatherOpMaker : public framework::OpProtoAndCheckerMaker { | ||
public: | ||
void Make() { | ||
AddInput("X", "(Tensor) tensor send."); | ||
AddInput("local_count", | ||
"(Tensor) Tensor which has n_expert * world_size elements that " | ||
"indicates" | ||
"how many data needed to be received from each expert."); | ||
|
||
AddInput("mp_global_count", | ||
"(Tensor) Tensor which has n_expert * world_size elements that " | ||
"indicates" | ||
"how many data needed to be received from each expert."); | ||
|
||
AddInput("dp_global_count", | ||
"(Tensor) Tensor which has n_expert * world_size elements that " | ||
"indicates" | ||
"how many data needed to be received from each expert."); | ||
|
||
AddOutput("Out", "(Tensor) the result of global_hierarchy_gather."); | ||
AddAttr<int>("inside_ring_id", | ||
"(int default 0) nccl communication ring id.") | ||
.SetDefault(0); | ||
|
||
AddAttr<int>("outside_ring_id", | ||
"(int default 0) nccl communication ring id.") | ||
.SetDefault(0); | ||
AddAttr<bool>( | ||
"use_calc_stream", | ||
"(bool default false) eject CUDA operations to calculation stream.") | ||
.SetDefault(false); | ||
AddComment(R"DOC( | ||
Global Hierarchy Gather Operator | ||
refer to Global Hierarchy Scatter. | ||
)DOC"); | ||
} | ||
}; | ||
|
||
template <typename T> | ||
class GlobalHierarchyGatherOpGradMaker | ||
: public framework::SingleGradOpMaker<T> { | ||
public: | ||
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; | ||
|
||
protected: | ||
void Apply(GradOpPtr<T> retv) const override { | ||
retv->SetType("global_hierarchy_scatter"); | ||
retv->SetInput("X", this->OutputGrad("Out")); | ||
retv->SetInput("local_count", this->Input("local_count")); | ||
retv->SetInput("mp_global_count", this->Input("mp_global_count")); | ||
retv->SetInput("dp_global_count", this->Input("dp_global_count")); | ||
retv->SetInput("outside_ring_id", this->Input("outside_ring_id")); | ||
retv->SetOutput("Out", this->InputGrad("X")); | ||
retv->SetAttrMap(this->Attrs()); | ||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators; | ||
namespace plat = paddle::platform; | ||
REGISTER_OPERATOR( | ||
global_hierarchy_gather, ops::GlobalHierarchyGatherOp, | ||
ops::GlobalHierarchyGatherOpMaker, | ||
ops::GlobalHierarchyGatherOpGradMaker<paddle::framework::OpDesc>, | ||
ops::GlobalHierarchyGatherOpGradMaker<paddle::imperative::OpBase>) | ||
|
||
REGISTER_OP_CPU_KERNEL(global_hierarchy_gather, | ||
ops::GlobalHierarchyGatherOpCPUKernel<float>, | ||
ops::GlobalHierarchyGatherOpCPUKernel<double>, | ||
ops::GlobalHierarchyGatherOpCPUKernel<int>, | ||
ops::GlobalHierarchyGatherOpCPUKernel<int64_t>, | ||
ops::GlobalHierarchyGatherOpCPUKernel<plat::float16>); |
204 changes: 204 additions & 0 deletions
204
paddle/fluid/operators/collective/global_hierarchy_gather_op.cu.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,204 @@ | ||
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/fluid/operators/collective/global_hierarchy_gather_op.h" | ||
|
||
#if defined(PADDLE_WITH_NCCL) | ||
#include "paddle/fluid/platform/collective_helper.h" | ||
#include "paddle/fluid/platform/device/gpu/nccl_helper.h" | ||
#endif | ||
|
||
namespace paddle { | ||
namespace operators { | ||
using framework::Tensor; | ||
template <typename DeviceContext, typename T> | ||
class GlobalHierarchyGatherOpCUDAKernel : public framework::OpKernel<T> { | ||
public: | ||
void Compute(const framework::ExecutionContext& ctx) const override { | ||
#if defined(PADDLE_WITH_NCCL) | ||
#if NCCL_VERSION_CODE >= 2703 | ||
auto x = ctx.Input<framework::LoDTensor>("X"); | ||
auto local_count = ctx.Input<framework::LoDTensor>("local_count"); | ||
auto mp_global_count = ctx.Input<framework::LoDTensor>("mp_global_count"); | ||
auto dp_global_count = ctx.Input<framework::LoDTensor>("dp_global_count"); | ||
|
||
auto local_count_type = local_count->type(); | ||
if (local_count_type != framework::proto::VarType::INT64) { | ||
PADDLE_THROW(platform::errors::InvalidArgument( | ||
"Please use int64 type in local_count.")); | ||
} | ||
auto out = ctx.Output<framework::LoDTensor>("Out"); | ||
const int64_t* cpu_local_count_data; | ||
const int64_t* cpu_mp_global_count_data; | ||
const int64_t* cpu_dp_global_count_data; | ||
|
||
framework::Tensor cpu_local_count, cpu_mp_global_count, cpu_dp_global_count; | ||
if (platform::is_cpu_place(local_count->place())) { | ||
cpu_local_count_data = local_count->data<int64_t>(); | ||
} else { | ||
framework::TensorCopySync(*local_count, platform::CPUPlace(), | ||
&cpu_local_count); | ||
cpu_local_count_data = cpu_local_count.data<int64_t>(); | ||
} | ||
|
||
if (platform::is_cpu_place(mp_global_count->place())) { | ||
cpu_mp_global_count_data = mp_global_count->data<int64_t>(); | ||
} else { | ||
framework::TensorCopySync(*mp_global_count, platform::CPUPlace(), | ||
&cpu_mp_global_count); | ||
cpu_mp_global_count_data = cpu_mp_global_count.data<int64_t>(); | ||
} | ||
|
||
if (platform::is_cpu_place(dp_global_count->place())) { | ||
cpu_dp_global_count_data = dp_global_count->data<int64_t>(); | ||
} else { | ||
framework::TensorCopySync(*dp_global_count, platform::CPUPlace(), | ||
&cpu_dp_global_count); | ||
cpu_dp_global_count_data = cpu_dp_global_count.data<int64_t>(); | ||
} | ||
|
||
ncclDataType_t dtype = platform::ToNCCLDataType(x->type()); | ||
|
||
int inside_ring_id = ctx.Attr<int>("inside_ring_id"); | ||
PADDLE_ENFORCE_GE( | ||
inside_ring_id, 0, | ||
platform::errors::InvalidArgument("The inside_ring_id (%d) for global " | ||
"gather op must be non-negative.", | ||
inside_ring_id)); | ||
|
||
int outside_ring_id = ctx.Attr<int>("outside_ring_id"); | ||
PADDLE_ENFORCE_GE( | ||
outside_ring_id, 0, | ||
platform::errors::InvalidArgument("The outside_ring_id (%d) for global " | ||
"gather op must be non-negative.", | ||
outside_ring_id)); | ||
|
||
auto place = ctx.GetPlace(); | ||
auto inside_comm = | ||
platform::NCCLCommContext::Instance().Get(inside_ring_id, place); | ||
auto outside_comm = | ||
platform::NCCLCommContext::Instance().Get(outside_ring_id, place); | ||
|
||
cudaStream_t stream = nullptr; | ||
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>(); | ||
|
||
if (ctx.Attr<bool>("use_calc_stream")) { | ||
stream = dev_ctx.stream(); | ||
} else { | ||
stream = inside_comm->stream(); | ||
} | ||
int inside_nranks = inside_comm->nranks(); | ||
int outside_nranks = outside_comm->nranks(); | ||
auto in_feat = x->dims()[1]; | ||
auto n_expert = local_count->dims()[0] / inside_nranks / outside_nranks; | ||
auto inside_all_experts = n_expert * inside_nranks; | ||
|
||
// Step1: outside all_to_all | ||
int dp_fwd_count = 0; | ||
for (auto i = 0; i < cpu_mp_global_count.numel(); ++i) { | ||
dp_fwd_count += cpu_mp_global_count_data[i]; | ||
} | ||
Tensor mp_global_res; | ||
mp_global_res = ctx.AllocateTmpTensor<T, DeviceContext>( | ||
{dp_fwd_count, in_feat}, dev_ctx); | ||
auto outside_send_ptr = 0; | ||
auto outside_recv_ptr = 0; | ||
auto outside_send_buf = x->data<T>(); | ||
auto outside_recv_buf = mp_global_res.mutable_data<T>(place); | ||
|
||
int recv_src = 0; | ||
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); | ||
for (auto i = 0; i < inside_all_experts; ++i) { | ||
for (auto j = 0; j < outside_nranks; ++j) { | ||
int idx = j + i * outside_nranks; | ||
if (cpu_dp_global_count_data[idx]) { | ||
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend( | ||
outside_send_buf + outside_send_ptr * in_feat, | ||
cpu_dp_global_count_data[idx] * in_feat, dtype, j, | ||
outside_comm->comm(), stream)); | ||
outside_send_ptr += cpu_dp_global_count_data[idx]; | ||
} | ||
if (cpu_mp_global_count_data[idx]) { | ||
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv( | ||
outside_recv_buf + outside_recv_ptr * in_feat, | ||
cpu_mp_global_count_data[idx] * in_feat, dtype, | ||
recv_src / inside_all_experts, outside_comm->comm(), stream)); | ||
outside_recv_ptr += cpu_mp_global_count_data[idx]; | ||
} | ||
recv_src++; | ||
} | ||
} | ||
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); | ||
|
||
// Step2: inside all_to_all | ||
int mp_fwd_count = 0; | ||
for (auto i = 0; i < cpu_local_count.numel(); ++i) { | ||
mp_fwd_count += cpu_local_count_data[i]; | ||
} | ||
|
||
auto inside_send_ptr = 0; | ||
auto inside_recv_ptr = 0; | ||
auto inside_send_buf = mp_global_res.data<T>(); | ||
auto inside_recv_buf = out->mutable_data<T>({mp_fwd_count, in_feat}, place); | ||
|
||
recv_src = 0; | ||
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); | ||
for (auto j = 0; j < outside_nranks * n_expert; ++j) { | ||
for (auto k = 0; k < inside_nranks; ++k) { | ||
int idx = k + j * inside_nranks; | ||
if (cpu_mp_global_count_data[idx]) { | ||
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend( | ||
inside_send_buf + inside_send_ptr * in_feat, | ||
cpu_mp_global_count_data[idx] * in_feat, dtype, k, | ||
inside_comm->comm(), stream)); | ||
inside_send_ptr += cpu_mp_global_count_data[idx]; | ||
} | ||
if (cpu_local_count_data[idx]) { | ||
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv( | ||
inside_recv_buf + inside_recv_ptr * in_feat, | ||
cpu_local_count_data[idx] * in_feat, dtype, | ||
recv_src / n_expert % inside_nranks, inside_comm->comm(), | ||
stream)); | ||
inside_recv_ptr += cpu_local_count_data[idx]; | ||
} | ||
recv_src++; | ||
} | ||
} | ||
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); | ||
|
||
#else | ||
PADDLE_THROW( | ||
platform::errors::Unavailable("NCCL version >= 2.7.3 is needed.")); | ||
#endif | ||
#else | ||
PADDLE_THROW( | ||
platform::errors::Unavailable("PaddlePaddle should compile with GPU.")); | ||
#endif | ||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators; | ||
namespace plat = paddle::platform; | ||
using GPUCtx = paddle::platform::CUDADeviceContext; | ||
|
||
REGISTER_OP_CUDA_KERNEL( | ||
global_hierarchy_gather, | ||
ops::GlobalHierarchyGatherOpCUDAKernel<GPUCtx, float>, | ||
ops::GlobalHierarchyGatherOpCUDAKernel<GPUCtx, double>, | ||
ops::GlobalHierarchyGatherOpCUDAKernel<GPUCtx, int>, | ||
ops::GlobalHierarchyGatherOpCUDAKernel<GPUCtx, int64_t>, | ||
ops::GlobalHierarchyGatherOpCUDAKernel<GPUCtx, plat::float16>); |
37 changes: 37 additions & 0 deletions
37
paddle/fluid/operators/collective/global_hierarchy_gather_op.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#pragma once | ||
#include "paddle/fluid/framework/data_type.h" | ||
#include "paddle/fluid/framework/lod_tensor.h" | ||
#include "paddle/fluid/framework/op_registry.h" | ||
|
||
#if defined(PADDLE_WITH_GLOO) | ||
#include "paddle/fluid/framework/fleet/gloo_wrapper.h" | ||
#endif | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
template <typename T> | ||
class GlobalHierarchyGatherOpCPUKernel : public framework::OpKernel<T> { | ||
public: | ||
void Compute(const framework::ExecutionContext& ctx) const override { | ||
PADDLE_THROW(platform::errors::Unavailable( | ||
"Do not support global hierarchy gather op for cpu kernel now.")); | ||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle |
Oops, something went wrong.