Skip to content

Commit

Permalink
corrected types for strides in triton FA (#274)
Browse files Browse the repository at this point in the history
Co-authored-by: Aleksandr Malyshev <[email protected]>
  • Loading branch information
maleksan85 and Aleksandr Malyshev authored Nov 13, 2024
1 parent 8de3a62 commit 9a46e97
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 21 deletions.
3 changes: 2 additions & 1 deletion vllm/attention/backends/rocm_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,7 +625,8 @@ def forward(
# QKV for prefill.
query = query[:num_prefill_tokens]

if key is not None and value is not None:
if key is not None and value is not None \
and attn_type != AttentionType.ENCODER_DECODER:
key = key[:num_prefill_tokens]
value = value[:num_prefill_tokens]

Expand Down
40 changes: 20 additions & 20 deletions vllm/attention/ops/triton_flash_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,26 +314,26 @@ def attn_fwd(
sm_scale,
L,
Out,
stride_qz,
stride_qh,
stride_qm,
stride_qk,
stride_kz,
stride_kh,
stride_kn,
stride_kk,
stride_vz,
stride_vh,
stride_vk,
stride_vn,
stride_oz,
stride_oh,
stride_om,
stride_on,
stride_bz,
stride_bh,
stride_bm,
stride_bn,
stride_qz: tl.int64,
stride_qh: tl.int64,
stride_qm: tl.int64,
stride_qk: tl.int64,
stride_kz: tl.int64,
stride_kh: tl.int64,
stride_kn: tl.int64,
stride_kk: tl.int64,
stride_vz: tl.int64,
stride_vh: tl.int64,
stride_vk: tl.int64,
stride_vn: tl.int64,
stride_oz: tl.int64,
stride_oh: tl.int64,
stride_om: tl.int64,
stride_on: tl.int64,
stride_bz: tl.int64,
stride_bh: tl.int64,
stride_bm: tl.int64,
stride_bn: tl.int64,
cu_seqlens_q,
cu_seqlens_k,
dropout_p,
Expand Down

0 comments on commit 9a46e97

Please sign in to comment.