[ROCM] Fix blockReduceSum to use correct warp counts for ROCm and CUDA
dllehr-amd authored and dbogunowicz committed Mar 26, 2024
1 parent a20b244 commit 84e31ca
Showing 3 changed files with 15 additions and 11 deletions.
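Background (not stated in the commit itself): CUDA warps are 32 lanes wide, while AMD wavefronts under ROCm are 64 lanes wide. An XOR shuffle butterfly whose first mask is a hard-coded 16 only touches lane-index bits 0-4, so on a 64-lane wavefront each 32-lane half reduces separately and half the inputs are dropped. A minimal sketch of the failure and the fix, using VLLM_SHFL_XOR_SYNC and WARP_SIZE from csrc/cuda_compat.h (the function names are illustrative, not from the commit):

```cpp
#include "cuda_compat.h"  // assumed include path; provides WARP_SIZE and VLLM_SHFL_XOR_SYNC

// Old pattern: correct on CUDA (32 lanes), wrong on ROCm (64 lanes).
// Starting the butterfly at 16 only pairs lanes within each 32-lane half,
// so on a 64-lane wavefront lane 0 ends up with the sum of lanes 0-31 only.
__device__ float sum_broken_on_rocm(float val) {
  for (int mask = 16; mask > 0; mask >>= 1)   // hard-coded for 32-lane warps
    val += VLLM_SHFL_XOR_SYNC(val, mask);     // lane i pairs with lane i^mask
  return val;
}

// Fixed pattern: the first mask is half the platform warp width, i.e. 16
// under CUDA and 32 under ROCm, so the butterfly covers the whole warp.
__device__ float sum_full_warp(float val) {
  for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1)
    val += VLLM_SHFL_XOR_SYNC(val, mask);
  return val;
}
```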
csrc/attention/attention_kernels.cu (10 changes: 2 additions & 8 deletions)

@@ -16,9 +16,6 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#ifdef USE_ROCM
-#include <hip/hip_runtime.h>
-#endif
 
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
@@ -33,11 +30,8 @@
 #include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh"
 #endif
 
-#ifndef USE_ROCM
-#define WARP_SIZE 32
-#else
-#define WARP_SIZE warpSize
-#endif
+#include <algorithm>
 
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define DIVIDE_ROUND_UP(a, b) (((a) + (b)-1) / (b))
csrc/cuda_compat.h (10 changes: 10 additions & 0 deletions)

@@ -1,5 +1,15 @@
 #pragma once
 
+#ifdef USE_ROCM
+#include <hip/hip_runtime.h>
+#endif
+
+#ifndef USE_ROCM
+#define WARP_SIZE 32
+#else
+#define WARP_SIZE warpSize
+#endif
+
 #ifndef USE_ROCM
 #define VLLM_LDG(arg) __ldg(arg)
 #else
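A minimal sketch of how the relocated macro resolves (the kernel and include path are assumptions, not part of the commit): under a CUDA build WARP_SIZE is the literal 32, while under a ROCm build it expands to HIP's warpSize, which hip_runtime.h exposes to device code (64 on most AMD GPUs).

```cpp
#include <cstdio>
#include "cuda_compat.h"  // assumed include path; WARP_SIZE = 32 (CUDA) or warpSize (ROCm)

// Illustrative kernel: prints the warp layout of a block on either platform.
__global__ void warp_layout() {
  int lane = threadIdx.x % WARP_SIZE;  // lane index within the warp/wavefront
  int warp = threadIdx.x / WARP_SIZE;  // warp index within the block
  if (lane == 0)
    printf("thread %d is lane 0 of warp %d (WARP_SIZE = %d)\n",
           (int)threadIdx.x, warp, (int)WARP_SIZE);
}
```

Because the ROCm value is a compile-time constant in device code, WARP_SIZE can also size a static __shared__ array, which is exactly what the reduction_utils.cuh change below relies on.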
csrc/reduction_utils.cuh (6 changes: 3 additions & 3 deletions)

@@ -24,15 +24,15 @@ namespace vllm {
 template<typename T>
 __inline__ __device__ T warpReduceSum(T val) {
 #pragma unroll
-  for (int mask = 16; mask > 0; mask >>= 1)
+  for (int mask = WARP_SIZE/2; mask > 0; mask >>= 1)
     val += VLLM_SHFL_XOR_SYNC(val, mask);
   return val;
 }
 
 /* Calculate the sum of all elements in a block */
 template<typename T>
 __inline__ __device__ T blockReduceSum(T val) {
-  static __shared__ T shared[32];
+  static __shared__ T shared[WARP_SIZE];
   int lane = threadIdx.x & 0x1f;
   int wid = threadIdx.x >> 5;
@@ -45,7 +45,7 @@ __inline__ __device__ T blockReduceSum(T val) {
 
   // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent
   // blockDim.x is not divided by 32
-  val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f);
+  val = (threadIdx.x < (blockDim.x / (WARP_SIZE * 1.0f))) ? shared[lane] : (T)(0.0f);
   val = warpReduceSum<T>(val);
   return val;
 }
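For context, a hedged usage sketch of the updated reduction (the kernel, buffer names, and launch shape are illustrative, not from this commit). Every thread in the block must call blockReduceSum, so out-of-range threads contribute zero rather than returning early:

```cpp
#include "reduction_utils.cuh"  // assumed include path for vllm::blockReduceSum

// Each block writes one partial sum of its slice of `in` to `out`.
__global__ void partial_sums(const float* in, float* out, int n) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  // All threads participate in the reduction; tail threads add 0.0f.
  float v = (idx < n) ? in[idx] : 0.0f;
  v = vllm::blockReduceSum<float>(v);
  if (threadIdx.x == 0)
    out[blockIdx.x] = v;  // thread 0 holds the block total after the reduction
}
```

Launching with, say, 256 threads per block and one block per 256 elements leaves one partial sum per block in out, to be combined in a second pass.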
