Skip to content

Commit

Permalink
Revert "[Core] Remove unnecessary copies in flash attn backend" (vllm…
Browse files Browse the repository at this point in the history
  • Loading branch information
Yard1 authored Jun 13, 2024
1 parent 719133f commit 226163c
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions vllm/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ def forward(
# normal attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
flash_attn_varlen_func(
out = flash_attn_varlen_func(
q=query,
k=key,
v=value,
Expand All @@ -329,13 +329,14 @@ def forward(
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
out=output[:num_prefill_tokens],
)
assert output[:num_prefill_tokens].shape == out.shape
output[:num_prefill_tokens] = out
else:
# prefix-enabled attention
assert prefill_meta.seq_lens is not None
max_seq_len = max(prefill_meta.seq_lens)
flash_attn_varlen_func(
output[:num_prefill_tokens] = flash_attn_varlen_func(
q=query,
k=key_cache,
v=value_cache,
Expand All @@ -347,12 +348,11 @@ def forward(
causal=True,
alibi_slopes=self.alibi_slopes,
block_table=prefill_meta.block_tables,
out=output[:num_prefill_tokens],
)

if decode_meta := attn_metadata.decode_metadata:
# Decoding run.
flash_attn_with_kvcache(
output[num_prefill_tokens:] = flash_attn_with_kvcache(
decode_query.unsqueeze(1),
key_cache,
value_cache,
Expand All @@ -361,8 +361,7 @@ def forward(
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
out=output[num_prefill_tokens:].unsqueeze(1),
)
).squeeze(1)

# Reshape the output tensor.
return output.view(num_tokens, hidden_size)

0 comments on commit 226163c

Please sign in to comment.