[Kernel] Optimize p_norm gpu (#69660)
* optimize p_norm gpu impl

* upload missing code
HydrogenSulfate authored Nov 25, 2024
1 parent 3e56bff commit a291887
Showing 2 changed files with 85 additions and 7 deletions.
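For context, the p-norm computed by this kernel, and the cases this commit fast-paths (all of this is read off the diff below):

\|x\|_p = \Bigl(\sum_i |x_i|^p\Bigr)^{1/p}
% p = 1: accumulate |x_i|; the final 1/p root is the identity, so it is skipped
% p = 2: accumulate x_i^2 (since x^2 = |x|^2, no fabs or pow is needed)
% p = 3: accumulate |x_i|^3 (three multiplies and one fabs)
% other p: accumulate |x_i|^p via the generic UnsignedPowFunctor, as before

Replacing the generic pow(|x|, p) with these cheap per-element functors removes a transcendental call per element for the most common norm orders.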
36 changes: 36 additions & 0 deletions paddle/phi/kernels/funcs/p_norm_utils.h
@@ -60,4 +60,40 @@ __device__ __forceinline__ float inline_pow(float base, float exponent) {
__device__ __forceinline__ double inline_pow(double base, double exponent) {
  return pow(base, exponent);
}

__device__ __forceinline__ dtype::float16 inline_fabs(dtype::float16 x) {
  return static_cast<dtype::float16>(fabs(static_cast<float>(x)));
}
__device__ __forceinline__ dtype::bfloat16 inline_fabs(dtype::bfloat16 x) {
  return static_cast<dtype::bfloat16>(fabs(static_cast<float>(x)));
}
__device__ __forceinline__ float inline_fabs(float x) { return fabs(x); }
__device__ __forceinline__ double inline_fabs(double x) { return fabs(x); }

__device__ __forceinline__ dtype::float16 inline_square(dtype::float16 x) {
  return static_cast<dtype::float16>(static_cast<float>(x) *
                                     static_cast<float>(x));
}
__device__ __forceinline__ dtype::bfloat16 inline_square(dtype::bfloat16 x) {
  return static_cast<dtype::bfloat16>(static_cast<float>(x) *
                                      static_cast<float>(x));
}
__device__ __forceinline__ float inline_square(float x) { return x * x; }
__device__ __forceinline__ double inline_square(double x) { return x * x; }

__device__ __forceinline__ dtype::float16 inline_fabs_cubic(dtype::float16 x) {
  return static_cast<dtype::float16>(fabs(
      static_cast<float>(x) * static_cast<float>(x) * static_cast<float>(x)));
}
__device__ __forceinline__ dtype::bfloat16 inline_fabs_cubic(
    dtype::bfloat16 x) {
  return static_cast<dtype::bfloat16>(fabs(
      static_cast<float>(x) * static_cast<float>(x) * static_cast<float>(x)));
}
__device__ __forceinline__ float inline_fabs_cubic(float x) {
  return fabs(x * x * x);
}
__device__ __forceinline__ double inline_fabs_cubic(double x) {
  return fabs(x * x * x);
}
} // namespace phi
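Note the pattern in the float16/bfloat16 overloads above: promote to float, compute there, cast back, so the reduced-precision value is never squeezed through an intermediate half-precision pow. Below is a standalone host-side sketch of the same idea; this is not Paddle code, and half_t is a hypothetical stand-in for dtype::float16. The checks confirm each fast path equals the generic pow(|x|, p) it replaces.

#include <cassert>
#include <cmath>
#include <cstdio>

// Hypothetical stand-in for dtype::float16: stores a float, models a
// narrower type (Paddle's real dtype::float16 is a 16-bit storage type).
struct half_t {
  float v;
  explicit half_t(float f) : v(f) {}
  explicit operator float() const { return v; }
};

// Same shape as the device helpers: promote to float, compute, cast back.
half_t inline_fabs(half_t x) { return half_t(std::fabs(static_cast<float>(x))); }
half_t inline_square(half_t x) {
  float f = static_cast<float>(x);
  return half_t(f * f);
}
half_t inline_fabs_cubic(half_t x) {
  float f = static_cast<float>(x);
  return half_t(std::fabs(f * f * f));
}

int main() {
  for (float f : {-2.5f, -1.0f, 0.5f, 3.0f}) {
    half_t x(f);
    // Each fast path equals the generic |x|^p it replaces.
    assert(std::fabs(static_cast<float>(inline_fabs(x)) - std::pow(std::fabs(f), 1.0f)) < 1e-5f);
    assert(std::fabs(static_cast<float>(inline_square(x)) - std::pow(std::fabs(f), 2.0f)) < 1e-4f);
    assert(std::fabs(static_cast<float>(inline_fabs_cubic(x)) - std::pow(std::fabs(f), 3.0f)) < 1e-4f);
  }
  std::printf("fast paths agree with pow(|x|, p) for p = 1, 2, 3\n");
  return 0;
}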
56 changes: 49 additions & 7 deletions paddle/phi/kernels/gpu/p_norm_kernel.cu
@@ -49,6 +49,30 @@ struct UnsignedPowFunctor {
  float porder;
};

template <typename T>
struct FabsFunctor {
  HOSTDEVICE explicit inline FabsFunctor() = default;
  HOSTDEVICE inline T operator()(const T x) const {
    return static_cast<T>(inline_fabs(x));
  }
};

template <typename T>
struct SquareFunctor {
  HOSTDEVICE explicit inline SquareFunctor() = default;
  HOSTDEVICE inline T operator()(const T x) const {
    return static_cast<T>(inline_square(x));
  }
};

template <typename T>
struct FabsCubicFunctor {
  HOSTDEVICE explicit inline FabsCubicFunctor() = default;
  HOSTDEVICE inline T operator()(const T x) const {
    return static_cast<T>(inline_fabs_cubic(x));
  }
};

template <typename T, typename Context>
void PNormKernel(const Context& dev_ctx,
                 const DenseTensor& x,

@@ -84,14 +108,32 @@ void PNormKernel(const Context& dev_ctx,
    phi::funcs::ReduceKernel<T, T, kps::MinFunctor, AbsFunctor<T>>(
        dev_ctx, *in_x, out_norm, AbsFunctor<T>(), reduce_axis);
  } else {
    if (porder == 1.0) {
      // fast 1-norm
      phi::funcs::ReduceKernel<T, T, kps::AddFunctor, FabsFunctor<T>>(
          dev_ctx, *in_x, out_norm, FabsFunctor<T>(), reduce_axis);
    } else if (porder == 2.0) {
      // fast 2-norm
      phi::funcs::ReduceKernel<T, T, kps::AddFunctor, SquareFunctor<T>>(
          dev_ctx, *in_x, out_norm, SquareFunctor<T>(), reduce_axis);
    } else if (porder == 3.0) {
      // fast 3-norm
      phi::funcs::ReduceKernel<T, T, kps::AddFunctor, FabsCubicFunctor<T>>(
          dev_ctx, *in_x, out_norm, FabsCubicFunctor<T>(), reduce_axis);
    } else {
      // vanilla norm
      phi::funcs::ReduceKernel<T, T, kps::AddFunctor, UnsignedPowFunctor<T>>(
          dev_ctx, *in_x, out_norm, UnsignedPowFunctor<T>(porder), reduce_axis);
    }

    if (porder != 1.0) {
      // save computation when porder is 1.0
      const DenseTensor* tmp_norm = out_norm;
      std::vector<const DenseTensor*> ins = {tmp_norm};
      std::vector<DenseTensor*> outs = {out_norm};
      phi::funcs::ElementwiseKernel<T>(
          dev_ctx, ins, &outs, UnsignedPowFunctor<T>(1. / porder));
    }
  }
}
} // namespace phi
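The kernel now runs in two stages: a reduction that accumulates sum_i f(x_i), with f picked by porder, then an elementwise pass applying the 1/porder root, which is skipped when porder == 1 because the root is the identity. Below is a minimal CPU sketch of the same control flow in plain C++, not the Paddle kernel API; the reduce lambda stands in for ReduceKernel and the final pow for ElementwiseKernel.

#include <cmath>
#include <cstdio>
#include <vector>

// CPU analogue of PNormKernel's hot path: accumulate sum of f(x_i),
// then take the 1/p root unless p == 1 (where the root is a no-op).
float p_norm(const std::vector<float>& x, float porder) {
  auto reduce = [&](auto f) {
    float s = 0.0f;
    for (float v : x) s += f(v);
    return s;
  };

  float sum;
  if (porder == 1.0f) {
    sum = reduce([](float v) { return std::fabs(v); });          // fast 1-norm
  } else if (porder == 2.0f) {
    sum = reduce([](float v) { return v * v; });                 // fast 2-norm
  } else if (porder == 3.0f) {
    sum = reduce([](float v) { return std::fabs(v * v * v); });  // fast 3-norm
  } else {
    sum = reduce([&](float v) { return std::pow(std::fabs(v), porder); });  // vanilla
  }

  // Second "elementwise" stage: skipped for p == 1, as in the commit.
  return porder == 1.0f ? sum : std::pow(sum, 1.0f / porder);
}

int main() {
  std::vector<float> x = {3.0f, -4.0f};
  std::printf("p=1: %g  p=2: %g  p=3: %g  p=4: %g\n",
              p_norm(x, 1.0f), p_norm(x, 2.0f), p_norm(x, 3.0f), p_norm(x, 4.0f));
  // Expected: p=1: 7, p=2: 5 (the 3-4-5 triangle), p=3: cbrt(91) ≈ 4.498, ...
  return 0;
}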
