Skip to content

Commit

Permalink
fix lars (#36431)
Browse files Browse the repository at this point in the history
  • Loading branch information
sneaxiy authored Oct 14, 2021
1 parent 3cf5764 commit 8256f6f
Showing 1 changed file with 31 additions and 10 deletions.
41 changes: 31 additions & 10 deletions paddle/fluid/operators/optimizers/lars_momentum_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,10 @@ __global__ void L2NormKernel(
int tid = threadIdx.x + blockDim.x * blockIdx.x;
int grid_stride = LARS_BLOCK_SIZE * gridDim.x;
const MT rescale_pow = rescale_grad * rescale_grad;
s_buffer[0] = static_cast<MT>(0);
s_buffer[1] = static_cast<MT>(0);
if (threadIdx.x == 0) {
s_buffer[0] = static_cast<MT>(0);
s_buffer[1] = static_cast<MT>(0);
}
MT p_tmp = static_cast<MT>(0);
MT g_tmp = static_cast<MT>(0);

Expand All @@ -175,8 +177,12 @@ __global__ void L2NormKernel(
p_tmp = static_cast<MT>(p_data[tid]);
g_tmp = static_cast<MT>(g_data[tid]);
}
s_buffer[0] += math::blockReduceSum<MT>(p_tmp * p_tmp, FINAL_MASK);
s_buffer[1] += math::blockReduceSum<MT>(g_tmp * g_tmp, FINAL_MASK);
MT tmp0 = math::blockReduceSum<MT>(p_tmp * p_tmp, FINAL_MASK);
MT tmp1 = math::blockReduceSum<MT>(g_tmp * g_tmp, FINAL_MASK);
if (threadIdx.x == 0) {
s_buffer[0] += tmp0;
s_buffer[1] += tmp1;
}
} else {
/* Avoid occupy too much temp buffer. Slice the whole data into 2 parts,
the front of data whose quantity is excatly multiple of grid-thread
Expand All @@ -185,8 +191,12 @@ __global__ void L2NormKernel(
p_tmp = static_cast<MT>(p_data[tid]);
g_tmp = static_cast<MT>(g_data[tid]);
tid += grid_stride;
s_buffer[0] += math::blockReduceSum<MT>(p_tmp * p_tmp, FINAL_MASK);
s_buffer[1] += math::blockReduceSum<MT>(g_tmp * g_tmp, FINAL_MASK);
MT tmp0 = math::blockReduceSum<MT>(p_tmp * p_tmp, FINAL_MASK);
MT tmp1 = math::blockReduceSum<MT>(g_tmp * g_tmp, FINAL_MASK);
if (threadIdx.x == 0) {
s_buffer[0] += tmp0;
s_buffer[1] += tmp1;
}
__syncthreads();
}
MT p_val = 0;
Expand All @@ -195,8 +205,12 @@ __global__ void L2NormKernel(
p_val = static_cast<MT>(p_data[tid]);
g_val = static_cast<MT>(g_data[tid]);
}
s_buffer[0] += math::blockReduceSum<MT>(p_val * p_val, FINAL_MASK);
s_buffer[1] += math::blockReduceSum<MT>(g_val * g_val, FINAL_MASK);
MT tmp0 = math::blockReduceSum<MT>(p_val * p_val, FINAL_MASK);
MT tmp1 = math::blockReduceSum<MT>(g_val * g_val, FINAL_MASK);
if (threadIdx.x == 0) {
s_buffer[0] += tmp0;
s_buffer[1] += tmp1;
}
}
__syncthreads();

Expand All @@ -208,8 +222,15 @@ __global__ void L2NormKernel(
cg->sync(); // Grid sync for writring partial result to gloabl memory
MT p_part_sum = threadIdx.x < gridDim.x ? p_buffer[threadIdx.x] : 0;
MT g_part_sum = threadIdx.x < gridDim.x ? g_buffer[threadIdx.x] : 0;
*p_n = Sqrt(math::blockReduceSum<MT>(p_part_sum, FINAL_MASK));
*g_n = Sqrt(rescale_pow * math::blockReduceSum<MT>(g_part_sum, FINAL_MASK));
MT tmp0 = math::blockReduceSum<MT>(p_part_sum, FINAL_MASK);
MT tmp1 = math::blockReduceSum<MT>(g_part_sum, FINAL_MASK);
if (threadIdx.x == 0) {
s_buffer[0] = tmp0;
s_buffer[1] = tmp1;
}
__syncthreads();
*p_n = Sqrt(s_buffer[0]);
*g_n = Sqrt(rescale_pow * s_buffer[1]);
#endif
}

Expand Down

0 comments on commit 8256f6f

Please sign in to comment.