Skip to content

Commit

Permalink
solve "Cannot find Symbol with name: _ZN4vllm15rms_norm_kernelIN3c108BFloat16ELi8EEENSt9enable_ifIXooooeqT0_Li0Entsr12_typeConvertIT_EE6existseqLi2ELi2EEvE4typeEPS4_PKS4_S9_fiii"
Browse files Browse the repository at this point in the history
  • Loading branch information
wunhuang committed Nov 18, 2024
1 parent 4cbbb99 commit c33c704
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions csrc/layernorm_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,14 @@

// Feature switch __HIP__MI300_MI250__: the guard below is true only when
// __HIPCC__ is defined together with at least one of the gfx90a / gfx940 /
// gfx941 / gfx942 architecture macros.
//
// NOTE(review): this view shows the removed and the added lines of the commit
// side by side, without +/- markers. It looks like the pre-commit form was
// "#define ... 1 / #else / #define ... 0" (read by the `== 1` / `== 0`
// enable_if conditions on rms_norm_kernel) and the post-commit form is the
// single value-less #define (read by `#ifdef` inside rms_norm) -- confirm
// against the raw diff before editing this region.
//
// NOTE(review): presumably the __gfx*__ macros are only set during the device
// compilation pass; if so, a host-side `#ifdef __HIP__MI300_MI250__` would
// never be true and the width-8 launch path would not be selected -- verify.
#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx940__) || \
defined(__gfx941__) || defined(__gfx942__))
// Valued definition: usable inside constant expressions such as `== 1`.
#define __HIP__MI300_MI250__ 1
#else
#define __HIP__MI300_MI250__ 0
// Value-less definition: only meaningful with #ifdef / defined(); it expands
// to nothing, so it cannot be compared with `== 0` or `== 1`.
#define __HIP__MI300_MI250__
#endif

namespace vllm {

// TODO(woosuk): Further optimize this kernel.
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists && __HIP__MI300_MI250__ == 1>
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
rms_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
Expand Down Expand Up @@ -75,7 +73,7 @@ rms_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size]
}

template <typename scalar_t, int width>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists || __HIP__MI300_MI250__ == 0>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
rms_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
Expand Down Expand Up @@ -242,11 +240,15 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

#ifdef __HIP__MI300_MI250__
if (vec_size % 8 == 0) {
LAUNCH_RMS_NORM(8);
} else {
LAUNCH_RMS_NORM(0);
}
#else
LAUNCH_RMS_NORM(0);
#endif
}

#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
Expand Down

0 comments on commit c33c704

Please sign in to comment.