Skip to content

Commit

Permalink
fix lars
Browse files Browse the repository at this point in the history
  • Loading branch information
sneaxiy committed Oct 14, 2021
1 parent f4eda86 commit 7d3d24e
Showing 1 changed file with 31 additions and 10 deletions.
41 changes: 31 additions & 10 deletions paddle/fluid/operators/optimizers/lars_momentum_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,10 @@ __global__ void L2NormKernel(
int tid = threadIdx.x + blockDim.x * blockIdx.x;
int grid_stride = LARS_BLOCK_SIZE * gridDim.x;
const MT rescale_pow = rescale_grad * rescale_grad;
s_buffer[0] = static_cast<MT>(0);
s_buffer[1] = static_cast<MT>(0);
if (threadIdx.x == 0) {
s_buffer[0] = static_cast<MT>(0);
s_buffer[1] = static_cast<MT>(0);
}
MT p_tmp = static_cast<MT>(0);
MT g_tmp = static_cast<MT>(0);

Expand All @@ -175,8 +177,12 @@ __global__ void L2NormKernel(
p_tmp = static_cast<MT>(p_data[tid]);
g_tmp = static_cast<MT>(g_data[tid]);
}
s_buffer[0] += math::blockReduceSum<MT>(p_tmp * p_tmp, FINAL_MASK);
s_buffer[1] += math::blockReduceSum<MT>(g_tmp * g_tmp, FINAL_MASK);
MT tmp0 = math::blockReduceSum<MT>(p_tmp * p_tmp, FINAL_MASK);
MT tmp1 = math::blockReduceSum<MT>(g_tmp * g_tmp, FINAL_MASK);
if (threadIdx.x == 0) {
s_buffer[0] += tmp0;
s_buffer[1] += tmp1;
}
} else {
/* Avoid occupying too much temp buffer. Slice the whole data into 2 parts:
the front part of the data, whose quantity is exactly a multiple of grid-thread
Expand All @@ -185,8 +191,12 @@ __global__ void L2NormKernel(
p_tmp = static_cast<MT>(p_data[tid]);
g_tmp = static_cast<MT>(g_data[tid]);
tid += grid_stride;
s_buffer[0] += math::blockReduceSum<MT>(p_tmp * p_tmp, FINAL_MASK);
s_buffer[1] += math::blockReduceSum<MT>(g_tmp * g_tmp, FINAL_MASK);
MT tmp0 = math::blockReduceSum<MT>(p_tmp * p_tmp, FINAL_MASK);
MT tmp1 = math::blockReduceSum<MT>(g_tmp * g_tmp, FINAL_MASK);
if (threadIdx.x == 0) {
s_buffer[0] += tmp0;
s_buffer[1] += tmp1;
}
__syncthreads();
}
MT p_val = 0;
Expand All @@ -195,8 +205,12 @@ __global__ void L2NormKernel(
p_val = static_cast<MT>(p_data[tid]);
g_val = static_cast<MT>(g_data[tid]);
}
s_buffer[0] += math::blockReduceSum<MT>(p_val * p_val, FINAL_MASK);
s_buffer[1] += math::blockReduceSum<MT>(g_val * g_val, FINAL_MASK);
MT tmp0 = math::blockReduceSum<MT>(p_val * p_val, FINAL_MASK);
MT tmp1 = math::blockReduceSum<MT>(g_val * g_val, FINAL_MASK);
if (threadIdx.x == 0) {
s_buffer[0] += tmp0;
s_buffer[1] += tmp1;
}
}
__syncthreads();

Expand All @@ -208,8 +222,15 @@ __global__ void L2NormKernel(
cg->sync(); // Grid sync for writing partial result to global memory
MT p_part_sum = threadIdx.x < gridDim.x ? p_buffer[threadIdx.x] : 0;
MT g_part_sum = threadIdx.x < gridDim.x ? g_buffer[threadIdx.x] : 0;
*p_n = Sqrt(math::blockReduceSum<MT>(p_part_sum, FINAL_MASK));
*g_n = Sqrt(rescale_pow * math::blockReduceSum<MT>(g_part_sum, FINAL_MASK));
MT tmp0 = math::blockReduceSum<MT>(p_part_sum, FINAL_MASK);
MT tmp1 = math::blockReduceSum<MT>(g_part_sum, FINAL_MASK);
if (threadIdx.x == 0) {
s_buffer[0] = tmp0;
s_buffer[1] = tmp1;
}
__syncthreads();
*p_n = Sqrt(s_buffer[0]);
*g_n = Sqrt(rescale_pow * s_buffer[1]);
#endif
}

Expand Down

1 comment on commit 7d3d24e

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulations! Your pull request passed all required CI checks. You can now ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.