From dde4eb4651ad8adf39daadb3e44c595ee710fa24 Mon Sep 17 00:00:00 2001 From: kliuae <17350011+kliuae@users.noreply.github.com> Date: Tue, 12 Mar 2024 04:14:07 +0800 Subject: [PATCH] [ROCm] Fix warp and lane calculation in blockReduceSum (#3321) --- csrc/reduction_utils.cuh | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/csrc/reduction_utils.cuh b/csrc/reduction_utils.cuh index 210bf0b023ab2..c25464e866e55 100644 --- a/csrc/reduction_utils.cuh +++ b/csrc/reduction_utils.cuh @@ -29,12 +29,22 @@ __inline__ __device__ T warpReduceSum(T val) { return val; } +__inline__ __device__ constexpr int _calculateLaneMask(int warp_size) { + return warp_size - 1; +} + +__inline__ __device__ constexpr int _calculateWidShift(int warp_size) { + return 5 + (warp_size >> 6); +} + /* Calculate the sum of all elements in a block */ template __inline__ __device__ T blockReduceSum(T val) { static __shared__ T shared[WARP_SIZE]; - int lane = threadIdx.x & 0x1f; - int wid = threadIdx.x >> 5; + constexpr auto LANE_MASK = _calculateLaneMask(WARP_SIZE); + constexpr auto WID_SHIFT = _calculateWidShift(WARP_SIZE); + int lane = threadIdx.x & LANE_MASK; + int wid = threadIdx.x >> WID_SHIFT; val = warpReduceSum(val);