Cherry pick for fix of operator precision. (#52705)
* Fix the scale kernel for low precision; cherry-pick of #50998.

* Fix the FP16 precision problem of add_n. (#50129)

* Change squared_l2_norm to reuse ReduceKernel and register fp16 and bf16 kernels; cherry-pick of #48315.

* Cherry-pick the fix of MPTypeTrait in KP, implemented in #50993.

* Cherry-pick the multi-precision support of AdamW for bf16, #48041.

* Fix compilation error.

* Cherry-pick the fix of CubTensorReduceImpl for bfloat16 in #50993.

* Fix unit tests.

---------

Co-authored-by: liuruyan <[email protected]>
Xreki and liuruyan authored Apr 11, 2023
1 parent d12588d commit d1e8b1e
Showing 11 changed files with 763 additions and 398 deletions.
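All of the kernel fixes in this commit apply the same mixed-precision pattern: keep the tensor data in fp16/bf16, widen each value to float (the MPType) before arithmetic, and narrow back only on the final store. Below is a minimal host-side editor's sketch of that pattern, using a hypothetical Half stand-in rather than Paddle's real dtypes.

// Editor's sketch, not Paddle code: widen-accumulate-narrow with a trait that
// maps a low-precision storage type to its float compute type.
#include <cstddef>
#include <cstdio>

template <typename T>
struct MPTypeTraitSketch {          // hypothetical stand-in for MPTypeTrait
  using Type = T;                   // full-precision types compute in place
};

struct Half {                       // hypothetical 16-bit storage type
  float v;                          // (kept as float here purely for the demo)
  explicit Half(float x) : v(x) {}
  explicit operator float() const { return v; }
};

template <>
struct MPTypeTraitSketch<Half> {
  using Type = float;               // low-precision types accumulate in float
};

template <typename T>
T AccumulateSum(const T* in, std::size_t n) {
  using MT = typename MPTypeTraitSketch<T>::Type;
  MT total = static_cast<MT>(0);
  for (std::size_t i = 0; i < n; ++i) {
    total += static_cast<MT>(in[i]);   // widen before adding
  }
  return static_cast<T>(total);        // narrow once, on the final write
}

int main() {
  Half xs[3] = {Half(1.f), Half(2.f), Half(3.f)};
  std::printf("%g\n", static_cast<float>(AccumulateSum(xs, 3)));  // 6
  return 0;
}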
38 changes: 29 additions & 9 deletions paddle/phi/kernels/funcs/reduce_function.h
@@ -986,14 +986,16 @@ template <typename Tx,
template <typename>
class ReduceOp,
typename TransformOp>
static typename std::enable_if<!std::is_same<Tx, phi::dtype::float16>::value,
void>::type
CubTensorReduceImpl(const Tx* x_data,
Ty* y_data,
const TransformOp& transform,
int reduce_num,
const KPDevice& dev_ctx,
KPStream stream) {
static
typename std::enable_if<!std::is_same<Tx, phi::dtype::float16>::value &&
!std::is_same<Tx, phi::dtype::bfloat16>::value,
void>::type
CubTensorReduceImpl(const Tx* x_data,
Ty* y_data,
const TransformOp& transform,
int reduce_num,
const KPDevice& dev_ctx,
KPStream stream) {
auto reducer = ReduceOp<Ty>();
cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(x_data,
transform);
@@ -1037,6 +1039,23 @@ CubTensorReduceImpl(const Tx* x_data,
PADDLE_THROW(phi::errors::InvalidArgument(
"Tx should not be float16 when using cub::DeviceReduce::Reduce()."));
}

template <typename Tx,
typename Ty,
template <typename>
class ReduceOp,
typename TransformOp>
static typename std::enable_if<std::is_same<Tx, phi::dtype::bfloat16>::value,
void>::type
CubTensorReduceImpl(const Tx* x_data,
Ty* y_data,
const TransformOp& transform,
int reduce_num,
const KPDevice& dev_ctx,
KPStream stream) {
PADDLE_THROW(phi::errors::InvalidArgument(
"Tx should not be bfloat16 when using cub::DeviceReduce::Reduce()."));
}
#endif // PADDLE_WITH_XPU_KP

template <typename Tx,
@@ -1081,7 +1100,8 @@ void ReduceKernel(const KPDevice& dev_ctx,

config.SetOutputData(y_data, dev_ctx, &tmp);
constexpr bool kIsTxFP16 = std::is_same<Tx, phi::dtype::float16>::value;
bool use_cub_reduce = config.reduce_num == numel && !kIsTxFP16;
constexpr bool kIsTxBF16 = std::is_same<Tx, phi::dtype::bfloat16>::value;
bool use_cub_reduce = config.reduce_num == numel && !kIsTxFP16 && !kIsTxBF16;
#ifndef PADDLE_WITH_XPU_KP
if (use_cub_reduce) {
if (is_mean) {
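The hunks above extend the std::enable_if guard so the cub-based fast path is compiled out for bfloat16 as well as float16, and add a dedicated overload that raises an error if a low-precision type ever reaches it. A condensed editor's sketch of the same SFINAE dispatch follows; the real code keeps separate fp16 and bf16 overloads, and plain structs stand in for phi's dtypes here.

// Editor's sketch, not Paddle code: disable one overload set for fp16/bf16 via
// std::enable_if so those types fall back to a different reduction path.
#include <stdexcept>
#include <type_traits>

struct float16 {};    // hypothetical stand-ins for phi::dtype::float16/bfloat16
struct bfloat16 {};

template <typename Tx>
typename std::enable_if<!std::is_same<Tx, float16>::value &&
                            !std::is_same<Tx, bfloat16>::value,
                        void>::type
ReduceWithCub(const Tx* /*x_data*/, int /*reduce_num*/) {
  // ... the real kernel would call cub::DeviceReduce::Reduce() here ...
}

template <typename Tx>
typename std::enable_if<std::is_same<Tx, float16>::value ||
                            std::is_same<Tx, bfloat16>::value,
                        void>::type
ReduceWithCub(const Tx* /*x_data*/, int /*reduce_num*/) {
  throw std::invalid_argument(
      "fp16/bf16 should not reach the cub::DeviceReduce path");
}

int main() {
  float data[2] = {1.f, 2.f};
  ReduceWithCub(data, 2);   // resolves to the generic overload
  return 0;
}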
19 changes: 11 additions & 8 deletions paddle/phi/kernels/gpu/add_n_kernel.cu
@@ -14,10 +14,10 @@

#include "paddle/phi/kernels/add_n_kernel.h"

#include "paddle/phi/kernels/impl/add_n_kernel_impl.h"

#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/kernels/impl/add_n_kernel_impl.h"

namespace phi {

@@ -38,16 +38,18 @@ __global__ void Sum2CUDAKernel(const T *in_0,
template <class T>
__global__ void SumArrayCUDAKernel(
T **in, T *out, int64_t N, size_t in_size, bool read_dst) {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
int id = blockIdx.x * blockDim.x + threadIdx.x;
while (id < N) {
T total(read_dst ? out[id] : static_cast<T>(0));
MPType total(read_dst ? static_cast<MPType>(out[id])
: static_cast<MPType>(0));
for (int i = 0; i < in_size; ++i) {
const T *tmp = in[i];
if (tmp) {
total += tmp[id];
total += static_cast<MPType>(tmp[id]);
}
}
out[id] = total;
out[id] = static_cast<T>(total);
id += blockDim.x * gridDim.x;
}
}
@@ -116,11 +118,12 @@ void AddNKernel(const Context &dev_ctx,
int64_t length_0 = in_0.numel();
int64_t length_1 = in_1.numel();
if (length_0 && length_1 && in_0.IsInitialized() && in_1.IsInitialized()) {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
auto result = EigenVector<T>::Flatten(*out);
auto &place = *dev_ctx.eigen_device();
auto in_0_e = EigenVector<T>::Flatten(in_0);
auto in_1_e = EigenVector<T>::Flatten(in_1);
result.device(place) = in_0_e + in_1_e;
auto in_0_e = EigenVector<T>::Flatten(in_0).template cast<MPType>();
auto in_1_e = EigenVector<T>::Flatten(in_1).template cast<MPType>();
result.device(place) = (in_0_e + in_1_e).template cast<T>();
} else if (length_0 && in_0.IsInitialized()) {
auto result = EigenVector<T>::Flatten(*out);
auto &place = *dev_ctx.eigen_device();
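Why the accumulator type matters for add_n: fp16 carries only 11 significand bits, so once the running sum reaches about 2048 any addend smaller than the fp16 spacing is rounded away, while a float accumulator keeps it. A hypothetical CUDA snippet (an editor's demo, not part of this diff) that shows the effect:

// Editor's demo, not Paddle code: at magnitude 2048 the fp16 spacing is 2,
// so adding 1 in fp16 is lost, whereas the float accumulator sees it.
#include <cstdio>
#include <cuda_fp16.h>

__global__ void Fp16AccumulationDemo() {
  __half half_acc = __float2half(2048.f);
  half_acc = __hadd(half_acc, __float2half(1.f));               // still 2048 in fp16
  float float_acc = 2048.f + __half2float(__float2half(1.f));   // 2049 in float
  printf("fp16 accumulator: %.1f, float accumulator: %.1f\n",
         __half2float(half_acc), float_acc);
}

int main() {
  Fp16AccumulationDemo<<<1, 1>>>();
  cudaDeviceSynchronize();
  return 0;
}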
23 changes: 15 additions & 8 deletions paddle/phi/kernels/gpu/scale_kernel.cu
@@ -15,28 +15,30 @@ limitations under the License. */
#include "paddle/phi/kernels/scale_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"

namespace phi {

template <typename InT>
template <typename DataT, typename ParamT>
struct ScaleFunctor {
InT bias;
InT scale;
ParamT bias;
ParamT scale;
bool bias_after_scale;

ScaleFunctor(InT scale_data, InT bias_data, bool is_bias_after_sacle)
ScaleFunctor(ParamT scale_data, ParamT bias_data, bool is_bias_after_sacle)
: bias(bias_data),
scale(scale_data),
bias_after_scale(is_bias_after_sacle) {}

__device__ __forceinline__ InT operator()(const InT x) const {
__device__ __forceinline__ DataT operator()(const DataT x) const {
if (bias_after_scale) {
return scale * x + bias;
return static_cast<DataT>(scale * static_cast<ParamT>(x) + bias);
} else {
return scale * (x + bias);
return static_cast<DataT>(scale * (static_cast<ParamT>(x) + bias));
}
}
};
@@ -48,16 +50,21 @@ void ScaleKernel(const Context& dev_ctx,
float bias,
bool bias_after_scale,
DenseTensor* out) {
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
std::vector<const DenseTensor*> inputs;
std::vector<DenseTensor*> outputs;
inputs.emplace_back(&x);
outputs.emplace_back(out);
dev_ctx.template Alloc<T>(out);
if (x.numel() <= 0 || (!x.IsInitialized())) {
return;
}
phi::funcs::ElementwiseKernel<T>(
dev_ctx,
inputs,
&outputs,
ScaleFunctor<T>(scale.to<T>(), static_cast<T>(bias), bias_after_scale));
ScaleFunctor<T, MT>(
scale.to<MT>(), static_cast<MT>(bias), bias_after_scale));
}

} // namespace phi
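The point of splitting ScaleFunctor into DataT/ParamT is that the scale and bias stay in float instead of being quantized to the tensor's dtype before the multiply. A hypothetical host-only illustration (an editor's demo, not part of this diff) of how coarsely bfloat16 would represent a small scale such as 1e-4:

// Editor's demo, not Paddle code: simulate the bf16 round trip by dropping the
// low 16 mantissa bits of a float (real conversions round to nearest even,
// but truncation is close enough to make the point).
#include <cstdint>
#include <cstdio>
#include <cstring>

static float ToBF16AndBack(float x) {
  std::uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits &= 0xFFFF0000u;              // keep sign, exponent, top 7 mantissa bits
  float y;
  std::memcpy(&y, &bits, sizeof(y));
  return y;
}

int main() {
  float scale = 1e-4f;
  std::printf("float scale: %.8g\n", scale);                 // 0.0001
  std::printf("bf16  scale: %.8g\n", ToBF16AndBack(scale));  // roughly 9.966e-05
  return 0;
}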
39 changes: 37 additions & 2 deletions paddle/phi/kernels/gpu/squared_l2_norm_grad_kernel.cu
@@ -15,12 +15,47 @@
#include "paddle/phi/kernels/squared_l2_norm_grad_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/squared_l2_norm_grad_kernel_impl.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"

namespace phi {
/**
 * Computes b * a * 2.0, i.e. 2 * x * dout, the gradient of squared_l2_norm.
 */
template <typename T>
struct DoubleMulFunctor {
__device__ __forceinline__ T operator()(const T a, const T b) const {
return b * a * static_cast<T>(2.0f);
}
};

template <typename T, typename Context>
void SquaredL2NormGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& dout,
DenseTensor* dx) {
dev_ctx.template Alloc<T>(dx);

PADDLE_ENFORCE_EQ(
dout.numel(),
1,
phi::errors::InvalidArgument(
"Input(GRAD@Out) of SquaredL2NormGradOP should be a scalar."));
std::vector<const DenseTensor*> ins{&x, &dout};
std::vector<DenseTensor*> outs{dx};

funcs::BroadcastKernel<ElementwiseType::kBinary, T, T>(
dev_ctx, ins, &outs, -1, phi::DoubleMulFunctor<T>());
}
} // namespace phi

PD_REGISTER_KERNEL(squared_l2_norm_grad,
GPU,
ALL_LAYOUT,
phi::SquaredL2NormGradKernel,
float,
double) {}
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
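For reference, an editor's derivation of the gradient implemented by DoubleMulFunctor, assuming the forward op (registered in the next file) computes the sum of squares as a scalar:

% With out = \sum_i x_i^2 a scalar, dout broadcasts over x:
\[
  \frac{\partial\,\text{out}}{\partial x_i} = 2 x_i
  \;\Longrightarrow\;
  \frac{\partial L}{\partial x_i}
    = \frac{\partial L}{\partial\,\text{out}} \cdot 2 x_i
    = 2\, x_i \cdot \text{dout},
\]
% i.e. exactly b * a * 2.0f with a = x_i and b = dout.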
31 changes: 28 additions & 3 deletions paddle/phi/kernels/gpu/squared_l2_norm_kernel.cu
@@ -15,9 +15,34 @@
#include "paddle/phi/kernels/squared_l2_norm_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/squared_l2_norm_kernel_impl.h"
#include "paddle/phi/kernels/funcs/reduce_function.h"

PD_REGISTER_KERNEL(
squared_l2_norm, GPU, ALL_LAYOUT, phi::SquaredL2NormKernel, float, double) {
namespace phi {

template <typename T, typename Context>
void SquaredL2NormKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out) {
dev_ctx.template Alloc<T>(out);
std::vector<int> origin_reduce_dims;
for (size_t i = 0; i < x.dims().size(); i++) {
origin_reduce_dims.push_back(i);
}
phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::SquareFunctor<T, T>>(
dev_ctx, x, out, kps::SquareFunctor<T, T>(), origin_reduce_dims, false);
}

} // namespace phi

PD_REGISTER_KERNEL(squared_l2_norm,
GPU,
ALL_LAYOUT,
phi::SquaredL2NormKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
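A host-side reference (an editor's sketch, not Paddle code) for what the ReduceKernel call above computes: apply SquareFunctor to every element and reduce with AddFunctor over all dimensions, yielding a single scalar.

// Editor's reference implementation of squared_l2_norm's forward semantics.
#include <cstdio>
#include <vector>

template <typename T>
T SquaredL2NormRef(const std::vector<T>& x) {
  T out = T(0);
  for (const T& v : x) {
    out += v * v;   // SquareFunctor, then AddFunctor
  }
  return out;
}

int main() {
  std::vector<float> x = {1.f, -2.f, 3.f};
  std::printf("%g\n", SquaredL2NormRef(x));   // 14
  return 0;
}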
6 changes: 6 additions & 0 deletions paddle/phi/kernels/primitive/compute_primitives.h
@@ -52,6 +52,12 @@ class MPTypeTrait<phi::dtype::float16> {
using Type = float;
};

template <>
class MPTypeTrait<phi::dtype::bfloat16> {
public:
using Type = float;
};

/**
* @brief Will be used in BlockYReduce, get the index of reduce_num in shared
* memory.
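With the specialization above, MPTypeTrait<bfloat16>::Type resolves to float just as it already did for float16, so any kernel that accumulates in MPTypeTrait<T>::Type picks up bf16 support automatically. A compile-time check of that pattern (an editor's sketch with stand-in tag types, not phi's real trait):

// Editor's sketch: the trait maps low-precision tags to float and leaves
// full-precision types alone; static_assert verifies the mapping.
#include <type_traits>

template <typename T>
struct MPTypeTraitSketch { using Type = T; };

struct float16_tag {};
struct bfloat16_tag {};
template <> struct MPTypeTraitSketch<float16_tag>  { using Type = float; };
template <> struct MPTypeTraitSketch<bfloat16_tag> { using Type = float; };

static_assert(std::is_same<MPTypeTraitSketch<bfloat16_tag>::Type, float>::value,
              "bf16 computes in float");
static_assert(std::is_same<MPTypeTraitSketch<double>::Type, double>::value,
              "full-precision types are unchanged");

int main() { return 0; }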
(Diffs for the remaining changed files are not shown.)
