From 4e1fc952256a283262c00203be50cd814a48083c Mon Sep 17 00:00:00 2001 From: JamesLim-sy Date: Thu, 14 Oct 2021 13:24:11 +0000 Subject: [PATCH 01/12] first commit --- .../operators/optimizers/lars_momentum_op.cu | 351 +++++++++--------- 1 file changed, 167 insertions(+), 184 deletions(-) diff --git a/paddle/fluid/operators/optimizers/lars_momentum_op.cu b/paddle/fluid/operators/optimizers/lars_momentum_op.cu index b640e62221f77..361e6ce97a0ff 100644 --- a/paddle/fluid/operators/optimizers/lars_momentum_op.cu +++ b/paddle/fluid/operators/optimizers/lars_momentum_op.cu @@ -84,32 +84,20 @@ class LarsThreadConfig { template __device__ inline void VectorizeLarsUpdate( - const T* __restrict__ grad, const MT* __restrict__ param, - const MT* __restrict__ velocity, T* __restrict__ param_out, - MT* __restrict__ velocity_out, const MT mu, MT local_lr, + const T* __restrict__ grad, MT* __restrict__ param, + MT* __restrict__ velocity, const MT mu, MT local_lr, const MT lars_weight_decay, const MT rescale_grad, const int tid, - const int grid_stride, const int numel, - MT* __restrict__ master_param_out = nullptr) { + const int grid_stride, const int numel) { using VecType = paddle::platform::AlignedVector; using VecMType = paddle::platform::AlignedVector; int main = numel >> (VecSize >> 1); int tail_offset = main * VecSize; const VecType* __restrict__ grad_vec = reinterpret_cast(grad); - const VecMType* __restrict__ param_vec = - reinterpret_cast(param); - const VecMType* __restrict__ velocity_vec = - reinterpret_cast(velocity); - VecType* param_out_vec = reinterpret_cast(param_out); - VecMType* velocity_out_vec = reinterpret_cast(velocity_out); - - VecMType* master_param_out_vec; - if (IsAmp) { - master_param_out_vec = reinterpret_cast(master_param_out); - } + VecMType* __restrict__ velocity_vec = reinterpret_cast(velocity); + VecMType* __restrict__ param_vec = reinterpret_cast(param); for (int i = tid; i < main; i += grid_stride) { - VecType param_out_tmp; VecMType velocity_tmp, param_tmp; VecType grad_data = grad_vec[i]; VecMType param_data = param_vec[i]; @@ -121,13 +109,9 @@ __device__ inline void VectorizeLarsUpdate( Fma(velocity_data[j], mu, local_lr * Fma(lars_weight_decay, param_data[j], grad_val)); param_tmp[j] = param_data[j] - velocity_tmp[j]; - param_out_tmp[j] = static_cast(param_tmp[j]); - } - param_out_vec[i] = param_out_tmp; - velocity_out_vec[i] = velocity_tmp; - if (IsAmp) { - master_param_out_vec[i] = param_tmp; } + param_vec[i] = param_tmp; + velocity_vec[i] = velocity_tmp; } for (int i = tid + tail_offset; i < numel; i += grid_stride) { @@ -136,11 +120,8 @@ __device__ inline void VectorizeLarsUpdate( MT velocity_tmp = Fma(velocity[i], mu, local_lr * Fma(lars_weight_decay, param_val, grad_val)); MT param_tmp = param_val - velocity_tmp; - param_out[i] = static_cast(param_tmp); - velocity_out[i] = velocity_tmp; - if (IsAmp) { - master_param_out[i] = param_tmp; - } + param[i] = param_tmp; + velocity[i] = velocity_tmp; } } @@ -158,10 +139,10 @@ template __global__ void L2NormKernel( #endif const T* __restrict__ p_data, const T* __restrict__ g_data, - MT* __restrict__ p_buffer, MT* __restrict__ g_buffer, const int64_t numel, - const int repeat_times, const MT rescale_grad, const int thresh = 0, - MT* __restrict__ p_n = nullptr, MT* __restrict__ g_n = nullptr) { - __shared__ MT s_buffer[2]; + MT* __restrict__ p_buffer, MT* __restrict__ g_buffer, MT s_buffer[], + const int64_t numel, const int repeat_times, const MT rescale_grad, + const int thresh = 0, MT* __restrict__ p_n = nullptr, + 
MT* __restrict__ g_n = nullptr) { int tid = threadIdx.x + blockDim.x * blockIdx.x; int grid_stride = LARS_BLOCK_SIZE * gridDim.x; const MT rescale_pow = rescale_grad * rescale_grad; @@ -236,15 +217,13 @@ __global__ void L2NormKernel( template __forceinline__ __device__ void MomentumUpdate( - const T* __restrict__ param, const T* __restrict__ grad, - const MT* __restrict__ velocity, T* param_out, MT* velocity_out, - const MT* __restrict__ master_param, MT* __restrict__ master_param_out, - const MT* __restrict__ learning_rate, const MT mu, - const MT lars_weight_decay, const MT lars_coeff, const MT epsilon, - const MT rescale_grad, const MT param_norm, const MT grad_norm, - const int tid, const int grid_stride, const int64_t numel, - const bool is_amp) { - const MT lr = learning_rate[0]; + T* __restrict__ param, const T* __restrict__ grad, + MT* __restrict__ velocity, MT* __restrict__ master_param, + const MT* __restrict__ learn_rate, const MT mu, const MT lars_weight_decay, + const MT lars_coeff, const MT epsilon, const MT rescale_grad, + const MT param_norm, const MT grad_norm, const int tid, + const int grid_stride, const int64_t numel, const bool is_amp) { + const MT lr = learn_rate[0]; MT local_lr = lr; if (lars_weight_decay > static_cast(0)) { local_lr = lr * lars_coeff * param_norm / @@ -252,109 +231,129 @@ __forceinline__ __device__ void MomentumUpdate( } if (is_amp) { VectorizeLarsUpdate( - grad, master_param, velocity, param_out, velocity_out, mu, local_lr, - lars_weight_decay, rescale_grad, tid, grid_stride, numel, - master_param_out); + grad, master_param, velocity, mu, local_lr, lars_weight_decay, + rescale_grad, tid, grid_stride, numel); } else { if (std::is_same::value || std::is_same::value) { /* TODO(limingshu): pointer cast may damage memory accessing for fp16 */ VectorizeLarsUpdate( - grad, reinterpret_cast(param), velocity, param_out, - velocity_out, mu, local_lr, lars_weight_decay, rescale_grad, tid, - grid_stride, numel); + grad, reinterpret_cast(param), velocity, mu, local_lr, + lars_weight_decay, rescale_grad, tid, grid_stride, numel); } else { VectorizeLarsUpdate( - grad, reinterpret_cast(param), velocity, param_out, - velocity_out, mu, local_lr, lars_weight_decay, rescale_grad, tid, - grid_stride, numel); + grad, reinterpret_cast(param), velocity, mu, local_lr, + lars_weight_decay, rescale_grad, tid, grid_stride, numel); } } } #if CUDA_VERSION >= 11000 -template -struct LarsParamWarpper { - int64_t numel_arr[LARS_MAX_MERGED_OPS]; - int repeat_arr[LARS_MAX_MERGED_OPS]; - const T* __restrict__ g_arr[LARS_MAX_MERGED_OPS]; - const MT* __restrict__ lr_arr[LARS_MAX_MERGED_OPS]; - T* __restrict__ p_out_arr[LARS_MAX_MERGED_OPS]; - MT* __restrict__ v_out_arr[LARS_MAX_MERGED_OPS]; - MT* __restrict__ master_p_out_arr[LARS_MAX_MERGED_OPS]; - MT weight_decay_arr[LARS_MAX_MERGED_OPS]; + +template +struct MergedLarsMasterParam { + DEVICE inline MT* GetMasterParam(size_t) const { return nullptr; } + constexpr void SetMasterParam(size_t, MT*) {} }; -template -__global__ void MergedMomentumLarsKernel(LarsParamWarpper lars_warpper, - MT* __restrict__ p_buffer, - MT* __restrict__ g_buffer, - const int op_num, const MT mu, - const MT lars_coeff, const MT epsilon, - const MT rescale_grad, - const bool is_amp) { +template +struct MergedLarsMasterParam { + MT* __restrict__ master_params[kOpNum]; + + DEVICE inline MT* __restrict__ GetMasterParam(size_t idx) const { + return master_params[idx]; + } + void SetMasterParam(size_t idx, MT* p) { master_params[idx] = p; } +}; + +template ::value ? 
85 : 90> +struct LarsParamWarpper : public MergedLarsMasterParam { + static constexpr int kNum = kOpNum; + + int64_t numel_arr[kOpNum]; + int repeat_arr[kOpNum]; + const T* __restrict__ g_arr[kOpNum]; + T* __restrict__ p_arr[kOpNum]; + MT* __restrict__ v_arr[kOpNum]; + MT weight_decay_arr[kOpNum]; +}; + +template +__global__ void MergedMomentumLarsKernel( + LarsWarpperType lars_warpper, MT* __restrict__ p_buffer, + MT* __restrict__ g_buffer, const MT* __restrict__ lr, const int op_num, + const MT mu, const MT lars_coeff, const MT epsilon, const MT rescale_grad, + const bool is_amp) { + __shared__ MT s_buffer[2]; int grid_stride = gridDim.x * LARS_BLOCK_SIZE; int tid = threadIdx.x + blockIdx.x * blockDim.x; const cooperative_groups::grid_group cg = cooperative_groups::this_grid(); - for (int i = 0; i < op_num; ++i) { + + for (int i = 0; i < lars_warpper.kNum; ++i) { + if (i > op_num) break; int numel = lars_warpper.numel_arr[i]; MT param_norm = static_cast(0); MT grad_norm = static_cast(0); - L2NormKernel(&cg, lars_warpper.p_out_arr[i], lars_warpper.g_arr[i], - p_buffer, g_buffer, numel, lars_warpper.repeat_arr[i], - rescale_grad, 0, ¶m_norm, &grad_norm); - MomentumUpdate( - lars_warpper.p_out_arr[i], lars_warpper.g_arr[i], - lars_warpper.v_out_arr[i], lars_warpper.p_out_arr[i], - lars_warpper.v_out_arr[i], lars_warpper.master_p_out_arr[i], - lars_warpper.master_p_out_arr[i], lars_warpper.lr_arr[i], mu, - lars_warpper.weight_decay_arr[i], lars_coeff, epsilon, rescale_grad, - param_norm, grad_norm, tid, grid_stride, numel, is_amp); + L2NormKernel(&cg, lars_warpper.p_arr[i], lars_warpper.g_arr[i], + p_buffer, g_buffer, s_buffer, numel, + lars_warpper.repeat_arr[i], rescale_grad, 0, + ¶m_norm, &grad_norm); + MomentumUpdate(lars_warpper.p_arr[i], lars_warpper.g_arr[i], + lars_warpper.v_arr[i], lars_warpper.GetMasterParam(i), + lr, mu, lars_warpper.weight_decay_arr[i], lars_coeff, + epsilon, rescale_grad, param_norm, grad_norm, tid, + grid_stride, numel, is_amp); } } #endif template __global__ void MomentumLarsKernel( - const T* __restrict__ param, const T* __restrict__ grad, - const MT* __restrict__ velocity, T* param_out, MT* velocity_out, - const MT* __restrict__ master_param, MT* __restrict__ master_param_out, - const MT* __restrict__ learning_rate, MT* __restrict__ p_buffer, + T* __restrict__ param, const T* __restrict__ grad, + MT* __restrict__ velocity, MT* __restrict__ master_param, + const MT* __restrict__ learn_rate, MT* __restrict__ p_buffer, MT* __restrict__ g_buffer, const MT mu, const MT lars_coeff, const MT lars_weight_decay, const MT epsilon, const MT rescale_grad, const int repeat_times, const int thresh, const int64_t numel, const bool is_amp) { + __shared__ MT s_buffer[2]; int tid = threadIdx.x + blockIdx.x * blockDim.x; int grid_stride = gridDim.x * LARS_BLOCK_SIZE; #if CUDA_VERSION >= 11000 const cooperative_groups::grid_group cg = cooperative_groups::this_grid(); MT param_norm = static_cast(0); MT grad_norm = static_cast(0); - L2NormKernel(&cg, param, grad, p_buffer, g_buffer, numel, repeat_times, - rescale_grad, gridDim.x, ¶m_norm, &grad_norm); + L2NormKernel(&cg, param, grad, p_buffer, g_buffer, s_buffer, numel, + repeat_times, rescale_grad, gridDim.x, ¶m_norm, + &grad_norm); #else const MT rescale_grad_pow = rescale_grad * rescale_grad; MT param_part_norm = threadIdx.x < thresh ? p_buffer[threadIdx.x] : 0; MT grad_part_norm = threadIdx.x < thresh ? 
g_buffer[threadIdx.x] : 0; + MT tmp0 = math::blockReduceSum(param_part_norm, FINAL_MASK); + MT tmp1 = math::blockReduceSum(param_part_norm, FINAL_MASK); + if (threadIdx.x == 0) { + s_buffer[0] = tmp0; + s_buffer[1] = tmp1; + } __syncthreads(); - MT param_norm = Sqrt(math::blockReduceSum(param_part_norm, FINAL_MASK)); - MT grad_norm = Sqrt(rescale_grad_pow * - math::blockReduceSum(grad_part_norm, FINAL_MASK)); + MT param_norm = Sqrt(s_buffer[0]); + MT grad_norm = Sqrt(rescale_pow * s_buffer[1]); #endif - MomentumUpdate(param, grad, velocity, param_out, velocity_out, - master_param, master_param_out, learning_rate, mu, + MomentumUpdate(param, grad, velocity, master_param, learn_rate, mu, lars_weight_decay, lars_coeff, epsilon, rescale_grad, param_norm, grad_norm, tid, grid_stride, numel, is_amp); } template inline void SeparatedLarsMomentumOpCUDAKernel( - const platform::CUDADeviceContext& cuda_ctx, const T* param_data, - T* param_out_data, const MT* velocity_data, MT* velocity_out_data, - const T* grad_data, const MT* lr, MT* p_buffer, MT* g_buffer, const MT mu, - const MT lars_coeff, const MT weight_decay, const MT epsilon, - const MT rescale_grad, const int64_t numel, const MT* master_param_data, - MT* master_out_data, const bool is_amp) { + const platform::CUDADeviceContext& cuda_ctx, T* param_data, + MT* velocity_data, const T* grad_data, const MT* lr, MT* p_buffer, + MT* g_buffer, const MT mu, const MT lars_coeff, const MT weight_decay, + const MT epsilon, const MT rescale_grad, const int64_t numel, + const MT* master_param_data, const bool is_amp) { LarsThreadConfig lars_thread_config(numel); L2NormKernel<<>>( @@ -363,9 +362,8 @@ inline void SeparatedLarsMomentumOpCUDAKernel( MomentumLarsKernel<<>>( - param_data, grad_data, velocity_data, param_out_data, velocity_out_data, - master_param_data, master_out_data, lr, p_buffer, g_buffer, mu, - lars_coeff, weight_decay, epsilon, rescale_grad, 0, + param_data, grad_data, velocity_data, master_param_data, lr, p_buffer, + g_buffer, mu, lars_coeff, weight_decay, epsilon, rescale_grad, 0, lars_thread_config.grid_for_norm, numel, is_amp); } @@ -376,7 +374,6 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { int num_blocks_per_sm = 0; - bool multi_precision = ctx.Attr("multi_precision"); auto& cuda_ctx = ctx.template device_context(); int sm_num = cuda_ctx.GetSMCount(); framework::Tensor tmp_buffer_t = @@ -385,6 +382,7 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel { auto* p_buffer = tmp_buffer_t.mutable_data(ctx.GetPlace()); auto* g_buffer = p_buffer + LARS_BLOCK_SIZE; + bool multi_precision = ctx.Attr("multi_precision"); MT mu = static_cast(ctx.Attr("mu")); MT lars_coeff = static_cast(ctx.Attr("lars_coeff")); MT epsilon = static_cast(ctx.Attr("epsilon")); @@ -400,18 +398,28 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel { auto master_param = ctx.MultiInput("MasterParam"); auto master_param_out = ctx.MultiOutput("MasterParamOut"); + auto* lr = learning_rate[0]->data(); int op_num = grad.size(); + for (size_t i = 0; i < op_num; ++i) { + PADDLE_ENFORCE_EQ( + param[i], param_out[i], + platform::errors::InvalidArgument( + "Input(Param) and Output(ParamOut) must be the same Tensors.")); + PADDLE_ENFORCE_EQ(velocity[i], velocity_out[i], + platform::errors::InvalidArgument( + "Input(Velocity) and Output(VelocityOut) must be " + "the same Tensors.")); + if (multi_precision) { + PADDLE_ENFORCE_EQ(master_param[i], 
master_param_out[i], + platform::errors::InvalidArgument( + "Input(MasterParam) and Output(MasterParamOut) " + "must be the same Tensors.")); + } + } #if CUDA_VERSION >= 11000 if (op_num > 1) { LarsParamWarpper lars_warpper; - PADDLE_ENFORCE_LT( - op_num, LARS_MAX_MERGED_OPS, - platform::errors::InvalidArgument( - "The maximum number of merged-ops supported is (%d), but" - "lars op required for trainning this model is (%d)\n", - LARS_MAX_MERGED_OPS, op_num)); - /* Implementation of lars optimizer consists of following two steps: 1. Figure out the L2 norm statistic result of grad data and param data. 2. Update param and velocity with usage of L2 norm statistic result. @@ -420,76 +428,60 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel { - The thread quantity shall less than pyhsical SM limited threads - Launche as thread-block can synchronizlly execute. */ cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &num_blocks_per_sm, MergedMomentumLarsKernel, LARS_BLOCK_SIZE, - sizeof(MT) << 1); - - size_t total_numel = 0; - for (int i = 0; i < op_num; ++i) { - size_t temp_numel = param[i]->numel(); - total_numel += temp_numel; - lars_warpper.numel_arr[i] = temp_numel; - lars_warpper.g_arr[i] = grad[i]->data(); - lars_warpper.lr_arr[i] = learning_rate[i]->data(); - lars_warpper.p_out_arr[i] = - param_out[i]->mutable_data(ctx.GetPlace()); - lars_warpper.v_out_arr[i] = - velocity_out[i]->mutable_data(ctx.GetPlace()); - lars_warpper.weight_decay_arr[i] = static_cast(weight_decay_arr[i]); - PADDLE_ENFORCE_EQ( - param[i]->data(), lars_warpper.p_out_arr[i], - platform::errors::InvalidArgument( - "Input(Param) and Output(ParamOut) must be the same Tensors.")); - PADDLE_ENFORCE_EQ(velocity[i]->data(), lars_warpper.v_out_arr[i], - platform::errors::InvalidArgument( - "Input(Velocity) and Output(VelocityOut) must be " - "the same Tensors.")); - } - int64_t avg_numel = total_numel / op_num; - LarsThreadConfig lars_thread_config(avg_numel, sm_num, - num_blocks_per_sm); - for (int i = 0; i < op_num; ++i) { - lars_warpper.repeat_arr[i] = - lars_thread_config.GetRepeatTimes(lars_warpper.numel_arr[i]); - } - if (multi_precision) { - for (int i = 0; i < op_num; ++i) { - lars_warpper.master_p_out_arr[i] = - master_param_out[i]->mutable_data(ctx.GetPlace()); - PADDLE_ENFORCE_EQ(master_param[i]->data(), - lars_warpper.master_p_out_arr[i], - platform::errors::InvalidArgument( - "Input(MasterParam) and Output(MasterParamOut) " - "must be the same Tensors.")); + &num_blocks_per_sm, + MergedMomentumLarsKernel, + LARS_BLOCK_SIZE, sizeof(MT) << 1); + + int merge_times = (op_num + lars_warpper.kNum - 1) / lars_warpper.kNum; + for (int j = 0; j < merge_times; ++j) { + size_t total_numel = 0; + int start_idx = j * lars_warpper.kNum; + int loop_num = std::min(lars_warpper.kNum, op_num - start_idx); + + for (int i = 0; i < loop_num; ++i) { + size_t temp_numel = param[start_idx + i]->numel(); + total_numel += temp_numel; + lars_warpper.numel_arr[i] = temp_numel; + lars_warpper.g_arr[i] = grad[start_idx + i]->data(); + lars_warpper.p_arr[i] = param_out[start_idx + i]->data(); + lars_warpper.v_arr[i] = velocity_out[start_idx + i]->data(); + lars_warpper.weight_decay_arr[i] = + static_cast(weight_decay_arr[start_idx + i]); } + int64_t avg_numel = total_numel / loop_num; + LarsThreadConfig lars_thread_config(avg_numel, sm_num, + num_blocks_per_sm); + for (int i = 0; i < loop_num; ++i) { + lars_warpper.repeat_arr[i] = + lars_thread_config.GetRepeatTimes(lars_warpper.numel_arr[i]); + if (multi_precision) { + 
lars_warpper.SetMasterParam( + i, master_param_out[i]->mutable_data(ctx.GetPlace())); + } + } + void* cuda_param[] = {reinterpret_cast(&lars_warpper), + reinterpret_cast(&p_buffer), + reinterpret_cast(&g_buffer), + reinterpret_cast(&lr), + reinterpret_cast(&loop_num), + reinterpret_cast(&mu), + reinterpret_cast(&lars_coeff), + reinterpret_cast(&epsilon), + reinterpret_cast(&rescale_grad), + reinterpret_cast(&multi_precision)}; + cudaLaunchCooperativeKernel( + reinterpret_cast( + MergedMomentumLarsKernel), + lars_thread_config.grid_for_lars, LARS_BLOCK_SIZE, cuda_param, 0, + cuda_ctx.stream()); } - void* cuda_param[] = {reinterpret_cast(&lars_warpper), - reinterpret_cast(&p_buffer), - reinterpret_cast(&g_buffer), - reinterpret_cast(&op_num), - reinterpret_cast(&mu), - reinterpret_cast(&lars_coeff), - reinterpret_cast(&epsilon), - reinterpret_cast(&rescale_grad), - reinterpret_cast(&multi_precision)}; - // Lanuch all sm theads, and thead of each block synchronizedly cooperate. - cudaLaunchCooperativeKernel( - reinterpret_cast(MergedMomentumLarsKernel), - lars_thread_config.grid_for_lars, LARS_BLOCK_SIZE, cuda_param, 0, - cuda_ctx.stream()); } else { auto* param_data = param[0]->data(); auto* grad_data = grad[0]->data(); - auto* velocity_data = velocity[0]->data(); + auto* velocity_data = velocity_out[0]->data(); auto* lr = learning_rate[0]->data(); - auto* param_out_data = param_out[0]->mutable_data(ctx.GetPlace()); - auto* velocity_out_data = - velocity_out[0]->mutable_data(ctx.GetPlace()); const MT* master_param_data = - multi_precision ? master_param[0]->data() : nullptr; - MT* master_param_out_data = - multi_precision - ? master_param_out[0]->mutable_data(ctx.GetPlace()) - : nullptr; + multi_precision ? master_param_out[0]->data() : nullptr; int64_t numel = param[0]->numel(); MT lars_weight_decay = weight_decay_arr[0]; @@ -501,14 +493,12 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel { num_blocks_per_sm); int repeat_times = lars_thread_config.GetRepeatTimes(numel); int thresh = 0; + void* cuda_param[] = { reinterpret_cast(¶m_data), reinterpret_cast(&grad_data), reinterpret_cast(&velocity_data), - reinterpret_cast(¶m_out_data), - reinterpret_cast(&velocity_out_data), reinterpret_cast(&master_param_data), - reinterpret_cast(&master_param_out_data), reinterpret_cast(&lr), reinterpret_cast(&p_buffer), reinterpret_cast(&g_buffer), @@ -530,19 +520,12 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel { #else for (int i = 0; i < op_num; ++i) { const MT* master_param_data = - multi_precision ? master_param[i]->data() : nullptr; - MT* master_param_out_data = - multi_precision - ? master_param_out[i]->mutable_data(ctx.GetPlace()) - : nullptr; + multi_precision ? 
master_param_out[i]->data() : nullptr; SeparatedLarsMomentumOpCUDAKernel( - cuda_ctx, param[i]->data(), - param_out[i]->mutable_data(ctx.GetPlace()), - velocity[i]->data(), - velocity_out[i]->mutable_data(ctx.GetPlace()), grad[i]->data(), - learning_rate[i]->data(), p_buffer, g_buffer, mu, lars_coeff, + cuda_ctx, param_out[i]->data(), velocity_out[i]->data(), + grad[i]->data(), lr, p_buffer, g_buffer, mu, lars_coeff, weight_decay_arr[i], epsilon, rescale_grad, param[i]->numel(), - master_param_data, master_param_out_data, multi_precision); + master_param_data, multi_precision); } #endif } From 2ffbbba4221e0930ba274ecb9608d8716851c88c Mon Sep 17 00:00:00 2001 From: JamesLim-sy Date: Thu, 14 Oct 2021 13:49:25 +0000 Subject: [PATCH 02/12] revert change about argmuent of l2_norm --- .../operators/optimizers/lars_momentum_op.cu | 21 ++++++++----------- 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/paddle/fluid/operators/optimizers/lars_momentum_op.cu b/paddle/fluid/operators/optimizers/lars_momentum_op.cu index 361e6ce97a0ff..f6077d0e1086a 100644 --- a/paddle/fluid/operators/optimizers/lars_momentum_op.cu +++ b/paddle/fluid/operators/optimizers/lars_momentum_op.cu @@ -139,10 +139,10 @@ template __global__ void L2NormKernel( #endif const T* __restrict__ p_data, const T* __restrict__ g_data, - MT* __restrict__ p_buffer, MT* __restrict__ g_buffer, MT s_buffer[], - const int64_t numel, const int repeat_times, const MT rescale_grad, - const int thresh = 0, MT* __restrict__ p_n = nullptr, - MT* __restrict__ g_n = nullptr) { + MT* __restrict__ p_buffer, MT* __restrict__ g_buffer, const int64_t numel, + const int repeat_times, const MT rescale_grad, const int thresh = 0, + MT* __restrict__ p_n = nullptr, MT* __restrict__ g_n = nullptr) { + __shared__ MT s_buffer[2]; int tid = threadIdx.x + blockDim.x * blockIdx.x; int grid_stride = LARS_BLOCK_SIZE * gridDim.x; const MT rescale_pow = rescale_grad * rescale_grad; @@ -286,7 +286,6 @@ __global__ void MergedMomentumLarsKernel( MT* __restrict__ g_buffer, const MT* __restrict__ lr, const int op_num, const MT mu, const MT lars_coeff, const MT epsilon, const MT rescale_grad, const bool is_amp) { - __shared__ MT s_buffer[2]; int grid_stride = gridDim.x * LARS_BLOCK_SIZE; int tid = threadIdx.x + blockIdx.x * blockDim.x; const cooperative_groups::grid_group cg = cooperative_groups::this_grid(); @@ -297,9 +296,8 @@ __global__ void MergedMomentumLarsKernel( MT param_norm = static_cast(0); MT grad_norm = static_cast(0); L2NormKernel(&cg, lars_warpper.p_arr[i], lars_warpper.g_arr[i], - p_buffer, g_buffer, s_buffer, numel, - lars_warpper.repeat_arr[i], rescale_grad, 0, - ¶m_norm, &grad_norm); + p_buffer, g_buffer, numel, lars_warpper.repeat_arr[i], + rescale_grad, 0, ¶m_norm, &grad_norm); MomentumUpdate(lars_warpper.p_arr[i], lars_warpper.g_arr[i], lars_warpper.v_arr[i], lars_warpper.GetMasterParam(i), lr, mu, lars_warpper.weight_decay_arr[i], lars_coeff, @@ -318,17 +316,16 @@ __global__ void MomentumLarsKernel( const MT lars_weight_decay, const MT epsilon, const MT rescale_grad, const int repeat_times, const int thresh, const int64_t numel, const bool is_amp) { - __shared__ MT s_buffer[2]; int tid = threadIdx.x + blockIdx.x * blockDim.x; int grid_stride = gridDim.x * LARS_BLOCK_SIZE; #if CUDA_VERSION >= 11000 const cooperative_groups::grid_group cg = cooperative_groups::this_grid(); MT param_norm = static_cast(0); MT grad_norm = static_cast(0); - L2NormKernel(&cg, param, grad, p_buffer, g_buffer, s_buffer, numel, - repeat_times, rescale_grad, 
gridDim.x, ¶m_norm, - &grad_norm); + L2NormKernel(&cg, param, grad, p_buffer, g_buffer, numel, repeat_times, + rescale_grad, gridDim.x, ¶m_norm, &grad_norm); #else + __shared__ MT s_buffer[2]; const MT rescale_grad_pow = rescale_grad * rescale_grad; MT param_part_norm = threadIdx.x < thresh ? p_buffer[threadIdx.x] : 0; MT grad_part_norm = threadIdx.x < thresh ? g_buffer[threadIdx.x] : 0; From 3da1b1f574d3098986b37cd6251991cd6db3f506 Mon Sep 17 00:00:00 2001 From: JamesLim-sy Date: Thu, 14 Oct 2021 13:56:17 +0000 Subject: [PATCH 03/12] get param addr from param_out tensor --- .../operators/optimizers/lars_momentum_op.cu | 28 ++++++++++--------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/operators/optimizers/lars_momentum_op.cu b/paddle/fluid/operators/optimizers/lars_momentum_op.cu index f6077d0e1086a..381d0d8df5e65 100644 --- a/paddle/fluid/operators/optimizers/lars_momentum_op.cu +++ b/paddle/fluid/operators/optimizers/lars_momentum_op.cu @@ -28,8 +28,6 @@ limitations under the License. */ #define LARS_BLOCK_SIZE 512 #endif -#define LARS_MAX_MERGED_OPS 60 - namespace paddle { namespace operators { @@ -314,7 +312,7 @@ __global__ void MomentumLarsKernel( const MT* __restrict__ learn_rate, MT* __restrict__ p_buffer, MT* __restrict__ g_buffer, const MT mu, const MT lars_coeff, const MT lars_weight_decay, const MT epsilon, const MT rescale_grad, - const int repeat_times, const int thresh, const int64_t numel, + const int repeat_times, int thresh, const int64_t numel, const bool is_amp) { int tid = threadIdx.x + blockIdx.x * blockDim.x; int grid_stride = gridDim.x * LARS_BLOCK_SIZE; @@ -326,7 +324,7 @@ __global__ void MomentumLarsKernel( rescale_grad, gridDim.x, ¶m_norm, &grad_norm); #else __shared__ MT s_buffer[2]; - const MT rescale_grad_pow = rescale_grad * rescale_grad; + const MT rescale_pow = rescale_grad * rescale_grad; MT param_part_norm = threadIdx.x < thresh ? p_buffer[threadIdx.x] : 0; MT grad_part_norm = threadIdx.x < thresh ? 
g_buffer[threadIdx.x] : 0; MT tmp0 = math::blockReduceSum(param_part_norm, FINAL_MASK); @@ -347,10 +345,10 @@ __global__ void MomentumLarsKernel( template inline void SeparatedLarsMomentumOpCUDAKernel( const platform::CUDADeviceContext& cuda_ctx, T* param_data, - MT* velocity_data, const T* grad_data, const MT* lr, MT* p_buffer, - MT* g_buffer, const MT mu, const MT lars_coeff, const MT weight_decay, - const MT epsilon, const MT rescale_grad, const int64_t numel, - const MT* master_param_data, const bool is_amp) { + MT* velocity_data, const T* grad_data, MT* master_param_data, const MT* lr, + MT* p_buffer, MT* g_buffer, const MT mu, const MT lars_coeff, + const MT weight_decay, const MT epsilon, const MT rescale_grad, + const int64_t numel, const bool is_amp) { LarsThreadConfig lars_thread_config(numel); L2NormKernel<<>>( @@ -429,6 +427,8 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel { MergedMomentumLarsKernel, LARS_BLOCK_SIZE, sizeof(MT) << 1); + VLOG(10) << "Num of ops merged in lars_warpper is " << lars_warpper.kNum; + int merge_times = (op_num + lars_warpper.kNum - 1) / lars_warpper.kNum; for (int j = 0; j < merge_times; ++j) { size_t total_numel = 0; @@ -471,10 +471,12 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel { MergedMomentumLarsKernel), lars_thread_config.grid_for_lars, LARS_BLOCK_SIZE, cuda_param, 0, cuda_ctx.stream()); + + VLOG(10) << "Lanuched ops number is " << loop_num; } } else { - auto* param_data = param[0]->data(); auto* grad_data = grad[0]->data(); + auto* param_data = param_out[0]->data(); auto* velocity_data = velocity_out[0]->data(); auto* lr = learning_rate[0]->data(); const MT* master_param_data = @@ -516,13 +518,13 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel { } #else for (int i = 0; i < op_num; ++i) { - const MT* master_param_data = + MT* master_param_data = multi_precision ? 
master_param_out[i]->data() : nullptr; SeparatedLarsMomentumOpCUDAKernel( cuda_ctx, param_out[i]->data(), velocity_out[i]->data(), - grad[i]->data(), lr, p_buffer, g_buffer, mu, lars_coeff, - weight_decay_arr[i], epsilon, rescale_grad, param[i]->numel(), - master_param_data, multi_precision); + grad[i]->data(), master_param_data, lr, p_buffer, g_buffer, mu, + lars_coeff, weight_decay_arr[i], epsilon, rescale_grad, + param[i]->numel(), multi_precision); } #endif } From fb89aefcc78e628b6caf986d5350d08fd09f6046 Mon Sep 17 00:00:00 2001 From: JamesLim-sy Date: Thu, 14 Oct 2021 17:03:14 +0000 Subject: [PATCH 04/12] shrink the struct size --- paddle/fluid/operators/optimizers/lars_momentum_op.cu | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/paddle/fluid/operators/optimizers/lars_momentum_op.cu b/paddle/fluid/operators/optimizers/lars_momentum_op.cu index 381d0d8df5e65..1f7466f121683 100644 --- a/paddle/fluid/operators/optimizers/lars_momentum_op.cu +++ b/paddle/fluid/operators/optimizers/lars_momentum_op.cu @@ -247,7 +247,6 @@ __forceinline__ __device__ void MomentumUpdate( } #if CUDA_VERSION >= 11000 - template struct MergedLarsMasterParam { DEVICE inline MT* GetMasterParam(size_t) const { return nullptr; } @@ -270,7 +269,7 @@ template { static constexpr int kNum = kOpNum; - int64_t numel_arr[kOpNum]; + int numel_arr[kOpNum]; int repeat_arr[kOpNum]; const T* __restrict__ g_arr[kOpNum]; T* __restrict__ p_arr[kOpNum]; From eec1fc6239e89fad6d2d2c2692a294c9e2c97a4a Mon Sep 17 00:00:00 2001 From: JamesLim-sy Date: Fri, 15 Oct 2021 03:51:16 +0000 Subject: [PATCH 05/12] add test file --- .../operators/optimizers/lars_momentum_op.cc | 5 + .../operators/optimizers/lars_momentum_op.cu | 55 ++--- python/paddle/fluid/optimizer.py | 219 ++++++++++++++---- 3 files changed, 209 insertions(+), 70 deletions(-) diff --git a/paddle/fluid/operators/optimizers/lars_momentum_op.cc b/paddle/fluid/operators/optimizers/lars_momentum_op.cc index 65be35843bdf9..960d43fc0202a 100644 --- a/paddle/fluid/operators/optimizers/lars_momentum_op.cc +++ b/paddle/fluid/operators/optimizers/lars_momentum_op.cc @@ -177,6 +177,11 @@ class LarsMomentumOpMaker : public framework::OpProtoAndCheckerMaker { "(float, default 1.0) Multiply the gradient with `rescale_grad`" "before updating. Often choose to be `1.0/batch_size`.") .SetDefault(1.0f); + AddAttr( + "merge_option", + "(float, default 1.0) Multiply the gradient with `rescale_grad`" + "before updating. Often choose to be `1.0/batch_size`.") + .SetDefault(false); AddComment(R"DOC( Lars Momentum Optimizer. 
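Note: written out as equations, the update that MomentumUpdate/VectorizeLarsUpdate in this series compute is the standard LARS rule. With w the parameter, g the gradient scaled by rescale_grad, v the velocity, mu the momentum, lambda the per-op lars_weight_decay and eta the learning rate (these symbol names are introduced here for readability), the per-layer step is

    \text{local\_lr} = \eta \cdot \text{lars\_coeff} \cdot \frac{\lVert w \rVert_2}{\lambda \lVert w \rVert_2 + \lVert g \rVert_2 + \epsilon} \qquad (\text{only when } \lambda > 0,\ \text{otherwise } \text{local\_lr} = \eta)
    v_{t+1} = \mu \, v_t + \text{local\_lr} \cdot \big(g_t + \lambda \, w_t\big)
    w_{t+1} = w_t - v_{t+1}

The second and third lines correspond directly to the two Fma calls and the subtraction in VectorizeLarsUpdate, and the norms come from L2NormKernel's grid-wide reduction into p_buffer/g_buffer. The trust-ratio denominator is the conventional LARS form and is an assumption here, since that sub-expression is not fully visible in these hunks.
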
diff --git a/paddle/fluid/operators/optimizers/lars_momentum_op.cu b/paddle/fluid/operators/optimizers/lars_momentum_op.cu index 1f7466f121683..077c45b3eafb5 100644 --- a/paddle/fluid/operators/optimizers/lars_momentum_op.cu +++ b/paddle/fluid/operators/optimizers/lars_momentum_op.cu @@ -82,18 +82,17 @@ class LarsThreadConfig { template __device__ inline void VectorizeLarsUpdate( - const T* __restrict__ grad, MT* __restrict__ param, - MT* __restrict__ velocity, const MT mu, MT local_lr, - const MT lars_weight_decay, const MT rescale_grad, const int tid, - const int grid_stride, const int numel) { + const T* __restrict__ grad, MT* param, MT* velocity, const MT mu, + MT local_lr, const MT lars_weight_decay, const MT rescale_grad, + const int tid, const int grid_stride, const int numel) { using VecType = paddle::platform::AlignedVector; using VecMType = paddle::platform::AlignedVector; int main = numel >> (VecSize >> 1); int tail_offset = main * VecSize; const VecType* __restrict__ grad_vec = reinterpret_cast(grad); - VecMType* __restrict__ velocity_vec = reinterpret_cast(velocity); - VecMType* __restrict__ param_vec = reinterpret_cast(param); + VecMType* velocity_vec = reinterpret_cast(velocity); + VecMType* param_vec = reinterpret_cast(param); for (int i = tid; i < main; i += grid_stride) { VecMType velocity_tmp, param_tmp; @@ -136,10 +135,9 @@ __forceinline__ __device__ void L2NormKernel( template __global__ void L2NormKernel( #endif - const T* __restrict__ p_data, const T* __restrict__ g_data, - MT* __restrict__ p_buffer, MT* __restrict__ g_buffer, const int64_t numel, - const int repeat_times, const MT rescale_grad, const int thresh = 0, - MT* __restrict__ p_n = nullptr, MT* __restrict__ g_n = nullptr) { + T* p_data, const T* __restrict__ g_data, MT* p_buffer, MT* g_buffer, + const int64_t numel, const int repeat_times, const MT rescale_grad, + const int thresh = 0, MT* p_n = nullptr, MT* g_n = nullptr) { __shared__ MT s_buffer[2]; int tid = threadIdx.x + blockDim.x * blockIdx.x; int grid_stride = LARS_BLOCK_SIZE * gridDim.x; @@ -215,8 +213,7 @@ __global__ void L2NormKernel( template __forceinline__ __device__ void MomentumUpdate( - T* __restrict__ param, const T* __restrict__ grad, - MT* __restrict__ velocity, MT* __restrict__ master_param, + T* param, const T* __restrict__ grad, MT* velocity, MT* master_param, const MT* __restrict__ learn_rate, const MT mu, const MT lars_weight_decay, const MT lars_coeff, const MT epsilon, const MT rescale_grad, const MT param_norm, const MT grad_norm, const int tid, @@ -255,9 +252,9 @@ struct MergedLarsMasterParam { template struct MergedLarsMasterParam { - MT* __restrict__ master_params[kOpNum]; + MT* master_params[kOpNum]; - DEVICE inline MT* __restrict__ GetMasterParam(size_t idx) const { + DEVICE inline MT* GetMasterParam(size_t idx) const { return master_params[idx]; } void SetMasterParam(size_t idx, MT* p) { master_params[idx] = p; } @@ -265,24 +262,23 @@ struct MergedLarsMasterParam { template ::value ? 85 : 90> + std::is_same::value ? 
80 : 90> struct LarsParamWarpper : public MergedLarsMasterParam { static constexpr int kNum = kOpNum; int numel_arr[kOpNum]; int repeat_arr[kOpNum]; const T* __restrict__ g_arr[kOpNum]; - T* __restrict__ p_arr[kOpNum]; - MT* __restrict__ v_arr[kOpNum]; + T* p_arr[kOpNum]; + MT* v_arr[kOpNum]; MT weight_decay_arr[kOpNum]; }; template __global__ void MergedMomentumLarsKernel( - LarsWarpperType lars_warpper, MT* __restrict__ p_buffer, - MT* __restrict__ g_buffer, const MT* __restrict__ lr, const int op_num, - const MT mu, const MT lars_coeff, const MT epsilon, const MT rescale_grad, - const bool is_amp) { + LarsWarpperType lars_warpper, MT* p_buffer, MT* g_buffer, const MT* lr, + const int op_num, const MT mu, const MT lars_coeff, const MT epsilon, + const MT rescale_grad, const bool is_amp) { int grid_stride = gridDim.x * LARS_BLOCK_SIZE; int tid = threadIdx.x + blockIdx.x * blockDim.x; const cooperative_groups::grid_group cg = cooperative_groups::this_grid(); @@ -306,13 +302,11 @@ __global__ void MergedMomentumLarsKernel( template __global__ void MomentumLarsKernel( - T* __restrict__ param, const T* __restrict__ grad, - MT* __restrict__ velocity, MT* __restrict__ master_param, - const MT* __restrict__ learn_rate, MT* __restrict__ p_buffer, - MT* __restrict__ g_buffer, const MT mu, const MT lars_coeff, - const MT lars_weight_decay, const MT epsilon, const MT rescale_grad, - const int repeat_times, int thresh, const int64_t numel, - const bool is_amp) { + T* param, const T* __restrict__ grad, MT* velocity, MT* master_param, + const MT* __restrict__ learn_rate, MT* p_buffer, MT* g_buffer, const MT mu, + const MT lars_coeff, const MT lars_weight_decay, const MT epsilon, + const MT rescale_grad, const int repeat_times, int thresh, + const int64_t numel, const bool is_amp) { int tid = threadIdx.x + blockIdx.x * blockDim.x; int grid_stride = gridDim.x * LARS_BLOCK_SIZE; #if CUDA_VERSION >= 11000 @@ -375,12 +369,12 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel { {LARS_BLOCK_SIZE << 1}, cuda_ctx); auto* p_buffer = tmp_buffer_t.mutable_data(ctx.GetPlace()); auto* g_buffer = p_buffer + LARS_BLOCK_SIZE; - bool multi_precision = ctx.Attr("multi_precision"); MT mu = static_cast(ctx.Attr("mu")); MT lars_coeff = static_cast(ctx.Attr("lars_coeff")); MT epsilon = static_cast(ctx.Attr("epsilon")); MT rescale_grad = static_cast(ctx.Attr("rescale_grad")); + bool merge_option = ctx.Attr("merge_option"); auto weight_decay_arr = ctx.Attr>("lars_weight_decay"); auto grad = ctx.MultiInput("Grad"); @@ -412,7 +406,8 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel { } } #if CUDA_VERSION >= 11000 - if (op_num > 1) { + // if (op_num > 1) { + if (merge_option) { LarsParamWarpper lars_warpper; /* Implementation of lars optimizer consists of following two steps: 1. Figure out the L2 norm statistic result of grad data and param data. 
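Note: the kOpNum cap (80 vs 90 in the hunk above) and the MergedLarsMasterParam base exist because the whole wrapper is passed to the merged cooperative kernel by value, so it must fit in the CUDA kernel-argument space (assumed here to be the usual 4 KB limit; the patch itself never states the number). Below is a minimal standalone sketch of that specialization trick with illustrative names, not the Paddle code itself; device qualifiers are omitted and `short` stands in for a 2-byte fp16 type.

// Amp == false: the base contributes no storage at all (empty-base optimization),
// so the non-AMP wrapper stays small and more ops fit per kernel launch.
template <typename MT, int kOpNum, bool Amp>
struct MergedMasterParamSketch {
  MT* GetMasterParam(int) const { return nullptr; }
  void SetMasterParam(int, MT*) {}
};

// Amp == true: a real pointer array for the fp32 master copies of fp16 params.
template <typename MT, int kOpNum>
struct MergedMasterParamSketch<MT, kOpNum, true> {
  MT* master_params[kOpNum];
  MT* GetMasterParam(int i) const { return master_params[i]; }
  void SetMasterParam(int i, MT* p) { master_params[i] = p; }
};

template <typename T, typename MT, bool Amp, int kOpNum>
struct MergedWrapperSketch : public MergedMasterParamSketch<MT, kOpNum, Amp> {
  int numel[kOpNum];  // plain int; PATCH 04 narrows int64_t to int for the same size reason
  const MT* lr[kOpNum];
  const T* grad[kOpNum];
  T* param[kOpNum];
  MT* velocity[kOpNum];
  MT weight_decay[kOpNum];
};

// The per-instantiation op cap only has to keep the by-value argument under the limit.
static_assert(sizeof(MergedWrapperSketch<float, float, false, 90>) <= 4096,
              "non-AMP wrapper must fit the kernel-argument space");
static_assert(sizeof(MergedWrapperSketch<short, float, true, 80>) <= 4096,
              "AMP wrapper must fit the kernel-argument space");

With this layout the merge loop in the op kernel simply packs up to kOpNum tensors into one wrapper per launch and iterates, which is why the later hunks loop over merge_times instead of enforcing a hard LARS_MAX_MERGED_OPS bound.
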
diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 228ba08499808..a48559e71a706 100755 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -1961,6 +1961,7 @@ def __init__(self, exclude_from_weight_decay=None, epsilon=0, multi_precision=False, + merge_option=False, rescale_grad=1.0): assert learning_rate is not None assert momentum is not None @@ -2045,54 +2046,192 @@ def _create_accumulators(self, block, parameters): self._add_accumulator(self._velocity_acc_str, p) def _append_optimize_op(self, block, param_and_grad): - assert isinstance(block, framework.Block) - _lars_weight_decay = self._lars_weight_decay - param_name = param_and_grad[0].name - if len(self._exclude_from_weight_decay) > 0: - for name in self._exclude_from_weight_decay: - if name in param_name: - _lars_weight_decay = 0.0 - break + if (not self._merge_option): + assert isinstance(block, framework.Block) + _lars_weight_decay = self._lars_weight_decay + param_name = param_and_grad[0].name + if len(self._exclude_from_weight_decay) > 0: + for name in self._exclude_from_weight_decay: + if name in param_name: + _lars_weight_decay = 0.0 + break - velocity_acc = self._get_accumulator(self._velocity_acc_str, - param_and_grad[0]) - lr = self._create_param_lr(param_and_grad) + velocity_acc = self._get_accumulator(self._velocity_acc_str, + param_and_grad[0]) + lr = self._create_param_lr(param_and_grad) + + find_master = self._multi_precision and param_and_grad[ + 0].dtype == core.VarDesc.VarType.FP16 + master_weight = (self._master_weights[param_and_grad[0].name] + if find_master else None) + + attrs = { + "mu": self._momentum, + "lars_coeff": self._lars_coeff, + "lars_weight_decay": [_lars_weight_decay], + "multi_precision": find_master, + "rescale_grad": self._rescale_grad + } - find_master = self._multi_precision and param_and_grad[ - 0].dtype == core.VarDesc.VarType.FP16 - master_weight = (self._master_weights[param_and_grad[0].name] - if find_master else None) + inputs = { + "Param": param_and_grad[0], + "Grad": param_and_grad[1], + "Velocity": velocity_acc, + "LearningRate": lr + } - attrs = { - "mu": self._momentum, - "lars_coeff": self._lars_coeff, - "lars_weight_decay": [_lars_weight_decay], - "multi_precision": find_master, - "rescale_grad": self._rescale_grad - } + outputs = { + "ParamOut": param_and_grad[0], + "VelocityOut": velocity_acc + } - inputs = { - "Param": param_and_grad[0], - "Grad": param_and_grad[1], - "Velocity": velocity_acc, - "LearningRate": lr - } + if find_master: + inputs["MasterParam"] = master_weight + outputs["MasterParamOut"] = master_weight - outputs = {"ParamOut": param_and_grad[0], "VelocityOut": velocity_acc} + # create the momentum optimize op + momentum_op = block.append_op( + type=self.type if _lars_weight_decay != 0.0 else 'momentum', + inputs=inputs, + outputs=outputs, + attrs=attrs, + stop_gradient=True) + return momentum_op - if find_master: - inputs["MasterParam"] = master_weight - outputs["MasterParamOut"] = master_weight + else: + attrs = { + "mu": self._momentum, + "lars_coeff": self._lars_coeff, + "rescale_grad": self._rescale_grad, + "multi_precision": False, + "merge_option": True + } - # create the momentum optimize op - momentum_op = block.append_op( - type=self.type if _lars_weight_decay != 0.0 else 'momentum', - inputs=inputs, - outputs=outputs, - attrs=attrs, - stop_gradient=True) + if self._merge_option and not framework.in_dygraph_mode(): + assert isinstance( + param_and_grad, list + ), "Once merging all lars 
ops, argument `param_and_grad` must be list type." + + lr_array = [] + grad_array = [] + param_array = [] + velocity_array = [] + lars_weight_decay_array = [] + find_master = self._multi_precision and param_and_grad[0][ + 0].dtype == core.VarDesc.VarType.FP16 + master_weight_array = [] if find_master else None + + for param_and_grad_element in param_and_grad: + param_array.append(param_and_grad_element[0]) + grad_array.append(param_and_grad_element[1]) + velocity_array.append( + self._get_accumulator(self._velocity_acc_str, + param_and_grad_element[0])) + lr_array.append( + self._create_param_lr(param_and_grad_element)) + if find_master: + master_weight_array.append(self._master_weights[ + param_and_grad_element[0].name]) + + if len(self._exclude_from_weight_decay) > 0: + _lars_weight_decay = self._lars_weight_decay + for name in self._exclude_from_weight_decay: + if name in param_and_grad_element[0].name: + _lars_weight_decay = 0.0 + break + lars_weight_decay_array.append(_lars_weight_decay) + else: + lars_weight_decay_array.append(self._lars_weight_decay) + + inputs = { + "Param": param_array, + "Grad": grad_array, + "Velocity": velocity_array, + "LearningRate": lr_array + } + outputs = { + "ParamOut": param_array, + "VelocityOut": velocity_array + } + + # param lars_weight_decay combination + attrs["lars_weight_decay"] = lars_weight_decay_array + + if find_master: + attrs["multi_precision"] = True + inputs["MasterParam"] = master_weight_array + outputs["MasterParamOut"] = master_weight_array + + # create the momentum optimize op + lars_momentum_op = block.append_op( + type=self.type, + inputs=inputs, + outputs=outputs, + attrs=attrs, + stop_gradient=True) + return lars_momentum_op - return momentum_op + def _create_optimization_pass(self, parameters_and_grads): + global_block = framework.default_main_program().global_block() + target_block = global_block + current_block = framework.default_main_program().current_block() + + start = len(target_block.ops) + self._update_param_device_map(parameters_and_grads, target_block) + self._create_accumulators( + target_block, + [p[0] for p in parameters_and_grads if p[0].trainable]) + self._create_global_learning_rate() + + if framework.in_dygraph_mode(): + for param_and_grad in parameters_and_grads: + if param_and_grad[0].trainable is True: + self._append_optimize_op(target_block, param_and_grad) + else: + if (not self._merge_option): + for param_and_grad in parameters_and_grads: + with param_and_grad[0].block.program._optimized_guard( + param_and_grad), name_scope("optimizer"): + if param_and_grad[0].trainable is True: + device = self._get_device_for_param(param_and_grad[ + 0].name) + with device_guard(device): + optimize_op = self._append_optimize_op( + target_block, param_and_grad) + else: + normal_parameters_and_grad = [] + multi_precision_parameters_and_grads = [] + has_amp_lars = False + has_lars = False + for param_and_grad in parameters_and_grads: + with param_and_grad[0].block.program._optimized_guard( + param_and_grad), name_scope("optimizer"): + if param_and_grad[0].trainable is True: + device = self._get_device_for_param(param_and_grad[ + 0].name) + if self._multi_precision and param_and_grad[ + 0].dtype == core.VarDesc.VarType.FP16: + has_amp_lars = True + multi_precision_parameters_and_grads.append( + param_and_grad) + else: + has_lars = True + normal_parameters_and_grad.append( + param_and_grad) + with device_guard(device): + if has_amp_lars: + multi_precision_optimize_op = self._append_optimize_op( + target_block, 
multi_precision_parameters_and_grads) + if has_lars: + normal_optimize_op = self._append_optimize_op( + target_block, normal_parameters_and_grad) + + # Get custom finish ops for subclasses + # FIXME: Need to fix this once we figure out how to handle dependencies + self._finish_update(target_block, parameters_and_grads) + + end = len(target_block.ops) + return target_block._slice_ops(start, end) class AdagradOptimizer(Optimizer): From 7be64342bb8dac53089802804cc88e0f8d05abab Mon Sep 17 00:00:00 2001 From: JamesLim-sy Date: Sun, 17 Oct 2021 08:13:31 +0000 Subject: [PATCH 06/12] fix python codes error --- .../operators/optimizers/lars_momentum_op.cu | 226 ++++++++++-------- python/paddle/fluid/optimizer.py | 13 +- .../test_fleet_lars_meta_optimizer.py | 4 +- 3 files changed, 135 insertions(+), 108 deletions(-) diff --git a/paddle/fluid/operators/optimizers/lars_momentum_op.cu b/paddle/fluid/operators/optimizers/lars_momentum_op.cu index 5474ef9c7dc39..ae224299f5be9 100644 --- a/paddle/fluid/operators/optimizers/lars_momentum_op.cu +++ b/paddle/fluid/operators/optimizers/lars_momentum_op.cu @@ -48,13 +48,13 @@ class LarsThreadConfig { public: int grid_for_lars; #if CUDA_VERSION >= 11000 - public: explicit LarsThreadConfig(int64_t numel, int sm_num, int num_blocks_per_sm) { int grid = (numel + LARS_BLOCK_SIZE - 1) / LARS_BLOCK_SIZE; grid_for_lars = std::min(std::min(sm_num * num_blocks_per_sm, grid), LARS_BLOCK_SIZE); } + #else int grid_for_norm; explicit LarsThreadConfig(const int64_t numel) { @@ -71,10 +71,10 @@ class LarsThreadConfig { template __device__ inline void VectorizeLarsUpdate( - const T* __restrict__ grad, MT* param, MT* velocity, T* param_out, - const MT mu, MT local_lr, const MT lars_weight_decay, const MT rescale_grad, - const int tid, const int grid_stride, const int numel, - MT* master_param_out = nullptr) { + const T* __restrict__ grad, const MT* param, const MT* velocity, + T* param_out, MT* velocity_out, const MT mu, MT local_lr, + const MT lars_weight_decay, const MT rescale_grad, const int tid, + const int grid_stride, const int numel, MT* master_param_out = nullptr) { using VecType = paddle::platform::AlignedVector; using VecMType = paddle::platform::AlignedVector; int main = numel >> (VecSize >> 1); @@ -82,8 +82,9 @@ __device__ inline void VectorizeLarsUpdate( const VecType* grad_vec = reinterpret_cast(grad); const VecMType* param_vec = reinterpret_cast(param); - VecMType* velocity_vec = reinterpret_cast(velocity); + const VecMType* velocity_vec = reinterpret_cast(velocity); VecType* param_out_vec = reinterpret_cast(param_out); + VecMType* velocity_out_vec = reinterpret_cast(velocity_out); VecMType* master_param_out_vec; if (IsAmp) { @@ -106,7 +107,7 @@ __device__ inline void VectorizeLarsUpdate( param_out_tmp[j] = static_cast(param_tmp[j]); } param_out_vec[i] = param_out_tmp; - velocity_vec[i] = velocity_tmp; + velocity_out_vec[i] = velocity_tmp; if (IsAmp) { master_param_out_vec[i] = param_tmp; } @@ -119,7 +120,7 @@ __device__ inline void VectorizeLarsUpdate( param_val, grad_val)); MT param_tmp = param_val - velocity_tmp; param_out[i] = static_cast(param_tmp); - velocity[i] = velocity_tmp; + velocity_out[i] = velocity_tmp; if (IsAmp) { master_param_out[i] = param_tmp; } @@ -142,10 +143,8 @@ __global__ void L2NormKernel( const T* p_data, const T* __restrict__ g_data, MT* __restrict__ p_buffer, MT* __restrict__ g_buffer, const int64_t numel, const MT rescale_grad, MT* __restrict__ p_n = nullptr, MT* __restrict__ g_n = nullptr) { - __shared__ MT s_buffer[2]; int tid = 
threadIdx.x + blockDim.x * blockIdx.x; int grid_stride = LARS_BLOCK_SIZE * gridDim.x; - const MT rescale_pow = rescale_grad * rescale_grad; MT p_tmp = static_cast(0); MT g_tmp = static_cast(0); @@ -164,6 +163,7 @@ __global__ void L2NormKernel( g_buffer[blockIdx.x] = g_tmp; } #if CUDA_VERSION >= 11000 + __shared__ MT s_buffer[2]; cg->sync(); // Grid sync for writring partial result to gloabl memory MT p_part_sum = threadIdx.x < gridDim.x ? p_buffer[threadIdx.x] : 0; MT g_part_sum = threadIdx.x < gridDim.x ? g_buffer[threadIdx.x] : 0; @@ -175,14 +175,15 @@ __global__ void L2NormKernel( } __syncthreads(); *p_n = Sqrt(s_buffer[0]); - *g_n = Sqrt(rescale_pow * s_buffer[1]); + *g_n = rescale_grad * Sqrt(s_buffer[1]); #endif } template __forceinline__ __device__ void MomentumUpdate( - const T* __restrict__ grad, T* param, MT* velocity, MT* master_param, - T* param_out, const MT* __restrict__ learning_rate, const MT mu, + const T* param, const T* __restrict__ grad, const MT* velocity, + T* param_out, MT* velocity_out, const MT* master_param, + MT* master_param_out, const MT* __restrict__ learning_rate, const MT mu, const MT lars_weight_decay, const MT lars_coeff, const MT epsilon, const MT rescale_grad, const MT param_norm, const MT grad_norm, const int tid, const int grid_stride, const int64_t numel, @@ -195,19 +196,22 @@ __forceinline__ __device__ void MomentumUpdate( } if (is_amp) { VectorizeLarsUpdate( - grad, master_param, velocity, param_out, mu, local_lr, - lars_weight_decay, rescale_grad, tid, grid_stride, numel); + grad, master_param, velocity, param_out, velocity_out, mu, local_lr, + lars_weight_decay, rescale_grad, tid, grid_stride, numel, + master_param_out); } else { if (std::is_same::value || std::is_same::value) { /* TODO(limingshu): pointer cast may damage memory accessing for fp16 */ VectorizeLarsUpdate( - grad, reinterpret_cast(param), velocity, param_out, mu, local_lr, - lars_weight_decay, rescale_grad, tid, grid_stride, numel); + grad, reinterpret_cast(param), velocity, param_out, + velocity_out, mu, local_lr, lars_weight_decay, rescale_grad, tid, + grid_stride, numel); } else { VectorizeLarsUpdate( - grad, reinterpret_cast(param), velocity, param_out, mu, local_lr, - lars_weight_decay, rescale_grad, tid, grid_stride, numel); + grad, reinterpret_cast(param), velocity, param_out, + velocity_out, mu, local_lr, lars_weight_decay, rescale_grad, tid, + grid_stride, numel); } } } @@ -236,15 +240,15 @@ struct LarsParamWarpper : public MergedLarsMasterParam { static constexpr int kNum = kOpNum; int numel_arr[kOpNum]; - const T* __restrict__ g_arr[kOpNum]; const MT* __restrict__ lr_arr[kOpNum]; + const T* __restrict__ g_arr[kOpNum]; T* p_arr[kOpNum]; MT* v_arr[kOpNum]; - MT weight_decay_arr[kOpNum]; + MT weight_decay[kOpNum]; }; -template -__global__ void MergedMomentumLarsKernel(LarsWarpperType lars_warpper, +template +__global__ void MergedMomentumLarsKernel(LarsParamWarpper lars_warpper, MT* __restrict__ p_buffer, MT* __restrict__ g_buffer, const int op_num, const MT mu, @@ -254,31 +258,33 @@ __global__ void MergedMomentumLarsKernel(LarsWarpperType lars_warpper, int grid_stride = gridDim.x * LARS_BLOCK_SIZE; int tid = threadIdx.x + blockIdx.x * blockDim.x; const cooperative_groups::grid_group cg = cooperative_groups::this_grid(); - for (int i = 0; i < lars_warpper.kNum; ++i) { - if (i > op_num) break; + for (int i = 0; i < op_num; ++i) { int numel = lars_warpper.numel_arr[i]; MT param_norm = static_cast(0); MT grad_norm = static_cast(0); L2NormKernel(&cg, lars_warpper.p_arr[i], 
lars_warpper.g_arr[i], p_buffer, g_buffer, numel, rescale_grad, ¶m_norm, &grad_norm); - MomentumUpdate(lars_warpper.g_arr[i], lars_warpper.p_arr[i], - lars_warpper.v_arr[i], lars_warpper.GetMasterParam(i), - lars_warpper.p_arr[i], lars_warpper.lr_arr[i], mu, - lars_warpper.weight_decay_arr[i], lars_coeff, epsilon, - rescale_grad, param_norm, grad_norm, tid, grid_stride, - numel, is_amp); + MomentumUpdate( + lars_warpper.p_arr[i], lars_warpper.g_arr[i], lars_warpper.v_arr[i], + lars_warpper.p_arr[i], lars_warpper.v_arr[i], + lars_warpper.GetMasterParam(i), lars_warpper.GetMasterParam(i), + lars_warpper.lr_arr[i], mu, lars_warpper.weight_decay[i], lars_coeff, + epsilon, rescale_grad, param_norm, grad_norm, tid, grid_stride, numel, + is_amp); } } #endif template __global__ void MomentumLarsKernel( - const T* __restrict__ grad, T* param, MT* velocity, MT* master_param, - const MT* __restrict__ learning_rate, MT* __restrict__ p_buffer, - MT* __restrict__ g_buffer, const MT mu, const MT lars_coeff, - const MT lars_weight_decay, const MT epsilon, const MT rescale_grad, - const int thresh, const int64_t numel, const bool is_amp) { + const T* param, const T* __restrict__ grad, const MT* velocity, + T* param_out, MT* velocity_out, const MT* master_param, + MT* master_param_out, const MT* __restrict__ learning_rate, + MT* __restrict__ p_buffer, MT* __restrict__ g_buffer, const MT mu, + const MT lars_coeff, const MT lars_weight_decay, const MT epsilon, + const MT rescale_grad, const int thresh, const int64_t numel, + const bool is_amp) { int tid = threadIdx.x + blockIdx.x * blockDim.x; int grid_stride = gridDim.x * LARS_BLOCK_SIZE; #if CUDA_VERSION >= 11000 @@ -288,33 +294,33 @@ __global__ void MomentumLarsKernel( L2NormKernel(&cg, param, grad, p_buffer, g_buffer, numel, rescale_grad, ¶m_norm, &grad_norm); #else - __shared__ MT s_buffer[2]; - const MT rescale_pow = rescale_grad * rescale_grad; - MT param_part_norm = threadIdx.x < thresh ? p_buffer[threadIdx.x] : 0; - MT grad_part_norm = threadIdx.x < thresh ? g_buffer[threadIdx.x] : 0; - MT tmp0 = math::blockReduceSum(param_part_norm, FINAL_MASK); - MT tmp1 = math::blockReduceSum(param_part_norm, FINAL_MASK); + MT p_part_sum = threadIdx.x < thresh ? p_buffer[threadIdx.x] : 0; + MT g_part_sum = threadIdx.x < thresh ? 
g_buffer[threadIdx.x] : 0; + MT tmp0 = math::blockReduceSum(p_part_sum, FINAL_MASK); + MT tmp1 = math::blockReduceSum(g_part_sum, FINAL_MASK); if (threadIdx.x == 0) { s_buffer[0] = tmp0; s_buffer[1] = tmp1; } __syncthreads(); MT param_norm = Sqrt(s_buffer[0]); - MT grad_norm = Sqrt(rescale_pow * s_buffer[1]); + MT grad_norm = rescale_grad * Sqrt(s_buffer[1]); + #endif - MomentumUpdate(grad, param, velocity, master_param, param /*inplace*/, - learning_rate, mu, lars_weight_decay, lars_coeff, - epsilon, rescale_grad, param_norm, grad_norm, tid, - grid_stride, numel, is_amp); + MomentumUpdate(param, grad, velocity, param_out, velocity_out, + master_param, master_param_out, learning_rate, mu, + lars_weight_decay, lars_coeff, epsilon, rescale_grad, + param_norm, grad_norm, tid, grid_stride, numel, is_amp); } template inline void SeparatedLarsMomentumOpCUDAKernel( - const platform::CUDADeviceContext& cuda_ctx, const T* grad_data, - T* param_data, MT* velocity_data, MT* master_param_data, const MT* lr, - MT* p_buffer, MT* g_buffer, const MT mu, const MT lars_coeff, - const MT weight_decay, const MT epsilon, const MT rescale_grad, - const int64_t numel, const bool is_amp) { + const platform::CUDADeviceContext& cuda_ctx, const T* param_data, + T* param_out_data, const MT* velocity_data, MT* velocity_out_data, + const T* grad_data, const MT* lr, MT* p_buffer, MT* g_buffer, const MT mu, + const MT lars_coeff, const MT weight_decay, const MT epsilon, + const MT rescale_grad, const int64_t numel, const MT* master_param_data, + MT* master_out_data, const bool is_amp) { LarsThreadConfig lars_thread_config(numel); L2NormKernel<<>>(param_data, grad_data, p_buffer, @@ -322,11 +328,11 @@ inline void SeparatedLarsMomentumOpCUDAKernel( MomentumLarsKernel<<>>( - grad_data, param_data, velocity_data, master_param_data, lr, p_buffer, - g_buffer, mu, lars_coeff, weight_decay, epsilon, rescale_grad, + param_data, grad_data, velocity_data, param_out_data, velocity_out_data, + master_param_data, master_out_data, lr, p_buffer, g_buffer, mu, + lars_coeff, weight_decay, epsilon, rescale_grad, lars_thread_config.grid_for_norm, numel, is_amp); } - template class LarsMomentumOpCUDAKernel : public framework::OpKernel { using MT = MultiPrecisionType; @@ -334,6 +340,7 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { int num_blocks_per_sm = 0; + bool multi_precision = ctx.Attr("multi_precision"); auto& cuda_ctx = ctx.template device_context(); int sm_num = cuda_ctx.GetSMCount(); framework::Tensor tmp_buffer_t = @@ -342,13 +349,12 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel { auto* p_buffer = tmp_buffer_t.mutable_data(ctx.GetPlace()); auto* g_buffer = p_buffer + LARS_BLOCK_SIZE; - bool multi_precision = ctx.Attr("multi_precision"); MT mu = static_cast(ctx.Attr("mu")); MT lars_coeff = static_cast(ctx.Attr("lars_coeff")); MT epsilon = static_cast(ctx.Attr("epsilon")); MT rescale_grad = static_cast(ctx.Attr("rescale_grad")); - auto weight_decay_arr = ctx.Attr>("lars_weight_decay"); + auto weight_decay_arr = ctx.Attr>("lars_weight_decay"); auto grad = ctx.MultiInput("Grad"); auto param = ctx.MultiInput("Param"); auto velocity = ctx.MultiInput("Velocity"); @@ -360,25 +366,11 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel { ctx.MultiOutput("MasterParamOut"); int op_num = grad.size(); - for (size_t i = 0; i < op_num; ++i) { - PADDLE_ENFORCE_EQ( - param[i], param_out[i], - platform::errors::InvalidArgument( - 
"Input(Param) and Output(ParamOut) must be the same Tensors.")); - PADDLE_ENFORCE_EQ(velocity[i], velocity_out[i], - platform::errors::InvalidArgument( - "Input(Velocity) and Output(VelocityOut) must be " - "the same Tensors.")); - if (multi_precision) { - PADDLE_ENFORCE_EQ(master_param[i], master_param_out[i], - platform::errors::InvalidArgument( - "Input(MasterParam) and Output(MasterParamOut) " - "must be the same Tensors.")); - } - } #if CUDA_VERSION >= 11000 if (op_num > 1) { LarsParamWarpper lars_warpper; + VLOG(10) << "Num of ops merged in lars_warpper is " << lars_warpper.kNum; + /* Implementation of lars optimizer consists of following two steps: 1. Figure out the L2 norm statistic result of grad data and param data. 2. Update param and velocity with usage of L2 norm statistic result. @@ -387,11 +379,8 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel { - The thread quantity shall less than pyhsical SM limited threads - Launche as thread-block can synchronizlly execute. */ cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &num_blocks_per_sm, - MergedMomentumLarsKernel, - LARS_BLOCK_SIZE, sizeof(MT) << 1); - - VLOG(10) << "Num of ops merged in lars_warpper is " << lars_warpper.kNum; + &num_blocks_per_sm, MergedMomentumLarsKernel, LARS_BLOCK_SIZE, + sizeof(MT) << 1); int merge_times = (op_num + lars_warpper.kNum - 1) / lars_warpper.kNum; for (int j = 0; j < merge_times; ++j) { @@ -404,21 +393,39 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel { total_numel += temp_numel; lars_warpper.numel_arr[i] = temp_numel; lars_warpper.g_arr[i] = grad[start_idx + i]->data(); - lars_warpper.p_arr[i] = param_out[start_idx + i]->data(); - lars_warpper.v_arr[i] = velocity_out[start_idx + i]->data(); - lars_warpper.lr_arr[i] = learning_rate[i]->data(); - lars_warpper.weight_decay_arr[i] = - static_cast(weight_decay_arr[start_idx + i]); + lars_warpper.p_arr[i] = + param_out[start_idx + i]->mutable_data(ctx.GetPlace()); + lars_warpper.v_arr[i] = + velocity_out[start_idx + i]->mutable_data(ctx.GetPlace()); + lars_warpper.lr_arr[i] = learning_rate[start_idx + i]->data(); + lars_warpper.weight_decay[i] = static_cast(weight_decay_arr[i]); + if (multi_precision) { + auto master_param_data = + master_param_out[start_idx + i]->mutable_data( + ctx.GetPlace()); + lars_warpper.SetMasterParam(i, master_param_data); + PADDLE_ENFORCE_EQ( + master_param[start_idx + i]->data(), master_param_data, + platform::errors::InvalidArgument( + "Input(MasterParam) and Output(MasterParamOut) of lars " + "optimizer must be the same Tensors.")); + } + PADDLE_ENFORCE_EQ( + param[start_idx + i]->data(), lars_warpper.p_arr[i], + platform::errors::InvalidArgument( + "Input(Param) and Output(ParamOut) of lars optimizer " + "must be the same Tensors.")); + PADDLE_ENFORCE_EQ(velocity[start_idx + i]->data(), + lars_warpper.v_arr[i], + platform::errors::InvalidArgument( + "Input(Velocity) and Output(VelocityOut) of " + "lars optimizer must be " + "the same Tensors.")); } + VLOG(10) << "Op number delt in loop " << j << " is : " << loop_num; int64_t avg_numel = total_numel / loop_num; LarsThreadConfig lars_thread_config(avg_numel, sm_num, num_blocks_per_sm); - for (int i = 0; i < loop_num; ++i) { - if (multi_precision) { - lars_warpper.SetMasterParam( - i, master_param_out[i]->mutable_data(ctx.GetPlace())); - } - } void* cuda_param[] = {reinterpret_cast(&lars_warpper), reinterpret_cast(&p_buffer), reinterpret_cast(&g_buffer), @@ -428,21 +435,26 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel { 
reinterpret_cast(&epsilon), reinterpret_cast(&rescale_grad), reinterpret_cast(&multi_precision)}; + // Launch on all SM threads; threads within each block cooperate synchronously. cudaLaunchCooperativeKernel( - reinterpret_cast( - MergedMomentumLarsKernel), + reinterpret_cast(MergedMomentumLarsKernel), lars_thread_config.grid_for_lars, LARS_BLOCK_SIZE, cuda_param, 0, cuda_ctx.stream()); - - VLOG(10) << "Lanuched ops number is " << loop_num; } } else { + auto* param_data = param[0]->data(); auto* grad_data = grad[0]->data(); - auto* param_data = param_out[0]->data(); - auto* velocity_data = velocity_out[0]->data(); + auto* velocity_data = velocity[0]->data(); auto* lr = learning_rate[0]->data(); + auto* param_out_data = param_out[0]->mutable_data(ctx.GetPlace()); + auto* velocity_out_data = + velocity_out[0]->mutable_data(ctx.GetPlace()); const MT* master_param_data = - multi_precision ? master_param_out[0]->data() : nullptr; + multi_precision ? master_param[0]->data() : nullptr; + MT* master_param_out_data = + multi_precision + ? master_param_out[0]->mutable_data(ctx.GetPlace()) + : nullptr; int64_t numel = param[0]->numel(); MT lars_weight_decay = weight_decay_arr[0]; @@ -453,12 +465,14 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel { LarsThreadConfig lars_thread_config(numel, sm_num, num_blocks_per_sm); int thresh = 0; - void* cuda_param[] = { - reinterpret_cast(&grad_data), reinterpret_cast(&param_data), + reinterpret_cast(&grad_data), reinterpret_cast(&velocity_data), + reinterpret_cast(&param_out_data), + reinterpret_cast(&velocity_out_data), reinterpret_cast(&master_param_data), + reinterpret_cast(&master_param_out_data), reinterpret_cast(&lr), reinterpret_cast(&p_buffer), reinterpret_cast(&g_buffer), @@ -478,14 +492,20 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel { } #else for (int i = 0; i < op_num; ++i) { - MT* master_param_data = - multi_precision ? master_param_out[i]->data() : nullptr; + const MT* master_param_data = + multi_precision ? master_param[i]->data() : nullptr; + MT* master_param_out_data = + multi_precision + ? 
master_param_out[i]->mutable_data(ctx.GetPlace()) + : nullptr; SeparatedLarsMomentumOpCUDAKernel( - cuda_ctx, grad[i]->data(), param_out[i]->data(), - velocity_out[i]->data(), master_param_data, + cuda_ctx, param[i]->data(), + param_out[i]->mutable_data(ctx.GetPlace()), + velocity[i]->data(), + velocity_out[i]->mutable_data(ctx.GetPlace()), grad[i]->data(), learning_rate[i]->data(), p_buffer, g_buffer, mu, lars_coeff, weight_decay_arr[i], epsilon, rescale_grad, param[i]->numel(), - multi_precision); + master_param_data, master_param_out_data, multi_precision); } #endif } diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index ed0459974e5e2..ecbf0d17a8a99 100755 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -2047,8 +2047,15 @@ def _create_accumulators(self, block, parameters): def _append_optimize_op(self, block, param_and_grad): assert isinstance(block, framework.Block) - if not isinstance(param_and_grad, list): - _lars_weight_decay = 0.0 + if not isinstance(param_and_grad, list) or framework.in_dygraph_mode(): + _lars_weight_decay = self._lars_weight_decay + param_name = param_and_grad[0].name + if len(self._exclude_from_weight_decay) > 0: + for name in self._exclude_from_weight_decay: + if name in param_name: + _lars_weight_decay = 0.0 + break + velocity_acc = self._get_accumulator(self._velocity_acc_str, param_and_grad[0]) lr = self._create_param_lr(param_and_grad) @@ -2083,7 +2090,7 @@ def _append_optimize_op(self, block, param_and_grad): # create the momentum optimize op momentum_op = block.append_op( - type='momentum', + type=self.type if _lars_weight_decay != 0.0 else 'momentum', inputs=inputs, outputs=outputs, attrs=attrs, diff --git a/python/paddle/fluid/tests/unittests/test_fleet_lars_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_lars_meta_optimizer.py index bee6acf732460..86ca75534ce27 100755 --- a/python/paddle/fluid/tests/unittests/test_fleet_lars_meta_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_lars_meta_optimizer.py @@ -64,7 +64,7 @@ def test_lars_optimizer(self): startup_prog = fluid.Program() train_prog = fluid.Program() avg_cost, strategy = self.net(train_prog, startup_prog) - optimizer = paddle.fluid.optimizer.Momentum( + optimizer = paddle.fluid.optimizer.LarsMomentum( learning_rate=0.01, momentum=0.9) optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) optimizer.minimize(avg_cost) @@ -139,7 +139,7 @@ def test_lars_apply_with_amp(self): "exclude_from_weight_decay": ["batch_norm", ".b"], } - optimizer = paddle.fluid.optimizer.Momentum( + optimizer = paddle.fluid.optimizer.LarsMomentum( learning_rate=0.01, momentum=0.9) optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) optimizer.minimize(avg_cost) From 57b952aff4d929411df27f8878afad3eaff4c5d2 Mon Sep 17 00:00:00 2001 From: JamesLim-sy Date: Sun, 17 Oct 2021 14:46:30 +0000 Subject: [PATCH 07/12] change the type form of lars_weight_decay --- .../operators/optimizers/lars_momentum_op.cc | 9 ---- .../operators/optimizers/lars_momentum_op.cu | 45 ++++++++++--------- .../operators/optimizers/lars_momentum_op.h | 3 +- python/paddle/fluid/optimizer.py | 12 +---- 4 files changed, 26 insertions(+), 43 deletions(-) diff --git a/paddle/fluid/operators/optimizers/lars_momentum_op.cc b/paddle/fluid/operators/optimizers/lars_momentum_op.cc index 65be35843bdf9..e37ed76685ec5 100644 --- a/paddle/fluid/operators/optimizers/lars_momentum_op.cc +++ 
b/paddle/fluid/operators/optimizers/lars_momentum_op.cc @@ -44,8 +44,6 @@ class LarsMomentumOp : public framework::OperatorWithKernel { auto grad_dim = ctx->GetInputsDim("Grad"); auto param_dim = ctx->GetInputsDim("Param"); auto velocity_dim = ctx->GetInputsDim("Velocity"); - auto lars_weight_decays = - ctx->Attrs().Get>("lars_weight_decay"); auto multi_precision = ctx->Attrs().Get("multi_precision"); PADDLE_ENFORCE_EQ( @@ -61,13 +59,6 @@ class LarsMomentumOp : public framework::OperatorWithKernel { "have same quantity. But number of Param is [%d] and Velocity " "is [%d].", param_dim.size(), velocity_dim.size())); - PADDLE_ENFORCE_EQ( - lars_weight_decays.size(), grad_dim.size(), - platform::errors::InvalidArgument( - "Attr(Lars_weight_decay) and " - "Input(Grad) of LarsMomentumOp should have same quantity. " - "But number of Lars_weight_decay is [%d] and Grad is [%d].", - lars_weight_decays.size(), grad_dim.size())); if (multi_precision) { OP_INOUT_CHECK(ctx->HasInputs("MasterParam"), "Input", "MasterParam", diff --git a/paddle/fluid/operators/optimizers/lars_momentum_op.cu b/paddle/fluid/operators/optimizers/lars_momentum_op.cu index ae224299f5be9..aae32f4567735 100644 --- a/paddle/fluid/operators/optimizers/lars_momentum_op.cu +++ b/paddle/fluid/operators/optimizers/lars_momentum_op.cu @@ -235,7 +235,7 @@ struct MergedLarsMasterParam { template ::value ? 70 : 90> + std::is_same::value ? 80 : 100> struct LarsParamWarpper : public MergedLarsMasterParam { static constexpr int kNum = kOpNum; @@ -244,7 +244,7 @@ struct LarsParamWarpper : public MergedLarsMasterParam { const T* __restrict__ g_arr[kOpNum]; T* p_arr[kOpNum]; MT* v_arr[kOpNum]; - MT weight_decay[kOpNum]; + MT weight_decay; }; template @@ -265,13 +265,13 @@ __global__ void MergedMomentumLarsKernel(LarsParamWarpper lars_warpper, L2NormKernel(&cg, lars_warpper.p_arr[i], lars_warpper.g_arr[i], p_buffer, g_buffer, numel, rescale_grad, ¶m_norm, &grad_norm); - MomentumUpdate( - lars_warpper.p_arr[i], lars_warpper.g_arr[i], lars_warpper.v_arr[i], - lars_warpper.p_arr[i], lars_warpper.v_arr[i], - lars_warpper.GetMasterParam(i), lars_warpper.GetMasterParam(i), - lars_warpper.lr_arr[i], mu, lars_warpper.weight_decay[i], lars_coeff, - epsilon, rescale_grad, param_norm, grad_norm, tid, grid_stride, numel, - is_amp); + MomentumUpdate(lars_warpper.p_arr[i], lars_warpper.g_arr[i], + lars_warpper.v_arr[i], lars_warpper.p_arr[i], + lars_warpper.v_arr[i], lars_warpper.GetMasterParam(i), + lars_warpper.GetMasterParam(i), + lars_warpper.lr_arr[i], mu, lars_warpper.weight_decay, + lars_coeff, epsilon, rescale_grad, param_norm, + grad_norm, tid, grid_stride, numel, is_amp); } } #endif @@ -294,6 +294,7 @@ __global__ void MomentumLarsKernel( L2NormKernel(&cg, param, grad, p_buffer, g_buffer, numel, rescale_grad, ¶m_norm, &grad_norm); #else + __shared__ MT s_buffer[2]; MT p_part_sum = threadIdx.x < thresh ? p_buffer[threadIdx.x] : 0; MT g_part_sum = threadIdx.x < thresh ? 
g_buffer[threadIdx.x] : 0; MT tmp0 = math::blockReduceSum(p_part_sum, FINAL_MASK); @@ -353,8 +354,9 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel { MT lars_coeff = static_cast(ctx.Attr("lars_coeff")); MT epsilon = static_cast(ctx.Attr("epsilon")); MT rescale_grad = static_cast(ctx.Attr("rescale_grad")); - auto weight_decay_arr = ctx.Attr>("lars_weight_decay"); + MT lars_weight_decay = weight_decay_arr[0]; + auto grad = ctx.MultiInput("Grad"); auto param = ctx.MultiInput("Param"); auto velocity = ctx.MultiInput("Velocity"); @@ -382,6 +384,7 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel { &num_blocks_per_sm, MergedMomentumLarsKernel, LARS_BLOCK_SIZE, sizeof(MT) << 1); + lars_warpper.weight_decay = lars_weight_decay; int merge_times = (op_num + lars_warpper.kNum - 1) / lars_warpper.kNum; for (int j = 0; j < merge_times; ++j) { size_t total_numel = 0; @@ -398,7 +401,6 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel { lars_warpper.v_arr[i] = velocity_out[start_idx + i]->mutable_data(ctx.GetPlace()); lars_warpper.lr_arr[i] = learning_rate[start_idx + i]->data(); - lars_warpper.weight_decay[i] = static_cast(weight_decay_arr[i]); if (multi_precision) { auto master_param_data = master_param_out[start_idx + i]->mutable_data( @@ -407,20 +409,20 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel { PADDLE_ENFORCE_EQ( master_param[start_idx + i]->data(), master_param_data, platform::errors::InvalidArgument( - "Input(MasterParam) and Output(MasterParamOut) of lars " - "optimizer must be the same Tensors.")); + "Since Input(MasterParam) and Output(MasterParamOut) of " + "lars optimizer must be the same Tensors.")); } PADDLE_ENFORCE_EQ( param[start_idx + i]->data(), lars_warpper.p_arr[i], platform::errors::InvalidArgument( - "Input(Param) and Output(ParamOut) of lars optimizer " + "Since Input(Param) and Output(ParamOut) of lars optimizer " "must be the same Tensors.")); - PADDLE_ENFORCE_EQ(velocity[start_idx + i]->data(), - lars_warpper.v_arr[i], - platform::errors::InvalidArgument( - "Input(Velocity) and Output(VelocityOut) of " - "lars optimizer must be " - "the same Tensors.")); + PADDLE_ENFORCE_EQ( + velocity[start_idx + i]->data(), lars_warpper.v_arr[i], + platform::errors::InvalidArgument( + "Since Input(Velocity) and Output(VelocityOut) of " + "lars optimizer must be " + "the same Tensors.")); } VLOG(10) << "Op number delt in loop " << j << " is : " << loop_num; int64_t avg_numel = total_numel / loop_num; @@ -456,7 +458,6 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel { ? master_param_out[0]->mutable_data(ctx.GetPlace()) : nullptr; int64_t numel = param[0]->numel(); - MT lars_weight_decay = weight_decay_arr[0]; // Figure out how many blocks can be active in each sm. 
cudaOccupancyMaxActiveBlocksPerMultiprocessor( @@ -504,7 +505,7 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel { velocity[i]->data(), velocity_out[i]->mutable_data(ctx.GetPlace()), grad[i]->data(), learning_rate[i]->data(), p_buffer, g_buffer, mu, lars_coeff, - weight_decay_arr[i], epsilon, rescale_grad, param[i]->numel(), + lars_weight_decay, epsilon, rescale_grad, param[i]->numel(), master_param_data, master_param_out_data, multi_precision); } #endif diff --git a/paddle/fluid/operators/optimizers/lars_momentum_op.h b/paddle/fluid/operators/optimizers/lars_momentum_op.h index df4d7b9a0438b..b7a6446c46b3d 100644 --- a/paddle/fluid/operators/optimizers/lars_momentum_op.h +++ b/paddle/fluid/operators/optimizers/lars_momentum_op.h @@ -33,11 +33,12 @@ class LarsMomentumOpKernel : public framework::OpKernel { T mu = static_cast(ctx.Attr("mu")); T lars_coeff = ctx.Attr("lars_coeff"); T epsilon = ctx.Attr("epsilon"); + T lars_weight_decay = weight_decay_arr[0]; int op_num = param.size(); for (int i = 0; i < op_num; ++i) { auto* lr = learning_rate[i]->data(); - T lars_weight_decay = weight_decay_arr[i]; + param_out[i]->mutable_data(ctx.GetPlace()); velocity_out[i]->mutable_data(ctx.GetPlace()); diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index ecbf0d17a8a99..94dfe8b7ac875 100755 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -2122,16 +2122,6 @@ def _append_optimize_op(self, block, param_and_grad): master_weight_array.append(self._master_weights[ param_and_grad_element[0].name]) - if len(self._exclude_from_weight_decay) > 0: - _lars_weight_decay = self._lars_weight_decay - for name in self._exclude_from_weight_decay: - if name in param_and_grad_element[0].name: - _lars_weight_decay = 0.0 - break - lars_weight_decay_array.append(_lars_weight_decay) - else: - lars_weight_decay_array.append(self._lars_weight_decay) - inputs = { "Param": param_array, "Grad": grad_array, @@ -2144,7 +2134,7 @@ def _append_optimize_op(self, block, param_and_grad): "lars_coeff": self._lars_coeff, "rescale_grad": self._rescale_grad, "multi_precision": find_master, - "lars_weight_decay": lars_weight_decay_array + "lars_weight_decay": [self._lars_weight_decay] } if find_master: From 3f627523130425eb108f00a3dba817e85b40f252 Mon Sep 17 00:00:00 2001 From: JamesLim-sy Date: Sun, 17 Oct 2021 19:26:09 +0000 Subject: [PATCH 08/12] add test file of merged lars --- .../operators/optimizers/lars_momentum_op.cu | 1 - python/paddle/fluid/optimizer.py | 194 ++++------------ .../unittests/test_merged_lars_optimizer.py | 210 ++++++++++++++++++ 3 files changed, 249 insertions(+), 156 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_merged_lars_optimizer.py diff --git a/paddle/fluid/operators/optimizers/lars_momentum_op.cu b/paddle/fluid/operators/optimizers/lars_momentum_op.cu index aae32f4567735..e12a2c902884e 100644 --- a/paddle/fluid/operators/optimizers/lars_momentum_op.cu +++ b/paddle/fluid/operators/optimizers/lars_momentum_op.cu @@ -372,7 +372,6 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel { if (op_num > 1) { LarsParamWarpper lars_warpper; VLOG(10) << "Num of ops merged in lars_warpper is " << lars_warpper.kNum; - /* Implementation of lars optimizer consists of following two steps: 1. Figure out the L2 norm statistic result of grad data and param data. 2. Update param and velocity with usage of L2 norm statistic result. 
diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 94dfe8b7ac875..94f66257917cc 100755 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -2047,168 +2047,52 @@ def _create_accumulators(self, block, parameters): def _append_optimize_op(self, block, param_and_grad): assert isinstance(block, framework.Block) - if not isinstance(param_and_grad, list) or framework.in_dygraph_mode(): - _lars_weight_decay = self._lars_weight_decay - param_name = param_and_grad[0].name - if len(self._exclude_from_weight_decay) > 0: - for name in self._exclude_from_weight_decay: - if name in param_name: - _lars_weight_decay = 0.0 - break + _lars_weight_decay = self._lars_weight_decay + param_name = param_and_grad[0].name + if len(self._exclude_from_weight_decay) > 0: + for name in self._exclude_from_weight_decay: + if name in param_name: + _lars_weight_decay = 0.0 + break - velocity_acc = self._get_accumulator(self._velocity_acc_str, - param_and_grad[0]) - lr = self._create_param_lr(param_and_grad) - - find_master = self._multi_precision and param_and_grad[ - 0].dtype == core.VarDesc.VarType.FP16 - master_weight = (self._master_weights[param_and_grad[0].name] - if find_master else None) - - attrs = { - "mu": self._momentum, - "lars_coeff": self._lars_coeff, - "lars_weight_decay": [_lars_weight_decay], - "multi_precision": find_master, - "rescale_grad": self._rescale_grad - } - inputs = { - "Param": param_and_grad[0], - "Grad": param_and_grad[1], - "Velocity": velocity_acc, - "LearningRate": lr - } - - outputs = { - "ParamOut": param_and_grad[0], - "VelocityOut": velocity_acc - } - - if find_master: - inputs["MasterParam"] = master_weight - outputs["MasterParamOut"] = master_weight - - # create the momentum optimize op - momentum_op = block.append_op( - type=self.type if _lars_weight_decay != 0.0 else 'momentum', - inputs=inputs, - outputs=outputs, - attrs=attrs, - stop_gradient=True) - - return momentum_op - else: - assert isinstance( - param_and_grad, list - ), "Once merging all lars ops, argument `param_and_grad` must be list type." 
- - lr_array = [] - grad_array = [] - param_array = [] - velocity_array = [] - lars_weight_decay_array = [] - find_master = self._multi_precision and param_and_grad[0][ - 0].dtype == core.VarDesc.VarType.FP16 - master_weight_array = [] if find_master else None - - for param_and_grad_element in param_and_grad: - param_array.append(param_and_grad_element[0]) - grad_array.append(param_and_grad_element[1]) - velocity_array.append( - self._get_accumulator(self._velocity_acc_str, - param_and_grad_element[0])) - lr_array.append(self._create_param_lr(param_and_grad_element)) - if find_master: - master_weight_array.append(self._master_weights[ - param_and_grad_element[0].name]) - - inputs = { - "Param": param_array, - "Grad": grad_array, - "Velocity": velocity_array, - "LearningRate": lr_array - } - outputs = {"ParamOut": param_array, "VelocityOut": velocity_array} - attrs = { - "mu": self._momentum, - "lars_coeff": self._lars_coeff, - "rescale_grad": self._rescale_grad, - "multi_precision": find_master, - "lars_weight_decay": [self._lars_weight_decay] - } - - if find_master: - inputs["MasterParam"] = master_weight_array - outputs["MasterParamOut"] = master_weight_array - - # create the momentum optimize op - lars_momentum_op = block.append_op( - type=self.type, - inputs=inputs, - outputs=outputs, - attrs=attrs, - stop_gradient=True) - return lars_momentum_op + velocity_acc = self._get_accumulator(self._velocity_acc_str, + param_and_grad[0]) + lr = self._create_param_lr(param_and_grad) - def _create_optimization_pass(self, parameters_and_grads): - global_block = framework.default_main_program().global_block() - target_block = global_block - current_block = framework.default_main_program().current_block() + find_master = self._multi_precision and param_and_grad[ + 0].dtype == core.VarDesc.VarType.FP16 + master_weight = (self._master_weights[param_and_grad[0].name] + if find_master else None) - start = len(target_block.ops) - self._update_param_device_map(parameters_and_grads, target_block) - self._create_accumulators( - target_block, - [p[0] for p in parameters_and_grads if p[0].trainable]) - self._create_global_learning_rate() + attrs = { + "mu": self._momentum, + "lars_coeff": self._lars_coeff, + "lars_weight_decay": [_lars_weight_decay], + "multi_precision": find_master, + "rescale_grad": self._rescale_grad + } + inputs = { + "Param": param_and_grad[0], + "Grad": param_and_grad[1], + "Velocity": velocity_acc, + "LearningRate": lr + } - if framework.in_dygraph_mode(): - for param_and_grad in parameters_and_grads: - if param_and_grad[0].trainable is True: - self._append_optimize_op(target_block, param_and_grad) - else: - normal_parameters_and_grad = [] - multi_precision_parameters_and_grads = [] - has_amp_lars = False - has_lars = False - for param_and_grad in parameters_and_grads: - with param_and_grad[0].block.program._optimized_guard( - param_and_grad), name_scope("optimizer"): - if param_and_grad[0].trainable is True: - device = self._get_device_for_param(param_and_grad[0] - .name) + outputs = {"ParamOut": param_and_grad[0], "VelocityOut": velocity_acc} - if len(self._exclude_from_weight_decay) > 0: - for name in self._exclude_from_weight_decay: - if name in param_and_grad[0].name: - with device_guard(device): - # While weight_decay is zero, lars is momentum. 
- optimize_op = self._append_optimize_op( - target_block, param_and_grad) - else: - if self._multi_precision and param_and_grad[ - 0].dtype == core.VarDesc.VarType.FP16: - has_amp_lars = True - multi_precision_parameters_and_grads.append( - param_and_grad) - else: - has_lars = True - normal_parameters_and_grad.append( - param_and_grad) - with device_guard(device): - if has_amp_lars: - multi_precision_optimize_op = self._append_optimize_op( - target_block, multi_precision_parameters_and_grads) - if has_lars: - normal_optimize_op = self._append_optimize_op( - target_block, normal_parameters_and_grad) + if find_master: + inputs["MasterParam"] = master_weight + outputs["MasterParamOut"] = master_weight - # Get custom finish ops for subclasses - # FIXME: Need to fix this once we figure out how to handle dependencies - self._finish_update(target_block, parameters_and_grads) + # create the momentum optimize op + momentum_op = block.append_op( + type=self.type if _lars_weight_decay != 0.0 else 'momentum', + inputs=inputs, + outputs=outputs, + attrs=attrs, + stop_gradient=True) - end = len(target_block.ops) - return target_block._slice_ops(start, end) + return momentum_op class AdagradOptimizer(Optimizer): diff --git a/python/paddle/fluid/tests/unittests/test_merged_lars_optimizer.py b/python/paddle/fluid/tests/unittests/test_merged_lars_optimizer.py new file mode 100644 index 0000000000000..4e1317b2b635d --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_merged_lars_optimizer.py @@ -0,0 +1,210 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +import paddle +import numpy as np +from paddle.fluid.layer_helper import LayerHelper +from collections import OrderedDict + + +def run_momentum_op(params, + grads, + velocitys, + master_params, + learning_rates, + place, + multi_precision, + weight_decay, + mu=0.9, + rescale_grad=0.01, + use_merged=False): + assert len(params) == len(grads) + assert len(params) == len(velocitys) + if multi_precision: + assert len(params) == len(master_params) + op_type = 'lars_momentum' + main = paddle.static.Program() + startup = paddle.static.Program() + + with paddle.static.program_guard(main, startup): + helper = LayerHelper(op_type, **locals()) + attrs = { + 'mu': mu, + 'multi_precision': multi_precision, + 'rescale_grad': rescale_grad, + 'lars_weight_decay': weight_decay.tolist() + } + + param_vars = [ + helper.create_variable( + persistable=True, shape=p.shape, dtype=p.dtype) for p in params + ] + grad_vars = [ + helper.create_variable( + shape=g.shape, dtype=g.dtype) for g in grads + ] + velocity_vars = [ + helper.create_variable( + persistable=True, shape=v.shape, dtype=v.dtype) + for v in velocitys + ] + lr_vars = [ + helper.create_variable( + persistable=True, shape=l.shape, dtype=l.dtype) + for l in learning_rates + ] + + feed_dict = OrderedDict() + + feed_dict.update( + OrderedDict([(p_var.name, p_val) + for p_var, p_val in zip(param_vars, params)])) + feed_dict.update( + OrderedDict([(v_var.name, v_val) + for v_var, v_val in zip(velocity_vars, velocitys)])) + fetch_list = list(feed_dict.keys()) + + feed_dict.update( + OrderedDict([(g_var.name, g_val) + for g_var, g_val in zip(grad_vars, grads)])) + feed_dict.update( + OrderedDict([(lr_var.name, lr_val) + for lr_var, lr_val in zip(lr_vars, learning_rates)])) + + if multi_precision: + master_param_vars = [ + helper.create_variable( + persistable=True, shape=p.shape, dtype=p.dtype) + for p in master_params + ] + feed_dict.update( + OrderedDict([(mp_var.name, mp_val) + for mp_var, mp_val in zip(master_param_vars, + master_params)])) + # CPUPlace does not use MasterParam + if isinstance(place, paddle.CUDAPlace): + fetch_list = fetch_list + [ + mp_var.name for mp_var in master_param_vars + ] + else: + master_param_vars = None + + if not use_merged: + for i, ( + p, g, v, lr + ) in enumerate(zip(param_vars, grad_vars, velocity_vars, lr_vars)): + inputs = { + 'Param': p, + 'Grad': g, + 'Velocity': v, + 'LearningRate': lr, + } + outputs = {'ParamOut': p, 'VelocityOut': v} + if multi_precision: + inputs['MasterParam'] = master_param_vars[i] + outputs['MasterParamOut'] = master_param_vars[i] + helper.append_op( + type=op_type, inputs=inputs, outputs=outputs, attrs=attrs) + else: + inputs = { + 'Param': param_vars, + 'Grad': grad_vars, + 'Velocity': velocity_vars, + 'LearningRate': lr_vars, + } + outputs = {'ParamOut': param_vars, 'VelocityOut': velocity_vars} + if multi_precision: + inputs['MasterParam'] = master_param_vars + outputs['MasterParamOut'] = master_param_vars + helper.append_op( + type=op_type, inputs=inputs, outputs=outputs, attrs=attrs) + + exe = paddle.static.Executor(place) + with paddle.static.scope_guard(paddle.static.Scope()): + exe.run(startup) + return exe.run(main, feed=feed_dict, fetch_list=fetch_list) + + +class TestMergedMomentum(unittest.TestCase): + def setUp(self): + paddle.enable_static() + self.shapes = [[3, 4], [2, 7], [5, 6], [7, 8]] + self.seed = 10 + + def gen_rand_data(self, shapes, dtype): + return [np.random.random(s).astype(dtype) for s in shapes] + + def gen_rand_lr(self, shapes, dtype): + lr = 
np.random.random(1).astype(dtype) + return [np.ones(1).astype(dtype) * lr for s in range(len(shapes))] + + def prepare_data(self, shapes, multi_precision, seed, place): + np.random.seed(seed) + mp_dtype = np.float32 + dtype = np.float16 if multi_precision and isinstance( + place, paddle.CUDAPlace) else np.float32 + params = self.gen_rand_data(shapes, dtype) + grads = self.gen_rand_data(shapes, dtype) + velocitys = self.gen_rand_data(shapes, mp_dtype) + weight_decay = self.gen_rand_data([[1]], mp_dtype)[0] + learning_rates = self.gen_rand_lr(shapes, mp_dtype) + if multi_precision: + master_params = [p.astype(mp_dtype) for p in params] + else: + master_params = None + return params, grads, velocitys, master_params, learning_rates, weight_decay + + def check_with_place(self, place, multi_precision): + params, grads, velocitys, master_params, learning_rates, weight_decay = self.prepare_data( + self.shapes, multi_precision, self.seed, place) + + def run_op(merge_option): + # CPU Momentum Op does not support rescale_grad + rescale_grad = 1.0 if isinstance(place, paddle.CPUPlace) else 0.01 + return run_momentum_op( + params, + grads, + velocitys, + master_params, + learning_rates, + place, + multi_precision, + weight_decay, + rescale_grad=rescale_grad, + use_merged=merge_option) + + outs1 = run_op(True) + outs2 = run_op(False) + self.assertEqual(len(outs1), len(outs2)) + for i, (out1, out2) in enumerate(zip(outs1, outs2)): + if isinstance(place, paddle.CUDAPlace): + self.assertTrue(np.array_equal(out1, out2)) + else: + self.assertTrue(np.allclose(out1, out2, atol=1e-7)) + + def get_places(self): + places = [paddle.CPUPlace()] + if paddle.is_compiled_with_cuda(): + places.append(paddle.CUDAPlace(0)) + return places + + def test_main(self): + for multi_precision in [False, True]: + for place in self.get_places(): + self.check_with_place(place, multi_precision) + + +if __name__ == "__main__": + unittest.main() From c6ef005fd4abba59d7b0ef92386bbe722220e514 Mon Sep 17 00:00:00 2001 From: JamesLim-sy Date: Mon, 18 Oct 2021 12:37:06 +0000 Subject: [PATCH 09/12] fix code according to comments. 
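
Note: the update rule touched by these kernels can be sanity-checked against a
minimal NumPy sketch of a single LARS step; the function name and the default
hyper-parameter values below are illustrative only and are not part of the op.

    import numpy as np

    def lars_single_step(param, grad, velocity, lr, mu=0.9, lars_coeff=0.001,
                         lars_weight_decay=0.0005, epsilon=0.0,
                         rescale_grad=1.0):
        # L2 norms of the whole parameter and of the rescaled gradient.
        p_norm = np.sqrt(np.sum(np.square(param)))
        g_norm = rescale_grad * np.sqrt(np.sum(np.square(grad)))
        local_lr = lr
        if lars_weight_decay > 0.0:
            local_lr = lr * lars_coeff * p_norm / (
                g_norm + lars_weight_decay * p_norm + epsilon)
        # Momentum update followed by the in-place parameter update.
        velocity_out = mu * velocity + local_lr * (
            rescale_grad * grad + lars_weight_decay * param)
        param_out = param - velocity_out
        return param_out, velocity_out

    # Example: one step on a small random tensor.
    p = np.random.rand(4, 3).astype(np.float32)
    g = np.random.rand(4, 3).astype(np.float32)
    v = np.zeros_like(p)
    p_new, v_new = lars_single_step(p, g, v, lr=0.01)

This mirrors L2NormKernel and MomentumUpdate above: the norms are taken over the
whole tensor, and rescale_grad is applied to the gradient both inside the norm
and inside the momentum term.
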
--- .../operators/optimizers/lars_momentum_op.cu | 131 ++++++++++-------- python/paddle/fluid/optimizer.py | 1 - 2 files changed, 70 insertions(+), 62 deletions(-) diff --git a/paddle/fluid/operators/optimizers/lars_momentum_op.cu b/paddle/fluid/operators/optimizers/lars_momentum_op.cu index e12a2c902884e..fa859c8e7a1cd 100644 --- a/paddle/fluid/operators/optimizers/lars_momentum_op.cu +++ b/paddle/fluid/operators/optimizers/lars_momentum_op.cu @@ -140,11 +140,13 @@ __forceinline__ __device__ void L2NormKernel( template __global__ void L2NormKernel( #endif - const T* p_data, const T* __restrict__ g_data, MT* __restrict__ p_buffer, - MT* __restrict__ g_buffer, const int64_t numel, const MT rescale_grad, - MT* __restrict__ p_n = nullptr, MT* __restrict__ g_n = nullptr) { + const T* p_data, const T* __restrict__ g_data, + MT* __restrict__ buffer_for_param_norm, + MT* __restrict__ buffer_for_grad_norm, const int64_t numel, + const MT rescale_grad, MT* __restrict__ p_n = nullptr, + MT* __restrict__ g_n = nullptr) { int tid = threadIdx.x + blockDim.x * blockIdx.x; - int grid_stride = LARS_BLOCK_SIZE * gridDim.x; + int grid_stride = blockDim.x * gridDim.x; MT p_tmp = static_cast(0); MT g_tmp = static_cast(0); @@ -159,14 +161,16 @@ __global__ void L2NormKernel( g_tmp = math::blockReduceSum(g_tmp, FINAL_MASK); if (threadIdx.x == 0) { - p_buffer[blockIdx.x] = p_tmp; - g_buffer[blockIdx.x] = g_tmp; + buffer_for_param_norm[blockIdx.x] = p_tmp; + buffer_for_grad_norm[blockIdx.x] = g_tmp; } #if CUDA_VERSION >= 11000 __shared__ MT s_buffer[2]; cg->sync(); // Grid sync for writring partial result to gloabl memory - MT p_part_sum = threadIdx.x < gridDim.x ? p_buffer[threadIdx.x] : 0; - MT g_part_sum = threadIdx.x < gridDim.x ? g_buffer[threadIdx.x] : 0; + MT p_part_sum = + threadIdx.x < gridDim.x ? buffer_for_param_norm[threadIdx.x] : 0; + MT g_part_sum = + threadIdx.x < gridDim.x ? buffer_for_grad_norm[threadIdx.x] : 0; MT tmp0 = math::blockReduceSum(p_part_sum, FINAL_MASK); MT tmp1 = math::blockReduceSum(g_part_sum, FINAL_MASK); if (threadIdx.x == 0) { @@ -217,15 +221,15 @@ __forceinline__ __device__ void MomentumUpdate( } #if CUDA_VERSION >= 11000 -template -struct MergedLarsMasterParam { +template +struct MasterParamHelper { DEVICE inline MT* GetMasterParam(size_t) const { return nullptr; } constexpr void SetMasterParam(size_t, MT*) {} }; -template -struct MergedLarsMasterParam { - MT* master_params[kOpNum]; +template +struct MasterParamHelper { + MT* master_params[OpNum]; DEVICE inline MT* GetMasterParam(size_t idx) const { return master_params[idx]; @@ -234,28 +238,28 @@ struct MergedLarsMasterParam { }; template ::value ? 
80 : 100> -struct LarsParamWarpper : public MergedLarsMasterParam { - static constexpr int kNum = kOpNum; - - int numel_arr[kOpNum]; - const MT* __restrict__ lr_arr[kOpNum]; - const T* __restrict__ g_arr[kOpNum]; - T* p_arr[kOpNum]; - MT* v_arr[kOpNum]; +struct LarsParamWarpper : public MasterParamHelper { + static constexpr int kNum = OpNum; + + int numel_arr[OpNum]; + const MT* __restrict__ lr_arr[OpNum]; + const T* __restrict__ g_arr[OpNum]; + T* p_arr[OpNum]; + MT* v_arr[OpNum]; MT weight_decay; }; template __global__ void MergedMomentumLarsKernel(LarsParamWarpper lars_warpper, - MT* __restrict__ p_buffer, - MT* __restrict__ g_buffer, + MT* __restrict__ buffer_for_param_norm, + MT* __restrict__ buffer_for_grad_norm, const int op_num, const MT mu, const MT lars_coeff, const MT epsilon, const MT rescale_grad, const bool is_amp) { - int grid_stride = gridDim.x * LARS_BLOCK_SIZE; + int grid_stride = gridDim.x * blockDim.x; int tid = threadIdx.x + blockIdx.x * blockDim.x; const cooperative_groups::grid_group cg = cooperative_groups::this_grid(); for (int i = 0; i < op_num; ++i) { @@ -263,8 +267,8 @@ __global__ void MergedMomentumLarsKernel(LarsParamWarpper lars_warpper, MT param_norm = static_cast(0); MT grad_norm = static_cast(0); L2NormKernel(&cg, lars_warpper.p_arr[i], lars_warpper.g_arr[i], - p_buffer, g_buffer, numel, rescale_grad, ¶m_norm, - &grad_norm); + buffer_for_param_norm, buffer_for_grad_norm, numel, + rescale_grad, ¶m_norm, &grad_norm); MomentumUpdate(lars_warpper.p_arr[i], lars_warpper.g_arr[i], lars_warpper.v_arr[i], lars_warpper.p_arr[i], lars_warpper.v_arr[i], lars_warpper.GetMasterParam(i), @@ -281,22 +285,23 @@ __global__ void MomentumLarsKernel( const T* param, const T* __restrict__ grad, const MT* velocity, T* param_out, MT* velocity_out, const MT* master_param, MT* master_param_out, const MT* __restrict__ learning_rate, - MT* __restrict__ p_buffer, MT* __restrict__ g_buffer, const MT mu, - const MT lars_coeff, const MT lars_weight_decay, const MT epsilon, - const MT rescale_grad, const int thresh, const int64_t numel, - const bool is_amp) { + MT* __restrict__ buffer_for_param_norm, + MT* __restrict__ buffer_for_grad_norm, const MT mu, const MT lars_coeff, + const MT lars_weight_decay, const MT epsilon, const MT rescale_grad, + const int thresh, const int64_t numel, const bool is_amp) { int tid = threadIdx.x + blockIdx.x * blockDim.x; - int grid_stride = gridDim.x * LARS_BLOCK_SIZE; + int grid_stride = gridDim.x * blockDim.x; #if CUDA_VERSION >= 11000 const cooperative_groups::grid_group cg = cooperative_groups::this_grid(); MT param_norm = static_cast(0); MT grad_norm = static_cast(0); - L2NormKernel(&cg, param, grad, p_buffer, g_buffer, numel, rescale_grad, - ¶m_norm, &grad_norm); + L2NormKernel(&cg, param, grad, buffer_for_param_norm, + buffer_for_grad_norm, numel, rescale_grad, ¶m_norm, + &grad_norm); #else __shared__ MT s_buffer[2]; - MT p_part_sum = threadIdx.x < thresh ? p_buffer[threadIdx.x] : 0; - MT g_part_sum = threadIdx.x < thresh ? g_buffer[threadIdx.x] : 0; + MT p_part_sum = threadIdx.x < thresh ? buffer_for_param_norm[threadIdx.x] : 0; + MT g_part_sum = threadIdx.x < thresh ? 
buffer_for_grad_norm[threadIdx.x] : 0; MT tmp0 = math::blockReduceSum(p_part_sum, FINAL_MASK); MT tmp1 = math::blockReduceSum(g_part_sum, FINAL_MASK); if (threadIdx.x == 0) { @@ -318,20 +323,22 @@ template inline void SeparatedLarsMomentumOpCUDAKernel( const platform::CUDADeviceContext& cuda_ctx, const T* param_data, T* param_out_data, const MT* velocity_data, MT* velocity_out_data, - const T* grad_data, const MT* lr, MT* p_buffer, MT* g_buffer, const MT mu, - const MT lars_coeff, const MT weight_decay, const MT epsilon, - const MT rescale_grad, const int64_t numel, const MT* master_param_data, - MT* master_out_data, const bool is_amp) { + const T* grad_data, const MT* lr, MT* buffer_for_param_norm, + MT* buffer_for_grad_norm, const MT mu, const MT lars_coeff, + const MT weight_decay, const MT epsilon, const MT rescale_grad, + const int64_t numel, const MT* master_param_data, MT* master_out_data, + const bool is_amp) { LarsThreadConfig lars_thread_config(numel); L2NormKernel<<>>(param_data, grad_data, p_buffer, - g_buffer, numel, rescale_grad); + cuda_ctx.stream()>>>( + param_data, grad_data, buffer_for_param_norm, buffer_for_grad_norm, numel, + rescale_grad); MomentumLarsKernel<<>>( param_data, grad_data, velocity_data, param_out_data, velocity_out_data, - master_param_data, master_out_data, lr, p_buffer, g_buffer, mu, - lars_coeff, weight_decay, epsilon, rescale_grad, + master_param_data, master_out_data, lr, buffer_for_param_norm, + buffer_for_grad_norm, mu, lars_coeff, weight_decay, epsilon, rescale_grad, lars_thread_config.grid_for_norm, numel, is_amp); } template @@ -344,11 +351,12 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel { bool multi_precision = ctx.Attr("multi_precision"); auto& cuda_ctx = ctx.template device_context(); int sm_num = cuda_ctx.GetSMCount(); - framework::Tensor tmp_buffer_t = + framework::Tensor tmbuffer_for_param_norm_t = ctx.AllocateTmpTensor( {LARS_BLOCK_SIZE << 1}, cuda_ctx); - auto* p_buffer = tmp_buffer_t.mutable_data(ctx.GetPlace()); - auto* g_buffer = p_buffer + LARS_BLOCK_SIZE; + auto* buffer_for_param_norm = + tmbuffer_for_param_norm_t.mutable_data(ctx.GetPlace()); + auto* buffer_for_grad_norm = buffer_for_param_norm + LARS_BLOCK_SIZE; MT mu = static_cast(ctx.Attr("mu")); MT lars_coeff = static_cast(ctx.Attr("lars_coeff")); @@ -384,13 +392,13 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel { sizeof(MT) << 1); lars_warpper.weight_decay = lars_weight_decay; - int merge_times = (op_num + lars_warpper.kNum - 1) / lars_warpper.kNum; - for (int j = 0; j < merge_times; ++j) { + int loop = (op_num + lars_warpper.kNum - 1) / lars_warpper.kNum; + for (int j = 0; j < loop; ++j) { size_t total_numel = 0; int start_idx = j * lars_warpper.kNum; - int loop_num = std::min(lars_warpper.kNum, op_num - start_idx); + int warpper_num = std::min(lars_warpper.kNum, op_num - start_idx); - for (int i = 0; i < loop_num; ++i) { + for (int i = 0; i < warpper_num; ++i) { size_t temp_numel = param[start_idx + i]->numel(); total_numel += temp_numel; lars_warpper.numel_arr[i] = temp_numel; @@ -423,14 +431,14 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel { "lars optimizer must be " "the same Tensors.")); } - VLOG(10) << "Op number delt in loop " << j << " is : " << loop_num; - int64_t avg_numel = total_numel / loop_num; + VLOG(10) << "Ops warpped in this loop " << j << " is : " << warpper_num; + int64_t avg_numel = total_numel / warpper_num; LarsThreadConfig lars_thread_config(avg_numel, sm_num, num_blocks_per_sm); void* cuda_param[] = 
{reinterpret_cast(&lars_warpper), - reinterpret_cast(&p_buffer), - reinterpret_cast(&g_buffer), - reinterpret_cast(&loop_num), + reinterpret_cast(&buffer_for_param_norm), + reinterpret_cast(&buffer_for_grad_norm), + reinterpret_cast(&warpper_num), reinterpret_cast(&mu), reinterpret_cast(&lars_coeff), reinterpret_cast(&epsilon), @@ -474,8 +482,8 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel { reinterpret_cast(&master_param_data), reinterpret_cast(&master_param_out_data), reinterpret_cast(&lr), - reinterpret_cast(&p_buffer), - reinterpret_cast(&g_buffer), + reinterpret_cast(&buffer_for_param_norm), + reinterpret_cast(&buffer_for_grad_norm), reinterpret_cast(&mu), reinterpret_cast(&lars_coeff), reinterpret_cast(&lars_weight_decay), @@ -503,9 +511,10 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel { param_out[i]->mutable_data(ctx.GetPlace()), velocity[i]->data(), velocity_out[i]->mutable_data(ctx.GetPlace()), grad[i]->data(), - learning_rate[i]->data(), p_buffer, g_buffer, mu, lars_coeff, - lars_weight_decay, epsilon, rescale_grad, param[i]->numel(), - master_param_data, master_param_out_data, multi_precision); + learning_rate[i]->data(), buffer_for_param_norm, + buffer_for_grad_norm, mu, lars_coeff, lars_weight_decay, epsilon, + rescale_grad, param[i]->numel(), master_param_data, + master_param_out_data, multi_precision); } #endif } diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 94f66257917cc..07566e9e9a678 100755 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -1961,7 +1961,6 @@ def __init__(self, exclude_from_weight_decay=None, epsilon=0, multi_precision=False, - merge_option=False, rescale_grad=1.0): assert learning_rate is not None assert momentum is not None From 47ba53b18e4bca4f90c464bf071b29a7cc1342c9 Mon Sep 17 00:00:00 2001 From: JamesLim-sy Date: Tue, 19 Oct 2021 19:01:12 +0000 Subject: [PATCH 10/12] change the format of lars_weight_decay from scalar into vector --- .../operators/optimizers/lars_momentum_op.cu | 29 ++++++++++--------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/operators/optimizers/lars_momentum_op.cu b/paddle/fluid/operators/optimizers/lars_momentum_op.cu index fa859c8e7a1cd..8dc583876adfe 100644 --- a/paddle/fluid/operators/optimizers/lars_momentum_op.cu +++ b/paddle/fluid/operators/optimizers/lars_momentum_op.cu @@ -239,7 +239,7 @@ struct MasterParamHelper { template ::value ? 80 : 100> + std::is_same::value ? 
80 : 90> struct LarsParamWarpper : public MasterParamHelper { static constexpr int kNum = OpNum; @@ -248,7 +248,7 @@ struct LarsParamWarpper : public MasterParamHelper { const T* __restrict__ g_arr[OpNum]; T* p_arr[OpNum]; MT* v_arr[OpNum]; - MT weight_decay; + MT weight_decay[OpNum]; }; template @@ -269,13 +269,13 @@ __global__ void MergedMomentumLarsKernel(LarsParamWarpper lars_warpper, L2NormKernel(&cg, lars_warpper.p_arr[i], lars_warpper.g_arr[i], buffer_for_param_norm, buffer_for_grad_norm, numel, rescale_grad, ¶m_norm, &grad_norm); - MomentumUpdate(lars_warpper.p_arr[i], lars_warpper.g_arr[i], - lars_warpper.v_arr[i], lars_warpper.p_arr[i], - lars_warpper.v_arr[i], lars_warpper.GetMasterParam(i), - lars_warpper.GetMasterParam(i), - lars_warpper.lr_arr[i], mu, lars_warpper.weight_decay, - lars_coeff, epsilon, rescale_grad, param_norm, - grad_norm, tid, grid_stride, numel, is_amp); + MomentumUpdate( + lars_warpper.p_arr[i], lars_warpper.g_arr[i], lars_warpper.v_arr[i], + lars_warpper.p_arr[i], lars_warpper.v_arr[i], + lars_warpper.GetMasterParam(i), lars_warpper.GetMasterParam(i), + lars_warpper.lr_arr[i], mu, lars_warpper.weight_decay[i], lars_coeff, + epsilon, rescale_grad, param_norm, grad_norm, tid, grid_stride, numel, + is_amp); } } #endif @@ -363,7 +363,6 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel { MT epsilon = static_cast(ctx.Attr("epsilon")); MT rescale_grad = static_cast(ctx.Attr("rescale_grad")); auto weight_decay_arr = ctx.Attr>("lars_weight_decay"); - MT lars_weight_decay = weight_decay_arr[0]; auto grad = ctx.MultiInput("Grad"); auto param = ctx.MultiInput("Param"); @@ -391,7 +390,6 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel { &num_blocks_per_sm, MergedMomentumLarsKernel, LARS_BLOCK_SIZE, sizeof(MT) << 1); - lars_warpper.weight_decay = lars_weight_decay; int loop = (op_num + lars_warpper.kNum - 1) / lars_warpper.kNum; for (int j = 0; j < loop; ++j) { size_t total_numel = 0; @@ -402,6 +400,7 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel { size_t temp_numel = param[start_idx + i]->numel(); total_numel += temp_numel; lars_warpper.numel_arr[i] = temp_numel; + lars_warpper.weight_decay[i] = static_cast(weight_decay_arr[i]); lars_warpper.g_arr[i] = grad[start_idx + i]->data(); lars_warpper.p_arr[i] = param_out[start_idx + i]->mutable_data(ctx.GetPlace()); @@ -465,6 +464,7 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel { ? master_param_out[0]->mutable_data(ctx.GetPlace()) : nullptr; int64_t numel = param[0]->numel(); + MT lars_weight_decay = static_cast(weight_decay_arr[0]); // Figure out how many blocks can be active in each sm. 
cudaOccupancyMaxActiveBlocksPerMultiprocessor( @@ -512,9 +512,10 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel { velocity[i]->data(), velocity_out[i]->mutable_data(ctx.GetPlace()), grad[i]->data(), learning_rate[i]->data(), buffer_for_param_norm, - buffer_for_grad_norm, mu, lars_coeff, lars_weight_decay, epsilon, - rescale_grad, param[i]->numel(), master_param_data, - master_param_out_data, multi_precision); + buffer_for_grad_norm, mu, lars_coeff, + static_cast(weight_decay_arr[i]), epsilon, rescale_grad, + param[i]->numel(), master_param_data, master_param_out_data, + multi_precision); } #endif } From 82dd12ab9f908a8c47c4f3a94bc5303b59e8733b Mon Sep 17 00:00:00 2001 From: JamesLim-sy Date: Tue, 19 Oct 2021 19:02:18 +0000 Subject: [PATCH 11/12] change the format of lars_weight_decay from scalar into vector --- paddle/fluid/operators/optimizers/lars_momentum_op.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/operators/optimizers/lars_momentum_op.cu b/paddle/fluid/operators/optimizers/lars_momentum_op.cu index 8dc583876adfe..5c31ebff52ddc 100644 --- a/paddle/fluid/operators/optimizers/lars_momentum_op.cu +++ b/paddle/fluid/operators/optimizers/lars_momentum_op.cu @@ -311,7 +311,6 @@ __global__ void MomentumLarsKernel( __syncthreads(); MT param_norm = Sqrt(s_buffer[0]); MT grad_norm = rescale_grad * Sqrt(s_buffer[1]); - #endif MomentumUpdate(param, grad, velocity, param_out, velocity_out, master_param, master_param_out, learning_rate, mu, @@ -341,6 +340,7 @@ inline void SeparatedLarsMomentumOpCUDAKernel( buffer_for_grad_norm, mu, lars_coeff, weight_decay, epsilon, rescale_grad, lars_thread_config.grid_for_norm, numel, is_amp); } + template class LarsMomentumOpCUDAKernel : public framework::OpKernel { using MT = MultiPrecisionType; From c5d06e0787bb0bba3871bd844f225724baf27619 Mon Sep 17 00:00:00 2001 From: JamesLim-sy Date: Wed, 20 Oct 2021 11:52:00 +0000 Subject: [PATCH 12/12] fix lars_weight_decay from a scalar into a vector. 
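
With this change the lars_weight_decay attribute is expected to carry one decay
value per parameter when several ops are merged, in the same order as the
Param/Grad/Velocity input lists, and the CUDA kernel now checks that the
attribute length matches the number of gradients. A rough sketch of the intended
attribute layouts (the numbers are made up for illustration):

    # Separated ops: each op carries a single-element list.
    attrs_single = {'lars_weight_decay': [0.0005]}

    # Merged op over three parameters: one entry per parameter, aligned with
    # the Param/Grad/Velocity lists; parameters matched by
    # exclude_from_weight_decay would get 0.0 here.
    attrs_merged = {'lars_weight_decay': [0.0005, 0.0, 0.0005]}
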
--- .../operators/optimizers/lars_momentum_op.cu | 4 ++ .../unittests/test_merged_lars_optimizer.py | 49 ++++++++++--------- 2 files changed, 31 insertions(+), 22 deletions(-) diff --git a/paddle/fluid/operators/optimizers/lars_momentum_op.cu b/paddle/fluid/operators/optimizers/lars_momentum_op.cu index 5c31ebff52ddc..efe9b622685f3 100644 --- a/paddle/fluid/operators/optimizers/lars_momentum_op.cu +++ b/paddle/fluid/operators/optimizers/lars_momentum_op.cu @@ -375,6 +375,10 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel { ctx.MultiOutput("MasterParamOut"); int op_num = grad.size(); + PADDLE_ENFORCE_EQ(weight_decay_arr.size(), op_num, + platform::errors::InvalidArgument( + "Since Input(lars_weight_decay) and Iutput(grad) of " + "lars optimizer must be the same size.")); #if CUDA_VERSION >= 11000 if (op_num > 1) { LarsParamWarpper lars_warpper; diff --git a/python/paddle/fluid/tests/unittests/test_merged_lars_optimizer.py b/python/paddle/fluid/tests/unittests/test_merged_lars_optimizer.py index 4e1317b2b635d..23c0d6206958e 100644 --- a/python/paddle/fluid/tests/unittests/test_merged_lars_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_merged_lars_optimizer.py @@ -19,17 +19,17 @@ from collections import OrderedDict -def run_momentum_op(params, - grads, - velocitys, - master_params, - learning_rates, - place, - multi_precision, - weight_decay, - mu=0.9, - rescale_grad=0.01, - use_merged=False): +def run_lars_momentum_op(params, + grads, + velocitys, + master_params, + learning_rates, + place, + multi_precision, + weight_decays, + mu=0.9, + rescale_grad=0.01, + use_merged=False): assert len(params) == len(grads) assert len(params) == len(velocitys) if multi_precision: @@ -44,9 +44,7 @@ def run_momentum_op(params, 'mu': mu, 'multi_precision': multi_precision, 'rescale_grad': rescale_grad, - 'lars_weight_decay': weight_decay.tolist() } - param_vars = [ helper.create_variable( persistable=True, shape=p.shape, dtype=p.dtype) for p in params @@ -111,6 +109,7 @@ def run_momentum_op(params, 'Velocity': v, 'LearningRate': lr, } + attrs['lars_weight_decay'] = [float(weight_decays[i])] outputs = {'ParamOut': p, 'VelocityOut': v} if multi_precision: inputs['MasterParam'] = master_param_vars[i] @@ -118,12 +117,17 @@ def run_momentum_op(params, helper.append_op( type=op_type, inputs=inputs, outputs=outputs, attrs=attrs) else: + lars_weight_decay = [] + for decay in weight_decays: + lars_weight_decay.append(float(decay)) + inputs = { 'Param': param_vars, 'Grad': grad_vars, 'Velocity': velocity_vars, 'LearningRate': lr_vars, } + attrs['lars_weight_decay'] = lars_weight_decay outputs = {'ParamOut': param_vars, 'VelocityOut': velocity_vars} if multi_precision: inputs['MasterParam'] = master_param_vars @@ -146,9 +150,11 @@ def setUp(self): def gen_rand_data(self, shapes, dtype): return [np.random.random(s).astype(dtype) for s in shapes] - def gen_rand_lr(self, shapes, dtype): - lr = np.random.random(1).astype(dtype) - return [np.ones(1).astype(dtype) * lr for s in range(len(shapes))] + def gen_lr_and_decay(self, shapes, dtype): + data = np.random.random(1).astype(dtype) + lr_rates = [np.ones(1).astype(dtype) * data for s in range(len(shapes))] + weight_decays = data * np.ones(len(shapes), dtype=np.float32) + return lr_rates, weight_decays def prepare_data(self, shapes, multi_precision, seed, place): np.random.seed(seed) @@ -158,22 +164,21 @@ def prepare_data(self, shapes, multi_precision, seed, place): params = self.gen_rand_data(shapes, dtype) grads = self.gen_rand_data(shapes, 
dtype) velocitys = self.gen_rand_data(shapes, mp_dtype) - weight_decay = self.gen_rand_data([[1]], mp_dtype)[0] - learning_rates = self.gen_rand_lr(shapes, mp_dtype) + learning_rates, weight_decays = self.gen_lr_and_decay(shapes, mp_dtype) if multi_precision: master_params = [p.astype(mp_dtype) for p in params] else: master_params = None - return params, grads, velocitys, master_params, learning_rates, weight_decay + return params, grads, velocitys, master_params, learning_rates, weight_decays def check_with_place(self, place, multi_precision): - params, grads, velocitys, master_params, learning_rates, weight_decay = self.prepare_data( + params, grads, velocitys, master_params, learning_rates, weight_decays = self.prepare_data( self.shapes, multi_precision, self.seed, place) def run_op(merge_option): # CPU Momentum Op does not support rescale_grad rescale_grad = 1.0 if isinstance(place, paddle.CPUPlace) else 0.01 - return run_momentum_op( + return run_lars_momentum_op( params, grads, velocitys, @@ -181,7 +186,7 @@ def run_op(merge_option): learning_rates, place, multi_precision, - weight_decay, + weight_decays, rescale_grad=rescale_grad, use_merged=merge_option)