【hackathon 4 No53】label_smooth add fp16 support (#51493)
enkilee authored Mar 14, 2023
1 parent 775fb43 commit 9cd99f7
Showing 4 changed files with 43 additions and 19 deletions.
13 changes: 8 additions & 5 deletions paddle/phi/kernels/gpu/label_smooth_grad_kernel.cu
@@ -15,20 +15,22 @@
#include "paddle/phi/kernels/label_smooth_grad_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"

namespace phi {
template <typename T>
struct LabelSmoothGradFunctor {
-  T epsilon;
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
+  MPType epsilon;

__forceinline__ LabelSmoothGradFunctor(float epsilon_data) {
-    epsilon = static_cast<T>(epsilon_data);
+    epsilon = static_cast<MPType>(epsilon_data);
}

__device__ __forceinline__ T operator()(const T x) const {
-    return static_cast<T>(1 - epsilon) * x;
+    return static_cast<T>((static_cast<MPType>(1) - epsilon) *
+                          static_cast<MPType>(x));
}
};

@@ -52,4 +54,5 @@ PD_REGISTER_KERNEL(label_smooth_grad,
ALL_LAYOUT,
phi::LabelSmoothGradKernel,
float,
-                   double) {}
+                   double,
+                   phi::dtype::float16) {}
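
The gradient functor now does its arithmetic in the promoted MPType (float when T is phi::dtype::float16) and casts only the final product back to T. A minimal NumPy sketch of that backward rule, illustrative only and not part of the commit:

import numpy as np

epsilon = 0.1
grad_out = np.array([0.25, 1.0, -0.5], dtype=np.float16)  # upstream gradient

# Mirror LabelSmoothGradFunctor: promote to float32, scale by (1 - epsilon),
# then cast the result back to float16 in a single step.
grad_label = ((1.0 - epsilon) * grad_out.astype(np.float32)).astype(np.float16)
print(grad_label)  # approximately [0.225, 0.9, -0.45]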
34 changes: 23 additions & 11 deletions paddle/phi/kernels/gpu/label_smooth_kernel.cu
@@ -17,24 +17,27 @@
#include <vector>

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"

namespace phi {

template <typename T>
struct LabelSmoothFunctor {
-  T epsilon;
-  T label_dim;
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
+  MPType epsilon;
+  MPType label_dim;

__forceinline__ LabelSmoothFunctor(float epsilon_data, int label_dim_data) {
-    epsilon = static_cast<T>(epsilon_data);
-    label_dim = static_cast<T>(label_dim_data);
+    epsilon = static_cast<MPType>(epsilon_data);
+    label_dim = static_cast<MPType>(label_dim_data);
}

__device__ __forceinline__ T operator()(const T x) const {
-    return (static_cast<T>(1 - epsilon) * x +
-            static_cast<T>(epsilon / label_dim));
+    return static_cast<T>(
+        static_cast<MPType>(static_cast<MPType>(1) - epsilon) *
+            static_cast<MPType>(x) +
+        static_cast<MPType>(epsilon / label_dim));
}
};

@@ -45,10 +48,14 @@ __global__ void LabelSmoothRunDistKernel(const int N,
const T* src,
const T* dist_data,
T* dst) {
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
CUDA_KERNEL_LOOP(idx, N) {
int dist_idx = idx % dist_numel;
-    dst[idx] = static_cast<T>(1 - epsilon) * src[idx] +
-               static_cast<T>(epsilon) * dist_data[dist_idx];
+    dst[idx] =
+        static_cast<T>((static_cast<MPType>(1) - static_cast<MPType>(epsilon)) *
+                           static_cast<MPType>(src[idx]) +
+                       static_cast<MPType>(epsilon) *
+                           static_cast<MPType>(dist_data[dist_idx]));
}
}

@@ -83,5 +90,10 @@ void LabelSmoothKernel(const Context& ctx,

} // namespace phi

-PD_REGISTER_KERNEL(
-    label_smooth, GPU, ALL_LAYOUT, phi::LabelSmoothKernel, float, double) {}
+PD_REGISTER_KERNEL(label_smooth,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::LabelSmoothKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {}
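
LabelSmoothRunDistKernel applies the same promotion when a prior distribution is supplied: out = (1 - epsilon) * src + epsilon * dist, with intermediates in MPType and one cast back to T at the end. A NumPy sketch of that formula with assumed example values (not taken from the commit):

import numpy as np

epsilon = 0.1
src = np.array([[0, 1, 0, 0]], dtype=np.float16)          # one-hot labels
dist = np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float16)   # prior distribution

# Promote to float32, apply (1 - eps) * src + eps * dist, cast back once.
out = ((1.0 - epsilon) * src.astype(np.float32)
       + epsilon * dist.astype(np.float32)).astype(np.float16)
print(out)  # approximately [[0.01, 0.92, 0.03, 0.04]]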
11 changes: 10 additions & 1 deletion python/paddle/fluid/tests/unittests/test_label_smooth_op.py
@@ -24,9 +24,10 @@ class TestLabelSmoothOp(OpTest):
def config(self):
self.op_type = "label_smooth"
self.python_api = paddle.nn.functional.label_smooth
+        self.init_dtype()
self.epsilon = 0.1
batch_size, self.label_dim = 10, 12
-        self.label = np.zeros((batch_size, self.label_dim)).astype("float64")
+        self.label = np.zeros((batch_size, self.label_dim)).astype(self.dtype)
nonzero_index = np.random.randint(self.label_dim, size=(batch_size))
self.label[np.arange(batch_size), nonzero_index] = 1

@@ -39,13 +40,21 @@ def setUp(self):
self.attrs = {'epsilon': self.epsilon}
self.outputs = {'Out': smoothed_label}

+    def init_dtype(self):
+        self.dtype = np.float64

def test_check_output(self):
self.check_output(check_eager=True)

def test_check_grad(self):
self.check_grad(["X"], "Out", check_eager=True)


+class TestLabelSmoothFP16OP(TestLabelSmoothOp):
+    def init_dtype(self):
+        self.dtype = np.float16


class TestLabelSmoothOpWithPriorDist(TestLabelSmoothOp):
def setUp(self):
self.config()
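The collapsed part of setUp() builds the expected output that check_output compares the GPU kernel against; consistent with the forward functor above, the no-prior-dist reference is (1 - epsilon) * label + epsilon / label_dim. A NumPy sketch of that reference computation with the same shapes as the test (a hedged reconstruction, since the exact lines are collapsed in this view):

import numpy as np

epsilon, batch_size, label_dim = 0.1, 10, 12

# One-hot float16 labels, shaped like the ones config() builds for the FP16 case.
label = np.zeros((batch_size, label_dim)).astype(np.float16)
label[np.arange(batch_size), np.random.randint(label_dim, size=batch_size)] = 1

# Reference smoothing formula in NumPy.
smoothed_label = (1 - epsilon) * label + epsilon / label_dim
print(smoothed_label.dtype, smoothed_label[0])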
4 changes: 2 additions & 2 deletions python/paddle/nn/functional/common.py
@@ -1923,7 +1923,7 @@ def label_smooth(label, prior_dist=None, epsilon=0.1, name=None):
label(Tensor): The input variable containing the label data. The
label data should use one-hot representation. It's
a multidimensional tensor with a shape of
-            :math:`[N_1, ..., Depth]`, where Depth is class number. The dtype can be "float32" and "float64".
+            :math:`[N_1, ..., Depth]`, where Depth is class number. The dtype can be "float16" "float32" and "float64".
prior_dist(Tensor, optional): The prior distribution to be used to smooth
labels. If not provided, an uniform distribution
is used. It's a multidimensional tensor with a shape of
@@ -1965,7 +1965,7 @@ def label_smooth(label, prior_dist=None, epsilon=0.1, name=None):
)

check_variable_and_dtype(
-        label, 'label', ['float32', 'float64'], 'label_smooth'
+        label, 'label', ['float16', 'float32', 'float64'], 'label_smooth'
)

helper = LayerHelper("label_smooth", **locals())
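With the new GPU kernels registered, label_smooth accepts float16 input directly. A minimal usage sketch, assuming a GPU build of PaddlePaddle (the float16 kernels added in this commit are GPU-only):

import paddle

paddle.set_device("gpu")

# One-hot float16 labels, shape [batch, num_classes].
label = paddle.to_tensor([[0, 1, 0, 0], [1, 0, 0, 0]], dtype="float16")

# With epsilon = 0.1 and 4 classes, ones become 0.9 + 0.1/4 = 0.925
# and zeros become 0.1/4 = 0.025.
smoothed = paddle.nn.functional.label_smooth(label, epsilon=0.1)
print(smoothed)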
