From 8b317c6dd09ce566f4b4abeb446585ac75262cce Mon Sep 17 00:00:00 2001
From: James Whedbee
Date: Wed, 10 Apr 2024 10:12:00 -0500
Subject: [PATCH] [Model][AMD] ROCm support for 256 head dims for Gemma (#3972)

---
 vllm/attention/ops/triton_flash_attention.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py
index b86e845020b07..87cf30cbef79a 100644
--- a/vllm/attention/ops/triton_flash_attention.py
+++ b/vllm/attention/ops/triton_flash_attention.py
@@ -677,8 +677,7 @@ def check_args(
     assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1]
     # TODO: Change assert if we support qkl f8 and v f16
     assert q.dtype == k.dtype and q.dtype == v.dtype
-    # TODO: Fix assert to check head size <=256 once supported
-    assert head_size <= 128
+    assert head_size <= 256
     assert o.shape == q.shape
     assert (nheads_q % nheads_k) == 0
 
@@ -729,7 +728,7 @@ def forward(
         o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3))
 
         # Get closest power of 2 over or equal to 32.
-        unpadded_head_dims = {32, 64, 128}
+        unpadded_head_dims = {32, 64, 128, 256}
         if head_size not in unpadded_head_dims:
             padded_d_model = None
             for i in unpadded_head_dims:
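
Note (not part of the patch): the second hunk extends the set of natively supported head dims so that Gemma's head_dim of 256 no longer falls into the padding path. The patch truncates the loop body after "for i in unpadded_head_dims:", so the sketch below is a reconstruction under that assumption; pad_head_dim is a hypothetical standalone helper, not a function in the vLLM source.

    # Minimal sketch of the head-dim selection the hunk touches: pick the
    # smallest supported dim that is >= head_size, or accept head_size as-is
    # if it is already one of the supported power-of-2 dims.
    def pad_head_dim(head_size: int) -> int:
        unpadded_head_dims = {32, 64, 128, 256}  # supported dims after this patch
        if head_size in unpadded_head_dims:
            return head_size
        padded_d_model = None
        for i in sorted(unpadded_head_dims):
            if i >= head_size:
                padded_d_model = i
                break
        if padded_d_model is None:
            raise ValueError(f"head_size {head_size} exceeds the supported maximum of 256")
        return padded_d_model

    # Example: Gemma's head_dim of 256 is now accepted directly, while an
    # unusual dim such as 96 is padded up to 128.
    assert pad_head_dim(256) == 256
    assert pad_head_dim(96) == 128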