Skip to content

Commit

Permalink
[ROCm] Fix warp and lane calculation in blockReduceSum (vllm-project#…
Browse files Browse the repository at this point in the history
  • Loading branch information
kliuae authored Mar 11, 2024
1 parent 0395610 commit dde4eb4
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions csrc/reduction_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,22 @@ __inline__ __device__ T warpReduceSum(T val) {
return val;
}

/*
 * Bit mask that extracts a thread's lane index from its thread index.
 *
 * Returns warp_size - 1. When warp_size is a power of two, the expression
 * (threadIdx.x & mask) equals threadIdx.x % warp_size.
 * NOTE(review): assumes warp_size is a power of two; this is not checked here.
 */
__inline__ __device__ constexpr int _calculateLaneMask(int warp_size) {
  return (warp_size - 1);
}

/*
 * Right-shift amount that converts a thread index into a warp index,
 * i.e. floor(log2(warp_size)), so that (threadIdx.x >> shift) equals
 * threadIdx.x / warp_size for power-of-two warp sizes.
 *
 * The previous closed form, 5 + (warp_size >> 6), was only correct for
 * warp sizes 32, 64 and 128. This version computes the logarithm exactly
 * for any positive width and yields the same results for 32 (-> 5) and
 * 64 (-> 6).
 *
 * Written as a single-return recursive expression so it remains a valid
 * constexpr function under C++11 rules. Returns 0 for warp_size <= 1.
 */
__inline__ __device__ constexpr int _calculateWidShift(int warp_size) {
  return warp_size <= 1 ? 0 : 1 + _calculateWidShift(warp_size >> 1);
}

/* Calculate the sum of all elements in a block */
template<typename T>
__inline__ __device__ T blockReduceSum(T val) {
static __shared__ T shared[WARP_SIZE];
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
constexpr auto LANE_MASK = _calculateLaneMask(WARP_SIZE);
constexpr auto WID_SHIFT = _calculateWidShift(WARP_SIZE);
int lane = threadIdx.x & LANE_MASK;
int wid = threadIdx.x >> WID_SHIFT;

val = warpReduceSum<T>(val);

Expand Down

0 comments on commit dde4eb4

Please sign in to comment.