From 8f793162c5760c59e40249c01416c34a63fb6014 Mon Sep 17 00:00:00 2001 From: charlifu Date: Fri, 23 Aug 2024 15:56:52 +0000 Subject: [PATCH 01/12] add kernel folder for hip and update CMake file --- .gitignore | 1 - CMakeLists.txt | 20 + csrc/hip/attention.hip | 1121 ++++++++++++++++++++++++++++++++++++ csrc/hip/torch_binding.cpp | 0 4 files changed, 1141 insertions(+), 1 deletion(-) create mode 100644 csrc/hip/attention.hip create mode 100644 csrc/hip/torch_binding.cpp diff --git a/.gitignore b/.gitignore index 761b00ac3bc48..a5b6b48f65823 100644 --- a/.gitignore +++ b/.gitignore @@ -187,7 +187,6 @@ _build/ *.swp # hip files generated by PyTorch -*.hip *_hip* hip_compat.h diff --git a/CMakeLists.txt b/CMakeLists.txt index ab91b86426cd4..417d4a17afb45 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -306,6 +306,26 @@ define_gpu_extension_target( USE_SABI 3 WITH_SOABI) +# +# _hip_C extension +# + +if(VLLM_GPU_LANG STREQUAL "HIP") + set(VLLM_HIP_EXT_SRC + "csrc/hip/torch_bindings.cpp" + "csrc/hip/attention.hip") + + define_gpu_extension_target( + _hip_C + DESTINATION vllm + LANGUAGE ${VLLM_GPU_LANG} + SOURCES ${VLLM_HIP_EXT_SRC} + COMPILE_FLAGS ${VLLM_GPU_FLAGS} + ARCHITECTURES ${VLLM_GPU_ARCHES} + USE_SABI 3 + WITH_SOABI) + +endif() if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") diff --git a/csrc/hip/attention.hip b/csrc/hip/attention.hip new file mode 100644 index 0000000000000..0f1050891c477 --- /dev/null +++ b/csrc/hip/attention.hip @@ -0,0 +1,1121 @@ +#include +#include +#include +#include + +#include + +#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx940__) || \ + defined(__gfx941__) || defined(__gfx942__)) + #define __HIP__MI300_MI250__ +#endif + +#if defined(NDEBUG) + #undef NDEBUG + #include + #define UNREACHABLE_CODE assert(false); + #define NDEBUG +#else + #define UNREACHABLE_CODE assert(false); +#endif + +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? 
(a) : (b)) +#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) +#define WARP_SIZE 64 + +#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support + + #define GCN_MFMA_INSTR1 __builtin_amdgcn_mfma_f32_16x16x4f32 + #define GCN_MFMA_INSTR __builtin_amdgcn_mfma_f32_4x4x4f16 + +using floatx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float; +using float16x4 = + __attribute__((__vector_size__(4 * sizeof(_Float16)))) _Float16; +typedef float16x4 _Half4; +typedef struct _Half8 { + _Half4 xy[2]; +} _Half8; + +using bit16_t = uint16_t; +using bit16x4 = __attribute__((__vector_size__(4 * sizeof(uint16_t)))) uint16_t; +typedef bit16x4 _B16x4; +typedef struct _B16x8 { + _B16x4 xy[2]; +} _B16x8; + +////// Non temporal load stores /////// + + #if 1 + +template +__device__ __forceinline__ T load(T* addr) { + return addr[0]; +} + +template +__device__ __forceinline__ void store(T value, T* addr) { + addr[0] = value; +} + + #else + +template +__device__ __forceinline__ T load(const T* addr) { + return __builtin_nontemporal_load(addr); +} + +template <> +__device__ __forceinline__ float2 load(const float2* addr) { + auto addr_alias{reinterpret_cast(addr)}; + auto result = __builtin_nontemporal_load(addr_alias); + auto ret = reinterpret_cast(&result); + return ret[0]; +} + +template <> +__device__ __forceinline__ float4 load(const float4* addr) { + auto addr_alias{reinterpret_cast(addr)}; + auto result1 = __builtin_nontemporal_load(addr_alias); + auto result2 = __builtin_nontemporal_load(addr_alias + 1); + float4 ret{}; + auto ret_alias = reinterpret_cast(&result1); + ret.x = ret_alias->x; + ret.y = ret_alias->y; + ret_alias = reinterpret_cast(&result2); + ret.z = ret_alias->x; + ret.w = ret_alias->y; + return ret; +} + +template <> +__device__ __forceinline__ __half load(const __half* addr) { + auto addr_alias{reinterpret_cast(addr)}; + auto result = __builtin_nontemporal_load(addr_alias); + auto ret = reinterpret_cast<__half*>(&result); + return ret[0]; +} + +template <> +__device__ __forceinline__ __half2 load(const __half2* addr) { + auto addr_alias{reinterpret_cast(addr)}; + auto result = __builtin_nontemporal_load(addr_alias); + auto ret = reinterpret_cast<__half2*>(&result); + return ret[0]; +} + +template <> +__device__ __forceinline__ vllm::Half4_ load(const vllm::Half4_* addr) { + auto addr_alias{reinterpret_cast(addr)}; + auto result = __builtin_nontemporal_load(addr_alias); + auto ret = reinterpret_cast(&result); + return ret[0]; +} + +template <> +__device__ __forceinline__ vllm::Half8_ load(const vllm::Half8_* addr) { + auto addr_alias{reinterpret_cast(addr)}; + auto result1 = __builtin_nontemporal_load(addr_alias); + auto result2 = __builtin_nontemporal_load(addr_alias + 1); + vllm::Half8_ ret{}; + auto ret_alias = reinterpret_cast(&result1); + ret.x = ret_alias->x; + ret.y = ret_alias->y; + ret_alias = reinterpret_cast(&result2); + ret.z = ret_alias->x; + ret.w = ret_alias->y; + return ret; +} + +//// Not using nontemporal stores for now +template +__device__ __forceinline__ void store(T value, T* addr) { + return __builtin_nontemporal_store(value, addr); +} + + #endif + +template +__device__ __forceinline__ floatx4 gcn_mfma_instr(const _B16x4& inpA, + const _B16x4& inpB, + const floatx4& inpC) { + if constexpr (std::is_same::value) { + return __builtin_amdgcn_mfma_f32_4x4x4f16(inpA, inpB, inpC, absz, cbid, + blgp); + } else if constexpr (std::is_same::value) { + return __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(inpA, inpB, inpC, absz, cbid, + blgp); + } else { + 
static_assert(false, "unsupported 16b dtype"); + } +} + +template +__device__ __forceinline__ float to_float(const T& inp) { + if constexpr (std::is_same::value) { + return (float)inp; + } else if constexpr (std::is_same::value) { + return __bfloat162float(inp); + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +template +__device__ __forceinline__ T from_float(const float& inp) { + if constexpr (std::is_same::value) { + return (_Float16)inp; + } else if constexpr (std::is_same::value) { + return __float2bfloat16(inp); + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +template +__device__ __forceinline__ _B16x4 from_floatx4(const floatx4& inp) { + union tmpcvt { + uint16_t u; + _Float16 f; + __hip_bfloat16 b; + } t16; + _B16x4 ret; + if constexpr (std::is_same::value) { + #pragma unroll + for (int i = 0; i < 4; i++) { + t16.f = (_Float16)inp[i]; + ret[i] = t16.u; + } + return ret; + } else if constexpr (std::is_same::value) { + #pragma unroll + for (int i = 0; i < 4; i++) { + t16.b = __float2bfloat16(inp[i]); + ret[i] = t16.u; + } + return ret; + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +template +__device__ __forceinline__ _B16x4 addx4(const _B16x4& inp1, + const _B16x4& inp2) { + union tmpcvt { + uint16_t u; + _Float16 f; + __hip_bfloat16 b; + } t1, t2, res; + _B16x4 ret; + if constexpr (std::is_same::value) { + #pragma unroll + for (int i = 0; i < 4; i++) { + t1.u = inp1[i]; + t2.u = inp2[i]; + res.f = t1.f + t2.f; + ret[i] = res.u; + } + return ret; + } else if constexpr (std::is_same::value) { + #pragma unroll + for (int i = 0; i < 4; i++) { + t1.u = inp1[i]; + t2.u = inp2[i]; + res.b = t1.b + t2.b; + ret[i] = res.u; + } + return ret; + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +/////////////////////////////////////// + +// grid (num_seqs, num_partitions,num_heads/gqa_ratio) +// block (partition size) +template +__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size] + #if 0 + scalar_t* __restrict__ qk_out, // [num_heads, num_seqs, max_ctx_blocks,block_size] + #endif + int max_ctx_blocks) { + constexpr int NWARPS = NUM_THREADS / WARP_SIZE; + const int warpid = threadIdx.x / WARP_SIZE; + const int laneid = threadIdx.x % WARP_SIZE; + const int lane4id = laneid % 4; + + const int seq_idx = blockIdx.x; + const int partition_idx = blockIdx.y; + const int partition_size = blockDim.x; + const int max_num_partitions = gridDim.y; + + const int context_len = context_lens[seq_idx]; + const int partition_start_token_idx = partition_idx * partition_size; + // 
exit if partition is out of context for seq + if (partition_start_token_idx >= context_len) { + return; + } + constexpr int QHLOOP = + DIVIDE_ROUND_UP(GQA_RATIO, 4); // each 4 lanes fetch 4 different qheads, + // total qheads =8, so qhloop is 2 + constexpr int GQA_RATIO4 = 4 * QHLOOP; + __shared__ float shared_qk_max[NWARPS][GQA_RATIO4 + 1]; + __shared__ float shared_exp_sum[NWARPS][GQA_RATIO4 + 1]; + _B16x8 Qlocal[QHLOOP]; + constexpr int x = 16 / sizeof(scalar_t); + constexpr int KHELOOP = HEAD_SIZE / x; + _B16x8 Klocal[KHELOOP]; + constexpr int VHELOOP = + HEAD_SIZE / + WARP_SIZE; // v head_size dimension is distributed across lanes + constexpr int VTLOOP = 8; // 16 separate 4xtokens across warp -> 16/2 + // 8xtokens + _B16x8 Vlocal[VHELOOP][VTLOOP]; + floatx4 dout[QHLOOP]; + float qk_max[QHLOOP]; + #pragma unroll + for (int h = 0; h < QHLOOP; h++) { + dout[h] = {0}; + qk_max[h] = -FLT_MAX; + } + + const int wg_start_head_idx = blockIdx.z * GQA_RATIO; + const int wg_start_kv_head_idx = blockIdx.z; + + const int warp_start_token_idx = + partition_start_token_idx + warpid * WARP_SIZE; + + if (warp_start_token_idx >= context_len) { // warp out of context + #pragma unroll + for (int h = 0; h < GQA_RATIO4; h++) { + shared_qk_max[warpid][h] = -FLT_MAX; + shared_exp_sum[warpid][h] = 0.0f; + } + } else { // warp within context + + const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); + const int last_ctx_block = num_context_blocks - 1; + + const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; + + const int local_token_idx = threadIdx.x; + const int global_token_idx = partition_start_token_idx + local_token_idx; + + const int block_idx = (global_token_idx < context_len) + ? global_token_idx / BLOCK_SIZE + : last_ctx_block; + // fetch block number for q and k + // int32 physical_block_number leads to overflow when multiplied with + // kv_block_stride + const int64_t physical_block_number = + static_cast(block_table[block_idx]); + + // fetch vphysical block numbers up front + constexpr int VBLOCKS = 8 * VTLOOP / BLOCK_SIZE; + int vphysical_blocks[VBLOCKS]; + + const int warp_start_block_idx = warp_start_token_idx / BLOCK_SIZE; + #pragma unroll + for (int b = 0; b < VBLOCKS; b++) { + const int vblock_idx = warp_start_block_idx + b; + const int vblock_idx_ctx = + (vblock_idx <= last_ctx_block) ? 
vblock_idx : last_ctx_block; + vphysical_blocks[b] = block_table[vblock_idx_ctx]; + } + // each 4 lanes fetch 8 helems, so warp fetches 8*16 = 128 helems + const scalar_t* q_ptr = + q + seq_idx * q_stride + wg_start_head_idx * HEAD_SIZE; + const _B16x8* q_ptrh8 = reinterpret_cast(q_ptr); + const int qhead_elemh8 = laneid / 4; + #pragma unroll + for (int h = 0; h < QHLOOP - 1; h++) { + const int qhead_idx = h * 4 + lane4id; + Qlocal[h] = q_ptrh8[qhead_idx * HEAD_SIZE / 8 + qhead_elemh8]; + } + const int final_qhead_idx = 4 * (QHLOOP - 1) + lane4id; + if (final_qhead_idx < GQA_RATIO) { + Qlocal[QHLOOP - 1] = + q_ptrh8[final_qhead_idx * HEAD_SIZE / 8 + qhead_elemh8]; + } else { + Qlocal[QHLOOP - 1].xy[0] = {0}; + Qlocal[QHLOOP - 1].xy[1] = {0}; + } + + const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride + + wg_start_kv_head_idx * kv_head_stride; + + const int physical_block_offset = + local_token_idx % BLOCK_SIZE; // since x=half8, physical_block_offset + // is already cast as _H8 + + const _B16x8* k_ptrh8 = reinterpret_cast(k_ptr); + #pragma unroll + for (int d = 0; d < KHELOOP; d++) { + Klocal[d] = k_ptrh8[d * BLOCK_SIZE + physical_block_offset]; + } + + float alibi_slope[QHLOOP]; + if (alibi_slopes != nullptr) { + #pragma unroll + for (int h = 0; h < QHLOOP; h++) { + const int qhead_idx = h * 4 + lane4id; + alibi_slope[h] = (qhead_idx < GQA_RATIO) + ? alibi_slopes[wg_start_head_idx + qhead_idx] + : 0.f; + } + } + + const scalar_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride; + const _B16x8* v_ptrh8 = reinterpret_cast(v_ptr); + // iterate over each v block + #pragma unroll + for (int b = 0; b < VBLOCKS; b++) { + // int32 physical_block_number leads to overflow when multiplied with + // kv_block_stride + const int64_t vphysical_block_number = + static_cast(vphysical_blocks[b]); + const _B16x8* v_ptrh8b = + v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8; + // iterate over each head elem (within head_size) + #pragma unroll + for (int h = 0; h < VHELOOP; h++) { + const int head_size_elem = h * WARP_SIZE + laneid; + const _B16x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8; + // iterate over all velems within block + #pragma unroll + for (int d = 0; d < BLOCK_SIZE / 8; d++) { + Vlocal[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; + } + } + } + + #pragma unroll + for (int h = 0; h < QHLOOP; h++) { + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[0].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[0].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[1].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[1].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[2].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[2].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[3].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[3].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[4].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[4].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[5].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[5].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[6].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[6].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[7].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[7].xy[1], dout[h]); + 
if constexpr (KHELOOP > 8) { + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[8].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[8].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[9].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[9].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[10].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[10].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[11].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[11].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[12].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[12].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[13].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[13].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[14].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[14].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[15].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[15].xy[1], dout[h]); + } // KHELOOP>8 + dout[h] *= scale; + } + // transpose dout so that 4 token ids are in each lane, and 4 heads are across + // 4 lanes + #pragma unroll + for (int h = 0; h < QHLOOP; h++) { + floatx4 tmp = {0}; + #pragma unroll + for (int i = 0; i < 4; i++) { + const float B = (lane4id == i) ? 1.0f : 0.0f; + // const float A = (global_token_idx < context_len) ? dout[h][i] : 0.0f; + tmp = __builtin_amdgcn_mfma_f32_4x4x1f32(dout[h][i], B, tmp, 0, 0, 0); + // tmp = __builtin_amdgcn_mfma_f32_4x4x1f32(A, B, tmp, 0, 0, 0); + } + dout[h] = tmp; + } + + const int lane4_token_idx = 4 * (global_token_idx >> 2); + const int alibi_offset = lane4_token_idx - context_len + 1; + if (alibi_slopes != nullptr) { + #pragma unroll + for (int h = 0; h < QHLOOP; h++) { + #pragma unroll + for (int i = 0; i < 4; i++) { + dout[h][i] += alibi_slope[h] * (alibi_offset + i); + } + } + } + + #pragma unroll + for (int h = 0; h < QHLOOP; h++) { + qk_max[h] = -FLT_MAX; + #pragma unroll + for (int i = 0; i < 4; i++) { + qk_max[h] = (lane4_token_idx + i < context_len) + ? fmaxf(qk_max[h], dout[h][i]) + : qk_max[h]; + } + #pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) { + qk_max[h] = fmaxf(qk_max[h], __shfl_xor(qk_max[h], mask)); + } + } + + float exp_sum[QHLOOP]; + #pragma unroll + for (int h = 0; h < QHLOOP; h++) { + exp_sum[h] = 0.0f; + #pragma unroll + for (int i = 0; i < 4; i++) { + dout[h][i] = (lane4_token_idx + i < context_len) + ? 
__expf(dout[h][i] - qk_max[h]) + : 0.0f; + exp_sum[h] += dout[h][i]; + } + #pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) { + exp_sum[h] += __shfl_xor(exp_sum[h], mask); + } + } + + #pragma unroll + for (int h = 0; h < QHLOOP; h++) { + const int head_idx = 4 * h + lane4id; + shared_qk_max[warpid][head_idx] = qk_max[h]; + shared_exp_sum[warpid][head_idx] = exp_sum[h]; + } + } // warp within context + + __syncthreads(); + + const int num_heads = gridDim.z * GQA_RATIO; + float* max_logits_ptr = + max_logits + seq_idx * num_heads * max_num_partitions + partition_idx; + float* exp_sums_ptr = + exp_sums + seq_idx * num_heads * max_num_partitions + partition_idx; + #pragma unroll + for (int h = 0; h < QHLOOP; h++) { + float global_qk_max = -FLT_MAX; + float warp_qk_max[NWARPS]; + const int head_idx = 4 * h + lane4id; + #pragma unroll + for (int w = 0; w < NWARPS; w++) { + warp_qk_max[w] = shared_qk_max[w][head_idx]; + global_qk_max = fmaxf(global_qk_max, warp_qk_max[w]); + } + float global_exp_sum = 0.0f; + #pragma unroll + for (int w = 0; w < NWARPS; w++) { + global_exp_sum += + shared_exp_sum[w][head_idx] * __expf(warp_qk_max[w] - global_qk_max); + } + if (head_idx < GQA_RATIO) { + max_logits_ptr[(wg_start_head_idx + head_idx) * max_num_partitions] = + global_qk_max; + exp_sums_ptr[(wg_start_head_idx + head_idx) * max_num_partitions] = + global_exp_sum; + } + const float global_inv_sum_scale = __fdividef(1.f, global_exp_sum + 1e-6f) * + __expf(qk_max[h] - global_qk_max); + dout[h] *= global_inv_sum_scale; + } + // logits[h] -> every 4 lanes hold 4 heads, each lane holds 4 tokens, there + // are 4x16 tokens across warp + _B16x4 logits[QHLOOP]; + #pragma unroll + for (int h = 0; h < QHLOOP; h++) { + logits[h] = from_floatx4(dout[h]); + } + + __shared__ _B16x4 vout_shared[QHLOOP][VHELOOP][WARP_SIZE][NWARPS + 1]; + + if (warp_start_token_idx >= context_len) { // warp out of context + #pragma unroll + for (int qh = 0; qh < QHLOOP; qh++) { + #pragma unroll + for (int vh = 0; vh < VHELOOP; vh++) { + vout_shared[qh][vh][laneid][warpid] = {0}; + } + } + } else { // warp in context + // iterate across heads + #pragma unroll + for (int qh = 0; qh < QHLOOP; qh++) { + // iterate over each v head elem (within head_size) + #pragma unroll + for (int vh = 0; vh < VHELOOP; vh++) { + floatx4 acc = {0}; + // iterate over tokens + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][0].xy[0], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][0].xy[1], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][1].xy[0], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][1].xy[1], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][2].xy[0], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][2].xy[1], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][3].xy[0], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][3].xy[1], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][4].xy[0], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][4].xy[1], + acc); + acc = gcn_mfma_instr(logits[qh], + Vlocal[vh][5].xy[0], acc); + acc = gcn_mfma_instr(logits[qh], + Vlocal[vh][5].xy[1], acc); + acc = gcn_mfma_instr(logits[qh], + Vlocal[vh][6].xy[0], acc); + acc = gcn_mfma_instr(logits[qh], + Vlocal[vh][6].xy[1], acc); + acc = gcn_mfma_instr(logits[qh], + Vlocal[vh][7].xy[0], acc); + acc = gcn_mfma_instr(logits[qh], + Vlocal[vh][7].xy[1], acc); + vout_shared[qh][vh][laneid][warpid] = from_floatx4(acc); + } + } + } // warp in context + + __syncthreads(); + + if (warpid == 0) { + _B16x4 
vout[QHLOOP][VHELOOP]; + // iterate across heads + scalar_t* out_ptr; + int out_num_partitions; + if (context_len > partition_size) { + out_num_partitions = max_num_partitions; + out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + partition_idx * HEAD_SIZE; + } else { + out_num_partitions = 1; + out_ptr = final_out + seq_idx * num_heads * HEAD_SIZE; + } + #pragma unroll + for (int qh = 0; qh < QHLOOP; qh++) { + // iterate over each v head elem (within head_size) + #pragma unroll + for (int vh = 0; vh < VHELOOP; vh++) { + vout[qh][vh] = {0}; + #pragma unroll + for (int w = 0; w < NWARPS; w++) { + vout[qh][vh] = + addx4(vout[qh][vh], vout_shared[qh][vh][laneid][w]); + } + const int head_size_elem = vh * WARP_SIZE + laneid; + bit16_t* out_ptr_b16 = reinterpret_cast(out_ptr); + #pragma unroll + for (int i = 0; i < 4; i++) { + const int head_idx = 4 * qh + i; + if (head_idx < GQA_RATIO) { + out_ptr_b16[(wg_start_head_idx + head_idx) * out_num_partitions * + HEAD_SIZE + + head_size_elem] = vout[qh][vh][i]; + } + } + } + } + } + + #if 0 + const int num_seqs = gridDim.x; + const int global_token4id = global_token_idx/4; + #pragma unroll + for (int t=0;t<4;t++) { + #pragma unroll + for (int h=0;h +__global__ +__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const float* __restrict__ exp_sums, // [num_seqs, num_heads, + // max_num_partitions] + const float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_partitions) { + const int num_heads = gridDim.x; + const int head_idx = blockIdx.x; + const int seq_idx = blockIdx.y; + const int context_len = context_lens[seq_idx]; + const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); + if (num_partitions == 1) { + // if num_partitions==1, main kernel will write to out directly, no work in + // reduction kernel + return; + } + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int warpid = threadIdx.x / WARP_SIZE; + const int laneid = threadIdx.x % WARP_SIZE; + + __shared__ float shared_global_exp_sum; + __shared__ float shared_exp_sums[2 * WARP_SIZE]; + + if (warpid == 0) { + const float* max_logits_ptr = max_logits + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + + // valid partition is the last valid partition in case threadid > num + // partitions + const int valid_partition = + (threadIdx.x < num_partitions) ? threadIdx.x : num_partitions - 1; + const int valid_partition2 = (WARP_SIZE + threadIdx.x < num_partitions) + ? WARP_SIZE + threadIdx.x + : num_partitions - 1; + float reg_max_logit = max_logits_ptr[valid_partition]; + float reg_max_logit2 = max_logits_ptr[valid_partition2]; + float max_logit = fmaxf(reg_max_logit, reg_max_logit2); + + #pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + max_logit = fmaxf(max_logit, __shfl_xor(max_logit, mask)); + } + + const float* exp_sums_ptr = exp_sums + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + + float global_exp_sum = 0.0f; + float rescaled_exp_sum = exp_sums_ptr[valid_partition]; + float rescaled_exp_sum2 = exp_sums_ptr[valid_partition2]; + rescaled_exp_sum *= + (threadIdx.x < num_partitions) ? 
expf(reg_max_logit - max_logit) : 0.0f; + rescaled_exp_sum2 *= (threadIdx.x + WARP_SIZE < num_partitions) + ? expf(reg_max_logit2 - max_logit) + : 0.0f; + global_exp_sum += rescaled_exp_sum + rescaled_exp_sum2; + shared_exp_sums[threadIdx.x] = rescaled_exp_sum; + shared_exp_sums[threadIdx.x + WARP_SIZE] = rescaled_exp_sum2; + + #pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + global_exp_sum += __shfl_xor(global_exp_sum, mask); + } + if (threadIdx.x == 0) { + shared_global_exp_sum = global_exp_sum; + } + } // warpid == 0 + const scalar_t* tmp_out_ptr = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + threadIdx.x; + constexpr int MAX_NPAR = 64; + scalar_t tmps[MAX_NPAR]; + const float dzero = 0.0f; + #pragma unroll + for (int j = 0; j < MAX_NPAR; j++) { + tmps[j] = from_float(dzero); + } + const int last_partition_offset = (num_partitions - 1) * HEAD_SIZE; + const int num_partition_offset = (num_partitions)*HEAD_SIZE; + int idx = 0; + + constexpr int JCHUNK = 16; + + #pragma unroll + for (int j = 0; j < JCHUNK * HEAD_SIZE; j += HEAD_SIZE) { + // lastj is last valid partition + const int lastj_offset = + (j < num_partition_offset) ? j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; + } + __syncthreads(); + + if (num_partitions > JCHUNK) { + #pragma unroll + for (int j = JCHUNK * HEAD_SIZE; j < 2 * JCHUNK * HEAD_SIZE; + j += HEAD_SIZE) { + const int lastj_offset = + (j < num_partition_offset) ? j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; + } + + if (num_partitions > 2 * JCHUNK) { + #pragma unroll + for (int j = 2 * JCHUNK * HEAD_SIZE; j < MAX_NPAR * HEAD_SIZE; + j += HEAD_SIZE) { + const int lastj_offset = + (j < num_partition_offset) ? j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; + } + } + } // num_partitions > JCHUNK + + // Aggregate tmp_out to out. + float acc = 0.0f; + #pragma unroll + for (int j = 0; j < JCHUNK; j++) { + acc += to_float(tmps[j]) * shared_exp_sums[j]; + } + if (num_partitions > JCHUNK) { + #pragma unroll + for (int j = JCHUNK; j < 2 * JCHUNK; j++) { + acc += to_float(tmps[j]) * shared_exp_sums[j]; + } + if (num_partitions > 2 * JCHUNK) { + #pragma unroll + for (int j = 2 * JCHUNK; j < MAX_NPAR; j++) { + acc += to_float(tmps[j]) * shared_exp_sums[j]; + } + } + } + + if (num_partitions > MAX_NPAR) { + idx = 0; + #pragma unroll + for (int j = MAX_NPAR * HEAD_SIZE; j < 2 * MAX_NPAR * HEAD_SIZE; + j += HEAD_SIZE) { + // lastj is last valid partition + const int lastj_offset = + (j < num_partition_offset) ? 
j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; + } + + #pragma unroll + for (int j = 0; j < MAX_NPAR; j++) { + acc += to_float(tmps[j]) * shared_exp_sums[j + MAX_NPAR]; + } + } + + const float inv_global_exp_sum = + __fdividef(1.0f, shared_global_exp_sum + 1e-6f); + acc *= inv_global_exp_sum; + scalar_t* out_ptr = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + out_ptr[threadIdx.x] = from_float(acc); +} + +#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support + +template +__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size] + #if 0 + scalar_t* __restrict__ qk_out, // [num_heads, num_seqs, max_ctx_blocks,block_size] + #endif + int max_ctx_blocks) { + UNREACHABLE_CODE +} + +// Grid: (num_heads, num_seqs). +template +__global__ +__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const float* __restrict__ exp_sums, // [num_seqs, num_heads, + // max_num_partitions] + const float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_partitions){UNREACHABLE_CODE} + +#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support + +#define LAUNCH_CUSTOM_ATTENTION(GQA_RATIO) \ + paged_attention_ll4mi_QKV_kernel \ + <<>>( \ + query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ + block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \ + alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ + exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks); + +template +void paged_attention_custom_launcher( + torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, + torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, const int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& context_lens, + int max_context_len, +#if 0 + torch::Tensor& qk_out, + torch::Tensor& softmax_out, +#endif + const c10::optional& alibi_slopes) { + + int num_seqs = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = query.stride(0); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + + // NOTE: alibi_slopes is optional. + const float* alibi_slopes_ptr = + alibi_slopes + ? 
reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + + T* out_ptr = reinterpret_cast(out.data_ptr()); + float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); + float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); + T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* context_lens_ptr = context_lens.data_ptr(); +#if 0 + T* qk_out_ptr = reinterpret_cast(qk_out.data_ptr()); + T* softmax_out_ptr = reinterpret_cast(softmax_out.data_ptr()); +#endif + + const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE); + const int max_num_partitions = + DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); + const int gqa_ratio = num_heads / num_kv_heads; + assert(num_heads % num_kv_heads == 0); + assert(head_size == HEAD_SIZE); + assert(max_num_partitions <= 128); + + constexpr int NTHR = PARTITION_SIZE; + dim3 grid(num_seqs, max_num_partitions, num_kv_heads); + dim3 block(NTHR); + const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + switch (gqa_ratio) { + case 1: + LAUNCH_CUSTOM_ATTENTION(1); + break; + case 2: + LAUNCH_CUSTOM_ATTENTION(2); + break; + case 3: + LAUNCH_CUSTOM_ATTENTION(3); + break; + case 4: + LAUNCH_CUSTOM_ATTENTION(4); + break; + case 5: + LAUNCH_CUSTOM_ATTENTION(5); + break; + case 6: + LAUNCH_CUSTOM_ATTENTION(6); + break; + case 7: + LAUNCH_CUSTOM_ATTENTION(7); + break; + case 8: + LAUNCH_CUSTOM_ATTENTION(8); + break; + case 9: + LAUNCH_CUSTOM_ATTENTION(9); + break; + case 10: + LAUNCH_CUSTOM_ATTENTION(10); + break; + case 11: + LAUNCH_CUSTOM_ATTENTION(11); + break; + case 12: + LAUNCH_CUSTOM_ATTENTION(12); + break; + case 13: + LAUNCH_CUSTOM_ATTENTION(13); + break; + case 14: + LAUNCH_CUSTOM_ATTENTION(14); + break; + case 15: + LAUNCH_CUSTOM_ATTENTION(15); + break; + case 16: + LAUNCH_CUSTOM_ATTENTION(16); + break; + default: + TORCH_CHECK(false, "Unsupported gqa ratio: ", gqa_ratio); + break; + } + // dim3 grid2(num_heads,num_seqs,head_size/HEAD_ELEMS_PER_WG); + // dim3 block2(1024); + // LAUNCH_CUSTOM_ATTENTION2; + + // reduction kernel is only required if max_context_len > partition size, + // otherwise main kernel writes directly to final output + // note there are cases with graphing where max_context_len is the max + // supported by graphing, not the actual max among all the sequences: in that + // case reduction kernel will still run but return immediately + if (max_context_len > PARTITION_SIZE) { + dim3 reduce_grid(num_heads, num_seqs); + dim3 reduce_block(head_size); + paged_attention_ll4mi_reduce_kernel + <<>>( + out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, + context_lens_ptr, max_num_partitions); + } +} + +#define CALL_CUSTOM_LAUNCHER(T, BLK_SIZE, HEAD_SIZE) \ + paged_attention_custom_launcher( \ + out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ + num_kv_heads, scale, block_tables, context_lens, max_context_len, \ + alibi_slopes); + +#define CALL_CUSTOM_LAUNCHER_BLK(T, HEAD_SIZE) \ + switch (block_size) { \ + case 16: \ + CALL_CUSTOM_LAUNCHER(T, 16, HEAD_SIZE); \ + break; \ + case 32: \ + CALL_CUSTOM_LAUNCHER(T, 32, HEAD_SIZE); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ + } + +#define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T) \ + switch 
(head_size) { \ + case 64: \ + CALL_CUSTOM_LAUNCHER_BLK(T, 64); \ + break; \ + case 128: \ + CALL_CUSTOM_LAUNCHER_BLK(T, 128); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported head size: ", head_size); \ + break; \ + } + +void paged_attention_custom( + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& + tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& + key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& + value_cache, // [num_blocks, num_heads, head_size, block_size] + int num_kv_heads, float scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& context_lens, // [num_seqs] + int block_size, int max_context_len, +#if 0 + torch::Tensor& qk_out, + torch::Tensor& softmax_out, +#endif + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype) { + assert(kv_cache_dtype == "auto"); + const int head_size = query.size(2); + if (query.dtype() == at::ScalarType::Half) { + CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16); + } else if (query.dtype() == at::ScalarType::BFloat16) { + CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16); + } else { + TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + } +} + +#undef WARP_SIZE +#undef MAX +#undef MIN +#undef DIVIDE_ROUND_UP \ No newline at end of file diff --git a/csrc/hip/torch_binding.cpp b/csrc/hip/torch_binding.cpp new file mode 100644 index 0000000000000..e69de29bb2d1d From 5f605bc7e28d29442ce59d34e05920e28a6975ca Mon Sep 17 00:00:00 2001 From: charlifu Date: Tue, 27 Aug 2024 18:33:58 +0000 Subject: [PATCH 02/12] register custom op --- csrc/hip/ops.h | 12 ++++++++++++ csrc/hip/torch_binding.cpp | 27 +++++++++++++++++++++++++++ 2 files changed, 39 insertions(+) create mode 100644 csrc/hip/ops.h diff --git a/csrc/hip/ops.h b/csrc/hip/ops.h new file mode 100644 index 0000000000000..82cc579996eca --- /dev/null +++ b/csrc/hip/ops.h @@ -0,0 +1,12 @@ +#pragma once + +#include +#include + +void paged_attention_custom( + torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, + torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int64_t num_kv_heads, double scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, + int64_t max_seq_len, const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype) \ No newline at end of file diff --git a/csrc/hip/torch_binding.cpp b/csrc/hip/torch_binding.cpp index e69de29bb2d1d..b2802868d0dc3 100644 --- a/csrc/hip/torch_binding.cpp +++ b/csrc/hip/torch_binding.cpp @@ -0,0 +1,27 @@ +#include "hip/ops.h" +#include "core/registration.h" + +#include + +// Note on op signatures: +// The X_meta signatures are for the meta functions corresponding to op X. +// They must be kept in sync with the signature for X. Generally, only +// functions that return Tensors require a meta function. +// +// See the following links for detailed docs on op registration and function +// schemas. 
+// https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ptttacy8y1u9 +// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#annotations + +TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { + // vLLM custom ops for rocm + + // Custom attention op + // Compute the attention between an input query and the cached + // keys/values using PagedAttention. + ops.def( + "paged_attention_custom(" + "" + ); + ops.impl("paged_attention_custom", torch::kCUDA, &paged_attention_custom) +} \ No newline at end of file From 554804b9c9635795d0f9e8d96e0580c121e1d051 Mon Sep 17 00:00:00 2001 From: charlifu Date: Mon, 9 Sep 2024 15:12:25 +0000 Subject: [PATCH 03/12] add paged attention for rocm --- CMakeLists.txt | 23 +++++++++++-------- csrc/hip/ops.h | 12 ---------- csrc/{hip/attention.hip => rocm/attention.cu} | 2 +- csrc/rocm/ops.h | 16 +++++++++++++ .../torch_bindings.cpp} | 20 +++++++++++----- setup.py | 3 +++ 6 files changed, 47 insertions(+), 29 deletions(-) delete mode 100644 csrc/hip/ops.h rename csrc/{hip/attention.hip => rocm/attention.cu} (99%) create mode 100644 csrc/rocm/ops.h rename csrc/{hip/torch_binding.cpp => rocm/torch_bindings.cpp} (55%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 7746bb7f11e06..861a2e3dc386b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -313,25 +313,24 @@ define_gpu_extension_target( USE_SABI 3 WITH_SOABI) -# -# _hip_C extension -# if(VLLM_GPU_LANG STREQUAL "HIP") - set(VLLM_HIP_EXT_SRC - "csrc/hip/torch_bindings.cpp" - "csrc/hip/attention.hip") + # + # _rocm_C extension + # + set(VLLM_ROCM_EXT_SRC + "csrc/rocm/torch_bindings.cpp" + "csrc/rocm/attention.cu") define_gpu_extension_target( - _hip_C + _rocm_C DESTINATION vllm LANGUAGE ${VLLM_GPU_LANG} - SOURCES ${VLLM_HIP_EXT_SRC} + SOURCES ${VLLM_ROCM_EXT_SRC} COMPILE_FLAGS ${VLLM_GPU_FLAGS} ARCHITECTURES ${VLLM_GPU_ARCHES} USE_SABI 3 WITH_SOABI) - endif() @@ -341,5 +340,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") message(STATUS "Enabling moe extension.") add_dependencies(default _moe_C) - endif() + +if(VLLM_GPU_LANG STREQUAL "HIP") + message(STATUS "Enabling rocm extension.") + add_dependencies(default _rocm_C) +endif() \ No newline at end of file diff --git a/csrc/hip/ops.h b/csrc/hip/ops.h deleted file mode 100644 index 82cc579996eca..0000000000000 --- a/csrc/hip/ops.h +++ /dev/null @@ -1,12 +0,0 @@ -#pragma once - -#include -#include - -void paged_attention_custom( - torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, - torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, - torch::Tensor& value_cache, int64_t num_kv_heads, double scale, - torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, - int64_t max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype) \ No newline at end of file diff --git a/csrc/hip/attention.hip b/csrc/rocm/attention.cu similarity index 99% rename from csrc/hip/attention.hip rename to csrc/rocm/attention.cu index 0f1050891c477..145b6cc239483 100644 --- a/csrc/hip/attention.hip +++ b/csrc/rocm/attention.cu @@ -1083,7 +1083,7 @@ void paged_attention_custom_launcher( break; \ } -void paged_attention_custom( +void paged_attention( torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] diff --git a/csrc/rocm/ops.h b/csrc/rocm/ops.h new file mode 100644 index 
0000000000000..01bbd789ed14e --- /dev/null +++ b/csrc/rocm/ops.h @@ -0,0 +1,16 @@ +#pragma once + +#include +#include + +#include "core/scalar_type.hpp" + +void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums, + torch::Tensor& max_logits, torch::Tensor& tmp_out, + torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int64_t num_kv_heads, + double scale, torch::Tensor& block_tables, + torch::Tensor& context_lens, int64_t block_size, + int64_t max_context_len, + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype); \ No newline at end of file diff --git a/csrc/hip/torch_binding.cpp b/csrc/rocm/torch_bindings.cpp similarity index 55% rename from csrc/hip/torch_binding.cpp rename to csrc/rocm/torch_bindings.cpp index b2802868d0dc3..cf0c391bcfdc2 100644 --- a/csrc/hip/torch_binding.cpp +++ b/csrc/rocm/torch_bindings.cpp @@ -1,4 +1,5 @@ -#include "hip/ops.h" +#include "cuda_utils.h" +#include "rocm/ops.h" #include "core/registration.h" #include @@ -19,9 +20,16 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Custom attention op // Compute the attention between an input query and the cached // keys/values using PagedAttention. - ops.def( - "paged_attention_custom(" - "" - ); - ops.impl("paged_attention_custom", torch::kCUDA, &paged_attention_custom) +custom_ops.def( + "paged_attention(Tensor! out, Tensor exp_sums," + " Tensor max_logits, Tensor tmp_out," + " Tensor query, Tensor key_cache," + " Tensor value_cache, int num_kv_heads," + " float scale, Tensor block_tables," + " Tensor context_lens, int block_size," + " int max_context_len," + " Tensor? alibi_slopes," + " str kv_cache_dtype) -> ()"); + custom_ops.impl("paged_attention", torch::kCUDA, + &paged_attention); } \ No newline at end of file diff --git a/setup.py b/setup.py index 1e08a5bd70cd3..ae8b5e8c70d48 100644 --- a/setup.py +++ b/setup.py @@ -459,6 +459,9 @@ def _read_requirements(filename: str) -> List[str]: if _is_cuda() or _is_hip(): ext_modules.append(CMakeExtension(name="vllm._moe_C")) +if _is_hip(): + ext_modules.append(CMakeExtension(name="vllm._rocm_C")) + if _build_custom_ops(): ext_modules.append(CMakeExtension(name="vllm._C")) From 276377823acf3a683a466cb20165c03f56acd9ea Mon Sep 17 00:00:00 2001 From: charlifu Date: Mon, 9 Sep 2024 22:40:21 +0000 Subject: [PATCH 04/12] enable custom page attn and unit test --- csrc/rocm/attention.cu | 26 +++-- csrc/rocm/ops.h | 5 +- csrc/rocm/torch_bindings.cpp | 24 ++--- tests/kernels/test_attention.py | 166 ++++++++++++++++++++++++++++++- vllm/_custom_ops.py | 28 ++++++ vllm/attention/ops/paged_attn.py | 83 +++++++++++----- vllm/attention/selector.py | 15 +++ vllm/envs.py | 5 + 8 files changed, 300 insertions(+), 52 deletions(-) diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index 145b6cc239483..df54e22147413 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -1,4 +1,20 @@ -#include +/* + * Copyright (c) 2024, The vLLM team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include #include #include #include @@ -1094,14 +1110,10 @@ void paged_attention( key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] - int num_kv_heads, float scale, + int64_t num_kv_heads, double scale, torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] torch::Tensor& context_lens, // [num_seqs] - int block_size, int max_context_len, -#if 0 - torch::Tensor& qk_out, - torch::Tensor& softmax_out, -#endif + int64_t block_size, int64_t max_context_len, const c10::optional& alibi_slopes, const std::string& kv_cache_dtype) { assert(kv_cache_dtype == "auto"); diff --git a/csrc/rocm/ops.h b/csrc/rocm/ops.h index 01bbd789ed14e..ba0c4591aac13 100644 --- a/csrc/rocm/ops.h +++ b/csrc/rocm/ops.h @@ -1,9 +1,6 @@ #pragma once -#include -#include - -#include "core/scalar_type.hpp" +#include void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& tmp_out, diff --git a/csrc/rocm/torch_bindings.cpp b/csrc/rocm/torch_bindings.cpp index cf0c391bcfdc2..2352dd9e6b565 100644 --- a/csrc/rocm/torch_bindings.cpp +++ b/csrc/rocm/torch_bindings.cpp @@ -1,8 +1,5 @@ -#include "cuda_utils.h" -#include "rocm/ops.h" #include "core/registration.h" - -#include +#include "rocm/ops.h" // Note on op signatures: // The X_meta signatures are for the meta functions corresponding to op X. @@ -14,13 +11,13 @@ // https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ptttacy8y1u9 // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#annotations -TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { - // vLLM custom ops for rocm +TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) { + // vLLM custom ops for rocm - // Custom attention op - // Compute the attention between an input query and the cached - // keys/values using PagedAttention. -custom_ops.def( + // Custom attention op + // Compute the attention between an input query and the cached + // keys/values using PagedAttention. + rocm_ops.def( "paged_attention(Tensor! out, Tensor exp_sums," " Tensor max_logits, Tensor tmp_out," " Tensor query, Tensor key_cache," @@ -30,6 +27,7 @@ custom_ops.def( " int max_context_len," " Tensor? alibi_slopes," " str kv_cache_dtype) -> ()"); - custom_ops.impl("paged_attention", torch::kCUDA, - &paged_attention); -} \ No newline at end of file + rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention); +} + +REGISTER_EXTENSION(TORCH_EXTENSION_NAME) \ No newline at end of file diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 8aa2d4a53aaa0..a80c63734ed1b 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -3,14 +3,16 @@ import pytest import torch -from xformers import ops as xops -from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask from vllm import _custom_ops as ops from vllm.utils import get_max_shared_memory_bytes, is_hip from .allclose_default import get_default_atol, get_default_rtol +if not is_hip(): + from xformers import ops as xops + from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask + FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 # This will change depending on the compute capability. 
# - 512 as a buffer @@ -121,6 +123,7 @@ def ref_single_query_cached_kv_attention( @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.skip() def test_paged_attention( kv_cache_factory, version: str, @@ -312,6 +315,164 @@ def ref_multi_query_kv_attention( return torch.cat(ref_outputs, dim=0) +@pytest.mark.parametrize("version", ["rocm"]) +@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", [64, 128]) # only test 64 128 +@pytest.mark.parametrize("use_alibi", USE_ALIBI) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("kv_cache_dtype", ["auto"]) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.skipif(not is_hip(), reason="only for rocm") +def test_paged_attention_rocm( + kv_cache_factory, + version: str, + num_seqs: int, + num_heads: Tuple[int, int], + head_size: int, + use_alibi: bool, + block_size: int, + dtype: torch.dtype, + kv_cache_dtype: str, + seed: int, + device: str, +) -> None: + random.seed(seed) + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.set_default_device(device) + scale = float(1.0 / (head_size**0.5)) + num_query_heads, num_kv_heads = num_heads + query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype) + query.uniform_(-scale, scale) + + assert num_query_heads % num_kv_heads == 0 + num_queries_per_kv = num_query_heads // num_kv_heads + alibi_slopes = None + if use_alibi: + alibi_slopes = torch.randn(num_query_heads, dtype=torch.float) + + context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] + context_lens[-1] = MAX_SEQ_LEN + #context_lens = [8192 for _ in range(num_seqs)] + max_context_len = max(context_lens) + context_lens = torch.tensor(context_lens, dtype=torch.int) + #print('>>> ctx lens', context_lens) + + # Create the block tables. + max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size + block_tables = [] + for _ in range(num_seqs): + block_table = [ + random.randint(0, NUM_BLOCKS - 1) + for _ in range(max_num_blocks_per_seq) + ] + block_tables.append(block_table) + block_tables = torch.tensor(block_tables, dtype=torch.int) + + # Create the KV caches. + key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1, + num_kv_heads, head_size, + kv_cache_dtype, dtype, seed, + device) + key_cache, value_cache = key_caches[0], value_caches[0] + + # Using default kv_scale + kv_scale = 1.0 + + # Call the paged attention kernel. 
+ output = torch.empty_like(query) + PARTITION_SIZE_ROCM = 256 + num_partitions = ((max_context_len + PARTITION_SIZE_ROCM - 1) // + PARTITION_SIZE_ROCM) + assert PARTITION_SIZE % block_size == 0 + num_seqs, num_heads, head_size = output.shape + tmp_output = torch.empty( + size=(num_seqs, num_heads, num_partitions, head_size), + dtype=output.dtype, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, num_partitions), + dtype=torch.float32, + ) + max_logits = torch.empty_like(exp_sums) + if version == "rocm": + ops.paged_attention_rocm( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + alibi_slopes, + kv_cache_dtype, + ) + else: + raise AssertionError(f"Unknown version: {version}") + + # Run the reference implementation. + if kv_cache_dtype == "fp8": + # Convert cache data back to dtype. + x = 16 // torch.tensor([], dtype=dtype).element_size() + key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, + block_size, x) + dequantized_key_cache = torch.empty(size=key_cache_shape, + dtype=dtype, + device=device) + ops.convert_fp8(key_cache, dequantized_key_cache) + key_cache = dequantized_key_cache + + value_cache_shape = value_cache.shape + dequantized_value_cache = torch.empty(size=value_cache_shape, + dtype=dtype, + device=device) + ops.convert_fp8(value_cache, dequantized_value_cache) + value_cache = dequantized_value_cache + + ref_output = torch.empty_like(query) + ref_single_query_cached_kv_attention( + ref_output, + query, + num_queries_per_kv, + key_cache, + value_cache, + block_tables, + context_lens, + scale, + alibi_slopes, + ) + + # NOTE(woosuk): Due to the kernel-level differences in the two + # implementations, there is a small numerical difference in the two + # outputs. Thus, we use a relaxed tolerance for the test. + atol = get_default_atol(output) if is_hip() else 1e-3 + rtol = get_default_rtol(output) if is_hip() else 1e-5 + + # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error, + # so we use a relaxed tolerance for the test. + atol, rtol = 1e-4, 1e-5 + if dtype == torch.bfloat16: + atol, rtol = 2e-4, 1e-5 + if use_alibi: + if dtype == torch.half: + atol, rtol = 5e-4, 1e-5 + if dtype == torch.bfloat16: + atol, rtol = 1e-3, 1e-5 + if kv_cache_dtype == "fp8": + atol, rtol = 1e-2, 1e-5 + assert torch.allclose(output, ref_output, atol=atol, rtol=rtol) + + # TODO(woosuk): Add tests for USE_ALIBI=True. 
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS) @pytest.mark.parametrize("num_heads", NUM_HEADS) @@ -319,6 +480,7 @@ def ref_multi_query_kv_attention( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.skipif(is_hip(), reason="skip for rocm") @torch.inference_mode() def test_multi_query_kv_attention( num_seqs: int, diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 151cdbee8eb04..72de8a1150f43 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -17,6 +17,10 @@ except ImportError as e: logger.warning("Failed to import from vllm._C with %r", e) +if current_platform.is_rocm(): + with contextlib.suppress(ImportError): + import vllm._rocm_C # noqa: F401 + with contextlib.suppress(ImportError): import vllm._moe_C # noqa: F401 @@ -127,6 +131,30 @@ def paged_attention_v2( blocksparse_block_size, blocksparse_head_sliding_step) +def paged_attention_rocm( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, +) -> None: + torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query, + key_cache, value_cache, num_kv_heads, + scale, block_tables, seq_lens, + block_size, max_seq_len, alibi_slopes, + kv_cache_dtype) + + # pos encoding ops def rotary_embedding( positions: torch.Tensor, diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 92023d5b75f5a..93a45d28018d8 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -4,13 +4,15 @@ import torch from vllm import _custom_ops as ops +from vllm.attention.selector import use_rocm_paged_attention from vllm.triton_utils import HAS_TRITON if HAS_TRITON: from vllm.attention.ops.prefix_prefill import context_attention_fwd # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. -_PARTITION_SIZE = 512 +_PARTITION_SIZE_V1V2 = 512 +_PARTITION_SIZE_ROCM = 256 @dataclass @@ -114,8 +116,17 @@ def forward_decode( output = torch.empty_like(query) block_size = value_cache.shape[3] num_seqs, num_heads, head_size = query.shape + gqa_ratio = num_heads // num_kv_heads + use_rocm = use_rocm_paged_attention(query.dtype, head_size, block_size, + kv_cache_dtype, gqa_ratio, + max_seq_len) + + # select the partition size + _PARTITION_SIZE = _PARTITION_SIZE_ROCM if use_rocm else \ + _PARTITION_SIZE_V1V2 max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE) + # NOTE(woosuk): We use a simple heuristic to decide whether to use # PagedAttention V1 or V2. If the number of partitions is 1, we use # V1 to avoid the overhead of reduction. Also, if the number of @@ -123,7 +134,7 @@ def forward_decode( # to parallelize. # TODO(woosuk): Tune this heuristic. # For context len > 8192, use V2 kernel to avoid shared memory shortage. 
- use_v1 = (max_seq_len <= 8192 + use_v1 = ((not use_rocm) and max_seq_len <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)) if use_v1: @@ -163,30 +174,50 @@ def forward_decode( device=output.device, ) max_logits = torch.empty_like(exp_sums) - ops.paged_attention_v2( - output, - exp_sums, - max_logits, - tmp_output, - query, - key_cache, - value_cache, - num_kv_heads, - scale, - block_tables, - seq_lens, - block_size, - max_seq_len, - alibi_slopes, - kv_cache_dtype, - k_scale, - v_scale, - tp_rank, - blocksparse_local_blocks, - blocksparse_vert_stride, - blocksparse_block_size, - blocksparse_head_sliding_step, - ) + if not use_rocm: + ops.paged_attention_v2( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + else: + # run rocm custom paged attention + ops.paged_attention_rocm( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + ) return output @staticmethod diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 855586d4e5961..9668a85b2e99d 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -38,6 +38,21 @@ def backend_name_to_enum(backend_name: str) -> _Backend: return _Backend[backend_name] +def use_rocm_paged_attention(qtype: torch.dtype, head_size: int, + block_size: int, kv_cache_dtype: str, + gqa_ratio: int, max_seq_len: int) -> bool: + # To use rocm custom paged attention kernel or not + rocm_paged_attention_available = ( + is_hip() and envs.VLLM_USE_ROCM_CUSTOM_PAGED_ATTN + and "gfx1" not in torch.cuda.get_device_properties("cuda").gcnArchName) + return (rocm_paged_attention_available + and (qtype == torch.half or qtype == torch.bfloat16) + and (head_size == 64 or head_size == 128) + and (block_size == 16 or block_size == 32) + and kv_cache_dtype == "auto" + and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768) + + def get_env_variable_attn_backend() -> Optional[_Backend]: ''' Get the backend override specified by the vLLM attention diff --git a/vllm/envs.py b/vllm/envs.py index ed45047e9f8fc..3a29463a5baca 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -203,6 +203,11 @@ def get_default_config_root(): (os.environ.get("VLLM_DYNAMO_USE_CUSTOM_DISPATCHER", "True").lower() in ("true", "1")), + # Rocm custom paged attention implemented for MI3* GPUs + "VLLM_USE_ROCM_CUSTOM_PAGED_ATTN": + lambda: (os.getenv("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in + ("true", "1") != "0"), + # local rank of the process in the distributed setting, used to determine # the GPU device id "LOCAL_RANK": From 5951a52b35dbeec44a8e7a6a9dcca33683e67316 Mon Sep 17 00:00:00 2001 From: charlifu Date: Mon, 9 Sep 2024 22:41:49 +0000 Subject: [PATCH 05/12] fix v1/v2 --- tests/kernels/test_attention.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index a80c63734ed1b..980e57c85868c 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -123,7 +123,6 @@ def ref_single_query_cached_kv_attention( @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) @pytest.mark.parametrize("seed", 
SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) -@pytest.mark.skip() def test_paged_attention( kv_cache_factory, version: str, From 0f66eb99cb1f7fece85ca1e76d33008380f80763 Mon Sep 17 00:00:00 2001 From: charlifu Date: Tue, 10 Sep 2024 00:45:31 +0000 Subject: [PATCH 06/12] add hip back to gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index a5b6b48f65823..761b00ac3bc48 100644 --- a/.gitignore +++ b/.gitignore @@ -187,6 +187,7 @@ _build/ *.swp # hip files generated by PyTorch +*.hip *_hip* hip_compat.h From 79449fe4836d4b12059f02bd6b562d0912bc8636 Mon Sep 17 00:00:00 2001 From: charlifu Date: Tue, 10 Sep 2024 00:48:58 +0000 Subject: [PATCH 07/12] linting --- tests/kernels/test_attention.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 980e57c85868c..ee1a3d1eaa371 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -317,7 +317,7 @@ def ref_multi_query_kv_attention( @pytest.mark.parametrize("version", ["rocm"]) @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) @pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", [64, 128]) # only test 64 128 +@pytest.mark.parametrize("head_size", [64, 128]) # only test 64 128 @pytest.mark.parametrize("use_alibi", USE_ALIBI) @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("dtype", DTYPES) @@ -379,8 +379,9 @@ def test_paged_attention_rocm( device) key_cache, value_cache = key_caches[0], value_caches[0] + # TODO(charlifu) enable fp8 kv cache # Using default kv_scale - kv_scale = 1.0 + # kv_scale = 1.0 # Call the paged attention kernel. output = torch.empty_like(query) @@ -470,7 +471,7 @@ def test_paged_attention_rocm( if kv_cache_dtype == "fp8": atol, rtol = 1e-2, 1e-5 assert torch.allclose(output, ref_output, atol=atol, rtol=rtol) - + # TODO(woosuk): Add tests for USE_ALIBI=True. 
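The commits above only touch the tests: un-skipping `test_paged_attention`, restricting `head_size` to the supported 64/128, and deferring fp8 KV-cache support with a TODO. If it helps while reviewing, a minimal way to run just the new ROCm test locally is sketched below; it assumes a ROCm build of vLLM with the `_rocm_C` extension and pytest installed, and is not part of the patch itself.

```python
# Run only the new ROCm paged-attention kernel test (assumes a ROCm build).
import pytest

pytest.main(["tests/kernels/test_attention.py",
             "-k", "test_paged_attention_rocm", "-q"])
```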
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS) From 6f4007953a71d072b5ae82398fb0ee803c7063aa Mon Sep 17 00:00:00 2001 From: charlifu Date: Tue, 10 Sep 2024 15:51:38 +0000 Subject: [PATCH 08/12] remove unneeded code --- csrc/rocm/attention.cu | 95 ------------------------------------------ 1 file changed, 95 deletions(-) diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index df54e22147413..816285f01db24 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -62,8 +62,6 @@ typedef struct _B16x8 { ////// Non temporal load stores /////// - #if 1 - template __device__ __forceinline__ T load(T* addr) { return addr[0]; @@ -74,83 +72,6 @@ __device__ __forceinline__ void store(T value, T* addr) { addr[0] = value; } - #else - -template -__device__ __forceinline__ T load(const T* addr) { - return __builtin_nontemporal_load(addr); -} - -template <> -__device__ __forceinline__ float2 load(const float2* addr) { - auto addr_alias{reinterpret_cast(addr)}; - auto result = __builtin_nontemporal_load(addr_alias); - auto ret = reinterpret_cast(&result); - return ret[0]; -} - -template <> -__device__ __forceinline__ float4 load(const float4* addr) { - auto addr_alias{reinterpret_cast(addr)}; - auto result1 = __builtin_nontemporal_load(addr_alias); - auto result2 = __builtin_nontemporal_load(addr_alias + 1); - float4 ret{}; - auto ret_alias = reinterpret_cast(&result1); - ret.x = ret_alias->x; - ret.y = ret_alias->y; - ret_alias = reinterpret_cast(&result2); - ret.z = ret_alias->x; - ret.w = ret_alias->y; - return ret; -} - -template <> -__device__ __forceinline__ __half load(const __half* addr) { - auto addr_alias{reinterpret_cast(addr)}; - auto result = __builtin_nontemporal_load(addr_alias); - auto ret = reinterpret_cast<__half*>(&result); - return ret[0]; -} - -template <> -__device__ __forceinline__ __half2 load(const __half2* addr) { - auto addr_alias{reinterpret_cast(addr)}; - auto result = __builtin_nontemporal_load(addr_alias); - auto ret = reinterpret_cast<__half2*>(&result); - return ret[0]; -} - -template <> -__device__ __forceinline__ vllm::Half4_ load(const vllm::Half4_* addr) { - auto addr_alias{reinterpret_cast(addr)}; - auto result = __builtin_nontemporal_load(addr_alias); - auto ret = reinterpret_cast(&result); - return ret[0]; -} - -template <> -__device__ __forceinline__ vllm::Half8_ load(const vllm::Half8_* addr) { - auto addr_alias{reinterpret_cast(addr)}; - auto result1 = __builtin_nontemporal_load(addr_alias); - auto result2 = __builtin_nontemporal_load(addr_alias + 1); - vllm::Half8_ ret{}; - auto ret_alias = reinterpret_cast(&result1); - ret.x = ret_alias->x; - ret.y = ret_alias->y; - ret_alias = reinterpret_cast(&result2); - ret.z = ret_alias->x; - ret.w = ret_alias->y; - return ret; -} - -//// Not using nontemporal stores for now -template -__device__ __forceinline__ void store(T value, T* addr) { - return __builtin_nontemporal_store(value, addr); -} - - #endif - template __device__ __forceinline__ floatx4 gcn_mfma_instr(const _B16x4& inpA, const _B16x4& inpB, @@ -699,22 +620,6 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( } } } - - #if 0 - const int num_seqs = gridDim.x; - const int global_token4id = global_token_idx/4; - #pragma unroll - for (int t=0;t<4;t++) { - #pragma unroll - for (int h=0;h Date: Thu, 12 Sep 2024 17:14:49 -0500 Subject: [PATCH 09/12] Update CMakeLists.txt Co-authored-by: Woosuk Kwon --- CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakeLists.txt 
b/CMakeLists.txt index ee9fa750b0d1c..7d0f26203e790 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -344,4 +344,4 @@ endif() if(VLLM_GPU_LANG STREQUAL "HIP") message(STATUS "Enabling rocm extension.") add_dependencies(default _rocm_C) -endif() \ No newline at end of file +endif() From 2fc628ca34ad5449fad7a24133228e1da4b0bc97 Mon Sep 17 00:00:00 2001 From: charlifu Date: Thu, 12 Sep 2024 22:20:50 +0000 Subject: [PATCH 10/12] add empty line and remove env --- csrc/rocm/attention.cu | 2 +- csrc/rocm/ops.h | 2 +- csrc/rocm/torch_bindings.cpp | 2 +- vllm/attention/selector.py | 3 +-- vllm/envs.py | 5 ----- 5 files changed, 4 insertions(+), 10 deletions(-) diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index 816285f01db24..8fa7c862fbfa8 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -1035,4 +1035,4 @@ void paged_attention( #undef WARP_SIZE #undef MAX #undef MIN -#undef DIVIDE_ROUND_UP \ No newline at end of file +#undef DIVIDE_ROUND_UP diff --git a/csrc/rocm/ops.h b/csrc/rocm/ops.h index ba0c4591aac13..4a07a3f1775bd 100644 --- a/csrc/rocm/ops.h +++ b/csrc/rocm/ops.h @@ -10,4 +10,4 @@ void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& context_lens, int64_t block_size, int64_t max_context_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype); \ No newline at end of file + const std::string& kv_cache_dtype); diff --git a/csrc/rocm/torch_bindings.cpp b/csrc/rocm/torch_bindings.cpp index 2352dd9e6b565..082e314587908 100644 --- a/csrc/rocm/torch_bindings.cpp +++ b/csrc/rocm/torch_bindings.cpp @@ -30,4 +30,4 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) { rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention); } -REGISTER_EXTENSION(TORCH_EXTENSION_NAME) \ No newline at end of file +REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 9668a85b2e99d..ad80070d3b453 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -42,8 +42,7 @@ def use_rocm_paged_attention(qtype: torch.dtype, head_size: int, block_size: int, kv_cache_dtype: str, gqa_ratio: int, max_seq_len: int) -> bool: # To use rocm custom paged attention kernel or not - rocm_paged_attention_available = ( - is_hip() and envs.VLLM_USE_ROCM_CUSTOM_PAGED_ATTN + rocm_paged_attention_available = (is_hip() and "gfx1" not in torch.cuda.get_device_properties("cuda").gcnArchName) return (rocm_paged_attention_available and (qtype == torch.half or qtype == torch.bfloat16) diff --git a/vllm/envs.py b/vllm/envs.py index 3a29463a5baca..ed45047e9f8fc 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -203,11 +203,6 @@ def get_default_config_root(): (os.environ.get("VLLM_DYNAMO_USE_CUSTOM_DISPATCHER", "True").lower() in ("true", "1")), - # Rocm custom paged attention implemented for MI3* GPUs - "VLLM_USE_ROCM_CUSTOM_PAGED_ATTN": - lambda: (os.getenv("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in - ("true", "1") != "0"), - # local rank of the process in the distributed setting, used to determine # the GPU device id "LOCAL_RANK": From f5733766a7b4c87c079fc04e7b849f60ca3aff1e Mon Sep 17 00:00:00 2001 From: charlifu Date: Fri, 13 Sep 2024 14:42:55 +0000 Subject: [PATCH 11/12] move kernel selection for rocm to rocm_flash_attn.py --- tests/kernels/test_attention.py | 2 +- vllm/_custom_ops.py | 3 +- vllm/attention/backends/rocm_flash_attn.py | 84 ++++++++++++++++++---- vllm/attention/ops/paged_attn.py | 81 +++++++-------------- vllm/attention/selector.py | 3 +- 5 files 
changed, 100 insertions(+), 73 deletions(-) diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index ee1a3d1eaa371..384e90e80cbef 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -388,7 +388,7 @@ def test_paged_attention_rocm( PARTITION_SIZE_ROCM = 256 num_partitions = ((max_context_len + PARTITION_SIZE_ROCM - 1) // PARTITION_SIZE_ROCM) - assert PARTITION_SIZE % block_size == 0 + assert PARTITION_SIZE_ROCM % block_size == 0 num_seqs, num_heads, head_size = output.shape tmp_output = torch.empty( size=(num_seqs, num_heads, num_partitions, head_size), diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 72de8a1150f43..59b22c22e01fb 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -18,8 +18,7 @@ logger.warning("Failed to import from vllm._C with %r", e) if current_platform.is_rocm(): - with contextlib.suppress(ImportError): - import vllm._rocm_C # noqa: F401 + import vllm._rocm_C # noqa: F401 with contextlib.suppress(ImportError): import vllm._moe_C # noqa: F401 diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index b0f4d0530b7f0..f1404b8b6bfe7 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -5,6 +5,7 @@ import torch import vllm.envs as envs +from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) from vllm.attention.backends.utils import (CommonAttentionState, @@ -15,6 +16,9 @@ logger = init_logger(__name__) +_PARTITION_SIZE = 256 +ON_NAVI = "gfx1" in torch.cuda.get_device_properties("cuda").gcnArchName + class ROCmFlashAttentionBackend(AttentionBackend): @@ -480,20 +484,61 @@ def forward( if decode_meta := attn_metadata.decode_metadata: # Decoding run. 
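The custom kernel's 256-token partition size appears both in the test above (`PARTITION_SIZE_ROCM`) and as `_PARTITION_SIZE` in `rocm_flash_attn.py`; the kernel requires the KV-cache block size to divide it evenly, which holds for the supported block sizes 16 and 32. A small illustrative check (the sequence lengths are arbitrary examples):

```python
# Sanity-check the partition bookkeeping used by the custom ROCm kernel.
PARTITION_SIZE = 256  # matches PARTITION_SIZE_ROCM / _PARTITION_SIZE above

for block_size in (16, 32):  # block sizes the heuristic accepts
    assert PARTITION_SIZE % block_size == 0

for max_seq_len in (256, 257, 8192, 32768):
    max_num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE
    print(max_seq_len, "->", max_num_partitions)  # 1, 2, 32, 128
```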
- output[num_prefill_tokens:] = PagedAttention.forward_decode( - decode_query, - key_cache, - value_cache, - decode_meta.block_tables, - decode_meta.seq_lens_tensor, - decode_meta.max_decode_seq_len, - self.kv_cache_dtype, - self.num_kv_heads, - self.scale, - self.alibi_slopes, - k_scale, - v_scale, - ) + # Whether to use rocm custom paged attention or not + num_seqs, num_heads, head_size = decode_query.shape + block_size = value_cache.shape[3] + gqa_ratio = num_heads // self.num_kv_heads + use_custom = use_rocm_custom_paged_attention( + decode_query.dtype, head_size, block_size, self.kv_cache_dtype, + gqa_ratio, decode_meta.max_decode_seq_len) + if use_custom: + max_seq_len = decode_meta.max_decode_seq_len + max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) // + _PARTITION_SIZE) + assert _PARTITION_SIZE % block_size == 0 + tmp_output = torch.empty( + size=(num_seqs, num_heads, max_num_partitions, head_size), + dtype=output.dtype, + device=output.device, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, max_num_partitions), + dtype=torch.float32, + device=output.device, + ) + max_logits = torch.empty_like(exp_sums) + ops.paged_attention_rocm( + output[num_prefill_tokens:], + exp_sums, + max_logits, + tmp_output, + decode_query, + key_cache, + value_cache, + self.num_kv_heads, + self.scale, + decode_meta.block_tables, + decode_meta.seq_lens_tensor, + block_size, + max_seq_len, + self.alibi_slopes, + self.kv_cache_dtype, + ) + else: + output[num_prefill_tokens:] = PagedAttention.forward_decode( + decode_query, + key_cache, + value_cache, + decode_meta.block_tables, + decode_meta.seq_lens_tensor, + decode_meta.max_decode_seq_len, + self.kv_cache_dtype, + self.num_kv_heads, + self.scale, + self.alibi_slopes, + k_scale, + v_scale, + ) # Reshape the output tensor. return output.view(num_tokens, hidden_size) @@ -532,3 +577,14 @@ def _sdpa_attention( start = end return output + + +def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int, + block_size: int, kv_cache_dtype: str, + gqa_ratio: int, max_seq_len: int) -> bool: + # rocm custom page attention not support on navi (gfx1*) + return (not ON_NAVI and (qtype == torch.half or qtype == torch.bfloat16) + and (head_size == 64 or head_size == 128) + and (block_size == 16 or block_size == 32) + and kv_cache_dtype == "auto" + and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768) diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 93a45d28018d8..85184559e9cce 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -4,15 +4,13 @@ import torch from vllm import _custom_ops as ops -from vllm.attention.selector import use_rocm_paged_attention from vllm.triton_utils import HAS_TRITON if HAS_TRITON: from vllm.attention.ops.prefix_prefill import context_attention_fwd # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. 
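To make the gating at the bottom of `rocm_flash_attn.py` concrete, the snippet below evaluates the same conditions as `use_rocm_custom_paged_attention` for one hypothetical GQA configuration (the shapes are made up, and the Navi check via `ON_NAVI` is omitted here):

```python
# Illustrative evaluation of the custom-kernel gating conditions.
import torch

num_heads, num_kv_heads = 32, 8          # hypothetical GQA model
head_size, block_size = 128, 16
qtype, kv_cache_dtype = torch.bfloat16, "auto"
max_seq_len = 4096

gqa_ratio = num_heads // num_kv_heads    # 4
eligible = (qtype in (torch.half, torch.bfloat16)
            and head_size in (64, 128)
            and block_size in (16, 32)
            and kv_cache_dtype == "auto"
            and 1 <= gqa_ratio <= 16
            and max_seq_len <= 32768)
print(eligible)  # True -> the custom ROCm kernel would handle this decode
```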
-_PARTITION_SIZE_V1V2 = 512 -_PARTITION_SIZE_ROCM = 256 +_PARTITION_SIZE = 512 @dataclass @@ -116,14 +114,7 @@ def forward_decode( output = torch.empty_like(query) block_size = value_cache.shape[3] num_seqs, num_heads, head_size = query.shape - gqa_ratio = num_heads // num_kv_heads - use_rocm = use_rocm_paged_attention(query.dtype, head_size, block_size, - kv_cache_dtype, gqa_ratio, - max_seq_len) - # select the partition size - _PARTITION_SIZE = _PARTITION_SIZE_ROCM if use_rocm else \ - _PARTITION_SIZE_V1V2 max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE) @@ -134,7 +125,7 @@ def forward_decode( # to parallelize. # TODO(woosuk): Tune this heuristic. # For context len > 8192, use V2 kernel to avoid shared memory shortage. - use_v1 = ((not use_rocm) and max_seq_len <= 8192 + use_v1 = (max_seq_len <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)) if use_v1: @@ -174,50 +165,30 @@ def forward_decode( device=output.device, ) max_logits = torch.empty_like(exp_sums) - if not use_rocm: - ops.paged_attention_v2( - output, - exp_sums, - max_logits, - tmp_output, - query, - key_cache, - value_cache, - num_kv_heads, - scale, - block_tables, - seq_lens, - block_size, - max_seq_len, - alibi_slopes, - kv_cache_dtype, - k_scale, - v_scale, - tp_rank, - blocksparse_local_blocks, - blocksparse_vert_stride, - blocksparse_block_size, - blocksparse_head_sliding_step, - ) - else: - # run rocm custom paged attention - ops.paged_attention_rocm( - output, - exp_sums, - max_logits, - tmp_output, - query, - key_cache, - value_cache, - num_kv_heads, - scale, - block_tables, - seq_lens, - block_size, - max_seq_len, - alibi_slopes, - kv_cache_dtype, - ) + ops.paged_attention_v2( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) return output @staticmethod diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index ad80070d3b453..c1b92ceabc2e5 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -42,7 +42,8 @@ def use_rocm_paged_attention(qtype: torch.dtype, head_size: int, block_size: int, kv_cache_dtype: str, gqa_ratio: int, max_seq_len: int) -> bool: # To use rocm custom paged attention kernel or not - rocm_paged_attention_available = (is_hip() + rocm_paged_attention_available = ( + is_hip() and "gfx1" not in torch.cuda.get_device_properties("cuda").gcnArchName) return (rocm_paged_attention_available and (qtype == torch.half or qtype == torch.bfloat16) From 208c9b33155988390c3e16794b1088ce1760ef78 Mon Sep 17 00:00:00 2001 From: charlifu Date: Fri, 13 Sep 2024 14:45:57 +0000 Subject: [PATCH 12/12] remove redundant codes --- vllm/attention/ops/paged_attn.py | 2 -- vllm/attention/selector.py | 15 --------------- 2 files changed, 17 deletions(-) diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 85184559e9cce..92023d5b75f5a 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -114,10 +114,8 @@ def forward_decode( output = torch.empty_like(query) block_size = value_cache.shape[3] num_seqs, num_heads, head_size = query.shape - max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE) - # NOTE(woosuk): We use a simple heuristic to decide whether to use 
# PagedAttention V1 or V2. If the number of partitions is 1, we use # V1 to avoid the overhead of reduction. Also, if the number of diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index c1b92ceabc2e5..855586d4e5961 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -38,21 +38,6 @@ def backend_name_to_enum(backend_name: str) -> _Backend: return _Backend[backend_name] -def use_rocm_paged_attention(qtype: torch.dtype, head_size: int, - block_size: int, kv_cache_dtype: str, - gqa_ratio: int, max_seq_len: int) -> bool: - # To use rocm custom paged attention kernel or not - rocm_paged_attention_available = ( - is_hip() - and "gfx1" not in torch.cuda.get_device_properties("cuda").gcnArchName) - return (rocm_paged_attention_available - and (qtype == torch.half or qtype == torch.bfloat16) - and (head_size == 64 or head_size == 128) - and (block_size == 16 or block_size == 32) - and kv_cache_dtype == "auto" - and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768) - - def get_env_variable_attn_backend() -> Optional[_Backend]: ''' Get the backend override specified by the vLLM attention
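After the final two commits, the custom-kernel decision lives entirely in `ROCmFlashAttentionBackend`, and `PagedAttention.forward_decode` is back to its original V1/V2 heuristic. A rough summary of the resulting decode-time dispatch (simplified pseudologic for orientation, not the actual vLLM source):

```python
# Simplified view of decode-kernel selection after this series.
def choose_decode_kernel(rocm_custom_eligible: bool, max_seq_len: int,
                         num_seqs: int, num_heads: int,
                         partition_size: int = 512) -> str:
    if rocm_custom_eligible:
        # ROCmFlashAttentionBackend.forward(): use_rocm_custom_paged_attention()
        # returned True, so ops.paged_attention_rocm handles the decode.
        return "paged_attention_rocm"
    # Otherwise PagedAttention.forward_decode() keeps its original V1/V2 split.
    max_num_partitions = (max_seq_len + partition_size - 1) // partition_size
    use_v1 = (max_seq_len <= 8192
              and (max_num_partitions == 1 or num_seqs * num_heads > 512))
    return "paged_attention_v1" if use_v1 else "paged_attention_v2"


print(choose_decode_kernel(False, 16384, num_seqs=4, num_heads=32))
# -> paged_attention_v2
```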