Skip to content

Commit

Permalink
[FIX] Fix Alibi implementation in PagedAttention kernel (vllm-project…
Browse files Browse the repository at this point in the history
…#945)

* [FIX] Fix Alibi implementation in PagedAttention kernel

* Fix test_attention

* Fix

---------

Co-authored-by: Woosuk Kwon <[email protected]>
Co-authored-by: Oliver-ss <[email protected]>
  • Loading branch information
3 people authored Sep 7, 2023
1 parent 963e775 commit b2e0bc0
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
2 changes: 1 addition & 1 deletion csrc/attention/attention_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ __global__ void single_query_cached_kv_attention_kernel(
// This includes a reduction across the threads in the same thread group.
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
// Add the ALiBi bias if slopes are given.
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len) : 0;
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;

if (thread_group_offset == 0) {
// Store the partial reductions to shared memory.
Expand Down
5 changes: 3 additions & 2 deletions tests/kernels/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
BLOCK_SIZES = [8, 16, 32]
USE_ALIBI = [False] # TODO(woosuk): Add USE_ALIBI=True
USE_ALIBI = [False, True]
SEEDS = [0]


Expand Down Expand Up @@ -83,7 +83,7 @@ def ref_single_query_cached_kv_attention(
if alibi_slopes is not None:
# Create the ALiBi bias used in the paged attention kernel.
position_ids = torch.arange(context_len, device="cuda").int()
alibi_bias = (context_len - position_ids).float()
alibi_bias = (position_ids - context_len + 1).float()
alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
1, 1, -1)

Expand Down Expand Up @@ -224,6 +224,7 @@ def ref_multi_query_kv_attention(
return ref_output


# TODO(woosuk): Add tests for USE_ALIBI=True.
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
Expand Down

0 comments on commit b2e0bc0

Please sign in to comment.