[AMP] support master_grad for adam and momentum (#54240)
* support master_grad for adam and momentum

Co-authored-by: [email protected] <zhangting2020>
zhangting2020 authored Jun 2, 2023
1 parent cdbf62f commit 703a64a
Showing 6 changed files with 278 additions and 125 deletions.
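
In summary: eager_api_set_master_grads now skips non-leaf tensors instead of raising a fatal error and casts any initialized FP16/BF16 gradient to FP32 in place, while the Adam, merged-Adam, and merged-momentum GPU kernels gain a separate gradient type parameter (TG) and dispatch on the gradient's run-time dtype, so FP32 master gradients can drive updates of low-precision parameters.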
13 changes: 6 additions & 7 deletions paddle/fluid/pybind/eager_functions.cc
@@ -1254,24 +1254,23 @@ static PyObject* eager_api_set_master_grads(PyObject* self,
   auto tensor_list = CastPyArg2VectorOfTensor(PyTuple_GET_ITEM(args, 0), 0);
   for (auto& tensor : tensor_list) {
     VLOG(6) << "set master_grad for tensor: " << tensor.name();
-    PADDLE_ENFORCE_EQ(
-        egr::egr_utils_api::IsLeafTensor(tensor),
-        true,
-        paddle::platform::errors::Fatal("Only leaf Tensor can be set grad."));
+    if (!egr::egr_utils_api::IsLeafTensor(tensor)) {
+      continue;
+    }
     paddle::Tensor* grad = egr::EagerUtils::mutable_grad(tensor);
     PADDLE_ENFORCE_NE(grad,
                       nullptr,
                       paddle::platform::errors::Fatal(
                           "Detected NULL grad"
                           "Please check if you have manually cleared"
                           "the grad inside autograd_meta"));
-    auto dtype = (*grad).dtype();
-    if ((*grad).initialized() &&
-        (dtype == phi::DataType::FLOAT16 || dtype == phi::DataType::BFLOAT16)) {
+    if ((*grad).initialized() && ((*grad).dtype() == phi::DataType::FLOAT16 ||
+                                  (*grad).dtype() == phi::DataType::BFLOAT16)) {
       auto master_grad =
           paddle::experimental::cast(*grad, phi::DataType::FLOAT32);
       grad->set_impl(master_grad.impl());
     }
+    VLOG(6) << "finish setting master_grad for tensor: " << tensor.name();
   }
   RETURN_PY_NONE
   EAGER_CATCH_AND_THROW_RETURN_NULL
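In effect, this hunk relaxes the leaf-tensor check from a hard error to a skip and promotes any initialized FP16/BF16 gradient to an FP32 master gradient in place. A minimal sketch of that promotion, factored out of the loop above (the helper name is hypothetical; cast and set_impl are the calls the diff itself uses):

// Hypothetical helper mirroring the loop body above: promote a
// low-precision gradient to FP32 in place. cast() allocates a new FP32
// tensor; set_impl() swaps it into the existing grad slot, so later
// optimizer steps read a float32 gradient.
static void PromoteToMasterGrad(paddle::Tensor* grad) {
  const auto dtype = grad->dtype();
  if (grad->initialized() && (dtype == phi::DataType::FLOAT16 ||
                              dtype == phi::DataType::BFLOAT16)) {
    auto master_grad =
        paddle::experimental::cast(*grad, phi::DataType::FLOAT32);
    grad->set_impl(master_grad.impl());
  }
}
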
230 changes: 158 additions & 72 deletions paddle/phi/kernels/gpu/adam_kernel.cu
@@ -30,7 +30,7 @@

 namespace phi {

-template <typename T, typename MT>
+template <typename T, typename TG, typename MT>
 __global__ void AdamKernelREG(MT beta1,
                               MT beta2,
                               MT epsilon,
@@ -41,7 +41,7 @@ __global__ void AdamKernelREG(MT beta1,
                               const MT* moment2,
                               MT* moment2_out,
                               const MT* lr_,
-                              const T* grad,
+                              const TG* grad,
                               const T* param,
                               T* param_out,
                               const MT* master_param,
@@ -73,7 +73,7 @@ __global__ void AdamKernelREG(MT beta1,
   }
 }

-template <typename T, typename MT>
+template <typename T, typename TG, typename MT>
 __global__ void AdamKernelMEM(MT beta1,
                               MT beta2,
                               MT epsilon,
@@ -84,7 +84,7 @@ __global__ void AdamKernelMEM(MT beta1,
                               const MT* moment2,
                               MT* moment2_out,
                               const MT* lr_,
-                              const T* grad,
+                              const TG* grad,
                               const T* param,
                               T* param_out,
                               const MT* master_param,
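
The new TG parameter decouples the gradient's element type from the parameter's: with master_grad enabled the gradient buffer holds float while param stays T (FP16/BF16). A hedged sketch of the per-element step (the standard bias-corrected Adam update; the kernels' real bodies lie outside this hunk), showing the one load site where TG is widened to the compute type MT:

// Sketch kernel (hypothetical name; assumes the standard bias-corrected
// Adam step). With master_grad, TG = float while T stays FP16/BF16.
template <typename T, typename TG, typename MT>
__global__ void AdamStepSketch(MT beta1, MT beta2, MT epsilon, MT lr,
                               MT beta1_pow, MT beta2_pow,
                               const MT* moment1, MT* moment1_out,
                               const MT* moment2, MT* moment2_out,
                               const TG* grad, const T* param, T* param_out,
                               int64_t num) {
  const MT one = static_cast<MT>(1.0);
  for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
       i += static_cast<int64_t>(gridDim.x) * blockDim.x) {
    const MT g = static_cast<MT>(grad[i]);  // TG -> MT: the cast TG changes
    const MT p = static_cast<MT>(param[i]);
    const MT m1 = beta1 * moment1[i] + (one - beta1) * g;
    const MT m2 = beta2 * moment2[i] + (one - beta2) * g * g;
    // Bias-corrected update; epsilon added after the second-moment sqrt.
    const MT denom = sqrt(m2) / sqrt(one - beta2_pow) + epsilon;
    moment1_out[i] = m1;
    moment2_out[i] = m2;
    param_out[i] = static_cast<T>(p - lr * m1 / ((one - beta1_pow) * denom));
  }
}
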
@@ -152,6 +152,7 @@ void AdamDenseKernel(const Context& dev_ctx,
                      DenseTensor* beta2_pow_out,
                      DenseTensor* master_param_outs) {
   using MPDType = typename phi::dtype::MPTypeTrait<T>::Type;
+  const auto grad_type = grad.dtype();

   VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow;

@@ -212,23 +213,44 @@ void AdamDenseKernel(const Context& dev_ctx,

   if (beta1_pow.place() == CPUPlace() && beta2_pow.place() == CPUPlace()) {
     // Compute with betapow in REG
-    AdamKernelREG<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
-        beta1_,
-        beta2_,
-        epsilon_,
-        *beta1_pow.data<MPDType>(),
-        *beta2_pow.data<MPDType>(),
-        moment1.data<MPDType>(),
-        dev_ctx.template Alloc<MPDType>(moment1_out),
-        moment2.data<MPDType>(),
-        dev_ctx.template Alloc<MPDType>(moment2_out),
-        learning_rate.data<MPDType>(),
-        grad.data<T>(),
-        param.data<T>(),
-        dev_ctx.template Alloc<T>(param_out),
-        master_in_data,
-        master_out_data,
-        param.numel());
+    if (grad_type == phi::DataType::FLOAT32) {
+      AdamKernelREG<T, float, MPDType>
+          <<<blocks, threads, 0, dev_ctx.stream()>>>(
+              beta1_,
+              beta2_,
+              epsilon_,
+              *beta1_pow.data<MPDType>(),
+              *beta2_pow.data<MPDType>(),
+              moment1.data<MPDType>(),
+              dev_ctx.template Alloc<MPDType>(moment1_out),
+              moment2.data<MPDType>(),
+              dev_ctx.template Alloc<MPDType>(moment2_out),
+              learning_rate.data<MPDType>(),
+              grad.data<float>(),
+              param.data<T>(),
+              dev_ctx.template Alloc<T>(param_out),
+              master_in_data,
+              master_out_data,
+              param.numel());
+    } else {
+      AdamKernelREG<T, T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
+          beta1_,
+          beta2_,
+          epsilon_,
+          *beta1_pow.data<MPDType>(),
+          *beta2_pow.data<MPDType>(),
+          moment1.data<MPDType>(),
+          dev_ctx.template Alloc<MPDType>(moment1_out),
+          moment2.data<MPDType>(),
+          dev_ctx.template Alloc<MPDType>(moment2_out),
+          learning_rate.data<MPDType>(),
+          grad.data<T>(),
+          param.data<T>(),
+          dev_ctx.template Alloc<T>(param_out),
+          master_in_data,
+          master_out_data,
+          param.numel());
+    }
     if (!use_global_beta_pow) {
       // Cpu update
       dev_ctx.template HostAlloc<MPDType>(beta1_pow_out)[0] =
@@ -237,23 +259,44 @@ void AdamDenseKernel(const Context& dev_ctx,
           beta2_ * beta2_pow.data<MPDType>()[0];
     }
   } else {
-    AdamKernelMEM<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
-        beta1_,
-        beta2_,
-        epsilon_,
-        beta1_pow.data<MPDType>(),
-        beta2_pow.data<MPDType>(),
-        moment1.data<MPDType>(),
-        dev_ctx.template Alloc<MPDType>(moment1_out),
-        moment2.data<MPDType>(),
-        dev_ctx.template Alloc<MPDType>(moment2_out),
-        learning_rate.data<MPDType>(),
-        grad.data<T>(),
-        param.data<T>(),
-        dev_ctx.template Alloc<T>(param_out),
-        master_in_data,
-        master_out_data,
-        param.numel());
+    if (grad_type == phi::DataType::FLOAT32) {
+      AdamKernelMEM<T, float, MPDType>
+          <<<blocks, threads, 0, dev_ctx.stream()>>>(
+              beta1_,
+              beta2_,
+              epsilon_,
+              beta1_pow.data<MPDType>(),
+              beta2_pow.data<MPDType>(),
+              moment1.data<MPDType>(),
+              dev_ctx.template Alloc<MPDType>(moment1_out),
+              moment2.data<MPDType>(),
+              dev_ctx.template Alloc<MPDType>(moment2_out),
+              learning_rate.data<MPDType>(),
+              grad.data<float>(),
+              param.data<T>(),
+              dev_ctx.template Alloc<T>(param_out),
+              master_in_data,
+              master_out_data,
+              param.numel());
+    } else {
+      AdamKernelMEM<T, T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
+          beta1_,
+          beta2_,
+          epsilon_,
+          beta1_pow.data<MPDType>(),
+          beta2_pow.data<MPDType>(),
+          moment1.data<MPDType>(),
+          dev_ctx.template Alloc<MPDType>(moment1_out),
+          moment2.data<MPDType>(),
+          dev_ctx.template Alloc<MPDType>(moment2_out),
+          learning_rate.data<MPDType>(),
+          grad.data<T>(),
+          param.data<T>(),
+          dev_ctx.template Alloc<T>(param_out),
+          master_in_data,
+          master_out_data,
+          param.numel());
+    }
     if (!use_global_beta_pow) {
       // Update with gpu
       UpdateBetaPow<MPDType><<<1, 1, 0, dev_ctx.stream()>>>(
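
Because the gradient dtype is only known at run time, every launch site now branches once to pick the TG instantiation: float for FP32 master grads, otherwise T. A hypothetical helper, not part of the patch, showing how that repeated branch could be written once with tag dispatch:

// Hypothetical helper (not in the patch): the run-time dtype check picks
// the compile-time TG instantiation once, instead of repeating the
// branch at every launch site. Requires a generic lambda (C++14).
template <typename T, typename F>
void DispatchGradType(phi::DataType grad_type, F&& launch) {
  if (grad_type == phi::DataType::FLOAT32) {
    launch(float{});  // TG = float: FP32 master gradients
  } else {
    launch(T{});      // TG = T: gradient dtype matches the parameter
  }
}

// Usage sketch for the launches above:
//   DispatchGradType<T>(grad_type, [&](auto tag) {
//     using TG = decltype(tag);
//     AdamKernelREG<T, TG, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
//         ..., grad.data<TG>(), ...);
//   });
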
@@ -308,26 +351,48 @@ void MergedAdamKernel(
     int threads = 512;
     int blocks = (param[idx]->numel() + threads - 1) / threads;

+    const auto grad_type = grad[idx]->dtype();
     if (beta1_pow[idx]->place() == CPUPlace() &&
         beta2_pow[idx]->place() == CPUPlace()) {
       // Compute with betapow in REG
-      AdamKernelREG<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
-          beta1_,
-          beta2_,
-          epsilon_,
-          *beta1_pow[idx]->data<MPDType>(),
-          *beta2_pow[idx]->data<MPDType>(),
-          moment1[idx]->data<MPDType>(),
-          dev_ctx.template Alloc<MPDType>(moment1_out[idx]),
-          moment2[idx]->data<MPDType>(),
-          dev_ctx.template Alloc<MPDType>(moment2_out[idx]),
-          learning_rate[idx]->data<MPDType>(),
-          grad[idx]->data<T>(),
-          param[idx]->data<T>(),
-          dev_ctx.template Alloc<T>(param_out[idx]),
-          master_in_data,
-          master_out_data,
-          param[idx]->numel());
+      if (grad_type == phi::DataType::FLOAT32) {
+        AdamKernelREG<T, float, MPDType>
+            <<<blocks, threads, 0, dev_ctx.stream()>>>(
+                beta1_,
+                beta2_,
+                epsilon_,
+                *beta1_pow[idx]->data<MPDType>(),
+                *beta2_pow[idx]->data<MPDType>(),
+                moment1[idx]->data<MPDType>(),
+                dev_ctx.template Alloc<MPDType>(moment1_out[idx]),
+                moment2[idx]->data<MPDType>(),
+                dev_ctx.template Alloc<MPDType>(moment2_out[idx]),
+                learning_rate[idx]->data<MPDType>(),
+                grad[idx]->data<float>(),
+                param[idx]->data<T>(),
+                dev_ctx.template Alloc<T>(param_out[idx]),
+                master_in_data,
+                master_out_data,
+                param[idx]->numel());
+      } else {
+        AdamKernelREG<T, T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
+            beta1_,
+            beta2_,
+            epsilon_,
+            *beta1_pow[idx]->data<MPDType>(),
+            *beta2_pow[idx]->data<MPDType>(),
+            moment1[idx]->data<MPDType>(),
+            dev_ctx.template Alloc<MPDType>(moment1_out[idx]),
+            moment2[idx]->data<MPDType>(),
+            dev_ctx.template Alloc<MPDType>(moment2_out[idx]),
+            learning_rate[idx]->data<MPDType>(),
+            grad[idx]->data<T>(),
+            param[idx]->data<T>(),
+            dev_ctx.template Alloc<T>(param_out[idx]),
+            master_in_data,
+            master_out_data,
+            param[idx]->numel());
+      }
       if (!use_global_beta_pow) {
         // Cpu update
         dev_ctx.template HostAlloc<MPDType>(beta1_pow_out[idx])[0] =
@@ -336,23 +401,44 @@ void MergedAdamKernel(
             beta2_ * beta2_pow[idx]->data<MPDType>()[0];
       }
     } else {
-      AdamKernelMEM<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
-          beta1_,
-          beta2_,
-          epsilon_,
-          beta1_pow[idx]->data<MPDType>(),
-          beta2_pow[idx]->data<MPDType>(),
-          moment1[idx]->data<MPDType>(),
-          dev_ctx.template Alloc<MPDType>(moment1_out[idx]),
-          moment2[idx]->data<MPDType>(),
-          dev_ctx.template Alloc<MPDType>(moment2_out[idx]),
-          learning_rate[idx]->data<MPDType>(),
-          grad[idx]->data<T>(),
-          param[idx]->data<T>(),
-          dev_ctx.template Alloc<T>(param_out[idx]),
-          master_in_data,
-          master_out_data,
-          param[idx]->numel());
+      if (grad_type == phi::DataType::FLOAT32) {
+        AdamKernelMEM<T, float, MPDType>
+            <<<blocks, threads, 0, dev_ctx.stream()>>>(
+                beta1_,
+                beta2_,
+                epsilon_,
+                beta1_pow[idx]->data<MPDType>(),
+                beta2_pow[idx]->data<MPDType>(),
+                moment1[idx]->data<MPDType>(),
+                dev_ctx.template Alloc<MPDType>(moment1_out[idx]),
+                moment2[idx]->data<MPDType>(),
+                dev_ctx.template Alloc<MPDType>(moment2_out[idx]),
+                learning_rate[idx]->data<MPDType>(),
+                grad[idx]->data<float>(),
+                param[idx]->data<T>(),
+                dev_ctx.template Alloc<T>(param_out[idx]),
+                master_in_data,
+                master_out_data,
+                param[idx]->numel());
+      } else {
+        AdamKernelMEM<T, T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
+            beta1_,
+            beta2_,
+            epsilon_,
+            beta1_pow[idx]->data<MPDType>(),
+            beta2_pow[idx]->data<MPDType>(),
+            moment1[idx]->data<MPDType>(),
+            dev_ctx.template Alloc<MPDType>(moment1_out[idx]),
+            moment2[idx]->data<MPDType>(),
+            dev_ctx.template Alloc<MPDType>(moment2_out[idx]),
+            learning_rate[idx]->data<MPDType>(),
+            grad[idx]->data<T>(),
+            param[idx]->data<T>(),
+            dev_ctx.template Alloc<T>(param_out[idx]),
+            master_in_data,
+            master_out_data,
+            param[idx]->numel());
+      }
       if (!use_global_beta_pow) {
         // Update with gpu
         UpdateBetaPow<MPDType><<<1, 1, 0, dev_ctx.stream()>>>(
49 changes: 34 additions & 15 deletions paddle/phi/kernels/impl/merged_momentum_impl.h
@@ -300,21 +300,40 @@ void MergedMomentumInnerCompute(
   } else if (ctx.GetPlace().GetType() == phi::AllocationType::GPU) {
     phi::funcs::ForRange<Context> for_range(
         static_cast<const Context &>(ctx), params[idx]->numel());
-#define PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(__nesterov, __reg_type)     \
-  phi::DenseMomentumFunctor<T, MT, __reg_type, __nesterov> functor(       \
-      params[idx]->data<T>(),                                             \
-      grads[idx]->data<T>(),                                              \
-      velocitys[idx]->data<MT>(),                                         \
-      lr_temp->data<MPType>(),                                            \
-      master_in_data,                                                     \
-      static_cast<MT>(mu),                                                \
-      static_cast<MT>(rescale_grad),                                      \
-      params[idx]->numel(),                                               \
-      regularization_coeff,                                               \
-      params_out[idx]->data<T>(),                                         \
-      velocitys_out[idx]->data<MT>(),                                     \
-      master_out_data);                                                   \
-  for_range(functor);
+    const auto grad_type = grads[idx]->dtype();
+#define PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(__nesterov, __reg_type)     \
+  if (grad_type == phi::DataType::FLOAT32) {                              \
+    DenseMomentumFunctor<T, float, MT, __reg_type, __nesterov> functor(   \
+        params[idx]->data<T>(),                                           \
+        grads[idx]->data<float>(),                                        \
+        velocitys[idx]->data<MT>(),                                       \
+        lr_temp->data<MPType>(),                                          \
+        master_in_data,                                                   \
+        static_cast<MT>(mu),                                              \
+        static_cast<MT>(rescale_grad),                                    \
+        params[idx]->numel(),                                             \
+        regularization_coeff,                                             \
+        params_out[idx]->data<T>(),                                       \
+        velocitys_out[idx]->data<MT>(),                                   \
+        master_out_data);                                                 \
+    for_range(functor);                                                   \
+  } else {                                                                \
+    DenseMomentumFunctor<T, T, MT, __reg_type, __nesterov> functor(       \
+        params[idx]->data<T>(),                                           \
+        grads[idx]->data<T>(),                                            \
+        velocitys[idx]->data<MT>(),                                       \
+        lr_temp->data<MPType>(),                                          \
+        master_in_data,                                                   \
+        static_cast<MT>(mu),                                              \
+        static_cast<MT>(rescale_grad),                                    \
+        params[idx]->numel(),                                             \
+        regularization_coeff,                                             \
+        params_out[idx]->data<T>(),                                       \
+        velocitys_out[idx]->data<MT>(),                                   \
+        master_out_data);                                                 \
+    for_range(functor);                                                   \
+  }
+
     if (use_nesterov) {
       if (regularization_flag == phi::RegularizationType::kL2DECAY) {
         PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(
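For reference, a hedged sketch of the per-element step such a functor runs under for_range (plain momentum, without the Nesterov and L2-decay variants the macro's __nesterov/__reg_type arguments select; the real DenseMomentumFunctor body lies outside this hunk). The struct name is hypothetical; HOSTDEVICE is phi's host/device qualifier macro:

// Hypothetical functor sketch: TG is the new gradient type parameter,
// float for FP32 master grads, otherwise T.
template <typename T, typename TG, typename MT>
struct MomentumStepSketch {
  const T* param;
  const TG* grad;
  const MT* velocity;
  const MT* lr;
  MT mu;
  MT rescale_grad;
  T* param_out;
  MT* velocity_out;

  HOSTDEVICE void operator()(size_t i) const {
    // Widen the TG gradient to the compute type MT, so an FP32 master
    // grad keeps full precision even when param is FP16/BF16.
    const MT g = rescale_grad * static_cast<MT>(grad[i]);
    const MT v = mu * velocity[i] + g;                    // velocity update
    velocity_out[i] = v;
    param_out[i] = static_cast<T>(static_cast<MT>(param[i]) - lr[0] * v);
  }
};
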
