[Cherry-Pick][Fix][Relax] Add the missing tree-attn func arg for KV cache creation

This PR fixes the TIRPagedKVCache construction failure caused by the missing tree-attention-with-paged-KV-cache kernel.
MasterJH5574 committed Sep 7, 2024
1 parent 95a3def commit 2685d6a
Showing 1 changed file with 1 addition and 0 deletions: python/tvm/relax/frontend/nn/llm/kv_cache.py
@@ -375,6 +375,7 @@ def __init__( # pylint: disable=too-many-locals
             bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, head_dim, dtype), "kv_cache_debug_get_kv"),
             bb.add_func(_compact_kv_copy(num_key_value_heads, head_dim, dtype, target), "kv_cache_compact_kv_copy"),
             bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask"),
+            bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache"),
             rope_ext_factors,
             # fmt: on
             # pylint: enable=line-too-long
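For context, here is a minimal sketch of the registration pattern this one-line change completes. It assumes the kernel builders are importable from tvm.relax.frontend.nn.llm.tree_attn (the module alongside kv_cache.py); the target and model dimensions below are illustrative, not taken from the commit.

# Sketch only: the import path, target, and dimensions are assumptions.
from tvm import relax
from tvm.relax.frontend.nn.llm.tree_attn import (
    tree_attn,
    tree_attn_with_paged_kv_cache,
)
from tvm.target import Target

bb = relax.BlockBuilder()
target = Target("cuda")  # illustrative compilation target

# Illustrative model dimensions; real values come from the model config.
num_key_value_heads, num_attention_heads, head_dim = 8, 32, 128
dtype, rope_scaling = "float16", {}

# Each builder returns a TIR kernel; BlockBuilder.add_func registers it
# in the module under a stable name and returns the GlobalVar that the
# TIRPagedKVCache constructor later receives.
bb.add_func(
    tree_attn(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling, target),
    "tir_attention_prefill_with_tree_mask",
)
# The previously missing kernel: tree attention that reads from the paged
# KV cache. Without this GlobalVar, TIRPagedKVCache construction fails.
bb.add_func(
    tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling, target),
    "tir_attention_prefill_with_tree_mask_with_paged_kv_cache",
)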
