From b6d103542c654fb63013a1e45a586d654ae36a2a Mon Sep 17 00:00:00 2001
From: mawong-amd <156021403+mawong-amd@users.noreply.github.com>
Date: Sat, 30 Mar 2024 14:26:38 -0700
Subject: [PATCH] [Kernel] Layernorm performance optimization (#3662)

---
 cmake/utils.cmake               |   5 +
 csrc/layernorm_kernels.cu       | 270 +++++++++++++++++++++++++++++---
 csrc/reduction_utils.cuh        |  54 ++++---
 tests/kernels/test_layernorm.py |   3 +-
 4 files changed, 285 insertions(+), 47 deletions(-)

diff --git a/cmake/utils.cmake b/cmake/utils.cmake
index 6bf5d5130290b..c7d3d85389838 100644
--- a/cmake/utils.cmake
+++ b/cmake/utils.cmake
@@ -100,6 +100,11 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
 
     if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8)
       list(APPEND GPU_FLAGS "-DENABLE_FP8_E5M2")
+      list(REMOVE_ITEM GPU_FLAGS
+        "-D__CUDA_NO_HALF_OPERATORS__"
+        "-D__CUDA_NO_HALF_CONVERSIONS__"
+        "-D__CUDA_NO_BFLOAT16_CONVERSIONS__"
+        "-D__CUDA_NO_HALF2_OPERATORS__")
     endif()
 
   elseif(${GPU_LANG} STREQUAL "HIP")
diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu
index 6d34d014c858e..ea30fa2747838 100644
--- a/csrc/layernorm_kernels.cu
+++ b/csrc/layernorm_kernels.cu
@@ -4,6 +4,16 @@
 
 #include "dispatch_utils.h"
 #include "reduction_utils.cuh"
+#ifndef USE_ROCM
+  #include <cuda_bf16.h>
+  #include <cuda_fp16.h>
+#else
+  #include <hip/hip_bf16.h>
+  #include <hip/hip_fp16.h>
+
+  using __nv_bfloat16 = __hip_bfloat16;
+  using __nv_bfloat162 = __hip_bfloat162;
+#endif
 
 namespace vllm {
 
@@ -35,9 +45,199 @@ __global__ void rms_norm_kernel(
   }
 }
 
-// TODO: Further optimize this kernel.
-template<typename scalar_t>
-__global__ void fused_add_rms_norm_kernel(
+
+/* Converter structs for the conversion from torch types to HIP/CUDA types,
+   and the associated type conversions within HIP/CUDA. These helpers need
+   to be implemented for now because the relevant type conversion
+   operators/constructors are not consistently implemented by HIP/CUDA, so
+   a generic conversion via type casts cannot be implemented.
+
+   Each struct should have the member static constexpr bool `exists`:
+   If false, the optimized kernel is not used for the corresponding torch type.
+   If true, the struct should be fully defined as shown in the examples below.
+ */
+template<typename torch_type>
+struct _typeConvert { static constexpr bool exists = false; };
+
+template<>
+struct _typeConvert<c10::Half> {
+  static constexpr bool exists = true;
+  using hip_type = __half;
+  using packed_hip_type = __half2;
+
+  __device__ static inline float convert(hip_type x) { return __half2float(x); }
+  __device__ static inline float2 convert(packed_hip_type x) { return __half22float2(x); }
+  __device__ static inline hip_type convert(float x) { return __float2half_rn(x); }
+  __device__ static inline packed_hip_type convert(float2 x) { return __float22half2_rn(x); }
+};
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+// CUDA_ARCH < 800 does not have BF16 support
+// TODO: Add in ROCm support once public headers handle bf16 maturely
+template<>
+struct _typeConvert<c10::BFloat16> {
+  static constexpr bool exists = true;
+  using hip_type = __nv_bfloat16;
+  using packed_hip_type = __nv_bfloat162;
+
+  __device__ static inline float convert(hip_type x) { return __bfloat162float(x); }
+  __device__ static inline float2 convert(packed_hip_type x) { return __bfloat1622float2(x); }
+  __device__ static inline hip_type convert(float x) { return __float2bfloat16(x); }
+  __device__ static inline packed_hip_type convert(float2 x) { return __float22bfloat162_rn(x); }
+};
+#endif
+
+
+/* Vector POD struct to generate vectorized and packed FP16/BF16 ops
+   for appropriate specializations of fused_add_rms_norm_kernel.
+   Only functions that are necessary in that kernel are implemented.
+   Alignment to 16 bytes is required to use 128-bit global memory ops.
+ */
+template<typename scalar_t, int width>
+struct alignas(16) _f16Vec {
+  /* Not theoretically necessary that width is a power of 2 but should
+     almost always be the case for optimization purposes */
+  static_assert(width > 0 && (width & (width - 1)) == 0,
+                "Width is not a positive power of 2!");
+  using Converter = _typeConvert<scalar_t>;
+  using T1 = typename Converter::hip_type;
+  using T2 = typename Converter::packed_hip_type;
+  T1 data[width];
+
+  __device__ _f16Vec& operator+=(const _f16Vec<scalar_t, width>& other) {
+    if constexpr (width % 2 == 0) {
+      #pragma unroll
+      for (int i = 0; i < width; i += 2) {
+        T2 temp{data[i], data[i+1]};
+        temp += T2{other.data[i], other.data[i+1]};
+        data[i] = temp.x;
+        data[i+1] = temp.y;
+      }
+    } else {
+      #pragma unroll
+      for (int i = 0; i < width; ++i)
+        data[i] += other.data[i];
+    }
+    return *this;
+  }
+
+  __device__ _f16Vec& operator*=(const _f16Vec<scalar_t, width>& other) {
+    if constexpr (width % 2 == 0) {
+      #pragma unroll
+      for (int i = 0; i < width; i += 2) {
+        T2 temp{data[i], data[i+1]};
+        temp *= T2{other.data[i], other.data[i+1]};
+        data[i] = temp.x;
+        data[i+1] = temp.y;
+      }
+    } else {
+      #pragma unroll
+      for (int i = 0; i < width; ++i)
+        data[i] *= other.data[i];
+    }
+    return *this;
+  }
+
+  __device__ _f16Vec& operator*=(const float scale) {
+    if constexpr (width % 2 == 0) {
+      #pragma unroll
+      for (int i = 0; i < width; i += 2) {
+        float2 temp_f = Converter::convert(T2{data[i], data[i+1]});
+        temp_f.x *= scale;
+        temp_f.y *= scale;
+        T2 temp = Converter::convert(temp_f);
+        data[i] = temp.x;
+        data[i+1] = temp.y;
+      }
+    } else {
+      #pragma unroll
+      for (int i = 0; i < width; ++i) {
+        float temp = Converter::convert(data[i]) * scale;
+        data[i] = Converter::convert(temp);
+      }
+    }
+    return *this;
+  }
+
+  __device__ float sum_squares() const {
+    float result = 0.0f;
+    if constexpr (width % 2 == 0) {
+      #pragma unroll
+      for (int i = 0; i < width; i += 2) {
+        float2 z = Converter::convert(T2{data[i], data[i+1]});
+        result += z.x * z.x + z.y * z.y;
+      }
+    } else {
+      #pragma unroll
+      for (int i = 0; i < width; ++i) {
+        float x = Converter::convert(data[i]);
+        result += x * x;
+      }
+    }
+    return result;
+  }
+};
+
+/* Function specialization in the case of FP16/BF16 tensors.
+   Additional optimizations we can make in this case are
+   packed and vectorized operations, which help with the
+   memory latency bottleneck. */
+template<typename scalar_t, int width>
+__global__ std::enable_if_t<
+  (width > 0) && _typeConvert<scalar_t>::exists> fused_add_rms_norm_kernel(
+  scalar_t* __restrict__ input,           // [..., hidden_size]
+  scalar_t* __restrict__ residual,        // [..., hidden_size]
+  const scalar_t* __restrict__ weight,    // [hidden_size]
+  const float epsilon,
+  const int num_tokens,
+  const int hidden_size) {
+  // Sanity checks on our vector struct and type-punned pointer arithmetic
+  static_assert(std::is_pod_v<_f16Vec<scalar_t, width>>);
+  static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);
+
+  const int vec_hidden_size = hidden_size / width;
+  __shared__ float s_variance;
+  float variance = 0.0f;
+  /* These and the argument pointers are all declared `restrict` as they are
+     not aliased in practice. Argument pointers should not be dereferenced
+     in this kernel as that would be undefined behavior */
+  auto* __restrict__ input_v = reinterpret_cast<_f16Vec<scalar_t, width>*>(input);
+  auto* __restrict__ residual_v = reinterpret_cast<_f16Vec<scalar_t, width>*>(residual);
+  auto* __restrict__ weight_v = reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);
+
+  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
+    int id = blockIdx.x * vec_hidden_size + idx;
+    _f16Vec<scalar_t, width> temp = input_v[id];
+    temp += residual_v[id];
+    variance += temp.sum_squares();
+    residual_v[id] = temp;
+  }
+  /* Keep the following if-else block in sync with the
+     calculation of max_block_size in fused_add_rms_norm */
+  if (num_tokens < 256) {
+    variance = blockReduceSum<float, 1024>(variance);
+  } else variance = blockReduceSum<float, 256>(variance);
+  if (threadIdx.x == 0) {
+    s_variance = rsqrtf(variance / hidden_size + epsilon);
+  }
+  __syncthreads();
+
+  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
+    int id = blockIdx.x * vec_hidden_size + idx;
+    _f16Vec<scalar_t, width> temp = residual_v[id];
+    temp *= s_variance;
+    temp *= weight_v[idx];
+    input_v[id] = temp;
+  }
+}
+
+
+/* Generic fused_add_rms_norm_kernel
+   The width field is not used here but necessary for other specializations.
+ */
+template<typename scalar_t, int width>
+__global__ std::enable_if_t<
+  (width == 0) || !_typeConvert<scalar_t>::exists> fused_add_rms_norm_kernel(
   scalar_t* __restrict__ input,           // [..., hidden_size]
   scalar_t* __restrict__ residual,        // [..., hidden_size]
   const scalar_t* __restrict__ weight,    // [hidden_size]
@@ -48,12 +248,17 @@ __global__ void fused_add_rms_norm_kernel(
   float variance = 0.0f;
 
   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    float x = (float) input[blockIdx.x * hidden_size + idx];
-    x += (float) residual[blockIdx.x * hidden_size + idx];
+    scalar_t z = input[blockIdx.x * hidden_size + idx];
+    z += residual[blockIdx.x * hidden_size + idx];
+    float x = (float) z;
     variance += x * x;
-    residual[blockIdx.x * hidden_size + idx] = (scalar_t) x;
+    residual[blockIdx.x * hidden_size + idx] = z;
   }
-  variance = blockReduceSum<float>(variance);
+  /* Keep the following if-else block in sync with the
+     calculation of max_block_size in fused_add_rms_norm */
+  if (num_tokens < 256) {
+    variance = blockReduceSum<float, 1024>(variance);
+  } else variance = blockReduceSum<float, 256>(variance);
   if (threadIdx.x == 0) {
     s_variance = rsqrtf(variance / hidden_size + epsilon);
   }
@@ -93,6 +298,21 @@ void rms_norm(
   });
 }
 
+#define LAUNCH_FUSED_ADD_RMS_NORM(width)                                      \
+  VLLM_DISPATCH_FLOATING_TYPES(                                               \
+    input.scalar_type(),                                                      \
+    "fused_add_rms_norm_kernel",                                              \
+    [&] {                                                                     \
+      vllm::fused_add_rms_norm_kernel<scalar_t, width>                        \
+      <<<grid, block, 0, stream>>>(                                           \
+        input.data_ptr<scalar_t>(),                                           \
+        residual.data_ptr<scalar_t>(),                                        \
+        weight.data_ptr<scalar_t>(),                                          \
+        epsilon,                                                              \
+        num_tokens,                                                           \
+        hidden_size);                                                         \
+    });
+
 void fused_add_rms_norm(
   torch::Tensor& input,     // [..., hidden_size]
   torch::Tensor& residual,  // [..., hidden_size]
@@ -102,19 +322,29 @@ void fused_add_rms_norm(
   int num_tokens = input.numel() / hidden_size;
 
   dim3 grid(num_tokens);
-  dim3 block(std::min(hidden_size, 1024));
+  /* This kernel is memory-latency bound in many scenarios.
+     When num_tokens is large, a smaller block size allows
+     for increased block occupancy on CUs and better latency
+     hiding on global mem ops. */
+  const int max_block_size = (num_tokens < 256) ? 1024 : 256;
+  dim3 block(std::min(hidden_size, max_block_size));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  VLLM_DISPATCH_FLOATING_TYPES(
-    input.scalar_type(),
-    "fused_add_rms_norm_kernel",
-    [&] {
-      vllm::fused_add_rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
-        input.data_ptr<scalar_t>(),
-        residual.data_ptr<scalar_t>(),
-        weight.data_ptr<scalar_t>(),
-        epsilon,
-        num_tokens,
-        hidden_size);
-    });
+  /*If the tensor types are FP16/BF16, try to use the optimized kernel
+    with packed + vectorized ops.
+    Max optimization is achieved with a width-8 vector of FP16/BF16s
+    since we can load at most 128 bits at once in a global memory op.
+    However, this requires each tensor's data to be aligned to 16
+    bytes.
+   */
+  auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
+  auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
+  auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
+  bool ptrs_are_aligned = inp_ptr % 16 == 0 && res_ptr % 16 == 0 \
+                          && wt_ptr % 16 == 0;
+  if (ptrs_are_aligned && hidden_size % 8 == 0) {
+    LAUNCH_FUSED_ADD_RMS_NORM(8);
+  } else {
+    LAUNCH_FUSED_ADD_RMS_NORM(0);
+  }
 }
diff --git a/csrc/reduction_utils.cuh b/csrc/reduction_utils.cuh
index c25464e866e55..bb5171f854d55 100644
--- a/csrc/reduction_utils.cuh
+++ b/csrc/reduction_utils.cuh
@@ -20,43 +20,45 @@
 #include "cuda_compat.h"
 
 namespace vllm {
-
-template<typename T>
+template<typename T, int numLanes = WARP_SIZE>
 __inline__ __device__ T warpReduceSum(T val) {
-#pragma unroll
-  for (int mask = WARP_SIZE/2; mask > 0; mask >>= 1)
+  static_assert(numLanes > 0 && (numLanes & (numLanes - 1)) == 0,
+                "numLanes is not a positive power of 2!");
+  static_assert(numLanes <= WARP_SIZE);
+  #pragma unroll
+  for (int mask = numLanes >> 1; mask > 0; mask >>= 1)
     val += VLLM_SHFL_XOR_SYNC(val, mask);
   return val;
 }
 
-__inline__ __device__ constexpr int _calculateLaneMask(int warp_size) {
-  return warp_size - 1;
-}
-
-__inline__ __device__ constexpr int _calculateWidShift(int warp_size) {
-  return 5 + (warp_size >> 6);
+// Helper function to return the next largest power of 2
+static constexpr int _nextPow2(unsigned int num) {
+  if (num <= 1) return num;
+  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
 }
 
 /* Calculate the sum of all elements in a block */
-template<typename T>
+template<typename T, int maxBlockSize = 1024>
 __inline__ __device__ T blockReduceSum(T val) {
-  static __shared__ T shared[WARP_SIZE];
-  constexpr auto LANE_MASK = _calculateLaneMask(WARP_SIZE);
-  constexpr auto WID_SHIFT = _calculateWidShift(WARP_SIZE);
-  int lane = threadIdx.x & LANE_MASK;
-  int wid = threadIdx.x >> WID_SHIFT;
-
-  val = warpReduceSum<T>(val);
-
-  if (lane == 0)
-    shared[wid] = val;
+  static_assert(maxBlockSize <= 1024);
+  if constexpr (maxBlockSize > WARP_SIZE) {
+    val = warpReduceSum<T>(val);
+    // Calculates max number of lanes that need to participate in the last warpReduce
+    constexpr int maxActiveLanes = (maxBlockSize + WARP_SIZE - 1) / WARP_SIZE;
+    static __shared__ T shared[maxActiveLanes];
+    int lane = threadIdx.x % WARP_SIZE;
+    int wid = threadIdx.x / WARP_SIZE;
+    if (lane == 0)
+      shared[wid] = val;
 
-  __syncthreads();
+    __syncthreads();
 
-  // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent
-  // blockDim.x is not divided by 32
-  val = (threadIdx.x < (blockDim.x / (WARP_SIZE * 1.0f))) ? shared[lane] : (T)(0.0f);
-  val = warpReduceSum<T>(val);
+    val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? shared[lane] : (T)(0.0f);
+    val = warpReduceSum<T, _nextPow2(maxActiveLanes)>(val);
+  } else {
+    // A single warpReduce is equal to blockReduce
+    val = warpReduceSum<T, _nextPow2(maxBlockSize)>(val);
+  }
   return val;
 }
diff --git a/tests/kernels/test_layernorm.py b/tests/kernels/test_layernorm.py
index b1e3c1a7f07f5..210d59e4f32fa 100644
--- a/tests/kernels/test_layernorm.py
+++ b/tests/kernels/test_layernorm.py
@@ -5,7 +5,8 @@
 
 DTYPES = [torch.half, torch.bfloat16, torch.float]
 NUM_TOKENS = [7, 83, 4096]  # Arbitrary values for testing
-HIDDEN_SIZES = [768, 5120, 8192]  # Arbitrary values for testing
+HIDDEN_SIZES = [768, 769, 770, 771, 5120, 5124, 5125, 5126, 8192,
+                8199]  # Arbitrary values for testing
 ADD_RESIDUAL = [False, True]
 SEEDS = [0]
 CUDA_DEVICES = [
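
A quick way to see what the new maxBlockSize template parameter of blockReduceSum buys is to evaluate the helper arithmetic on the host. The sketch below is illustrative only and not part of the patch: it copies the patch's _nextPow2 helper and assumes a 32-lane warp (WARP_SIZE is 64 on some AMD GPUs, so the exact lane counts differ there).

#include <climits>

// Standalone copy of the patch's _nextPow2 helper, usable at compile time.
static constexpr int _nextPow2(unsigned int num) {
  if (num <= 1) return num;
  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}

constexpr int WARP_SIZE = 32;  // assumption: NVIDIA-style 32-lane warps

// blockReduceSum<float, 1024> (num_tokens < 256 path): up to 32 warp partials
// feed the final warpReduceSum, which therefore shuffles across a full warp.
static_assert((1024 + WARP_SIZE - 1) / WARP_SIZE == 32, "maxActiveLanes, 1024-thread blocks");
static_assert(_nextPow2(32) == 32, "final reduce spans 32 lanes");

// blockReduceSum<float, 256> (num_tokens >= 256 path): only 8 warp partials
// remain, so the final warpReduceSum needs just 3 shuffle rounds over 8 lanes.
static_assert((256 + WARP_SIZE - 1) / WARP_SIZE == 8, "maxActiveLanes, 256-thread blocks");
static_assert(_nextPow2(8) == 8, "final reduce spans 8 lanes");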