diff --git a/paddle/fluid/operators/roll_op.cu b/paddle/fluid/operators/roll_op.cu index ce93c5f984e37..34d4d67e39d53 100644 --- a/paddle/fluid/operators/roll_op.cu +++ b/paddle/fluid/operators/roll_op.cu @@ -36,13 +36,16 @@ __global__ void RollCudaKernel(const T* input, T* output, int64_t N, } int64_t output_idx = idx; - int64_t dim_idx, dim_idx_shift; + int64_t new_dim_idx = 0; -#pragma unroll Rank +#pragma unroll for (size_t i = 0; i < Rank; i++) { - dim_idx = (idx / strides[i]) % sizes[i]; - dim_idx_shift = (dim_idx + shifts[i]) % sizes[i]; - output_idx = output_idx + (dim_idx_shift - dim_idx) * strides[i]; + new_dim_idx = (idx / strides[i]) % sizes[i] + shifts[i]; + if (new_dim_idx >= sizes[i]) { + output_idx += (shifts[i] - sizes[i]) * strides[i]; + } else { + output_idx += shifts[i] * strides[i]; + } } output[output_idx] = input[idx]; }