Skip to content

Commit

Permalink
[cherry-pick] fix bug when the cuda kernel config exceeds dims max (#33748) (#33893)
Browse files Browse the repository at this point in the history

fix bug when the cuda kernel config exceeds dims max
  • Loading branch information
zhiqiu authored Jul 1, 2021
1 parent 702610e commit bedcf0d
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions paddle/fluid/operators/layer_norm_op.cu
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -399,9 +399,9 @@ __global__ void LayerNormBackwardComputeGradInput(
const U *__restrict__ mean, const U *__restrict__ var, const float epsilon,
const U *gamma, T *grad_input) {
#ifdef __HIPCC__
for (auto i1 = hipBlockIdx_y; i1 < n1; i1 += hipGridDim_y) {
for (auto i1 = hipBlockIdx_x; i1 < n1; i1 += hipGridDim_x) {
#else
for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
for (auto i1 = blockIdx.x; i1 < n1; i1 += gridDim.x) {
#endif
U sum_loss1 = U(0);
U sum_loss2 = U(0);
Expand Down Expand Up @@ -867,9 +867,8 @@ static void LayerNormBackward(const T *x, const T *d_y, const U *scale,
constexpr int BDIMX1 = 32;
constexpr int BDIMY1 = 4;
dim3 threads1(BDIMX1, BDIMY1, 1);
const dim3 blocks1(1, batch_size, 1);
LayerNormBackwardComputeGradInput<
T, U, BDIMX1, BDIMY1><<<blocks1, threads1, 0, stream>>>(
T, U, BDIMX1, BDIMY1><<<batch_size, threads1, 0, stream>>>(
d_y, x, batch_size, feature_size, mean, var, epsilon, scale, d_x);
break;
}
Expand Down

0 comments on commit bedcf0d

Please sign in to comment.