
Commit

[Model][AMD] ROCm support for 256 head dims for Gemma (vllm-project#3972)
jamestwhedbee authored and SageMoore committed Apr 11, 2024
1 parent b454575 commit 88424b5
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions vllm/attention/ops/triton_flash_attention.py
@@ -677,8 +677,7 @@ def check_args(
     assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1]
     # TODO: Change assert if we support qkl f8 and v f16
     assert q.dtype == k.dtype and q.dtype == v.dtype
-    # TODO: Fix assert to check head size <=256 once supported
-    assert head_size <= 128
+    assert head_size <= 256
     assert o.shape == q.shape
     assert (nheads_q % nheads_k) == 0

@@ -729,7 +728,7 @@ def forward(
         o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3))

         # Get closest power of 2 over or equal to 32.
-        unpadded_head_dims = {32, 64, 128}
+        unpadded_head_dims = {32, 64, 128, 256}
         if head_size not in unpadded_head_dims:
             padded_d_model = None
             for i in unpadded_head_dims:
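For context, here is a minimal standalone sketch, based only on the behavior visible in this diff, of the head-dim selection after the change: head sizes already in the supported set run unpadded (which now includes Gemma's 256), and anything else is padded up to the smallest supported power of two. The function name pick_padded_head_dim and the standalone layout are illustrative, not part of vllm's actual code.

# Sketch of the head-dim padding choice this commit touches; illustrative only.
# The real logic lives inline in vllm/attention/ops/triton_flash_attention.py.
UNPADDED_HEAD_DIMS = {32, 64, 128, 256}  # 256 added by this commit

def pick_padded_head_dim(head_size: int) -> int:
    # Mirrors the relaxed check_args assertion (previously head_size <= 128).
    assert head_size <= 256
    if head_size in UNPADDED_HEAD_DIMS:
        return head_size  # e.g. Gemma's head dim of 256 now runs unpadded
    # Otherwise pad up to the closest supported power of 2 >= head_size.
    for dim in sorted(UNPADDED_HEAD_DIMS):
        if dim > head_size:
            return dim
    raise AssertionError("head_size exceeds all supported dims")

# 256 -> 256 (new with this commit), 96 -> 128, 64 -> 64
print(pick_padded_head_dim(256), pick_padded_head_dim(96), pick_padded_head_dim(64))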

