【hackathon 4 No53】label_smooth add fp16 support #51493

Merged · 14 commits · Mar 14, 2023
13 changes: 8 additions & 5 deletions paddle/phi/kernels/gpu/label_smooth_grad_kernel.cu
@@ -15,20 +15,22 @@
#include "paddle/phi/kernels/label_smooth_grad_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"

namespace phi {
template <typename T>
struct LabelSmoothGradFunctor {
T epsilon;
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
MPType epsilon;

__forceinline__ LabelSmoothGradFunctor(float epsilon_data) {
epsilon = static_cast<T>(epsilon_data);
epsilon = static_cast<MPType>(epsilon_data);
}

__device__ __forceinline__ T operator()(const T x) const {
return static_cast<T>(1 - epsilon) * x;
return static_cast<T>((static_cast<MPType>(1) - epsilon) *
static_cast<MPType>(x));
}
};

@@ -52,4 +54,5 @@ PD_REGISTER_KERNEL(label_smooth_grad,
ALL_LAYOUT,
phi::LabelSmoothGradKernel,
float,
double) {}
double,
phi::dtype::float16) {}
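
The key change above is that the functor now routes its arithmetic through `MPType` (`phi::dtype::MPTypeTrait<T>::Type`), which is `float` when `T` is `float16`: the scale `1 - epsilon` is formed and multiplied in single precision, and only the final product is cast back to `T`. A small NumPy sketch of that numerical idea (an illustration only, not the kernel code; shapes and values are made up):

```python
import numpy as np

epsilon = 0.1
out_grad = np.random.rand(8, 12).astype(np.float16)

# Naive path: every intermediate stays in float16.
naive = (np.float16(1) - np.float16(epsilon)) * out_grad

# MPType-style path: upcast to float32, compute, cast back once at the end.
upcast = ((1.0 - epsilon) * out_grad.astype(np.float32)).astype(np.float16)

# float64 reference to measure the rounding error of each path.
ref = (1.0 - epsilon) * out_grad.astype(np.float64)
print("naive  max err :", np.abs(naive.astype(np.float64) - ref).max())
print("upcast max err :", np.abs(upcast.astype(np.float64) - ref).max())
```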
34 changes: 23 additions & 11 deletions paddle/phi/kernels/gpu/label_smooth_kernel.cu
@@ -17,24 +17,27 @@
#include <vector>

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"

namespace phi {

template <typename T>
struct LabelSmoothFunctor {
T epsilon;
T label_dim;
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
MPType epsilon;
MPType label_dim;

__forceinline__ LabelSmoothFunctor(float epsilon_data, int label_dim_data) {
epsilon = static_cast<T>(epsilon_data);
label_dim = static_cast<T>(label_dim_data);
epsilon = static_cast<MPType>(epsilon_data);
label_dim = static_cast<MPType>(label_dim_data);
}

__device__ __forceinline__ T operator()(const T x) const {
return (static_cast<T>(1 - epsilon) * x +
static_cast<T>(epsilon / label_dim));
return static_cast<T>(
static_cast<MPType>(static_cast<MPType>(1) - epsilon) *
static_cast<MPType>(x) +
static_cast<MPType>(epsilon / label_dim));
}
};

@@ -45,10 +48,14 @@ __global__ void LabelSmoothRunDistKernel(const int N,
const T* src,
const T* dist_data,
T* dst) {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
CUDA_KERNEL_LOOP(idx, N) {
int dist_idx = idx % dist_numel;
dst[idx] = static_cast<T>(1 - epsilon) * src[idx] +
static_cast<T>(epsilon) * dist_data[dist_idx];
dst[idx] =
static_cast<T>((static_cast<MPType>(1) - static_cast<MPType>(epsilon)) *
static_cast<MPType>(src[idx]) +
static_cast<MPType>(epsilon) *
static_cast<MPType>(dist_data[dist_idx]));
}
}

@@ -83,5 +90,10 @@ void LabelSmoothKernel(const Context& ctx,

} // namespace phi

PD_REGISTER_KERNEL(
label_smooth, GPU, ALL_LAYOUT, phi::LabelSmoothKernel, float, double) {}
PD_REGISTER_KERNEL(label_smooth,
GPU,
ALL_LAYOUT,
phi::LabelSmoothKernel,
float,
double,
phi::dtype::float16) {}
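
With or without a prior distribution, the forward kernels above implement the usual label-smoothing formula out = (1 - ε)·x + ε·dist, where dist is either the supplied prior distribution or the uniform value 1/label_dim; only the precision handling changes for float16. A minimal NumPy reference of that formula (the same computation the unit test below builds its expected output from; the helper name here is ours):

```python
import numpy as np

def label_smooth_reference(label, epsilon=0.1, prior_dist=None):
    """(1 - epsilon) * label + epsilon * dist, with dist = 1 / num_classes
    when no prior distribution is given."""
    num_classes = label.shape[-1]
    if prior_dist is None:
        return (1.0 - epsilon) * label + epsilon / num_classes
    return (1.0 - epsilon) * label + epsilon * prior_dist

one_hot = np.eye(4, dtype=np.float32)[[0, 2, 3]]  # 3 samples, 4 classes
print(label_smooth_reference(one_hot, epsilon=0.1))
# each row: 0.925 at the hot index, 0.025 elsewhere
```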
18 changes: 17 additions & 1 deletion python/paddle/fluid/tests/unittests/test_label_smooth_op.py
@@ -18,15 +18,17 @@
from op_test import OpTest

import paddle
import paddle.fluid.core as core


class TestLabelSmoothOp(OpTest):
def config(self):
self.op_type = "label_smooth"
self.python_api = paddle.nn.functional.label_smooth
self.init_dtype()
self.epsilon = 0.1
batch_size, self.label_dim = 10, 12
self.label = np.zeros((batch_size, self.label_dim)).astype("float64")
self.label = np.zeros((batch_size, self.label_dim)).astype(self.dtype)
nonzero_index = np.random.randint(self.label_dim, size=(batch_size))
self.label[np.arange(batch_size), nonzero_index] = 1

@@ -39,13 +41,27 @@ def setUp(self):
self.attrs = {'epsilon': self.epsilon}
self.outputs = {'Out': smoothed_label}

def init_dtype(self):
self.dtype = np.float64

def test_check_output(self):
self.check_output(check_eager=True)

def test_check_grad(self):
self.check_grad(["X"], "Out", check_eager=True)


class TestLabelSmoothFP16OP(TestLabelSmoothOp):
def init_dtype(self):
self.dtype = np.float16

def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, check_eager=True)


class TestLabelSmoothOpWithPriorDist(TestLabelSmoothOp):
def setUp(self):
self.config()
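
Outside of OpTest, the new float16 path can also be exercised directly through the Python API; a rough standalone check (assuming a CUDA build of Paddle on a device with float16 support, and using a loose tolerance appropriate for half precision) might look like this:

```python
import numpy as np
import paddle
import paddle.nn.functional as F

if paddle.is_compiled_with_cuda():
    paddle.set_device("gpu")
    # random one-hot labels: 10 samples, 12 classes
    label_np = np.eye(12, dtype=np.float32)[np.random.randint(12, size=10)]

    out_fp32 = F.label_smooth(paddle.to_tensor(label_np), epsilon=0.1)
    out_fp16 = F.label_smooth(
        paddle.to_tensor(label_np.astype(np.float16)), epsilon=0.1
    )

    # float16 carries roughly 3 decimal digits, so compare with a relaxed atol.
    np.testing.assert_allclose(
        out_fp16.astype("float32").numpy(), out_fp32.numpy(), atol=1e-3
    )
```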
4 changes: 2 additions & 2 deletions python/paddle/nn/functional/common.py
@@ -1924,7 +1924,7 @@ def label_smooth(label, prior_dist=None, epsilon=0.1, name=None):
label(Tensor): The input variable containing the label data. The
label data should use one-hot representation. It's
a multidimensional tensor with a shape of
:math:`[N_1, ..., Depth]`, where Depth is class number. The dtype can be "float32" and "float64".
:math:`[N_1, ..., Depth]`, where Depth is class number. The dtype can be "float16", "float32" or "float64".
prior_dist(Tensor, optional): The prior distribution to be used to smooth
labels. If not provided, an uniform distribution
is used. It's a multidimensional tensor with a shape of
@@ -1966,7 +1966,7 @@ def label_smooth(label, prior_dist=None, epsilon=0.1, name=None):
)

check_variable_and_dtype(
label, 'label', ['float32', 'float64'], 'label_smooth'
label, 'label', ['float16', 'float32', 'float64'], 'label_smooth'
)

helper = LayerHelper("label_smooth", **locals())
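
For completeness, a usage sketch of the `prior_dist` argument described in the docstring above: the smoothing mass ε is spread according to `prior_dist` instead of uniformly (the values here are illustrative; with this PR the same call also accepts a float16 `label` on GPU):

```python
import paddle
import paddle.nn.functional as F

label = paddle.to_tensor([[0.0, 1.0, 0.0],
                          [1.0, 0.0, 0.0]], dtype="float32")
prior = paddle.to_tensor([[0.5, 0.3, 0.2]], dtype="float32")  # shape [1, Depth]

smoothed = F.label_smooth(label, prior_dist=prior, epsilon=0.1)
print(smoothed)
# each row is 0.9 * label + 0.1 * prior, e.g. [0.05, 0.93, 0.02]
```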