Commit
use flash-attn via xformers (vllm-project#877)
tmm1 authored Aug 30, 2023
1 parent f2588f4 commit 7cf85f3
Showing 2 changed files with 0 additions and 5 deletions.
2 changes: 0 additions & 2 deletions tests/kernels/test_attention.py
@@ -266,7 +266,6 @@ def run_multi_query_kv_attention(
     qkv.uniform_(-1e-3, 1e-3)
     query, key, value = qkv.unbind(dim=1)
 
-    attn_op = xops.fmha.cutlass.FwOp()
     attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
     output = xops.memory_efficient_attention_forward(
         query.unsqueeze(0),
@@ -275,7 +274,6 @@ def run_multi_query_kv_attention(
         attn_bias=attn_bias,
         p=0.0,
         scale=scale,
-        op=attn_op,
     )
     output = output.squeeze(0)
3 changes: 0 additions & 3 deletions vllm/model_executor/layers/attention.py
@@ -61,7 +61,6 @@ def __init__(self,
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
-        self.attn_op = xops.fmha.cutlass.FwOp()
         self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
 
         assert self.num_heads % self.num_kv_heads == 0
@@ -115,7 +114,6 @@ def multi_query_kv_attention(
             attn_bias=input_metadata.attn_bias[0],
             p=0.0,
             scale=self.scale,
-            op=self.attn_op,
         )
         # TODO(woosuk): Unnecessary copy. Optimize.
         output.copy_(out.squeeze(0))
@@ -404,7 +402,6 @@ def multi_query_kv_attention(
                 attn_bias=input_metadata.attn_bias[i],
                 p=0.0,
                 scale=self.scale,
-                op=self.attn_op,
             )
             # TODO(woosuk): Unnecessary copy. Optimize.
             output[start:end].copy_(out.squeeze(0))
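
Why this change matters: pinning op=xops.fmha.cutlass.FwOp() forced xformers to run its CUTLASS forward kernel. With the op argument removed, xformers' dispatcher is free to pick the fastest supported backend, which on recent GPUs is the flash-attn forward op. Below is a minimal sketch of the resulting call pattern; the shapes, dtypes, and sequence lengths are illustrative and not taken from the diff.

# Minimal sketch (not part of the commit): calling memory_efficient_attention_forward
# without op= lets xformers choose the backend automatically.
import torch
import xformers.ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

seq_lens = [7, 5]                 # hypothetical per-sequence prompt lengths
num_heads, head_size = 8, 64
num_tokens = sum(seq_lens)

# Pack Q, K, V for all tokens of all sequences into one tensor.
qkv = torch.empty(num_tokens, 3, num_heads, head_size,
                  dtype=torch.float16, device="cuda")
qkv.uniform_(-1e-3, 1e-3)
query, key, value = qkv.unbind(dim=1)

# Block-diagonal causal mask keeps attention within each sequence.
attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)

out = xops.memory_efficient_attention_forward(
    query.unsqueeze(0),
    key.unsqueeze(0),
    value.unsqueeze(0),
    attn_bias=attn_bias,
    p=0.0,
    scale=head_size ** -0.5,
    # op omitted: xformers dispatches to the flash-attn op when it is
    # installed and the inputs are supported, otherwise falls back to CUTLASS.
)
out = out.squeeze(0)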
