Skip to content

Commit

Permalink
Add vectorized rms_norm support for Navi31
Browse files Browse the repository at this point in the history
- supports vectorized rms_norm_kernel
  • Loading branch information
hyoon1 committed Nov 12, 2024
1 parent 8f3bf8b commit 265c248
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions csrc/layernorm_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ using __nv_bfloat162 = __hip_bfloat162;
#endif

#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx940__) || \
- defined(__gfx941__) || defined(__gfx942__))
- #define __HIP__MI300_MI250__
+ defined(__gfx941__) || defined(__gfx942__) || \
+ defined(__gfx1100__))
+ #define __HIP__MI300_MI250_Navi31__
#endif

namespace vllm {
Expand Down Expand Up @@ -72,7 +73,7 @@ struct __align__(16) vec8_t {
__device__ scalar_t sum() const { return x + y + z + w + u + v + s + t; }
};

- #ifdef __HIP__MI300_MI250__
+ #ifdef __HIP__MI300_MI250_Navi31__

// TODO(woosuk): Further optimize this kernel.
template <typename scalar_t>
Expand Down

0 comments on commit 265c248

Please sign in to comment.