
Commit 12155d1: Fix
co63oc committed Aug 27, 2024
1 parent a25a2d5
Showing 12 changed files with 187 additions and 109 deletions.
10 changes: 0 additions & 10 deletions paddle/fluid/operators/collective/mp_allreduce_sum_op.cc
@@ -82,13 +82,3 @@ REGISTER_OPERATOR(mp_allreduce_sum,
                  ops::MpAllReduceSumOpGradMaker<paddle::imperative::OpBase>,
                  ops::MpAllReduceSumOpMaker,
                  ops::MpAllReduceSumInplaceInferer);

PD_REGISTER_STRUCT_KERNEL(mp_allreduce_sum,
                          CPU,
                          ALL_LAYOUT,
                          ops::MpAllReduceSumCPUKernel,
                          float,
                          double,
                          int,
                          int64_t,
                          phi::dtype::float16) {}
39 changes: 0 additions & 39 deletions paddle/fluid/operators/collective/mp_allreduce_sum_op.cu.cc

This file was deleted.

29 changes: 0 additions & 29 deletions paddle/fluid/operators/collective/mp_allreduce_sum_op.kps

This file was deleted.

31 changes: 0 additions & 31 deletions paddle/fluid/operators/collective/mp_allreduce_sum_op_xpu.cc

This file was deleted.
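Note on the deletions so far: the three deleted files, together with the registration removed from mp_allreduce_sum_op.cc above, retire the legacy fluid implementation, in which each backend (CPU, CUDA, KP, XPU) had its own translation unit and the kernel was a class registered via PD_REGISTER_STRUCT_KERNEL. Roughly, the retired shape looked like the sketch below (illustrative only; the deleted sources are not reproduced here, and the Compute body is a placeholder):

// Illustrative fluid-style kernel class; the name follows the deleted
// registration above, the body is a placeholder, not the removed code.
template <typename T, typename DeviceContext>
class MpAllReduceSumCPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    // Pull X/Out and the ring_id/use_calc_stream attributes out of the
    // ExecutionContext, then perform the sum reduction.
  }
};

The phi kernels added below replace this with a single templated free function that takes its inputs and attributes as explicit arguments, which is what lets one definition serve CPU, GPU, KPS, and XPU with only the registration differing.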

2 changes: 2 additions & 0 deletions paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
@@ -202,6 +202,8 @@
    'nce',
    'lrn',
    'max_pool2d_v2',
    'mp_allreduce_sum',
    'mp_allreduce_sum_',
    'partial_sum',
    'pull_gpups_sparse',
    'pull_gpups_sparse_',
37 changes: 37 additions & 0 deletions paddle/phi/kernels/cpu/mp_allreduce_sum_kernel.cc
@@ -0,0 +1,37 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/all_reduce_kernel.h"

namespace phi {
template <typename T, typename Context>
void MpAllReduceSumKernel(const Context& dev_ctx,
                          const DenseTensor& x,
                          int ring_id UNUSED,
                          bool use_calc_stream UNUSED,
                          DenseTensor* out) {
  AllReduceKernel<T, Context>(
      dev_ctx, x, static_cast<int>(ReduceType::kRedSum), out);
}
}  // namespace phi
PD_REGISTER_KERNEL(mp_allreduce_sum,
                   CPU,
                   ALL_LAYOUT,
                   phi::MpAllReduceSumKernel,
                   float,
                   double,
                   int,
                   int64_t,
                   phi::dtype::float16) {}
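The new kernel is deliberately thin: ring_id and use_calc_stream are accepted only for signature compatibility (hence UNUSED), and the work is forwarded to the generic AllReduceKernel with ReduceType::kRedSum, so the dtype list in the registration is the only CPU-specific piece. A minimal sketch of a direct call, assuming a Paddle build; in practice the framework reaches this function through the registry entry, and the tensor setup here is abbreviated:

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/dense_tensor.h"

// Hypothetical helper, not part of this commit: invokes the new CPU kernel
// directly for a float tensor.
void MpAllReduceSumOnCpu(const phi::CPUContext& ctx,
                         const phi::DenseTensor& x,
                         phi::DenseTensor* out) {
  phi::MpAllReduceSumKernel<float, phi::CPUContext>(
      ctx, x, /*ring_id=*/0, /*use_calc_stream=*/false, out);
}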
52 changes: 52 additions & 0 deletions paddle/phi/kernels/gpu/mp_allreduce_sum_kernel.cu
@@ -0,0 +1,52 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/all_reduce_kernel.h"

namespace phi {
template <typename T, typename Context>
void MpAllReduceSumKernel(const Context& dev_ctx,
                          const DenseTensor& x,
                          int ring_id UNUSED,
                          bool use_calc_stream UNUSED,
                          DenseTensor* out) {
  AllReduceKernel<T, Context>(
      dev_ctx, x, static_cast<int>(ReduceType::kRedSum), out);
}
}  // namespace phi

#if (NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000) || \
    defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(mp_allreduce_sum,
                   GPU,
                   ALL_LAYOUT,
                   phi::MpAllReduceSumKernel,
                   float,
                   double,
                   int,
                   int64_t,
                   phi::dtype::bfloat16,
                   phi::dtype::float16) {}
#else
PD_REGISTER_KERNEL(mp_allreduce_sum,
                   GPU,
                   ALL_LAYOUT,
                   phi::MpAllReduceSumKernel,
                   float,
                   double,
                   int,
                   int64_t,
                   phi::dtype::float16) {}
#endif
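The preprocessor guard above decides whether bfloat16 can be registered. Under NCCL's version encoding (for releases from 2.9 onward: major * 10000 + minor * 100 + patch), the threshold 21000 is NCCL 2.10.0, the release that introduced the ncclBfloat16 data type; CUDA_VERSION >= 11000 (CUDA 11.0) covers device-side bfloat16 support, and ROCm builds (PADDLE_WITH_HIP) take the bfloat16 branch unconditionally. A one-line sanity check of that encoding:

// NCCL post-2.9 version encoding: major * 10000 + minor * 100 + patch.
static_assert(2 * 10000 + 10 * 100 + 0 == 21000,
              "NCCL_VERSION_CODE 21000 corresponds to NCCL 2.10.0");

Older toolchains simply fall through to the #else branch, which registers the same kernel minus the bfloat16 instantiation.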
36 changes: 36 additions & 0 deletions paddle/phi/kernels/kps/mp_allreduce_sum_kernel.kps
@@ -0,0 +1,36 @@
#ifdef PADDLE_WITH_XPU_KP

// Please do not modify the following code
#if defined(__CUDA_ARCH__)
#undef __CUDA_ARCH__
#endif

#if defined(__CUDACC__)
#undef __CUDACC__
#endif

#if defined(__CUDA__)
#undef __CUDA__
#endif

#if defined(__NVCC__)
#undef __NVCC__
#endif

#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/all_reduce_kernel.h"

namespace phi {
template <typename T, typename Context>
void MpAllReduceSumKernel(const Context& dev_ctx,
                          const DenseTensor& x,
                          int ring_id UNUSED,
                          bool use_calc_stream UNUSED,
                          DenseTensor* out) {
  AllReduceKernel<T, Context>(
      dev_ctx, x, static_cast<int>(ReduceType::kRedSum), out);
}
}  // namespace phi
PD_REGISTER_KERNEL(
    mp_allreduce_sum, KPS, ALL_LAYOUT, phi::MpAllReduceSumKernel, float) {}
#endif
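The "do not modify" preamble in the .kps file undefines the CUDA compiler macros so that shared phi headers, which branch on them, take the portable path when this file is built by the XPU KP toolchain rather than by nvcc. A sketch of the header pattern this protects against (illustrative, not code from this commit):

// Illustrative shared-header branch: with __CUDACC__/__NVCC__ undefined by
// the .kps preamble, the portable KP branch below is the one that compiles.
#if defined(__CUDACC__) || defined(__NVCC__)
  // CUDA-specific intrinsics would be chosen here.
#else
  // Portable implementation used for the XPU KP build.
#endif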
36 changes: 36 additions & 0 deletions paddle/phi/kernels/xpu/mp_allreduce_sum_kernel.cc
@@ -0,0 +1,36 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/all_reduce_kernel.h"

namespace phi {
template <typename T, typename Context>
void MpAllReduceSumKernel(const Context& dev_ctx,
                          const DenseTensor& x,
                          int ring_id UNUSED,
                          bool use_calc_stream UNUSED,
                          DenseTensor* out) {
  AllReduceKernel<T, Context>(
      dev_ctx, x, static_cast<int>(ReduceType::kRedSum), out);
}
}  // namespace phi

PD_REGISTER_KERNEL(mp_allreduce_sum,
                   XPU,
                   ALL_LAYOUT,
                   phi::MpAllReduceSumKernel,
                   float,
                   int,
                   phi::dtype::float16) {}
6 changes: 6 additions & 0 deletions paddle/phi/ops/yaml/inconsistent/static_backward.yaml
@@ -313,6 +313,12 @@
    func : minimum_grad
  composite : minimum_grad(x, y, out_grad, x_grad, y_grad)

- backward_op : mp_allreduce_sum_grad
  forward : mp_allreduce_sum(Tensor x, int ring_id = 0, bool use_calc_stream = false) -> Tensor(out)
  args : (Tensor out_grad, int ring_id = 0, bool use_calc_stream = false)
  output : Tensor(x_grad)
  invoke : c_identity(out_grad, ring_id, false, false)

- backward_op : multiply_double_grad
forward : multiply_grad (Tensor x, Tensor y, Tensor grad_out, int axis = -1) -> Tensor(grad_x), Tensor(grad_y)
args : (Tensor x, Tensor y, Tensor grad_out, Tensor grad_x_grad, Tensor grad_y_grad, int axis = -1)
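About the mp_allreduce_sum_grad entry added in this file: it encodes the standard tensor-parallel identity-backward pairing. If the forward all-reduce leaves every rank holding the same replicated sum, the derivative of that sum with respect to each rank's contribution is the identity, so each rank's input gradient is exactly the (already replicated) output gradient. A sketch of the derivation:

y = \sum_{r=1}^{R} x_r
\quad\Longrightarrow\quad
\frac{\partial L}{\partial x_r}
  = \frac{\partial L}{\partial y}\,\frac{\partial y}{\partial x_r}
  = \frac{\partial L}{\partial y}.

That is why the backward is declared as invoke : c_identity(out_grad, ring_id, false, false) rather than as another collective kernel.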
12 changes: 12 additions & 0 deletions paddle/phi/ops/yaml/inconsistent/static_ops.yaml
@@ -576,6 +576,18 @@
  backward : minimum_grad
  interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : mp_allreduce_sum
  args : (Tensor x, int ring_id = 0, bool use_calc_stream = false)
  output : Tensor(out)
  infer_meta :
    func : AllReduceInferMeta
    param: [x]
  kernel :
    func : mp_allreduce_sum
    param: [x, ring_id, use_calc_stream]
  backward: mp_allreduce_sum_grad
  inplace: (x -> out)

- op : multiply
  args : (Tensor x, Tensor y)
  output : Tensor
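About the mp_allreduce_sum entry added in this file: infer_meta reuses AllReduceInferMeta with param [x], so the output inherits x's shape and dtype; the kernel line dispatches to the per-backend registrations added earlier in this commit; and inplace: (x -> out) is what yields the trailing-underscore variant mp_allreduce_sum_ that was added to ops_api_gen.py. As a rough illustration (signatures hypothetical, not taken from the generator's actual output), the implied API pair looks like:

// Hypothetical sketch of the out-of-place / in-place pair implied by
// "inplace: (x -> out)"; the generator's real signatures may differ.
paddle::Tensor mp_allreduce_sum(const paddle::Tensor& x, int ring_id = 0,
                                bool use_calc_stream = false);
paddle::Tensor& mp_allreduce_sum_(paddle::Tensor& x, int ring_id = 0,
                                  bool use_calc_stream = false);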
6 changes: 6 additions & 0 deletions paddle/phi/ops/yaml/op_compat.yaml
@@ -4431,6 +4431,12 @@
  outputs:
    out: Out

- op: mp_allreduce_sum
  inputs :
    x : X
  outputs :
    out: Out

- op: nce
  backward: nce_grad
  inputs:
