diff --git a/src/runtime/relax_vm/kv_state.cc b/src/runtime/relax_vm/kv_state.cc
index b730a4eb07..7dd7d6f499 100644
--- a/src/runtime/relax_vm/kv_state.cc
+++ b/src/runtime/relax_vm/kv_state.cc
@@ -71,10 +71,23 @@ TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_get_query_positions")
 TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_debug_get_kv")
     .set_body_method(&AttentionKVCacheObj::DebugGetKV);
 TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_attention_with_fused_qkv")
-    .set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id,
-                       double attn_score_scaling_factor, NDArray qkv_data, NDArray o_data) {
+    .set_body([](TVMArgs args, TVMRetValue* rv) {
+      CHECK(args.size() == 5 || args.size() == 6)
+          << "KVState AttentionWithFusedQKV only accepts 5 or 6 arguments";
+      AttentionKVCache kv_cache = args[0];
+      int64_t layer_id = args[1];
+      double attn_score_scaling_factor = args[2];
+      NDArray qkv_data = args[3];
+      NDArray o_data;
+      Optional<NDArray> ext_factors = NullOpt;
+      if (args.size() == 5) {
+        o_data = args[4];
+      } else {
+        ext_factors = args[4].operator tvm::runtime::NDArray();
+        o_data = args[5];
+      }
       kv_cache->AttentionWithFusedQKV(layer_id, std::move(qkv_data), NullOpt, std::move(o_data),
-                                      attn_score_scaling_factor);
+                                      std::move(ext_factors), attn_score_scaling_factor);
     });
 
 // RNN State methods
diff --git a/src/runtime/relax_vm/kv_state.h b/src/runtime/relax_vm/kv_state.h
index f4d6036b96..bac6d18810 100644
--- a/src/runtime/relax_vm/kv_state.h
+++ b/src/runtime/relax_vm/kv_state.h
@@ -167,10 +167,13 @@ class AttentionKVCacheObj : public KVStateObj {
    * `(total_length, num_qo_heads + 2 * num_kv_heads, head_dim)`.
    * \param mask The input mask data, in layout `(total_sqr_length)`.
    * \param o_data The output O data, in layout `(total_length, num_qo_heads, head_dim)`.
+   * \param rope_ext_factors The RoPE extension factor array in shape `(head_dim // 2,)`.
+   * \param attn_score_scaling_factor The additional attention scaling factor.
    * \sa AttentionKVCache::Attention
    */
   virtual void AttentionWithFusedQKV(int64_t layer_id, NDArray qkv_data, Optional<NDArray> mask,
-                                     NDArray o_data, double attn_score_scaling_factor) = 0;
+                                     NDArray o_data, Optional<NDArray> rope_ext_factors,
+                                     double attn_score_scaling_factor) = 0;
 
   /************** Positions **************/
 
diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc
index 6bf3dc7ce6..10ac657259 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -1685,7 +1685,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
   }
 
   void AttentionWithFusedQKV(int64_t layer_id, NDArray qkv_data, Optional<NDArray> mask,
-                             NDArray o_data, double attn_score_scaling_factor) final {
+                             NDArray o_data, Optional<NDArray> rope_ext_factors,
+                             double attn_score_scaling_factor) final {
     // Part 1. Shape and dtype check.
     int64_t local_layer_id = layer_id - layer_id_begin_offset_;
     CHECK_GE(local_layer_id, 0);
@@ -1726,8 +1727,15 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     NDArray v_data = temp_attn_v_device_.CreateView({total_seq_length, num_kv_heads_, head_dim_},
                                                     qkv_data->dtype);
     // Part 2. Split fused qkv and apply rotary embedding to q/k data.
-    f_split_rotary_(qkv_data, q_rope_position_map_view_, q_data, k_data, v_data,
-                    static_cast<int>(rope_mode_ == RoPEMode::kNormal));
+    if (!rope_ext_factors.defined()) {
+      f_split_rotary_(qkv_data, q_rope_position_map_view_, q_data, k_data, v_data,
+                      static_cast<int>(rope_mode_ == RoPEMode::kNormal));
+    } else {
+      CHECK(rope_mode_ == RoPEMode::kNormal)
+          << "The RoPE mode must be normal to support RoPE extension factors.";
+      f_split_rotary_(qkv_data, q_rope_position_map_view_, q_data, k_data, v_data,
+                      rope_ext_factors.value());
+    }
     // Part 3. Append k/v data to kv-cache if flag "append_before_attn" is set.
     if (append_before_attn_) {
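Caller-side note (not part of the patch): the sketch below illustrates the new dispatch-on-arity convention of the "vm.builtin.attention_kv_cache_attention_with_fused_qkv" builtin. The helper name AttentionWithOptionalRopeExt is hypothetical, and the kv_cache, qkv, out, and ext_factors objects are assumed to be created elsewhere (e.g. by the KV-cache creation builtin and NDArray::Empty). The only point shown is that the optional RoPE extension factor array, when supplied, is passed immediately before the output tensor, matching the args[4]/args[5] handling in the registration above.

// Illustrative sketch only -- not part of this patch.
#include <tvm/runtime/container/optional.h>
#include <tvm/runtime/logging.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>

using tvm::runtime::NDArray;
using tvm::runtime::ObjectRef;
using tvm::runtime::Optional;
using tvm::runtime::PackedFunc;
using tvm::runtime::Registry;

// Hypothetical helper: forwards to the builtin with 5 or 6 arguments depending
// on whether RoPE extension factors are provided.
void AttentionWithOptionalRopeExt(ObjectRef kv_cache, int64_t layer_id, double attn_score_scale,
                                  NDArray qkv, NDArray out, Optional<NDArray> ext_factors) {
  const PackedFunc* f =
      Registry::Get("vm.builtin.attention_kv_cache_attention_with_fused_qkv");
  ICHECK(f != nullptr) << "builtin is not registered";
  if (!ext_factors.defined()) {
    // 5-argument form: identical to the pre-patch calling convention.
    (*f)(kv_cache, layer_id, attn_score_scale, qkv, out);
  } else {
    // 6-argument form: the (head_dim // 2,) extension factor array goes in
    // position 4, before the output tensor.
    (*f)(kv_cache, layer_id, attn_score_scale, qkv, ext_factors.value(), out);
  }
}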