From 9df6dabfed3c6553a0778f262bfc843a063a706d Mon Sep 17 00:00:00 2001
From: tianyan01
Date: Mon, 25 Sep 2023 11:57:29 +0800
Subject: [PATCH] fix grid dim.y should less than 65535 bug

---
 .../fluid/operators/fused/fused_dropout_act_bias.h   | 12 +++++++-----
 paddle/fluid/operators/fused/fused_dropout_common.h  |  8 ++++++--
 .../operators/fused/fused_residual_dropout_bias.h    | 12 +++++++-----
 paddle/fluid/operators/fused/fused_softmax_mask.cu.h |  8 +++++---
 4 files changed, 25 insertions(+), 15 deletions(-)

diff --git a/paddle/fluid/operators/fused/fused_dropout_act_bias.h b/paddle/fluid/operators/fused/fused_dropout_act_bias.h
index e3e19d9ea6ebc..aa6c56f937524 100644
--- a/paddle/fluid/operators/fused/fused_dropout_act_bias.h
+++ b/paddle/fluid/operators/fused/fused_dropout_act_bias.h
@@ -86,8 +86,8 @@ __global__ void FusedDropoutActBias(
     const int quant_round_type = 1,
     const float quant_max_bound = 127.0,
     const float quant_min_bound = -127.0) {
-  int col_id = blockDim.x * blockIdx.x + threadIdx.x;
-  int row_id = blockIdx.y;
+  int col_id = threadIdx.x;
+  int row_id = gridDim.y * blockIdx.x + blockIdx.y;
   int idx = row_id * cols + col_id;
 
   curandStatePhilox4_32_10_t state;
@@ -95,9 +95,11 @@ __global__ void FusedDropoutActBias(
 
   const T factor = GetFactor<T>(dropout_prob, is_upscale_in_train, is_test);
 
-  for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) {
-    for (int i = col_id * VecSize; i < cols;
-         i += blockDim.x * gridDim.x * VecSize) {
+  int i = col_id * VecSize;
+  int r = row_id;
+  int stride = blockDim.x * VecSize;
+  for (; r < rows; r += blockDim.y * gridDim.y * gridDim.x) {
+    for (; i < cols; i += stride) {
       FusedResidualDropoutBiasOneThread<T,
                                         MaskType,
                                         VecSize,
diff --git a/paddle/fluid/operators/fused/fused_dropout_common.h b/paddle/fluid/operators/fused/fused_dropout_common.h
--- a/paddle/fluid/operators/fused/fused_dropout_common.h
+++ b/paddle/fluid/operators/fused/fused_dropout_common.h
@@ ... @@ inline platform::GpuLaunchConfig Get1DBlocksAnd2DGrids(
                                   static_cast<uint32_t>(std::min(
                                       ctx.GetMaxThreadsPerBlock(), 512))));
-  const auto blocks_x =
+  auto blocks_x =
       std::max(static_cast<uint32_t>(1), (tmp_cols + threads - 1) / threads);
-  const auto blocks_y = std::max(static_cast<uint32_t>(1), rows);
+  auto blocks_y = std::max(static_cast<uint32_t>(1), rows);
   platform::GpuLaunchConfig config;
+  while (blocks_y > 65535) {
+    blocks_x *= 2;
+    blocks_y /= 2;
+  }
   config.block_per_grid.x = blocks_x;
   config.block_per_grid.y = blocks_y;
   config.thread_per_block.x = threads;
diff --git a/paddle/fluid/operators/fused/fused_residual_dropout_bias.h b/paddle/fluid/operators/fused/fused_residual_dropout_bias.h
index f162d200abfe1..cabcbea2d2e60 100644
--- a/paddle/fluid/operators/fused/fused_residual_dropout_bias.h
+++ b/paddle/fluid/operators/fused/fused_residual_dropout_bias.h
@@ -174,16 +174,18 @@ __global__ void FusedResidualDropoutBias(
     const float *dequant_out_scale_data = nullptr,
     const int quant_out_scale_offset = 0,
     const float quant_next_in_scale = 1.0) {
-  int col_id = blockDim.x * blockIdx.x + threadIdx.x;
-  int row_id = blockIdx.y;
+  int col_id = threadIdx.x;
+  int row_id = gridDim.y * blockIdx.x + blockIdx.y;
   int idx = row_id * cols + col_id;
   curandStatePhilox4_32_10_t state;
   curand_init(seed, idx, increment, &state);
   const T factor = GetFactor<T>(dropout_prob, is_upscale_in_train, is_test);
   phi::funcs::ReluFunctor<T> relu;
-  for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) {
-    for (int i = col_id * VecSize; i < cols;
-         i += blockDim.x * gridDim.x * VecSize) {
+  int i = col_id * VecSize;
+  int r = row_id;
+  int stride = blockDim.x * VecSize;
+  for (; r < rows; r += blockDim.y * gridDim.y * gridDim.x) {
+    for (; i < cols; i += stride) {
       FusedResidualDropoutBiasOneThread<T,
                                         MaskType,
                                         VecSize,
diff --git a/paddle/fluid/operators/fused/fused_softmax_mask.cu.h b/paddle/fluid/operators/fused/fused_softmax_mask.cu.h
--- a/paddle/fluid/operators/fused/fused_softmax_mask.cu.h
+++ b/paddle/fluid/operators/fused/fused_softmax_mask.cu.h
@@ ... @@
   if (seq_id >= seq_len) return;
   // ((bid*head_num + hid)*seq_len + seq_id) * seq_len
-  int offset =
+  int64_t offset =
       ((blockIdx.y * gridDim.z + blockIdx.z) * seq_len + seq_id) * seq_len;
   // (bid * seq_len + seq_id) * seq_len
-  int mask_offset = (blockIdx.y * seq_len + seq_id) * seq_len;
+  int64_t mask_offset = (blockIdx.y * seq_len + seq_id) * seq_len;
   src += offset;
   dst += offset;
   mask += mask_offset;
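
Why this works: CUDA allows gridDim.x up to 2^31 - 1 but caps gridDim.y (and gridDim.z) at 65535, so the old launch of one blockIdx.y per row fails once rows exceeds 65535. The patch halves blocks_y and doubles blocks_x until blocks_y fits, and the kernels rebuild a flat row index as gridDim.y * blockIdx.x + blockIdx.y. The following is a minimal standalone sketch of the same capping loop and index recovery, under assumed names (scale_rows and its shapes are illustrative, not the patched Paddle kernels); note it restarts the column index for every row a thread handles, so it stays correct even when the floor division in the capping loop leaves fewer blocks than rows.

    #include <cstdio>
    #include <cuda_runtime.h>

    __global__ void scale_rows(const float *in, float *out, int rows, int cols) {
      // Recover a flat row id from the 2D grid, as the patched kernels do.
      int row = gridDim.y * blockIdx.x + blockIdx.y;
      int row_stride = gridDim.x * gridDim.y;  // blockDim.y == 1 here
      for (int r = row; r < rows; r += row_stride) {
        // Column index restarts at threadIdx.x for every row this thread handles.
        for (int c = threadIdx.x; c < cols; c += blockDim.x) {
          size_t idx = (size_t)r * cols + c;  // 64-bit index for large tensors
          out[idx] = 2.0f * in[idx];
        }
      }
    }

    int main() {
      const int rows = 100000;  // > 65535: one blockIdx.y per row would fail to launch
      const int cols = 256;
      size_t bytes = (size_t)rows * cols * sizeof(float);
      float *d_in, *d_out;
      cudaMalloc(&d_in, bytes);
      cudaMalloc(&d_out, bytes);

      unsigned int blocks_x = 1, blocks_y = rows;
      while (blocks_y > 65535) {  // same capping loop the patch adds
        blocks_x *= 2;
        blocks_y /= 2;
      }
      dim3 grid(blocks_x, blocks_y);  // (2, 50000) for rows = 100000
      scale_rows<<<grid, 256>>>(d_in, d_out, rows, cols);
      printf("grid=(%u, %u) err=%s\n", blocks_x, blocks_y,
             cudaGetErrorString(cudaGetLastError()));
      cudaFree(d_in);
      cudaFree(d_out);
      return 0;
    }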
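
The fused_softmax_mask.cu.h hunk widens offset and mask_offset to int64_t because batch * head_num * seq_len * seq_len elements can exceed a 32-bit int: with batch 16, head_num 16, and seq_len 4096, the last offset reaches 16 * 16 * 4096 * 4096 = 2^32, past INT_MAX. One caveat, offered as a hedged aside rather than as part of the commit: the right-hand side shown above still multiplies 32-bit values before the assignment, so for shapes that large the first operand would also need widening, along these lines:

    // Illustrative variant, not from the commit: widening the first operand
    // makes the whole chain of * and + evaluate in 64-bit arithmetic.
    int64_t offset =
        ((static_cast<int64_t>(blockIdx.y) * gridDim.z + blockIdx.z) * seq_len +
         seq_id) * seq_len;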