From bd3b95892dba071bec274d5d259d12507055c896 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sat, 13 Jan 2024 11:02:18 -0500 Subject: [PATCH] [Unity] PagedKVCache supporting on-the-fly RoPE calculation This PR enhances PagedKVCache with the inline RoPE compute, which unblocks the movement towards sliding window and attention sink. Both FlashInfer and TIR kernels are updated in this PR with the RoPE calculation. Note that FlashInfer is bumped in order to include the RoPE update. The previous standalone kernel used for RoPE application are thereby removed. --- Co-authored-by: Bohan Hou Co-authored-by: Hongyi Jin --- 3rdparty/flashinfer | 2 +- src/runtime/relax_vm/paged_kv_cache.cc | 131 ++-- ...in_paged_attention_kv_cache_flashinfer.py} | 22 +- ...me_builtin_paged_attention_kv_cache_tir.py | 688 ++++++++++++++---- 4 files changed, 631 insertions(+), 212 deletions(-) rename tests/python/relax/{test_runtime_builtin_paged_attention_kv_cache.py => test_runtime_builtin_paged_attention_kv_cache_flashinfer.py} (95%) diff --git a/3rdparty/flashinfer b/3rdparty/flashinfer index 7d3a47310af1a..9cd1f42e968a8 160000 --- a/3rdparty/flashinfer +++ b/3rdparty/flashinfer @@ -1 +1 @@ -Subproject commit 7d3a47310af1ac0795e0d8e8435e42c882c96a13 +Subproject commit 9cd1f42e968a8de7d3af2c7567072e0ad6c8ffed diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc index 88234678730bd..20e68a9d33001 100644 --- a/src/runtime/relax_vm/paged_kv_cache.cc +++ b/src/runtime/relax_vm/paged_kv_cache.cc @@ -70,6 +70,8 @@ struct Block { std::vector page_ids; /*! \brief The total sequence length in the block. */ int32_t seq_length = 0; + /*! \brief The start position in sequence of this block. */ + int32_t start_pos = 0; /*! \brief The global index of the block. */ const int32_t index; @@ -236,14 +238,18 @@ class PagedAttentionKVCacheObj : public AttentionKVCache { std::vector page_indices_on_depths_device_; /*! \brief The number of KV slots used in the last page of sequences. */ std::vector last_page_len_on_depths_device_; + /*! \brief The k position offset of applying RoPE for each sequence. */ + std::vector k_rope_pos_offset_device_; /*! * \brief The append length indptr array on device. * \note Since the Q/K/V data may have raggedness in terms of lengths, * we represent the the append lengths in CSR format. */ NDArray cur_append_length_indptr_device_; - /*! \brief The position offset of applying RoPE for each sequence. */ - NDArray cur_rope_offset_device_; + /*! \brief The k position offset of applying RoPE for each sequence. */ + NDArray k_ragged_rope_pos_offset_device_; + /*! \brief The q position mapping of applying RoPE for each sequence. */ + NDArray q_rope_position_map_device_; /*! * \brief The corresponding position in global KV cache (pages) * for each position along the length dimension of K/V data when @@ -264,7 +270,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCache { // attention/append. 
//------------------------------------------- NDArray cur_append_length_indptr_view_; - NDArray cur_rope_offset_view_; + NDArray k_ragged_rope_pos_offset_view_; + NDArray q_rope_position_map_view_; NDArray append_position_map_view_; NDArray temp_attn_output_view_; NDArray temp_attn_scores_view_; @@ -273,6 +280,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCache { std::vector page_indptr_on_depths_view_; std::vector page_indices_on_depths_view_; std::vector last_page_len_on_depths_view_; + std::vector k_rope_pos_offset_view_; PackedFunc f_transpose_append_; PackedFunc f_attention_prefill_; @@ -284,7 +292,6 @@ class PagedAttentionKVCacheObj : public AttentionKVCache { Optional f_attention_prefill_end_forward_; Optional f_attention_decode_begin_forward_; Optional f_attention_decode_end_forward_; - PackedFunc f_rotary_; Optional f_merge_inplace_; Optional f_debug_get_kv_; @@ -312,7 +319,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCache { Optional f_attention_prefill_end_forward, Optional f_attention_decode_begin_forward, Optional f_attention_decode_end_forward, - PackedFunc f_rotary, Optional f_merge_inplace, + Optional f_merge_inplace, Optional f_debug_get_kv) : page_size_(page_size), num_layers_(num_layers), @@ -333,7 +340,6 @@ class PagedAttentionKVCacheObj : public AttentionKVCache { f_attention_prefill_end_forward_(std::move(f_attention_prefill_end_forward)), f_attention_decode_begin_forward_(std::move(f_attention_decode_begin_forward)), f_attention_decode_end_forward_(std::move(f_attention_decode_end_forward)), - f_rotary_(std::move(f_rotary)), f_merge_inplace_(std::move(f_merge_inplace)), f_debug_get_kv_(std::move(f_debug_get_kv)) { pages_.reserve(num_layers); @@ -350,13 +356,16 @@ class PagedAttentionKVCacheObj : public AttentionKVCache { NDArray::Empty({num_total_pages}, dtype_aux_, device)); last_page_len_on_depths_device_.push_back( NDArray::Empty({reserved_num_seqs}, dtype_aux_, device)); + k_rope_pos_offset_device_.push_back(NDArray::Empty({reserved_num_seqs}, dtype_aux_, device)); qo_indptr_on_depths_view_.push_back(NDArray()); page_indptr_on_depths_view_.push_back(NDArray()); page_indices_on_depths_view_.push_back(NDArray()); last_page_len_on_depths_view_.push_back(NDArray()); + k_rope_pos_offset_view_.push_back(NDArray()); } cur_append_length_indptr_device_ = NDArray::Empty({reserved_num_seqs + 1}, dtype_aux_, device); - cur_rope_offset_device_ = NDArray::Empty({reserved_num_seqs}, dtype_aux_, device); + k_ragged_rope_pos_offset_device_ = NDArray::Empty({reserved_num_seqs}, dtype_aux_, device); + q_rope_position_map_device_ = NDArray::Empty({num_total_pages * page_size}, dtype_aux_, device); append_position_map_device_ = NDArray::Empty({num_total_pages * page_size}, dtype_aux_, device); temp_attn_output_device_ = @@ -428,6 +437,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCache { int32_t parent_block_idx = parent_it->second.last_block_idx; // Create a child block with the parent block pointer. int32_t child_block_idx = GetFreeBlock(); + global_block_pool_[child_block_idx].start_pos = parent_it->second.seq_length; global_block_pool_[child_block_idx].parent_idx = parent_block_idx; // Create the child sequence with the child block. seq_map_.insert({child_seq_id, Sequence(global_block_pool_, child_block_idx)}); @@ -471,16 +481,16 @@ class PagedAttentionKVCacheObj : public AttentionKVCache { // - Collect sequence/block/page information for attention. 
std::vector sequences; - std::vector rope_offset; + std::vector k_ragged_rope_pos_offset; is_decode_request_ = true; sequences.reserve(cur_batch_size_); - rope_offset.reserve(cur_batch_size_); + k_ragged_rope_pos_offset.reserve(cur_batch_size_); for (int i = 0; i < cur_batch_size_; ++i) { auto it = seq_map_.find(seq_ids[i]); CHECK(it != seq_map_.end()) << "The sequence \"" << seq_ids[i] << "\" cannot be found in KV cache."; sequences.push_back(&it->second); - rope_offset.push_back(it->second.seq_length); + k_ragged_rope_pos_offset.push_back(it->second.seq_length); it->second.seq_length += append_lengths[i]; if (append_lengths[i] != 1) { is_decode_request_ = false; @@ -504,6 +514,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCache { std::vector> page_indptr_on_depths; std::vector> page_indices_on_depths; std::vector> last_page_len_on_depths; + std::vector> k_rope_pos_offset_on_depths; use_decode_kernel_.clear(); for (int d = 0; d < num_depths_; ++d) { auto [chunked_block_ids, use_decode_kernel] = GetChunkedBlockIds(block_ids_on_depths[d]); @@ -513,23 +524,27 @@ class PagedAttentionKVCacheObj : public AttentionKVCache { std::vector page_indptr_h{0}; std::vector page_indices_h; std::vector last_page_len_h; + std::vector k_rope_pos_offset_h; for (const auto& [block_id, chunk_append_length] : chunked_block_ids) { qo_indptr_h.push_back(qo_indptr_h.back() + chunk_append_length); if (block_id == -1) { page_indptr_h.push_back(page_indptr_h.back()); last_page_len_h.push_back(0); + k_rope_pos_offset_h.push_back(0); } else { const Block& block = global_block_pool_[block_id]; page_indptr_h.push_back(page_indptr_h.back() + block.page_ids.size()); page_indices_h.insert(page_indices_h.end(), block.page_ids.begin(), block.page_ids.end()); last_page_len_h.push_back( block.seq_length == 0 ? 0 : (block.seq_length - 1) % page_size_ + 1); + k_rope_pos_offset_h.push_back(block.start_pos); } } qo_indptr_on_depths.push_back(qo_indptr_h); page_indptr_on_depths.push_back(page_indptr_h); page_indices_on_depths.push_back(page_indices_h); last_page_len_on_depths.push_back(last_page_len_h); + k_rope_pos_offset_on_depths.push_back(k_rope_pos_offset_h); } if (num_depths_ > 1) { @@ -543,21 +558,25 @@ class PagedAttentionKVCacheObj : public AttentionKVCache { // Map each the token position in the input batch to the position // in the global KV cache. The mapping is used in when appending k/v values. + std::vector q_rope_position_map; std::vector append_position_map; for (int i = 0; i < cur_batch_size_; ++i) { int64_t append_length = append_lengths[i]; const Block& block = global_block_pool_[sequences[i]->last_block_idx]; for (int64_t pos = 0; pos < append_length; ++pos) { - int64_t pos_in_seq = block.seq_length - append_length + pos; - append_position_map.push_back(block.page_ids[pos_in_seq / page_size_] * page_size_ + - pos_in_seq % page_size_); + int64_t pos_in_block = block.seq_length - append_length + pos; + q_rope_position_map.push_back(sequences[i]->seq_length - append_length + pos); + append_position_map.push_back(block.page_ids[pos_in_block / page_size_] * page_size_ + + pos_in_block % page_size_); } } // - Sync NDArrays to GPU. 
SyncAuxArrayToDevice(std::move(qo_indptr_on_depths), std::move(page_indptr_on_depths), std::move(page_indices_on_depths), std::move(last_page_len_on_depths), - std::move(rope_offset), std::move(append_position_map)); + std::move(k_rope_pos_offset_on_depths), + std::move(k_ragged_rope_pos_offset), std::move(q_rope_position_map), + std::move(append_position_map)); // NOTE(Zihao): This logic is problematic ATM because we need a unique split per depth KernelBeginForward(); @@ -643,14 +662,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCache { << "The auxiliary arrays are not synchronized to device. Please call " "`BeginForward` to synchronize before calling `Attention`."; - // Part 2: apply rotary embedding to q/k data. - f_rotary_(q_data, k_data, cur_append_length_indptr_view_, cur_rope_offset_view_, - cur_batch_size_, num_qo_heads_, num_kv_heads_, head_dim_, /*qkv_layout=*/0, - rotary_scale_, rotary_theta_); - - // Part 3: append k/v data to kv-cache + // Part 2: append k/v data to kv-cache f_transpose_append_(pages_[layer_id], k_data, v_data, append_position_map_view_); - // Part 4: perform attention + // Part 3: perform attention AttentionInternal(layer_id, q_data, k_data, v_data, o_data); } @@ -865,7 +879,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCache { if (use_decode_kernel_[0]) { f_attention_decode_begin_forward_.value()( /*depth=*/0, page_indptr_on_depths_view_[0], last_page_len_on_depths_view_[0], - num_qo_heads_, num_kv_heads_, head_dim_, page_size_, /*rotary_mode=*/true); + num_qo_heads_, num_kv_heads_, head_dim_, page_size_, /*rotary_mode=*/1); } else { f_attention_prefill_begin_forward_.value()(/*depth=*/0, qo_indptr_on_depths_view_[0], cur_batch_size_, num_qo_heads_, num_kv_heads_); @@ -880,7 +894,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCache { if (use_decode_kernel_[d]) { f_attention_decode_begin_forward_.value()( d, page_indptr_on_depths_view_[d], last_page_len_on_depths_view_[d], num_qo_heads_, - num_kv_heads_, head_dim_, page_size_, /*rotary_mode=*/false); + num_kv_heads_, head_dim_, page_size_, /*rotary_mode=*/1); } else { f_attention_prefill_begin_forward_.value()(/*depth=*/d, qo_indptr_on_depths_view_[d], last_page_len_on_depths_view_[d]->shape[0], @@ -901,22 +915,25 @@ class PagedAttentionKVCacheObj : public AttentionKVCache { if (use_decode_kernel_[0]) { f_attention_decode_(/*depth=*/0, q_data, pages_[layer_id], page_indptr_on_depths_view_[0], page_indices_on_depths_view_[0], last_page_len_on_depths_view_[0], - output, merged_attn_scores_view_, - /*rotary_mode=*/0, rotary_scale_, rotary_theta_); + k_rope_pos_offset_view_[0], q_rope_position_map_view_, output, + merged_attn_scores_view_, + /*rotary_mode=*/1, rotary_scale_, rotary_theta_); } else { f_attention_prefill_(/*depth=*/0, q_data, qo_indptr_on_depths_view_[0], pages_[layer_id], page_indptr_on_depths_view_[0], page_indices_on_depths_view_[0], - last_page_len_on_depths_view_[0], output, merged_attn_scores_view_, + last_page_len_on_depths_view_[0], k_rope_pos_offset_view_[0], + q_rope_position_map_view_, output, merged_attn_scores_view_, /*causal=*/1, - /*rotary_mode=*/0, rotary_scale_, rotary_theta_); + /*rotary_mode=*/1, rotary_scale_, rotary_theta_); } } else { // Compute appended text self-attention f_attention_prefill_ragged_.value()(q_data, cur_append_length_indptr_view_, k_data, v_data, - cur_append_length_indptr_view_, output, + cur_append_length_indptr_view_, q_rope_position_map_view_, + k_ragged_rope_pos_offset_view_, output, merged_attn_scores_view_, /*causal=*/1, - 
/*rotary_mode=*/0, rotary_scale_, rotary_theta_); + /*rotary_mode=*/1, rotary_scale_, rotary_theta_); for (int d = 0; d < num_depths_; ++d) { if (page_indices_on_depths_view_[d]->shape[0] == 0) { @@ -926,16 +943,18 @@ class PagedAttentionKVCacheObj : public AttentionKVCache { // Use decode kernel for depth d f_attention_decode_(/*depth=*/d, q_data, pages_[layer_id], page_indptr_on_depths_view_[d], page_indices_on_depths_view_[d], last_page_len_on_depths_view_[d], + k_rope_pos_offset_view_[d], q_rope_position_map_view_, temp_attn_output_view_, temp_attn_scores_view_, - /*rotary_mode=*/0, rotary_scale_, rotary_theta_); + /*rotary_mode=*/1, rotary_scale_, rotary_theta_); } else { // Use prefill kernel for depth d f_attention_prefill_(/*depth=*/d, q_data, qo_indptr_on_depths_view_[d], pages_[layer_id], page_indptr_on_depths_view_[d], page_indices_on_depths_view_[d], - last_page_len_on_depths_view_[d], temp_attn_output_view_, + last_page_len_on_depths_view_[d], k_rope_pos_offset_view_[d], + q_rope_position_map_view_, temp_attn_output_view_, temp_attn_scores_view_, /*causal=*/0, - /*rotary_mode=*/0, rotary_scale_, rotary_theta_); + /*rotary_mode=*/1, rotary_scale_, rotary_theta_); } f_merge_inplace_.value()(output, merged_attn_scores_view_, temp_attn_output_view_, temp_attn_scores_view_); @@ -952,7 +971,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCache { std::vector> page_indptr_on_depths, std::vector> page_indices_on_depths, std::vector> last_page_len_on_depths, - std::vector rope_offset, + std::vector> k_rope_pos_offset_on_depths, + std::vector k_ragged_rope_pos_offset, + std::vector q_rope_position_map, std::vector append_position_map) { ICHECK(dtype_aux_.bits == 32 && dtype_aux_.code == kDLInt); ICHECK_EQ(qo_indptr_on_depths.size(), num_depths_); @@ -1015,22 +1036,37 @@ class PagedAttentionKVCacheObj : public AttentionKVCache { fcopy_from_vec(last_page_len_on_depths_view_[d], last_page_len_on_depths[d].data()); } - // 5. cur_append_lengths_indptr + // 5. k_rope_pos_offset + for (int d = 0; d < num_depths_; ++d) { + ICHECK_EQ(k_rope_pos_offset_on_depths[d].size() + 1, qo_indptr_on_depths[d].size()); + k_rope_pos_offset_view_[d] = k_rope_pos_offset_device_[d].CreateView( + {static_cast(k_rope_pos_offset_on_depths[d].size())}, dtype_aux_); + fcopy_from_vec(k_rope_pos_offset_view_[d], k_rope_pos_offset_on_depths[d].data()); + } + + // 6. cur_append_lengths_indptr cur_append_length_indptr_view_ = cur_append_length_indptr_device_.CreateView({num_sequences + 1}, dtype_aux_); fcopy_from_vec(cur_append_length_indptr_view_, cur_append_lengths_indptr.data()); - // 6. cur_rope_offset - ICHECK_EQ(rope_offset.size(), num_sequences); - cur_rope_offset_view_ = cur_rope_offset_device_.CreateView({num_sequences}, dtype_aux_); - fcopy_from_vec(cur_rope_offset_view_, rope_offset.data()); + // 7. k_ragged_rope_pos_offset + ICHECK_EQ(k_ragged_rope_pos_offset.size(), num_sequences); + k_ragged_rope_pos_offset_view_ = + k_ragged_rope_pos_offset_device_.CreateView({num_sequences}, dtype_aux_); + fcopy_from_vec(k_ragged_rope_pos_offset_view_, k_ragged_rope_pos_offset.data()); + + // 8. q_rope_position_map + ICHECK_EQ(q_rope_position_map.size(), total_append_length); + q_rope_position_map_view_ = + q_rope_position_map_device_.CreateView({total_append_length}, dtype_aux_); + fcopy_from_vec(q_rope_position_map_view_, q_rope_position_map.data()); - // 7. append_position_map + // 9. 
append_position_map append_position_map_view_ = append_position_map_device_.CreateView({total_append_length}, dtype_aux_); fcopy_from_vec(append_position_map_view_, append_position_map.data()); - // 8. Create view for temporary arrays for attention computation. + // 10. Create view for temporary arrays for attention computation. temp_attn_output_view_ = temp_attn_output_device_.CreateView( {total_append_length, num_qo_heads_, head_dim_}, temp_attn_output_device_->dtype); temp_attn_scores_view_ = temp_attn_scores_device_.CreateView( @@ -1065,8 +1101,8 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create") PackedFunc f_attention_prefill_begin_forward, PackedFunc f_attention_prefill_end_forward, PackedFunc f_attention_decode_begin_forward, - PackedFunc f_attention_decode_end_forward, PackedFunc f_rotary, - PackedFunc f_merge_inplace, Optional f_debug_get_kv) { + PackedFunc f_attention_decode_end_forward, PackedFunc f_merge_inplace, + Optional f_debug_get_kv) { CHECK_EQ(cache_config.size(), 3); int64_t reserved_num_seqs = cache_config[0]; int64_t total_token_capacity = cache_config[1]; @@ -1081,7 +1117,7 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create") std::move(f_attention_prefill_ragged_end_forward), std::move(f_attention_prefill_begin_forward), std::move(f_attention_prefill_end_forward), std::move(f_attention_decode_begin_forward), std::move(f_attention_decode_end_forward), - std::move(f_rotary), std::move(f_merge_inplace), std::move(f_debug_get_kv)); + std::move(f_merge_inplace), std::move(f_debug_get_kv)); return PagedAttentionKVCache(std::move(n)); }); @@ -1090,7 +1126,8 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced") int64_t num_kv_heads, int64_t head_dim, double rotary_scale, double rotary_theta, NDArray init, PackedFunc f_transpose_append, PackedFunc f_attention_prefill, PackedFunc f_attention_decode, - PackedFunc f_rotary, Optional f_debug_get_kv) { + PackedFunc f_attention_prefill_ragged, PackedFunc f_merge_inplace, + Optional f_debug_get_kv) { CHECK_EQ(cache_config.size(), 3); int64_t reserved_num_seqs = cache_config[0]; int64_t total_token_capacity = cache_config[1]; @@ -1100,9 +1137,9 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced") page_size, num_layers, num_qo_heads, num_kv_heads, head_dim, reserved_num_seqs, num_total_pages, rotary_scale, rotary_theta, init->dtype, init->device, std::move(f_transpose_append), std::move(f_attention_prefill), - std::move(f_attention_decode), // - NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, // - std::move(f_rotary), NullOpt, std::move(f_debug_get_kv)); + std::move(f_attention_decode), std::move(f_attention_prefill_ragged), // + NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, // + std::move(f_merge_inplace), std::move(f_debug_get_kv)); return PagedAttentionKVCache(std::move(n)); }); diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py similarity index 95% rename from tests/python/relax/test_runtime_builtin_paged_attention_kv_cache.py rename to tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py index 94507d3931e4d..69b7a15793b53 100644 --- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache.py +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py @@ -59,7 +59,6 @@ fattention_prefill_ragged_begin_forward = None fattention_prefill_ragged_end_forward = None 
fattention_merge_state = None -fattention_rotary = None @T.prim_func @@ -145,7 +144,7 @@ def set_global_func(): global fattention_prefill_ragged global fattention_prefill_ragged_begin_forward global fattention_prefill_ragged_end_forward - global fattention_merge_state, fattention_rotary + global fattention_merge_state fclear = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_clear") fcreate = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_create") @@ -182,7 +181,6 @@ def set_global_func(): "flashinfer.attention_kernel_prefill_with_ragged_kv_cache_end_forward" ) fattention_merge_state = tvm.get_global_func("flashinfer.merge_state_in_place") - fattention_rotary = tvm.get_global_func("flashinfer.batch_qk_apply_rotary_in_place") def create_kv_cache(): @@ -216,7 +214,6 @@ def create_kv_cache(): fattention_prefill_end_forward, fattention_decode_begin_forward, fattention_decode_end_forward, - fattention_rotary, fattention_merge_state, fcopy_cache, ) @@ -303,13 +300,7 @@ def apply_attention( cached_k[seq_id] = np.concatenate( [ cached_k[seq_id], - np.stack( - [ - f_apply_rotary(new_k[l], cached_k[seq_id].shape[1], rope_scale, rope_theta) - for l in range(num_layers) - ], - axis=0, - ), + np.stack([new_k[l] for l in range(num_layers)], axis=0), ], axis=1, ) @@ -347,12 +338,9 @@ def apply_attention( rope_scale, rope_theta, ).transpose(1, 0, 2) - # Todo(Zihao, Ruihang): fold RoPE into flashinfer attn kernel in multi-level cases. - # so that k/v values in cache does not have RoPE applied. - # k_seq = f_apply_rotary(cached_k[seq_id][layer_id], 0, rope_scale, rope_theta).transpose( - # 1, 2, 0 - # ) - k_seq = cached_k[seq_id][layer_id].transpose(1, 2, 0) + k_seq = f_apply_rotary(cached_k[seq_id][layer_id], 0, rope_scale, rope_theta).transpose( + 1, 2, 0 + ) v_seq = cached_v[seq_id][layer_id].transpose(1, 0, 2) k_seq = np.repeat(k_seq, num_qo_heads // num_kv_heads, axis=0) diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py index 280ad7e0ea319..721ea42dc179d 100644 --- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py @@ -58,11 +58,10 @@ def kv_cache_transpose_append( var_v_data: T.handle, var_position_map: T.handle, ): - ntoken = T.SizeVar("ntoken", "int64") - page_size = T.SizeVar("page_size", "int64") - num_pages = T.int64() + ntoken = T.SizeVar("ntoken", "int32") + num_pages = T.int32() - pages = T.match_buffer(var_pages, (num_pages, 2, num_kv_heads, page_size, head_dim), dtype) + pages = T.match_buffer(var_pages, (num_pages, 2, num_kv_heads, 16, head_dim), dtype) k_data = T.match_buffer(var_k_data, (ntoken, num_kv_heads, head_dim), dtype) v_data = T.match_buffer(var_v_data, (ntoken, num_kv_heads, head_dim), dtype) position_map = T.match_buffer(var_position_map, (ntoken,), "int32") @@ -72,21 +71,21 @@ def kv_cache_transpose_append( vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f]) T.reads(position_map[vgpos], k_data[vgpos, vh, vf]) T.writes( - pages[position_map[vgpos] // page_size, 0, vh, position_map[vgpos] % page_size, vf] + pages[position_map[vgpos] // 16, 0, vh, position_map[vgpos] % 16, vf] ) position: T.int64 = T.Cast("int64", position_map[vgpos]) pages[ - T.floordiv(position, page_size), 0, vh, T.floormod(position, page_size), vf + T.floordiv(position, 16), 0, vh, T.floormod(position, 16), vf ] = k_data[vgpos, vh, vf] with T.block("v_transpose_append"): vgpos, vh, 
vf = T.axis.remap("SSS", [global_pos, h, f]) T.reads(position_map[vgpos], k_data[vgpos, vh, vf]) T.writes( - pages[position_map[vgpos] // page_size, 1, vh, position_map[vgpos] % page_size, vf] + pages[position_map[vgpos] // 16, 1, vh, position_map[vgpos] % 16, vf] ) position: T.int64 = T.Cast("int64", position_map[vgpos]) pages[ - T.floordiv(position, page_size), 1, vh, T.floormod(position, page_size), vf + T.floordiv(position, 16), 1, vh, T.floormod(position, 16), vf ] = v_data[vgpos, vh, vf] @@ -150,7 +149,8 @@ def create_kv_cache(): copy_cache, _attention_prefill(num_kv_heads, num_qo_heads, head_dim, dtype), _attention_decode(num_kv_heads, num_qo_heads, head_dim, dtype), - _inplace_rope(rope_theta, rope_scale, head_dim, num_qo_heads, num_kv_heads, dtype), + _attention_prefill_ragged(num_kv_heads, num_qo_heads, head_dim, dtype), + _merge_state_inplace(num_qo_heads, head_dim, dtype), ]: mod = tvm.IRModule({"main": tir_func}) with target: @@ -158,7 +158,14 @@ def create_kv_cache(): f = tvm.build(mod["main"], target=target) builts.append(f.entry_func) - ftranspose_append, fcopy_cache, fattn_prefill, fattn_decode, fbatch_rotary = builts + ( + ftranspose_append, + fcopy_cache, + fattn_prefill, + fattn_decode, + fattn_prefill_ragged, + fmerge_state, + ) = builts fcreate = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_create_reduced") cache = fcreate( tvm.runtime.ShapeTuple([reserved_nseq, maximum_total_seq_length, page_size]), @@ -172,7 +179,8 @@ def create_kv_cache(): ftranspose_append, fattn_prefill, fattn_decode, - fbatch_rotary, + fattn_prefill_ragged, + fmerge_state, fcopy_cache, ) return cache @@ -258,13 +266,7 @@ def apply_attention( cached_k[seq_id] = np.concatenate( [ cached_k[seq_id], - np.stack( - [ - f_apply_rotary(new_k[l], cached_k[seq_id].shape[1], rope_scale, rope_theta) - for l in range(num_layers) - ], - axis=0, - ), + np.stack([new_k[l] for l in range(num_layers)], axis=0), ], axis=1, ) @@ -302,12 +304,9 @@ def apply_attention( rope_scale, rope_theta, ).transpose(1, 0, 2) - # Todo(Zihao, Ruihang): fold RoPE into flashinfer attn kernel in multi-level cases. - # so that k/v values in cache does not have RoPE applied. - # k_seq = f_apply_rotary(cached_k[seq_id][layer_id], 0, rope_scale, rope_theta).transpose( - # 1, 2, 0 - # ) - k_seq = cached_k[seq_id][layer_id].transpose(1, 2, 0) + k_seq = f_apply_rotary(cached_k[seq_id][layer_id], 0, rope_scale, rope_theta).transpose( + 1, 2, 0 + ) v_seq = cached_v[seq_id][layer_id].transpose(1, 0, 2) k_seq = np.repeat(k_seq, num_qo_heads // num_kv_heads, axis=0) @@ -385,6 +384,33 @@ def test_paged_attention_kv_cache_remove_sequence(kv_cache): ) +@tvm.testing.requires_gpu +@tvm.testing.requires_cuda +def test_paged_attention_kv_cache_fork_sequence(kv_cache): + fclear(kv_cache) + + cached_k = {} + cached_v = {} + batch = [(0, 60), (1, 88), (2, 17), (3, 4)] + apply_attention(kv_cache, batch, cached_k, cached_v) + # Fork existing sequences. + apply_attention(kv_cache, [((4, 3), 35)], cached_k, cached_v) + apply_attention(kv_cache, [((5, 0), 20)], cached_k, cached_v) + apply_attention(kv_cache, [((6, 5), 102)], cached_k, cached_v) + apply_attention(kv_cache, [((7, 0), 3)], cached_k, cached_v) + apply_attention(kv_cache, [((8, 5), 71)], cached_k, cached_v) + apply_attention(kv_cache, [((9, 5), 20)], cached_k, cached_v) + # Mixture of decode and prefill. 
+ operation_seq = [ + [(2, 1), (4, 1), (7, 1), (6, 1), (8, 1), (9, 1)], + [(7, 1), (6, 1), (8, 1), (9, 1)], + [(7, 1), (1, 1), (6, 1), (2, 1), (8, 1), (4, 1), (9, 1)], + [(7, 10), (6, 2), (8, 3), (9, 4)], + ] + for batch in operation_seq: + apply_attention(kv_cache, batch, cached_k, cached_v) + + @tvm.testing.requires_gpu @tvm.testing.requires_cuda def test_paged_attention_kv_cache_popn(kv_cache): @@ -404,76 +430,6 @@ def test_paged_attention_kv_cache_popn(kv_cache): verify_cached_kv(kv_cache, seq_ids=list(range(4)), expected_k=cached_k, expected_v=cached_v) -def _inplace_rope( - theta: float, - scale: float, - head_dim: int, - num_q_heads: int, - num_kv_heads: int, - dtype: str, -): - assert head_dim <= 128, "Rotary embedding currently only supports head_dim <= 128" - rotary_dim = head_dim - - def _rope( - x: T.Buffer, - s: tir.Var, - h: tir.Var, - d: tir.Var, - rope_offset: tir.Var, - instance_offset: tir.Var, - ): - cos_freq, sin_freq = rope_freq((s + rope_offset) * scale, d, rotary_dim, theta, dtype) - cos = cos_freq * x[s + instance_offset, h, d] - sin = sin_freq * tir.if_then_else( - d < rotary_dim // 2, - -x[s + instance_offset, h, d + rotary_dim // 2], - x[s + instance_offset, h, d - rotary_dim // 2], - ) - return cos + sin - - # fmt: off - @T.prim_func - def tir_rotary( - var_q: T.handle, - var_k: T.handle, - var_append_len_indptr: T.handle, - var_rope_offsets: T.handle, - _0: T.int32, - _1: T.int32, - _2: T.int32, - _3: T.int32, - _4: T.int32, - _5: T.float32, - _6: T.float32, - ): - T.func_attr({"tir.is_scheduled": 1}) - total_len = T.int32() - batch_size = T.int32() - q = T.match_buffer(var_q, (total_len, num_q_heads, head_dim), dtype) - k = T.match_buffer(var_k, (total_len, num_kv_heads, head_dim), dtype) - rope_offsets = T.match_buffer(var_rope_offsets, (batch_size,), "int32") - append_len_indptr = T.match_buffer(var_append_len_indptr, (batch_size + 1,), "int32") - for b_h in T.thread_binding(batch_size * (num_q_heads + num_kv_heads), thread="blockIdx.x"): - b: T.int32 = b_h // (num_q_heads + num_kv_heads) - h: T.int32 = b_h % (num_q_heads + num_kv_heads) - instance_offset: T.int32 = append_len_indptr[b] - rope_offset: T.int32 = rope_offsets[b] - append_len: T.int32 = append_len_indptr[b + 1] - append_len_indptr[b] - for s0 in range(T.ceildiv(append_len, 32)): - for s1 in T.thread_binding(32, thread="threadIdx.y"): - for d0 in T.thread_binding(T.ceildiv(head_dim, 4), thread="threadIdx.x"): - for d1 in T.vectorized(4): - s: T.int32 = s0 * 32 + s1 - d: T.int32 = d0 * 4 + d1 - if s < append_len and d < head_dim: - if h < num_q_heads: - q[s + instance_offset, h, d] = _rope(q, s, h, d, rope_offset, instance_offset) - else: - k[s + instance_offset, h - num_q_heads, d] = _rope(k, s, h - num_q_heads, d, rope_offset, instance_offset) - return tir_rotary - - def rope_freq(s: tir.Var, d: tir.Var, d_range: int, theta: float, dtype: str): """Compute the inverse frequency of RoPE and then return the cosine and sine of it. 
@@ -562,25 +518,28 @@ def batch_prefill_paged_kv( var_page_indptr: T.handle, # [batch_size + 1] var_page_values: T.handle, # [nnz_pages] var_last_page_len: T.handle, # [b] + var_k_rope_pos_offset: T.handle, # [b] + var_q_rope_position: T.handle, # [total_q_len] var_output: T.handle, # [total_len, h_q, d] var_lse: T.handle, # [total_len, h_q] causal: T.int32, - _1: T.int32, - _2: T.float32, - _3: T.float32, + rotary_mode: T.int32, + rope_scale: T.float32, + rope_theta: T.float32, ): batch_size = T.int32(is_size_var=True) total_len = T.int32(is_size_var=True) nnz_pages = T.int32(is_size_var=True) max_num_pages = T.int32(is_size_var=True) - page_size = T.int32(is_size_var=True) q = T.match_buffer(var_q, (total_len, h_q, d), dtype) q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32") - pages = T.match_buffer(var_pages, (max_num_pages, 2, h_kv, page_size, d), dtype) + pages = T.match_buffer(var_pages, (max_num_pages, 2, h_kv, 16, d), dtype) page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32") page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32") last_page_len = T.match_buffer(var_last_page_len, (batch_size,), "int32") + k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32") + q_rope_position = T.match_buffer(var_q_rope_position, (total_len,), "int32") output = T.match_buffer(var_output, (total_len, h_q, d), dtype) lse = T.match_buffer(var_lse, (total_len, h_q), "float32") # pylint: disable=unused-variable @@ -599,9 +558,6 @@ def batch_prefill_paged_kv( batch_rows = _var("int32") iterator = _var("int32") kv_chunk_len = _var("int32") - m_new = _var("float32") - m_prev = _var("float32") - d_new = _var("float32") Q_smem = T.alloc_buffer((tile_x, d), dtype, scope="shared") K_smem = T.alloc_buffer((tile_z, d), dtype, scope="shared") @@ -615,6 +571,10 @@ def batch_prefill_paged_kv( m_prev_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared") d_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared") + m_new = T.alloc_buffer((math.ceil(tile_x / (32 * num_warps)),), "float32", scope="local") + m_prev = T.alloc_buffer((math.ceil(tile_x / (32 * num_warps)),), "float32", scope="local") + d_new = T.alloc_buffer((math.ceil(tile_x / (32 * num_warps)),), "float32", scope="local") + ## get tile_no, batch_idx, batch_tiles, batch_rows tile_id[0] = bx batch_idx[0] = 0 @@ -640,7 +600,7 @@ def batch_prefill_paged_kv( cur_last_page_len: T.int32 = last_page_len[b_idx] kv_chunk_len[0] = T.if_then_else( cur_page_indptr_begin != cur_page_indptr_end, - (cur_page_indptr_end - cur_page_indptr_begin - 1) * page_size + cur_last_page_len, + (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + cur_last_page_len, 0 ) T.tvm_storage_sync("shared") @@ -667,7 +627,11 @@ def batch_prefill_paged_kv( cur_L = L_start + i // group_size cur_H_qo = H_qo_start + i % group_size if cur_L < q_indptr[b_idx + 1]: - Q_smem[i, j] = q[cur_L, cur_H_qo, j] + Q_smem[i, j] = T.if_then_else( + rotary_mode == 1, + _rope(q, q_rope_position[cur_L], d, rope_theta, rope_scale, (cur_L, cur_H_qo, j)), + q[cur_L, cur_H_qo, j] + ) else: Q_smem[i, j] = 0.0 T.tvm_storage_sync("shared") @@ -681,9 +645,13 @@ def batch_prefill_paged_kv( T.writes() cur_L = L_kv_start + i if cur_L < kv_chunk_len[0]: - page_no: T.int32(is_size_var=True) = page_values[cur_page_indptr_begin + T.floordiv(cur_L, page_size)] - page_offset: T.int32(is_size_var=True) = T.floormod(cur_L, page_size) - K_smem[i, j] = pages[page_no, 0, by, page_offset, j] + page_no: T.int32(is_size_var=True) = 
page_values[cur_page_indptr_begin + T.floordiv(cur_L, 16)] + page_offset: T.int32(is_size_var=True) = T.floormod(cur_L, 16) + K_smem[i, j] = T.if_then_else( + rotary_mode == 1, + _rope(pages, k_rope_pos_offset[b_idx] + cur_L, d, rope_theta, rope_scale, (page_no, 0, by, page_offset, j)), + pages[page_no, 0, by, page_offset, j] + ) else: K_smem[i, j] = 0.0 T.tvm_storage_sync("shared") @@ -694,8 +662,8 @@ def batch_prefill_paged_kv( T.writes() cur_L = L_kv_start + i if cur_L < kv_chunk_len[0]: - page_no: T.int32(is_size_var=True) = page_values[cur_page_indptr_begin + T.floordiv(cur_L, page_size)] - page_offset: T.int32(is_size_var=True) = T.floormod(cur_L, page_size) + page_no: T.int32(is_size_var=True) = page_values[cur_page_indptr_begin + T.floordiv(cur_L, 16)] + page_offset: T.int32(is_size_var=True) = T.floormod(cur_L, 16) V_smem[i, j] = pages[page_no, 1, by, page_offset, j] else: V_smem[i, j] = 0.0 @@ -721,8 +689,8 @@ def batch_prefill_paged_kv( row: T.int32 = i * 32 * num_warps + ty * 32 + tx if row < tile_x: with T.block("update1"): - m_prev[0] = m_smem[row] - m_new[0] = m_smem[row] + m_prev[i] = m_smem[row] + m_new[i] = m_smem[row] # mask out of kv_chunk_len S for j in T.serial(tile_z): if mask(causal, @@ -730,8 +698,8 @@ def batch_prefill_paged_kv( col=L_kv_start + j, kv_len=kv_chunk_len[0], qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]): - m_new[0] = T.max(m_new[0], S_smem[row, j]) - d_new[0] = d_smem[row] * T.exp2(m_prev[0] - m_new[0]) + m_new[i] = T.max(m_new[i], S_smem[row, j]) + d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i]) for i in T.serial(T.ceildiv(tile_x, 32 * num_warps)): row: T.int32 = i * 32 * num_warps + ty * 32 + tx @@ -744,19 +712,19 @@ def batch_prefill_paged_kv( col=L_kv_start + j, kv_len=kv_chunk_len[0], qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]): - S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[0]) + S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i]) else: - S_smem[row, j] = T.exp2(-5e4 - m_new[0]) + S_smem[row, j] = T.exp2(-5e4 - m_new[i]) for i in T.serial(T.ceildiv(tile_x, 32 * num_warps)): row: T.int32 = i * 32 * num_warps + ty * 32 + tx if row < tile_x: with T.block("update"): for j in T.serial(tile_z): - d_new[0] += S_smem[row, j] - m_smem[row] = m_new[0] - d_smem[row] = d_new[0] - m_prev_smem[row] = m_prev[0] + d_new[i] += S_smem[row, j] + m_smem[row] = m_new[i] + d_smem[row] = d_new[i] + m_prev_smem[row] = m_prev[i] T.tvm_storage_sync("shared") # Update O @@ -775,6 +743,13 @@ def batch_prefill_paged_kv( if L_start + i // group_size < q_indptr[b_idx + 1]: output[L_start + i // group_size, H_qo_start + i % group_size, j] = O_local[i, j] / d_smem[i] + # Store LSE to gmem + for li in T.grid(tile_x): + with T.block("lse_store"): + i = T.axis.remap("S", [li]) + if L_start + i // group_size < q_indptr[b_idx + 1]: + lse[L_start + i // group_size, H_qo_start + i % group_size] = m_smem[i] + T.log2(d_smem[i]) + # move to next tile tile_id[0] += NUM_BLKS # fmt: on @@ -835,6 +810,12 @@ def apply_to_gemm( # pylint: disable=too-many-arguments,unused-argument sch.reorder(ko, ki, xi, yi) sch.decompose_reduction(block, ty) + def apply_to_md(sch, block): + loop = sch.get_loops(block)[-1] + _, ty, tx = sch.split(loop, factors=[None, num_warps, 32]) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + tile_s = get_tile_size(tile_x, tile_z, 32 * num_warps) tile_o = get_tile_size(tile_x, tile_y, 32 * num_warps) apply_to_gemm(sch, sch.get_block("S_gemm"), tile_s, 0, 1, k_major=True) @@ -845,6 +826,7 @@ def apply_to_gemm( # pylint: 
disable=too-many-arguments,unused-argument apply_to_qkv_load(sch, sch.get_block("Q_load")) apply_to_qkv_load(sch, sch.get_block("K_load")) apply_to_qkv_load(sch, sch.get_block("V_load")) + apply_to_md(sch, sch.get_block("lse_store")) return sch.mod["main"].with_attr("tir.is_scheduled", 1) @@ -861,6 +843,7 @@ def _attention_decode(num_kv_heads, num_qo_heads, head_dim, qkv_dtype): GROUP_SIZE = H_qo // H_kv VEC_SIZE = max(8 // qkv_dtype_bytes, D // 32) bdx = D // VEC_SIZE + assert bdx == 32 bdy = GROUP_SIZE threads_per_CTA = max(128, bdx * bdy) bdz = threads_per_CTA // (bdx * bdy) @@ -871,12 +854,14 @@ def _attention_decode(num_kv_heads, num_qo_heads, head_dim, qkv_dtype): # fmt: off @T.prim_func def batch_decode_paged_kv( - handler_id: T.int32, # pylint: disable=unused-argument + _0: T.int32, # pylint: disable=unused-argument Q_handle: T.handle, pages_handle: T.handle, page_table_indptr_handle: T.handle, page_table_values_handle: T.handle, last_page_len_handle: T.handle, + k_rope_pos_offset_handle: T.handle, + q_rope_position_handle: T.handle, output_handle: T.handle, lse_handle: T.handle, rotary_mode: T.int32, @@ -885,16 +870,17 @@ def batch_decode_paged_kv( ): T.func_attr({"tir.is_scheduled": 1}) B = T.int32(is_size_var=True) - page_size = T.int32(is_size_var=True) nnz_pages = T.int32(is_size_var=True) max_num_pages = T.int32(is_size_var=True) Q = T.match_buffer(Q_handle, (B, H_qo, D), qkv_dtype) pages = T.match_buffer( - pages_handle, (max_num_pages, 2, H_kv, page_size, D), qkv_dtype + pages_handle, (max_num_pages, 2, H_kv, 16, D), qkv_dtype ) page_table_indptr = T.match_buffer(page_table_indptr_handle, (B + 1,), "int32") page_table_values = T.match_buffer(page_table_values_handle, (nnz_pages,), "int32") + k_rope_pos_offset = T.match_buffer(k_rope_pos_offset_handle, (B,), "int32") + q_rope_position = T.match_buffer(q_rope_position_handle, (B,), "int32") last_page_len = T.match_buffer(last_page_len_handle, (B,), "int32") output = T.match_buffer(output_handle, (B, H_qo, D), qkv_dtype) lse = T.match_buffer(lse_handle, (B, H_qo), "float32") # pylint: disable=unused-variable @@ -911,14 +897,15 @@ def batch_decode_paged_kv( kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local") K_smem = T.alloc_buffer((bdz * bdy * tile_size_per_bdx, D), qkv_dtype, scope="shared") V_smem = T.alloc_buffer((bdz * bdy * tile_size_per_bdx, D), qkv_dtype, scope="shared") - S_allreduce = T.alloc_buffer((bdz, bdy, bdx), "float32", scope="shared") O_allreduce = T.alloc_buffer((bdz, bdy, D), "float32", scope="shared") md_allreduce = T.alloc_buffer((bdz, bdy, 2), "float32", scope="shared") + S_reduce_local = T.alloc_buffer((1,), "float32", scope="local") + mask = T.alloc_buffer((1,), "uint32", scope="local") + t0 = T.alloc_buffer((1,), "float32", scope="local") S_local = T.alloc_buffer((bdy * tile_size_per_bdx), "float32", scope="local") K_local = T.alloc_buffer((VEC_SIZE,), qkv_dtype, scope="local") V_local = T.alloc_buffer((VEC_SIZE,), qkv_dtype, scope="local") - offset = T.alloc_buffer((1,), "int32", scope="local") m_prev = T.alloc_buffer((1,), "float32", scope="local") d_prev = T.alloc_buffer((1,), "float32", scope="local") other_m = T.alloc_buffer((1,), "float32", scope="local") @@ -934,7 +921,7 @@ def batch_decode_paged_kv( cur_last_page_len: T.int32 = last_page_len[batch_idx] kv_chunk_len[0] = T.if_then_else( cur_page_indptr_begin != cur_page_indptr_end, - (cur_page_indptr_end - cur_page_indptr_begin - 1) * page_size + cur_last_page_len, + (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + 
cur_last_page_len, 0 ) @@ -946,9 +933,11 @@ def batch_decode_paged_kv( # load q for vec in T.vectorized(VEC_SIZE): - Q_local[vec] = T.if_then_else(rotary_mode == 1, - _rope(Q, kv_chunk_len[0]-1, head_dim, rope_theta, rope_scale, (bx, by * GROUP_SIZE + ty, tx * VEC_SIZE + vec)), - Q[bx, by * GROUP_SIZE + ty, tx * VEC_SIZE + vec]) + Q_local[vec] = T.if_then_else( + rotary_mode == 1, + _rope(Q, q_rope_position[batch_idx], head_dim, rope_theta, rope_scale, (bx, by * GROUP_SIZE + ty, tx * VEC_SIZE + vec)), + Q[bx, by * GROUP_SIZE + ty, tx * VEC_SIZE + vec] + ) for iterator in T.serial(T.ceildiv(kv_chunk_len[0], tile_size_per_bdx * bdy * bdz)): tile_start_s: T.int32(is_size_var=True) = (tz * bdy + ty) * tile_size_per_bdx @@ -957,12 +946,14 @@ def batch_decode_paged_kv( for j in T.serial(tile_size_per_bdx): row_g: T.int32(is_size_var=True) = tile_start_g + j if row_g < kv_chunk_len[0]: - page_no: T.int32(is_size_var=True) = page_table_values[cur_page_indptr_begin + T.floordiv(row_g, page_size)] - page_offset: T.int32(is_size_var=True) = T.floormod(row_g, page_size) + page_no: T.int32(is_size_var=True) = page_table_values[cur_page_indptr_begin + T.floordiv(row_g, 16)] + page_offset: T.int32(is_size_var=True) = T.floormod(row_g, 16) for vec in T.vectorized(VEC_SIZE): - K_smem[tile_start_s + j, tx * VEC_SIZE + vec] = T.if_then_else(rotary_mode == 1, - _rope(pages, row_g, head_dim, rope_theta, rope_scale, (page_no, 0, by, page_offset, tx * VEC_SIZE + vec)), - pages[page_no, 0, by, page_offset, tx * VEC_SIZE + vec]) + K_smem[tile_start_s + j, tx * VEC_SIZE + vec] = T.if_then_else( + rotary_mode == 1, + _rope(pages, k_rope_pos_offset[batch_idx] + row_g, head_dim, rope_theta, rope_scale, (page_no, 0, by, page_offset, tx * VEC_SIZE + vec)), + pages[page_no, 0, by, page_offset, tx * VEC_SIZE + vec] + ) else: for vec in T.vectorized(VEC_SIZE): K_smem[tile_start_s + j, tx * VEC_SIZE + vec] = 0.0 @@ -971,8 +962,8 @@ def batch_decode_paged_kv( for j in T.serial(tile_size_per_bdx): row_g: T.int32(is_size_var=True) = tile_start_g + j if row_g < kv_chunk_len[0]: - page_no: T.int32(is_size_var=True) = page_table_values[cur_page_indptr_begin + T.floordiv(row_g, page_size)] - page_offset: T.int32(is_size_var=True) = T.floormod(row_g, page_size) + page_no: T.int32(is_size_var=True) = page_table_values[cur_page_indptr_begin + T.floordiv(row_g, 16)] + page_offset: T.int32(is_size_var=True) = T.floormod(row_g, 16) for vec in T.vectorized(VEC_SIZE): V_smem[tile_start_s + j, tx * VEC_SIZE + vec] = pages[page_no, 1, by, page_offset, tx * VEC_SIZE + vec] else: @@ -989,19 +980,23 @@ def batch_decode_paged_kv( for vec in T.vectorized(VEC_SIZE): K_local[vec] = K_smem[tz * bdy * tile_size_per_bdx + j, tx * VEC_SIZE + vec] # compute S = Q * K * sm_scale - S_local[j] = 0 + S_reduce_local[0] = 0 for vec in T.serial(VEC_SIZE): - S_local[j] += Q_local[vec] * K_local[vec] * sm_scale - # allreduce over bdx - S_allreduce[tz, ty, tx] = S_local[j] - T.tvm_storage_sync("shared") - offset[0] = bdx // 2 - while offset[0] > 0: - if tx < offset[0]: - S_allreduce[tz, ty, tx] += S_allreduce[tz, ty, tx + offset[0]] - T.tvm_storage_sync("shared") - offset[0] = offset[0] >> 1 - S_local[j] = S_allreduce[tz, ty, 0] + S_reduce_local[0] += Q_local[vec] * K_local[vec] * sm_scale + + t0[0] = T.tvm_warp_shuffle_down(mask[0], S_reduce_local[0], 16, 32, 32) + S_reduce_local[0] = S_reduce_local[0] + t0[0] + t0[0] = T.tvm_warp_shuffle_down(mask[0], S_reduce_local[0], 8, 32, 32) + S_reduce_local[0] = S_reduce_local[0] + t0[0] + t0[0] = 
T.tvm_warp_shuffle_down(mask[0], S_reduce_local[0], 4, 32, 32) + S_reduce_local[0] = S_reduce_local[0] + t0[0] + t0[0] = T.tvm_warp_shuffle_down(mask[0], S_reduce_local[0], 2, 32, 32) + S_reduce_local[0] = S_reduce_local[0] + t0[0] + t0[0] = T.tvm_warp_shuffle_down(mask[0], S_reduce_local[0], 1, 32, 32) + S_reduce_local[0] = S_reduce_local[0] + t0[0] + S_reduce_local[0] = T.tvm_warp_shuffle(mask[0], S_reduce_local[0], 0, 32, 32) + + S_local[j] = S_reduce_local[0] # update st_m st_m[0] = T.max(st_m[0], S_local[j]) @@ -1054,13 +1049,412 @@ def batch_decode_paged_kv( # store O to global memory for vec in T.vectorized(VEC_SIZE): output[batch_idx, by * GROUP_SIZE + ty, tx * VEC_SIZE + vec] = O_local[vec] + + # store lse to global memory + lse[batch_idx, by * GROUP_SIZE + ty] = st_m[0] + T.log2(st_d[0]) # fmt: on # pylint: enable=line-too-long,invalid-name,too-many-arguments,too-many-branches return batch_decode_paged_kv +def _attention_prefill_ragged(h_kv, h_q, d, dtype): + assert dtype == "float16", f"TIR attention kernel does not support dtype {dtype} right now" + # pylint: disable=invalid-name + NUM_BLKS = 16 + LOAD_VEC = 8 // ((tvm.DataType(dtype).bits + 7) // 8) # 8 bytes + group_size = h_q // h_kv + sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1)) + + num_warps = 4 + tile_x, tile_y, tile_z = 32, d, 16 + L_per_cta = tile_x // group_size + + def mask(causal, row, col, kv_len, qo_len): + return T.if_then_else( + causal > 0, + col < kv_len - qo_len + row + 1, + col < kv_len, + ) + + # fmt: off + @T.prim_func + def batch_prefill_ragged_kv( + var_q: T.handle, # [total_len, h_q, d] + var_q_indptr: T.handle, # [batch_size + 1] + var_k: T.handle, # [total_len, h_kv, d] + var_v: T.handle, # [total_len, h_kv, d] + var_kv_indptr: T.handle, # [batch_size + 1] + var_q_rope_position: T.handle, # [total_q_len] + var_k_rope_pos_offset: T.handle, # [b] + var_output: T.handle, # [total_len, h_q, d] + var_lse: T.handle, # [total_len, h_q] + causal: T.int32, + rotary_mode: T.int32, + rope_scale: T.float32, + rope_theta: T.float32, + ): + batch_size = T.int32(is_size_var=True) + qo_len = T.int32(is_size_var=True) + kv_len = T.int32(is_size_var=True) + + q = T.match_buffer(var_q, (qo_len, h_q, d), dtype) + q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32") + k = T.match_buffer(var_k, (kv_len, h_kv, d), dtype) + v = T.match_buffer(var_v, (kv_len, h_kv, d), dtype) + kv_indptr = T.match_buffer(var_kv_indptr, (batch_size + 1,), "int32") + q_rope_position = T.match_buffer(var_q_rope_position, (qo_len,), "int32") + k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32") + output = T.match_buffer(var_output, (qo_len, h_q, d), dtype) + lse = T.match_buffer(var_lse, (qo_len, h_q), "float32") # pylint: disable=unused-variable + + # kernel code + for lbx in T.thread_binding(NUM_BLKS, thread="blockIdx.x"): + for lby in T.thread_binding(h_kv, thread="blockIdx.y"): + for lty in T.thread_binding(num_warps, thread="threadIdx.y"): + for ltx in T.thread_binding(32, thread="threadIdx.x"): + with T.block("attn"): + bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx]) + T.reads() + T.writes() + tile_id = _var("int32") + batch_idx = _var("int32") + batch_tiles = _var("int32") + batch_rows = _var("int32") + iterator = _var("int32") + kv_chunk_len = _var("int32") + + Q_smem = T.alloc_buffer((tile_x, d), dtype, scope="shared") + K_smem = T.alloc_buffer((tile_z, d), dtype, scope="shared") + V_smem = T.alloc_buffer((tile_z, d), dtype, scope="shared") + S_smem = 
T.alloc_buffer((tile_x, tile_z), "float32", scope="shared") + + S_local = T.alloc_buffer((tile_x, tile_z), "float32", scope="local") + O_local = T.alloc_buffer((tile_x, d), "float32", scope="local") + + m_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared") + m_prev_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared") + d_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared") + + m_new = T.alloc_buffer((math.ceil(tile_x / (32 * num_warps)),), "float32", scope="local") + m_prev = T.alloc_buffer((math.ceil(tile_x / (32 * num_warps)),), "float32", scope="local") + d_new = T.alloc_buffer((math.ceil(tile_x / (32 * num_warps)),), "float32", scope="local") + + ## get tile_no, batch_idx, batch_tiles, batch_rows + tile_id[0] = bx + batch_idx[0] = 0 + batch_rows[0] = (q_indptr[1] - q_indptr[0]) * group_size + batch_tiles[0] = T.ceildiv(batch_rows[0], tile_x) + while T.tvm_thread_invariant(batch_idx[0] < batch_size): + # advance to next tile + while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size: + tile_id[0] -= batch_tiles[0] + batch_idx[0] += 1 + if batch_idx[0] < batch_size: + b_idx: T.int32 = batch_idx[0] + batch_rows[0] = (q_indptr[b_idx + 1] - q_indptr[b_idx]) * group_size + batch_tiles[0] = T.ceildiv(batch_rows[0], tile_x) + + if T.tvm_thread_invariant(batch_idx[0] < batch_size): + b_idx: T.int32 = batch_idx[0] + L_start: T.int32 = q_indptr[b_idx] + tile_id[0] * L_per_cta + H_qo_start: T.int32 = by * group_size + + kv_chunk_len[0] = kv_indptr[b_idx + 1] - kv_indptr[b_idx] + T.tvm_storage_sync("shared") + + # init states + for i in T.serial(T.ceildiv(tile_x, 32 * num_warps)): + row: T.int32 = i * 32 * num_warps + ty * 32 + tx + if row < tile_x: + m_smem[row] = -5e4 + d_smem[row] = 1.0 + + for li, lj in T.grid(tile_x, tile_y): + with T.block("O_init"): + i, j = T.axis.remap("SS", [li, lj]) + O_local[i, j] = 0.0 + T.tvm_storage_sync("shared") + + # Load Q from gmem to smem + for li, lj in T.grid(tile_x, tile_y): + with T.block("Q_load"): + i, j = T.axis.remap("SS", [li, lj]) + T.reads() + T.writes() + cur_L = L_start + i // group_size + cur_H_qo = H_qo_start + i % group_size + if cur_L < q_indptr[b_idx + 1]: + Q_smem[i, j] = T.if_then_else( + rotary_mode == 1, + _rope(q, q_rope_position[cur_L], d, rope_theta, rope_scale, (cur_L, cur_H_qo, j)), + q[cur_L, cur_H_qo, j] + ) + else: + Q_smem[i, j] = 0.0 + T.tvm_storage_sync("shared") + + for iterator in T.serial(T.ceildiv(kv_chunk_len[0], tile_z)): + L_kv_start: T.int32 = iterator * tile_z + L_kv_base: T.int32 = kv_indptr[b_idx] + for lz, ly in T.grid(tile_z, tile_y): + with T.block("K_load"): + i, j = T.axis.remap("SS", [lz, ly]) + T.reads() + T.writes() + cur_L = L_kv_start + i + if cur_L < kv_chunk_len[0]: + K_smem[i, j] = T.if_then_else( + rotary_mode == 1, + _rope(k, k_rope_pos_offset[b_idx] + cur_L, d, rope_theta, rope_scale, (L_kv_base + cur_L, by, j)), + k[L_kv_base + cur_L, by, j] + ) + else: + K_smem[i, j] = 0.0 + T.tvm_storage_sync("shared") + for lz, ly in T.grid(tile_z, tile_y): + with T.block("V_load"): + i, j = T.axis.remap("SS", [lz, ly]) + T.reads() + T.writes() + cur_L = L_kv_start + i + if cur_L < kv_chunk_len[0]: + V_smem[i, j] = v[L_kv_base + cur_L, by, j] + else: + V_smem[i, j] = 0.0 + T.tvm_storage_sync("shared") + + # Compute S + with T.block(): + for li, lj, lk in T.grid(tile_x, tile_z, tile_y): + with T.block("S_gemm"): + i, j, k = T.axis.remap("SSR", [li, lj, lk]) + with T.init(): + S_local[i, j] = 0.0 + S_local[i, j] += Q_smem[i, k] * K_smem[j, k] * sm_scale + T.tvm_storage_sync("shared") 
+ for li, lj in T.grid(tile_x, tile_z): + with T.block("S_store"): + i, j = T.axis.remap("SS", [li, lj]) + S_smem[i, j] = S_local[i, j] + T.tvm_storage_sync("shared") + + # Update S, m, d + for i in T.serial(T.ceildiv(tile_x, 32 * num_warps)): + row: T.int32 = i * 32 * num_warps + ty * 32 + tx + if row < tile_x: + with T.block("update1"): + m_prev[i] = m_smem[row] + m_new[i] = m_smem[row] + # mask out of kv_chunk_len S + for j in T.serial(tile_z): + if mask(causal, + row=tile_id[0] * L_per_cta + row // group_size, + col=L_kv_start + j, + kv_len=kv_chunk_len[0], + qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]): + m_new[i] = T.max(m_new[i], S_smem[row, j]) + d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i]) + + for i in T.serial(T.ceildiv(tile_x, 32 * num_warps)): + row: T.int32 = i * 32 * num_warps + ty * 32 + tx + with T.block("update"): + for j in T.serial(tile_z): + # this is to avoid sync inside condition branch + if row < tile_x: + if mask(causal, + row=tile_id[0] * L_per_cta + row // group_size, + col=L_kv_start + j, + kv_len=kv_chunk_len[0], + qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]): + S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i]) + else: + S_smem[row, j] = T.exp2(-5e4 - m_new[i]) + + for i in T.serial(T.ceildiv(tile_x, 32 * num_warps)): + row: T.int32 = i * 32 * num_warps + ty * 32 + tx + if row < tile_x: + with T.block("update"): + for j in T.serial(tile_z): + d_new[i] += S_smem[row, j] + m_smem[row] = m_new[i] + d_smem[row] = d_new[i] + m_prev_smem[row] = m_prev[i] + T.tvm_storage_sync("shared") + + # Update O + with T.block(): + for li, lj, lk in T.grid(tile_x, tile_y, tile_z): + with T.block("O_gemm"): + i, j, k = T.axis.remap("SSR", [li, lj, lk]) + with T.init(): + O_local[i, j] *= T.exp2(m_prev_smem[i] - m_smem[i]) + O_local[i, j] += S_smem[i, k] * V_smem[k, j] + + # Store O from smem to gmem + for li, lj in T.grid(tile_x, tile_y): + with T.block("O_store"): + i, j = T.axis.remap("SS", [li, lj]) + if L_start + i // group_size < q_indptr[b_idx + 1]: + output[L_start + i // group_size, H_qo_start + i % group_size, j] = O_local[i, j] / d_smem[i] + + # Store LSE to gmem + for li in T.grid(tile_x): + with T.block("lse_store"): + i = T.axis.remap("S", [li]) + if L_start + i // group_size < q_indptr[b_idx + 1]: + lse[L_start + i // group_size, H_qo_start + i % group_size] = m_smem[i] + T.log2(d_smem[i]) + + # move to next tile + tile_id[0] += NUM_BLKS + # fmt: on + # pylint: enable=line-too-long,invalid-name,too-many-arguments,too-many-branches + sch = tir.Schedule(batch_prefill_ragged_kv) + + def get_tile_size(x, y, t): + cnt = (x * y) // t + assert (x * y) % t == 0 + tile_y = (int)(math.ceil(math.sqrt(cnt))) + while cnt % tile_y != 0 and y % tile_y != 0 and tile_y <= cnt: + tile_y += 1 + assert tile_y <= cnt + tile_x = cnt // tile_y + return tile_x, tile_y + + def apply_to_qkv_load(sch: tir.Schedule, block): + loop_x, loop_y = sch.get_loops(block)[-2:] + loop = sch.fuse(loop_x, loop_y) + _, ty, tx, vec = sch.split( + loop, factors=[None, num_warps, 32, LOAD_VEC], preserve_unit_iters=True + ) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + sch.vectorize(vec) + + def apply_to_so_ewise(sch: tir.Schedule, block, tile, vec_len=4): + loop_x, loop_y = sch.get_loops(block)[-2:] + xo, xi = sch.split(loop_x, factors=[None, tile[0]]) + yo, yi = sch.split(loop_y, factors=[None, tile[1]]) + sch.reorder(xo, yo, xi, yi) + t = sch.fuse(xo, yo) + ty, tx = sch.split(t, factors=[num_warps, 32]) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + if tile[1] % 
vec_len == 0: + yi, vec = sch.split(yi, factors=[None, vec_len]) + sch.vectorize(vec) + elif tile[1] in [2, 4]: + sch.vectorize(yi) + + def apply_to_gemm( # pylint: disable=too-many-arguments,unused-argument + sch: tir.Schedule, block, tile, read_0, read_1, r_len=8, k_major=False + ): + loop_x, loop_y, loop_z = sch.get_loops(block)[-3:] + xo, xi = sch.split(loop_x, factors=[None, tile[0]]) + yo, yi = sch.split(loop_y, factors=[None, tile[1]]) + sch.reorder(xo, yo, xi, yi) + t = sch.fuse(xo, yo) + ty, tx = sch.split(t, factors=[num_warps, 32]) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + + ko, ki = sch.split(loop_z, factors=[None, r_len]) + if k_major: + sch.reorder(ko, xi, yi, ki) + else: + sch.reorder(ko, ki, xi, yi) + sch.decompose_reduction(block, ty) + + def apply_to_md(sch, block): + loop = sch.get_loops(block)[-1] + _, ty, tx = sch.split(loop, factors=[None, num_warps, 32]) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + + tile_s = get_tile_size(tile_x, tile_z, 32 * num_warps) + tile_o = get_tile_size(tile_x, tile_y, 32 * num_warps) + apply_to_gemm(sch, sch.get_block("S_gemm"), tile_s, 0, 1, k_major=True) + apply_to_gemm(sch, sch.get_block("O_gemm"), tile_o, 2, 3, k_major=False) + apply_to_so_ewise(sch, sch.get_block("S_store"), tile_s) + apply_to_so_ewise(sch, sch.get_block("O_init"), tile_o) + apply_to_so_ewise(sch, sch.get_block("O_store"), tile_o) + apply_to_qkv_load(sch, sch.get_block("Q_load")) + apply_to_qkv_load(sch, sch.get_block("K_load")) + apply_to_qkv_load(sch, sch.get_block("V_load")) + + apply_to_md(sch, sch.get_block("lse_store")) + return sch.mod["main"].with_attr("tir.is_scheduled", 1) + + +def _merge_state_inplace(num_heads, head_dim, v_dtype): + # pylint: disable=invalid-name + v_dtype_bytes = 2 + VEC_SIZE = max(8 // v_dtype_bytes, head_dim // 32) + bdx = head_dim // VEC_SIZE + bdy = num_heads + + @T.prim_func + def merge_state_inplace( + v: T.handle, + s: T.handle, + v_other: T.handle, + s_other: T.handle, + ): + T.func_attr({"tir.is_scheduled": 1}) + N = T.int32(is_size_var=True) + H = T.int32(is_size_var=True) + D = T.int32(is_size_var=True) + + V = T.match_buffer(v, (N, H, D), v_dtype) + S = T.match_buffer(s, (N, H), "float32") + V_other = T.match_buffer(v_other, (N, H, D), v_dtype) + S_other = T.match_buffer(s_other, (N, H), "float32") + + for bx in T.thread_binding(N, thread="blockIdx.x"): + for ty in T.thread_binding(bdy, thread="threadIdx.y"): + for tx in T.thread_binding(bdx, thread="threadIdx.x"): + with T.block("merge"): + s_val = _var("float32") + s_other_val = _var("float32") + s_max = _var("float32") + scale = _var("float32") + other_scale = _var("float32") + + v_vec = T.alloc_buffer((VEC_SIZE,), v_dtype, scope="local") + v_other_vec = T.alloc_buffer((VEC_SIZE,), v_dtype, scope="local") + + s_val[0] = S[bx, ty] + s_other_val[0] = S_other[bx, ty] + s_max[0] = T.max(s_val[0], s_other_val[0]) + s_val[0] = T.exp2(s_val[0] - s_max[0]) + s_other_val[0] = T.exp2(s_other_val[0] - s_max[0]) + scale[0] = s_val[0] / (s_val[0] + s_other_val[0]) + other_scale[0] = s_other_val[0] / (s_val[0] + s_other_val[0]) + + # load v + for vec in T.vectorized(VEC_SIZE): + v_vec[vec] = V[bx, ty, tx * VEC_SIZE + vec] + # load v_other + for vec in T.vectorized(VEC_SIZE): + v_other_vec[vec] = V_other[bx, ty, tx * VEC_SIZE + vec] + + # merge + for vec in T.serial(VEC_SIZE): + v_vec[vec] = v_vec[vec] * scale[0] + v_other_vec[vec] * other_scale[0] + + # store v + for vec in T.vectorized(VEC_SIZE): + V[bx, ty, tx * VEC_SIZE + vec] = v_vec[vec] + + # 
store s + S[bx, ty] = T.log2(s_val[0] + s_other_val[0]) + s_max[0] + + # pylint: enable=invalid-name + return merge_state_inplace + + if __name__ == "__main__": cache = create_kv_cache() test_paged_attention_kv_cache_prefill_and_decode(cache) test_paged_attention_kv_cache_remove_sequence(cache) + test_paged_attention_kv_cache_fork_sequence(cache) test_paged_attention_kv_cache_popn(cache)