From 4e036fa1a0c21b5b089809f575d37b2a0e6538da Mon Sep 17 00:00:00 2001 From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com> Date: Sun, 17 Oct 2021 23:01:23 +0800 Subject: [PATCH] refine rescale_grad (#36490) --- paddle/fluid/operators/optimizers/lars_momentum_op.cu | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/paddle/fluid/operators/optimizers/lars_momentum_op.cu b/paddle/fluid/operators/optimizers/lars_momentum_op.cu index 89326679d5d50..2c27a2135c14b 100644 --- a/paddle/fluid/operators/optimizers/lars_momentum_op.cu +++ b/paddle/fluid/operators/optimizers/lars_momentum_op.cu @@ -160,7 +160,6 @@ __global__ void L2NormKernel( __shared__ MT s_buffer[2]; int tid = threadIdx.x + blockDim.x * blockIdx.x; int grid_stride = LARS_BLOCK_SIZE * gridDim.x; - const MT rescale_pow = rescale_grad * rescale_grad; MT p_tmp = static_cast(0); MT g_tmp = static_cast(0); @@ -190,7 +189,7 @@ __global__ void L2NormKernel( } __syncthreads(); *p_n = Sqrt(s_buffer[0]); - *g_n = Sqrt(rescale_pow * s_buffer[1]); + *g_n = rescale_grad * Sqrt(s_buffer[1]); #endif }