[KVCache] Fix attention prefill kernel for Metal and Android
This PR fixes two bugs in the attention prefill ragged kernel.

* The first bug is the unrolling of loop `ki`, which causes a TIR build
failure in the PointerValueTypeRewrite pass due to the resulting vector size.
* The second is that the tile sizes `tile_z` and `tile_y` may violate
the assertion check in `get_tile_size`; they are now grown until the check holds (see the sketch below).
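
As a rough illustration of the second fix, the sketch below replays the tile-size adjustment with made-up values for the head dimension and launch configuration (the real kernel derives these from `dtype`, `d`, `bdx`, and `num_warps`): each tile size is grown in multiples of its original value until the tile area divides evenly by the thread-block size, so the divisibility that `get_tile_size` asserts on holds.

```python
# Minimal, self-contained sketch of the tile-size adjustment in this commit.
# Example values only; in the kernel, d is the head dimension, bdx * num_warps
# is the thread-block size, and the bit width comes from DataType(dtype).bits.
dtype_bits = 16             # e.g. float16
d = 70                      # contrived head dimension that triggers the fix
bdx, num_warps = 32, 4      # example launch configuration

tile_x, tile_y, tile_z = (
    64 // ((dtype_bits + 7) // 8) // max(d // 128, 1),
    d,
    64 // ((dtype_bits + 7) // 8) // max(d // 128, 1),
)

# Grow tile_z / tile_y in multiples of their original values until the tile
# area is divisible by the thread-block size (bdx * num_warps).
original_tile_y, original_tile_z = tile_y, tile_z
while (tile_x * tile_z) % (bdx * num_warps) != 0:
    tile_z += original_tile_z
while (tile_x * tile_y) % (bdx * num_warps) != 0:
    tile_y += original_tile_y

print(tile_x, tile_y, tile_z)   # 32 140 32 -> tile_y grew from 70 to 140
```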
MasterJH5574 committed Nov 21, 2024
1 parent 4b4a668 commit e6b2a55
Showing 1 changed file with 6 additions and 1 deletion.
python/tvm/relax/frontend/nn/llm/kv_cache.py (6 additions, 1 deletion)
@@ -1579,6 +1579,12 @@ def _attention_prefill_ragged(h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any],
         d,
         64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1),
     )
+    original_tile_y = tile_y
+    original_tile_z = tile_z
+    while (tile_x * tile_z) % (bdx * num_warps) != 0:
+        tile_z += original_tile_z
+    while (tile_x * tile_y) % (bdx * num_warps) != 0:
+        tile_y += original_tile_y
 
     # Otherwise we would exceed maxComputeWorkgroupStorageSize
     if (
@@ -1907,7 +1913,6 @@ def apply_to_gemm( # pylint: disable=unused-argument
         sch.unroll(yio)
         sch.vectorize(yiv)
         sch.unroll(xi)
-        sch.unroll(ki)
         sch.decompose_reduction(block, ty)
 
     def apply_to_md(sch, block):
