From d132897968ad7e53e2c350ca434516e5c79d305e Mon Sep 17 00:00:00 2001
From: Ruihang Lai
Date: Thu, 22 Aug 2024 15:17:36 -0400
Subject: [PATCH] [Runtime] Support KV cache with RoPE extension factor array

This PR enhances the KV cache with RoPE extension factor support.
With this PR, the KV cache can support models like Phi-3.5, which come
with an extension factor.
---
 src/runtime/relax_vm/kv_state.cc       | 19 ++++++++++++++++---
 src/runtime/relax_vm/kv_state.h        |  5 ++++-
 src/runtime/relax_vm/paged_kv_cache.cc | 14 +++++++++++---
 3 files changed, 31 insertions(+), 7 deletions(-)

diff --git a/src/runtime/relax_vm/kv_state.cc b/src/runtime/relax_vm/kv_state.cc
index b730a4eb07..7dd7d6f499 100644
--- a/src/runtime/relax_vm/kv_state.cc
+++ b/src/runtime/relax_vm/kv_state.cc
@@ -71,10 +71,23 @@ TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_get_query_positions")
 TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_debug_get_kv")
     .set_body_method(&AttentionKVCacheObj::DebugGetKV);
 TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_attention_with_fused_qkv")
-    .set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id,
-                       double attn_score_scaling_factor, NDArray qkv_data, NDArray o_data) {
+    .set_body([](TVMArgs args, TVMRetValue* rv) {
+      CHECK(args.size() == 5 || args.size() == 6)
+          << "KVState AttentionWithFusedQKV only accepts 5 or 6 arguments";
+      AttentionKVCache kv_cache = args[0];
+      int64_t layer_id = args[1];
+      double attn_score_scaling_factor = args[2];
+      NDArray qkv_data = args[3];
+      NDArray o_data;
+      Optional<NDArray> ext_factors = NullOpt;
+      if (args.size() == 5) {
+        o_data = args[4];
+      } else {
+        ext_factors = args[4].operator tvm::runtime::NDArray();
+        o_data = args[5];
+      }
       kv_cache->AttentionWithFusedQKV(layer_id, std::move(qkv_data), NullOpt, std::move(o_data),
-                                      attn_score_scaling_factor);
+                                      std::move(ext_factors), attn_score_scaling_factor);
     });
 
 // RNN State methods
diff --git a/src/runtime/relax_vm/kv_state.h b/src/runtime/relax_vm/kv_state.h
index f4d6036b96..bac6d18810 100644
--- a/src/runtime/relax_vm/kv_state.h
+++ b/src/runtime/relax_vm/kv_state.h
@@ -167,10 +167,13 @@ class AttentionKVCacheObj : public KVStateObj {
    * `(total_length, num_qo_heads + 2 * num_kv_heads, head_dim)`.
    * \param mask The input mask data, in layout `(total_sqr_length)`.
    * \param o_data The output O data, in layout `(total_length, num_qo_heads, head_dim)`.
+   * \param rope_ext_factors The RoPE extension factor array in shape `(head_dim // 2,)`.
+   * \param attn_score_scaling_factor The additional attention scaling factor.
    * \sa AttentionKVCache::Attention
    */
   virtual void AttentionWithFusedQKV(int64_t layer_id, NDArray qkv_data, Optional<NDArray> mask,
-                                     NDArray o_data, double attn_score_scaling_factor) = 0;
+                                     NDArray o_data, Optional<NDArray> rope_ext_factors,
+                                     double attn_score_scaling_factor) = 0;
 
   /************** Positions **************/
diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc
index 6bf3dc7ce6..10ac657259 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -1685,7 +1685,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
   }
 
   void AttentionWithFusedQKV(int64_t layer_id, NDArray qkv_data, Optional<NDArray> mask,
-                             NDArray o_data, double attn_score_scaling_factor) final {
+                             NDArray o_data, Optional<NDArray> rope_ext_factors,
+                             double attn_score_scaling_factor) final {
     // Part 1. Shape and dtype check.
     int64_t local_layer_id = layer_id - layer_id_begin_offset_;
     CHECK_GE(local_layer_id, 0);
@@ -1726,8 +1727,15 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     NDArray v_data = temp_attn_v_device_.CreateView({total_seq_length, num_kv_heads_, head_dim_},
                                                     qkv_data->dtype);
     // Part 2. Split fused qkv and apply rotary embedding to q/k data.
-    f_split_rotary_(qkv_data, q_rope_position_map_view_, q_data, k_data, v_data,
-                    static_cast<int>(rope_mode_ == RoPEMode::kNormal));
+    if (!rope_ext_factors.defined()) {
+      f_split_rotary_(qkv_data, q_rope_position_map_view_, q_data, k_data, v_data,
+                      static_cast<int>(rope_mode_ == RoPEMode::kNormal));
+    } else {
+      CHECK(rope_mode_ == RoPEMode::kNormal)
+          << "The RoPE mode must be normal to support RoPE extension factors.";
+      f_split_rotary_(qkv_data, q_rope_position_map_view_, q_data, k_data, v_data,
+                      rope_ext_factors.value());
+    }
 
     // Part 3. Append k/v data to kv-cache if flag "append_before_attn" is set.
     if (append_before_attn_) {
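
For reference, a minimal C++ sketch of how a caller can drive the updated
builtin in both calling conventions. The registered name and the argument
order (ext_factors slotting in before o_data) come from the kv_state.cc
change above; the RunAttention wrapper and the concrete layer_id / scaling
values are illustrative assumptions, not part of this patch.

#include <tvm/runtime/logging.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/registry.h>

using tvm::runtime::NDArray;
using tvm::runtime::PackedFunc;
using tvm::runtime::Registry;

// Hypothetical helper: exercises the 5-argument (no extension factors) and
// 6-argument (with extension factors) forms accepted by the new set_body.
void RunAttention(tvm::runtime::ObjectRef kv_cache, NDArray qkv, NDArray out,
                  NDArray rope_ext_factors /* shape (head_dim // 2,) */) {
  const PackedFunc* f =
      Registry::Get("vm.builtin.attention_kv_cache_attention_with_fused_qkv");
  ICHECK(f != nullptr) << "builtin not registered";
  int64_t layer_id = 0;                    // illustrative values
  double attn_score_scaling_factor = 1.0;
  // 5 arguments: plain RoPE, selected by rope_mode_ inside the cache.
  (*f)(kv_cache, layer_id, attn_score_scaling_factor, qkv, out);
  // 6 arguments: the extension factor array is passed just before o_data,
  // and the cache requires rope_mode_ == RoPEMode::kNormal on this path.
  (*f)(kv_cache, layer_id, attn_score_scaling_factor, qkv, rope_ext_factors, out);
}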