From 703a64a3990069311c3e5953356fcf40003ecd01 Mon Sep 17 00:00:00 2001 From: Zhang Ting Date: Fri, 2 Jun 2023 11:12:00 +0800 Subject: [PATCH] [AMP] support master_grad for adam and momentum (#54240) * support master_grad for adam and momentum Co-authored-by: zhangting_2017@163.com --- paddle/fluid/pybind/eager_functions.cc | 13 +- paddle/phi/kernels/gpu/adam_kernel.cu | 230 ++++++++++++------ .../phi/kernels/impl/merged_momentum_impl.h | 49 ++-- .../phi/kernels/impl/momentum_kernel_impl.h | 65 +++-- python/paddle/optimizer/optimizer.py | 17 ++ test/amp/test_amp_master_grad.py | 29 ++- 6 files changed, 278 insertions(+), 125 deletions(-) diff --git a/paddle/fluid/pybind/eager_functions.cc b/paddle/fluid/pybind/eager_functions.cc index 3603e569d21db..f877c55de4671 100644 --- a/paddle/fluid/pybind/eager_functions.cc +++ b/paddle/fluid/pybind/eager_functions.cc @@ -1254,10 +1254,9 @@ static PyObject* eager_api_set_master_grads(PyObject* self, auto tensor_list = CastPyArg2VectorOfTensor(PyTuple_GET_ITEM(args, 0), 0); for (auto& tensor : tensor_list) { VLOG(6) << "set master_grad for tensor: " << tensor.name(); - PADDLE_ENFORCE_EQ( - egr::egr_utils_api::IsLeafTensor(tensor), - true, - paddle::platform::errors::Fatal("Only leaf Tensor can be set grad.")); + if (!egr::egr_utils_api::IsLeafTensor(tensor)) { + continue; + } paddle::Tensor* grad = egr::EagerUtils::mutable_grad(tensor); PADDLE_ENFORCE_NE(grad, nullptr, @@ -1265,13 +1264,13 @@ static PyObject* eager_api_set_master_grads(PyObject* self, "Detected NULL grad" "Please check if you have manually cleared" "the grad inside autograd_meta")); - auto dtype = (*grad).dtype(); - if ((*grad).initialized() && - (dtype == phi::DataType::FLOAT16 || dtype == phi::DataType::BFLOAT16)) { + if ((*grad).initialized() && ((*grad).dtype() == phi::DataType::FLOAT16 || + (*grad).dtype() == phi::DataType::BFLOAT16)) { auto master_grad = paddle::experimental::cast(*grad, phi::DataType::FLOAT32); grad->set_impl(master_grad.impl()); } + VLOG(6) << "finish setting master_grad for tensor: " << tensor.name(); } RETURN_PY_NONE EAGER_CATCH_AND_THROW_RETURN_NULL diff --git a/paddle/phi/kernels/gpu/adam_kernel.cu b/paddle/phi/kernels/gpu/adam_kernel.cu index 23dc6cdfd3398..5292d7d29c07b 100644 --- a/paddle/phi/kernels/gpu/adam_kernel.cu +++ b/paddle/phi/kernels/gpu/adam_kernel.cu @@ -30,7 +30,7 @@ namespace phi { -template +template __global__ void AdamKernelREG(MT beta1, MT beta2, MT epsilon, @@ -41,7 +41,7 @@ __global__ void AdamKernelREG(MT beta1, const MT* moment2, MT* moment2_out, const MT* lr_, - const T* grad, + const TG* grad, const T* param, T* param_out, const MT* master_param, @@ -73,7 +73,7 @@ __global__ void AdamKernelREG(MT beta1, } } -template +template __global__ void AdamKernelMEM(MT beta1, MT beta2, MT epsilon, @@ -84,7 +84,7 @@ __global__ void AdamKernelMEM(MT beta1, const MT* moment2, MT* moment2_out, const MT* lr_, - const T* grad, + const TG* grad, const T* param, T* param_out, const MT* master_param, @@ -152,6 +152,7 @@ void AdamDenseKernel(const Context& dev_ctx, DenseTensor* beta2_pow_out, DenseTensor* master_param_outs) { using MPDType = typename phi::dtype::MPTypeTrait::Type; + const auto grad_type = grad.dtype(); VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow; @@ -212,23 +213,44 @@ void AdamDenseKernel(const Context& dev_ctx, if (beta1_pow.place() == CPUPlace() && beta2_pow.place() == CPUPlace()) { // Compute with betapow in REG - AdamKernelREG<<>>( - beta1_, - beta2_, - epsilon_, - *beta1_pow.data(), - *beta2_pow.data(), - 
moment1.data(), - dev_ctx.template Alloc(moment1_out), - moment2.data(), - dev_ctx.template Alloc(moment2_out), - learning_rate.data(), - grad.data(), - param.data(), - dev_ctx.template Alloc(param_out), - master_in_data, - master_out_data, - param.numel()); + if (grad_type == phi::DataType::FLOAT32) { + AdamKernelREG + <<>>( + beta1_, + beta2_, + epsilon_, + *beta1_pow.data(), + *beta2_pow.data(), + moment1.data(), + dev_ctx.template Alloc(moment1_out), + moment2.data(), + dev_ctx.template Alloc(moment2_out), + learning_rate.data(), + grad.data(), + param.data(), + dev_ctx.template Alloc(param_out), + master_in_data, + master_out_data, + param.numel()); + } else { + AdamKernelREG<<>>( + beta1_, + beta2_, + epsilon_, + *beta1_pow.data(), + *beta2_pow.data(), + moment1.data(), + dev_ctx.template Alloc(moment1_out), + moment2.data(), + dev_ctx.template Alloc(moment2_out), + learning_rate.data(), + grad.data(), + param.data(), + dev_ctx.template Alloc(param_out), + master_in_data, + master_out_data, + param.numel()); + } if (!use_global_beta_pow) { // Cpu update dev_ctx.template HostAlloc(beta1_pow_out)[0] = @@ -237,23 +259,44 @@ void AdamDenseKernel(const Context& dev_ctx, beta2_ * beta2_pow.data()[0]; } } else { - AdamKernelMEM<<>>( - beta1_, - beta2_, - epsilon_, - beta1_pow.data(), - beta2_pow.data(), - moment1.data(), - dev_ctx.template Alloc(moment1_out), - moment2.data(), - dev_ctx.template Alloc(moment2_out), - learning_rate.data(), - grad.data(), - param.data(), - dev_ctx.template Alloc(param_out), - master_in_data, - master_out_data, - param.numel()); + if (grad_type == phi::DataType::FLOAT32) { + AdamKernelMEM + <<>>( + beta1_, + beta2_, + epsilon_, + beta1_pow.data(), + beta2_pow.data(), + moment1.data(), + dev_ctx.template Alloc(moment1_out), + moment2.data(), + dev_ctx.template Alloc(moment2_out), + learning_rate.data(), + grad.data(), + param.data(), + dev_ctx.template Alloc(param_out), + master_in_data, + master_out_data, + param.numel()); + } else { + AdamKernelMEM<<>>( + beta1_, + beta2_, + epsilon_, + beta1_pow.data(), + beta2_pow.data(), + moment1.data(), + dev_ctx.template Alloc(moment1_out), + moment2.data(), + dev_ctx.template Alloc(moment2_out), + learning_rate.data(), + grad.data(), + param.data(), + dev_ctx.template Alloc(param_out), + master_in_data, + master_out_data, + param.numel()); + } if (!use_global_beta_pow) { // Update with gpu UpdateBetaPow<<<1, 1, 0, dev_ctx.stream()>>>( @@ -308,26 +351,48 @@ void MergedAdamKernel( int threads = 512; int blocks = (param[idx]->numel() + threads - 1) / threads; + const auto grad_type = grad[idx]->dtype(); if (beta1_pow[idx]->place() == CPUPlace() && beta2_pow[idx]->place() == CPUPlace()) { // Compute with betapow in REG - AdamKernelREG<<>>( - beta1_, - beta2_, - epsilon_, - *beta1_pow[idx]->data(), - *beta2_pow[idx]->data(), - moment1[idx]->data(), - dev_ctx.template Alloc(moment1_out[idx]), - moment2[idx]->data(), - dev_ctx.template Alloc(moment2_out[idx]), - learning_rate[idx]->data(), - grad[idx]->data(), - param[idx]->data(), - dev_ctx.template Alloc(param_out[idx]), - master_in_data, - master_out_data, - param[idx]->numel()); + if (grad_type == phi::DataType::FLOAT32) { + AdamKernelREG + <<>>( + beta1_, + beta2_, + epsilon_, + *beta1_pow[idx]->data(), + *beta2_pow[idx]->data(), + moment1[idx]->data(), + dev_ctx.template Alloc(moment1_out[idx]), + moment2[idx]->data(), + dev_ctx.template Alloc(moment2_out[idx]), + learning_rate[idx]->data(), + grad[idx]->data(), + param[idx]->data(), + dev_ctx.template 
Alloc(param_out[idx]), + master_in_data, + master_out_data, + param[idx]->numel()); + } else { + AdamKernelREG<<>>( + beta1_, + beta2_, + epsilon_, + *beta1_pow[idx]->data(), + *beta2_pow[idx]->data(), + moment1[idx]->data(), + dev_ctx.template Alloc(moment1_out[idx]), + moment2[idx]->data(), + dev_ctx.template Alloc(moment2_out[idx]), + learning_rate[idx]->data(), + grad[idx]->data(), + param[idx]->data(), + dev_ctx.template Alloc(param_out[idx]), + master_in_data, + master_out_data, + param[idx]->numel()); + } if (!use_global_beta_pow) { // Cpu update dev_ctx.template HostAlloc(beta1_pow_out[idx])[0] = @@ -336,23 +401,44 @@ void MergedAdamKernel( beta2_ * beta2_pow[idx]->data()[0]; } } else { - AdamKernelMEM<<>>( - beta1_, - beta2_, - epsilon_, - beta1_pow[idx]->data(), - beta2_pow[idx]->data(), - moment1[idx]->data(), - dev_ctx.template Alloc(moment1_out[idx]), - moment2[idx]->data(), - dev_ctx.template Alloc(moment2_out[idx]), - learning_rate[idx]->data(), - grad[idx]->data(), - param[idx]->data(), - dev_ctx.template Alloc(param_out[idx]), - master_in_data, - master_out_data, - param[idx]->numel()); + if (grad_type == phi::DataType::FLOAT32) { + AdamKernelMEM + <<>>( + beta1_, + beta2_, + epsilon_, + beta1_pow[idx]->data(), + beta2_pow[idx]->data(), + moment1[idx]->data(), + dev_ctx.template Alloc(moment1_out[idx]), + moment2[idx]->data(), + dev_ctx.template Alloc(moment2_out[idx]), + learning_rate[idx]->data(), + grad[idx]->data(), + param[idx]->data(), + dev_ctx.template Alloc(param_out[idx]), + master_in_data, + master_out_data, + param[idx]->numel()); + } else { + AdamKernelMEM<<>>( + beta1_, + beta2_, + epsilon_, + beta1_pow[idx]->data(), + beta2_pow[idx]->data(), + moment1[idx]->data(), + dev_ctx.template Alloc(moment1_out[idx]), + moment2[idx]->data(), + dev_ctx.template Alloc(moment2_out[idx]), + learning_rate[idx]->data(), + grad[idx]->data(), + param[idx]->data(), + dev_ctx.template Alloc(param_out[idx]), + master_in_data, + master_out_data, + param[idx]->numel()); + } if (!use_global_beta_pow) { // Update with gpu UpdateBetaPow<<<1, 1, 0, dev_ctx.stream()>>>( diff --git a/paddle/phi/kernels/impl/merged_momentum_impl.h b/paddle/phi/kernels/impl/merged_momentum_impl.h index 40364507e8b2c..cdf90cba70690 100644 --- a/paddle/phi/kernels/impl/merged_momentum_impl.h +++ b/paddle/phi/kernels/impl/merged_momentum_impl.h @@ -300,21 +300,40 @@ void MergedMomentumInnerCompute( } else if (ctx.GetPlace().GetType() == phi::AllocationType::GPU) { phi::funcs::ForRange for_range( static_cast(ctx), params[idx]->numel()); -#define PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(__nesterov, __reg_type) \ - phi::DenseMomentumFunctor functor( \ - params[idx]->data(), \ - grads[idx]->data(), \ - velocitys[idx]->data(), \ - lr_temp->data(), \ - master_in_data, \ - static_cast(mu), \ - static_cast(rescale_grad), \ - params[idx]->numel(), \ - regularization_coeff, \ - params_out[idx]->data(), \ - velocitys_out[idx]->data(), \ - master_out_data); \ - for_range(functor); + const auto grad_type = grads[idx]->dtype(); +#define PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(__nesterov, __reg_type) \ + if (grad_type == phi::DataType::FLOAT32) { \ + DenseMomentumFunctor functor( \ + params[idx]->data(), \ + grads[idx]->data(), \ + velocitys[idx]->data(), \ + lr_temp->data(), \ + master_in_data, \ + static_cast(mu), \ + static_cast(rescale_grad), \ + params[idx]->numel(), \ + regularization_coeff, \ + params_out[idx]->data(), \ + velocitys_out[idx]->data(), \ + master_out_data); \ + for_range(functor); \ + } else { \ + 
DenseMomentumFunctor functor( \ + params[idx]->data(), \ + grads[idx]->data(), \ + velocitys[idx]->data(), \ + lr_temp->data(), \ + master_in_data, \ + static_cast(mu), \ + static_cast(rescale_grad), \ + params[idx]->numel(), \ + regularization_coeff, \ + params_out[idx]->data(), \ + velocitys_out[idx]->data(), \ + master_out_data); \ + for_range(functor); \ + } + if (use_nesterov) { if (regularization_flag == phi::RegularizationType::kL2DECAY) { PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL( diff --git a/paddle/phi/kernels/impl/momentum_kernel_impl.h b/paddle/phi/kernels/impl/momentum_kernel_impl.h index 932cdfe57dade..17917b15951ff 100644 --- a/paddle/phi/kernels/impl/momentum_kernel_impl.h +++ b/paddle/phi/kernels/impl/momentum_kernel_impl.h @@ -104,6 +104,7 @@ class CPUDenseMomentumFunctor { }; template @@ -112,11 +113,11 @@ class DenseMomentumFunctor; // NOTE(dzh) for performance. // avoid if/else in inside kernel, implement GPU UseNesterov/NoNesterov as two // functor. -template -class DenseMomentumFunctor { +template +class DenseMomentumFunctor { private: const T* param_; - const T* grad_; + const TG* grad_; const MT* velocity_; const MultiPrecisionType* lr_; const MT* master_param_; @@ -130,7 +131,7 @@ class DenseMomentumFunctor { public: DenseMomentumFunctor(const T* param, - const T* grad, + const TG* grad, const MT* velocity, const MultiPrecisionType* learning_rate, const MT* master_param, @@ -176,11 +177,11 @@ class DenseMomentumFunctor { } }; -template -class DenseMomentumFunctor { +template +class DenseMomentumFunctor { private: const T* param_; - const T* grad_; + const TG* grad_; const MT* velocity_; const MultiPrecisionType* lr_; const MT* master_param_; @@ -194,7 +195,7 @@ class DenseMomentumFunctor { public: DenseMomentumFunctor(const T* param, - const T* grad, + const TG* grad, const MT* velocity, const MultiPrecisionType* learning_rate, const MT* master_param, @@ -459,21 +460,39 @@ void MomentumDenseImpl(const Context& ctx, velocity_out); } else if (ctx.GetPlace().GetType() == phi::AllocationType::GPU) { funcs::ForRange for_range(ctx, param.numel()); -#define PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(__nesterov, __reg_type) \ - DenseMomentumFunctor functor( \ - param.data(), \ - grad.data(), \ - velocity.data(), \ - learning_rate.data>(), \ - master_in_data, \ - mu, \ - rescale_grad, \ - param.numel(), \ - regularization_coeff, \ - ctx.template Alloc(param_out), \ - ctx.template Alloc(velocity_out), \ - master_out_data); \ - for_range(functor); + const auto grad_type = grad.dtype(); +#define PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(__nesterov, __reg_type) \ + if (grad_type == phi::DataType::FLOAT32) { \ + DenseMomentumFunctor functor( \ + param.data(), \ + grad.data(), \ + velocity.data(), \ + learning_rate.data>(), \ + master_in_data, \ + mu, \ + rescale_grad, \ + param.numel(), \ + regularization_coeff, \ + ctx.template Alloc(param_out), \ + ctx.template Alloc(velocity_out), \ + master_out_data); \ + for_range(functor); \ + } else { \ + DenseMomentumFunctor functor( \ + param.data(), \ + grad.data(), \ + velocity.data(), \ + learning_rate.data>(), \ + master_in_data, \ + mu, \ + rescale_grad, \ + param.numel(), \ + regularization_coeff, \ + ctx.template Alloc(param_out), \ + ctx.template Alloc(velocity_out), \ + master_out_data); \ + for_range(functor); \ + } if (use_nesterov) { if (regularization_flag == RegularizationType::kL2DECAY) { diff --git a/python/paddle/optimizer/optimizer.py b/python/paddle/optimizer/optimizer.py index 5792497a2a340..392e5310bdabc 100644 --- 
a/python/paddle/optimizer/optimizer.py +++ b/python/paddle/optimizer/optimizer.py @@ -1265,6 +1265,23 @@ def _create_regularization_of_grad(self, param, grad, regularization=None): ): return grad regularization_term = None + + # when master_grad is true in amp training, grad will be fp32, but param maybe fp16. + # we get master weight when master_grad is true to avoid type mismatch error. + def get_target_param(param, grad): + target_param = param + if param.dtype != grad.dtype: + find_master = ( + self._multi_precision + and self._is_dtype_fp16_or_bf16(param.dtype) + ) + if find_master and len(self._master_weights) != 0: + target_param = self._master_weights[param.name] + else: + target_param = param.astype(grad.dtype) + return target_param + + param = get_target_param(param, grad) if hasattr(param, 'regularizer') and param.regularizer is not None: # Add variable for regularization term in grad block regularization_term = param.regularizer(param, grad, grad.block) diff --git a/test/amp/test_amp_master_grad.py b/test/amp/test_amp_master_grad.py index 6b5aebf35771e..d94923f33f6bf 100644 --- a/test/amp/test_amp_master_grad.py +++ b/test/amp/test_amp_master_grad.py @@ -44,7 +44,7 @@ def check_results( # fp16 calls self.assertEqual(int(op_list['matmul_v2'].split(',')[0]), total_steps) self.assertEqual( - int(op_list['adamw_'].split(',')[0]), + int(op_list['adam_'].split(',')[0]), 2 * (total_steps / accumulate_batchs_num), ) self.assertEqual( @@ -52,14 +52,11 @@ def check_results( total_steps + total_steps * 2, ) - def run_dygraph(self, total_steps, accumulate_batchs_num): - model = SimpleNet(2, 4) - opt = paddle.optimizer.AdamW(parameters=model.parameters()) + def run_dygraph(self, total_steps, accumulate_batchs_num, model, optimizer): model, opt = paddle.amp.decorate( - model, optimizers=opt, level='O2', master_grad=True + model, optimizers=optimizer, level='O2', master_grad=True ) scaler = paddle.amp.GradScaler() - paddle.amp.debugging.enable_operator_stats_collection() for i in range(total_steps): x = np.random.random((2, 2)).astype('float32') @@ -81,16 +78,32 @@ def run_dygraph(self, total_steps, accumulate_batchs_num): op_list = paddle.fluid.core.get_low_precision_op_list() return fp32_grads, op_list - def test_master_grad(self): + def test_adam_master_grad(self): total_steps = 4 accumulate_batchs_num = 2 + model = SimpleNet(2, 4) + opt = paddle.optimizer.Adam(parameters=model.parameters()) fp32_grads, op_list = self.run_dygraph( - total_steps, accumulate_batchs_num + total_steps, accumulate_batchs_num, model, opt ) self.check_results( fp32_grads, op_list, total_steps, accumulate_batchs_num ) + def test_momentum_master_grad(self): + total_steps = 4 + accumulate_batchs_num = 1 + model = SimpleNet(2, 4) + L1Decay = paddle.regularizer.L1Decay(0.0001) + opt = paddle.optimizer.Momentum( + parameters=model.parameters(), weight_decay=L1Decay + ) + fp32_grads, op_list = self.run_dygraph( + total_steps, accumulate_batchs_num, model, opt + ) + for grad in fp32_grads: + self.assertEqual(grad.dtype, paddle.float32) + if __name__ == '__main__': unittest.main()
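
Editor's note on usage: the hunks above let the Adam/MergedAdam and Momentum/MergedMomentum GPU kernels consume an FP32 gradient for an FP16/BF16 parameter, make eager_api_set_master_grads skip non-leaf tensors instead of raising, and have the regularization path in optimizer.py fall back to the master weight when the gradient dtype no longer matches the parameter dtype. The patch itself only shows the kernel and optimizer changes plus the unit test; below is a minimal end-to-end sketch of how the feature is driven from Python, assuming the public paddle.amp API as used in test_amp_master_grad.py. The Linear model and the explicit dtype assertions are illustrative and not part of the patch.

    import numpy as np
    import paddle

    # Plain FP32 model and a Momentum optimizer with L1 weight decay,
    # mirroring the regularizer setup in test_momentum_master_grad.
    model = paddle.nn.Linear(2, 4)
    opt = paddle.optimizer.Momentum(
        parameters=model.parameters(),
        weight_decay=paddle.regularizer.L1Decay(0.0001),
    )

    # master_grad=True asks AMP to cast FP16/BF16 leaf gradients to FP32
    # before the optimizer and regularizer see them, which is what the new
    # FP32-gradient kernel paths above are dispatched for.
    model, opt = paddle.amp.decorate(
        model, optimizers=opt, level='O2', master_grad=True
    )
    scaler = paddle.amp.GradScaler()

    x = paddle.to_tensor(np.random.random((2, 2)).astype('float32'))
    with paddle.amp.auto_cast(level='O2'):
        loss = model(x).mean()

    scaler.scale(loss).backward()
    scaler.step(opt)
    scaler.update()

    # With master_grad enabled the parameter gradients are FP32 even though
    # the parameters themselves run in FP16 under O2.
    for p in model.parameters():
        assert p.grad.dtype == paddle.float32

    opt.clear_grad()

This mirrors the flow in run_dygraph in the test above (decorate with master_grad=True, scale, backward, step) and the FP32-gradient check in test_momentum_master_grad; it is a sketch of the intended usage, not an additional test shipped with the patch.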