From e6b2a55d1e1668d889ce69efa3921bc73dcb8b8a Mon Sep 17 00:00:00 2001
From: Ruihang Lai
Date: Wed, 20 Nov 2024 23:22:39 -0500
Subject: [PATCH] [KVCache] Fix attention prefill kernel for Metal and Android

This PR fixes two bugs in the attention prefill ragged kernel.

* The first bug is the unroll of loop `ki`, which causes a TIR build
  failure in the PointerValueTypeRewrite pass due to the vector size.
* The second is that the tile sizes `tile_z` and `tile_y` may violate
  the assertion check in `get_tile_size`.
---
 python/tvm/relax/frontend/nn/llm/kv_cache.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py b/python/tvm/relax/frontend/nn/llm/kv_cache.py
index 618345d0a5..18f3e19909 100644
--- a/python/tvm/relax/frontend/nn/llm/kv_cache.py
+++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py
@@ -1579,6 +1579,12 @@ def _attention_prefill_ragged(h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any],
         d,
         64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1),
     )
+    original_tile_y = tile_y
+    original_tile_z = tile_z
+    while (tile_x * tile_z) % (bdx * num_warps) != 0:
+        tile_z += original_tile_z
+    while (tile_x * tile_y) % (bdx * num_warps) != 0:
+        tile_y += original_tile_y
 
     # Otherwise we would exceed maxComputeWorkgroupStorageSize
     if (
@@ -1907,7 +1913,6 @@ def apply_to_gemm(  # pylint: disable=unused-argument
         sch.unroll(yio)
         sch.vectorize(yiv)
         sch.unroll(xi)
-        sch.unroll(ki)
         sch.decompose_reduction(block, ty)
 
     def apply_to_md(sch, block):
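
Note: for illustration, below is a minimal standalone sketch of the
tile-size adjustment added above. The helper name `adjust_tile_sizes`
and the default values of `bdx` and `num_warps` are assumptions made
for the example, not part of the kernel; only the two while loops
mirror the added lines in the patch.

# Hypothetical helper illustrating the fix; bdx=32 and num_warps=4
# are illustrative defaults, not values taken from the patch.
def adjust_tile_sizes(tile_x, tile_y, tile_z, bdx=32, num_warps=4):
    # Grow tile_y / tile_z in multiples of their original values until
    # tile_x * tile_z and tile_x * tile_y are divisible by the total
    # thread count, which is the condition asserted in get_tile_size.
    original_tile_y = tile_y
    original_tile_z = tile_z
    threads = bdx * num_warps
    while (tile_x * tile_z) % threads != 0:
        tile_z += original_tile_z
    while (tile_x * tile_y) % threads != 0:
        tile_y += original_tile_y
    return tile_y, tile_z

# Example: with 128 threads, tile_x=16 and tile_z=20 give
# 16 * 20 = 320, and 320 % 128 != 0; one bump to tile_z=40 yields
# 16 * 40 = 640, a multiple of 128, so the assertion now holds.
print(adjust_tile_sizes(16, 20, 20))  # -> (40, 40)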