-
Notifications
You must be signed in to change notification settings - Fork 1.3k
/
flash_attention.py
115 lines (101 loc) · 5.41 KB
/
flash_attention.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import math
import torch
import torch.nn as nn
from einops import rearrange
from flash_attn.rotary import RotaryEmbedding, RotaryEmbedding2D
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis
class FlashAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.

    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime). ``None`` is forwarded to the kernel unchanged.
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0). Only applied while the module is
                           in training mode.
        device, dtype: Accepted so the constructor can be called with the
                       usual factory kwargs; unused, because this module
                       creates no parameters or buffers.
    """

    def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
        super().__init__()
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout

    def _run_kernel(self, qkv, cu_seqlens, max_s, causal):
        """Invoke the packed-QKV kernel on already-unpadded input.

        Dropout is disabled outside training mode so evaluation is
        deterministic.
        """
        dropout_p = self.dropout_p if self.training else 0.0
        return flash_attn_unpadded_qkvpacked_func(
            qkv, cu_seqlens, max_s, dropout_p,
            softmax_scale=self.softmax_scale, causal=causal
        )

    def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
                max_s=None, need_weights=False):
        """Implements the multihead softmax attention.

        Arguments
        ---------
            qkv: The tensor containing the query, key, and value.
                (B, S, 3, H, D) if cu_seqlens is None;
                if unpadded (cu_seqlens given): (nnz, 3, h, d).
                Must be a CUDA tensor of dtype float16 or bfloat16.
            key_padding_mask: a bool tensor of shape (B, S), or None when
                every sequence uses the full length S. Only consulted when
                cu_seqlens is None. (Presumably True marks tokens to keep;
                confirm against ``unpad_input``.)
            causal: forwarded to the kernel as its ``causal`` flag.
            cu_seqlens: sequence boundaries for already-unpadded input,
                passed straight to the kernel. When None they are derived
                from the padded layout.
            max_s: maximum sequence length; required when cu_seqlens is
                given, otherwise computed here.
            need_weights: must be False; attention weights are never
                returned.

        Returns
        -------
            A tuple ``(output, None)``. For padded input the output is
            reshaped back to a leading (B, S, ...) layout; for unpadded
            input the kernel output is returned as-is. The second element
            is always None (placeholder for attention weights).

        Raises
        ------
            AssertionError: if need_weights is True, qkv is not a
                float16/bfloat16 CUDA tensor, or cu_seqlens is given
                without max_s.
        """
        assert not need_weights, "FlashAttention cannot return attention weights (need_weights=True)"
        assert qkv.dtype in [torch.float16, torch.bfloat16], \
            f"qkv must be float16 or bfloat16, got {qkv.dtype}"
        assert qkv.is_cuda, "qkv must be a CUDA tensor"

        if cu_seqlens is not None:
            # Caller already unpadded the input: run the kernel directly.
            assert max_s is not None, "max_s is required when cu_seqlens is provided"
            return self._run_kernel(qkv, cu_seqlens, max_s, causal), None

        batch_size = qkv.shape[0]
        seqlen = qkv.shape[1]

        if key_padding_mask is None:
            # No padding: flatten (B, S) into one token axis and describe the
            # equal-length sequences with evenly spaced boundaries.
            qkv = rearrange(qkv, 'b s ... -> (b s) ...')
            max_s = seqlen
            cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen,
                                      dtype=torch.int32, device=qkv.device)
            output = self._run_kernel(qkv, cu_seqlens, max_s, causal)
            output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
            return output, None

        # Padded batch: drop masked tokens, attend, then scatter the result
        # back into the padded (B, S) layout.
        nheads = qkv.shape[-2]
        x = rearrange(qkv, 'b s three h d -> b s (three h d)')
        x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
        x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
        output_unpad = self._run_kernel(x_unpad, cu_seqlens, max_s, causal)
        output_flat = rearrange(output_unpad, 'nnz h d -> nnz (h d)')
        output = rearrange(pad_input(output_flat, indices, batch_size, seqlen),
                           'b s (h d) -> b s h d', h=nheads)
        return output, None
class FlashMHA(nn.Module):
    """Multi-head self-attention layer built on ``FlashAttention``.

    The input is projected to packed Q/K/V with a single linear layer,
    rotary embeddings are optionally applied to Q and K, attention is
    computed by ``FlashAttention``, and the heads are merged and passed
    through an output projection.

    Arguments
    ---------
        embed_dim: model width; must be divisible by num_heads.
        num_heads: number of attention heads. The resulting head dimension
            (embed_dim // num_heads) must be 16, 32, 64, or 128.
        bias: whether the QKV and output projections use a bias.
        batch_first: must be True; only (batch, seqlen, hidden) input is
            supported.
        attention_dropout: dropout rate forwarded to ``FlashAttention``.
        causal: forwarded to ``FlashAttention`` on every forward call.
        use_rotary_emb: None (no rotary embedding), '1d', or '2d'.
        device, dtype: factory kwargs for the linear projections.
        **kwargs: accepted and ignored.
    """

    def __init__(self, embed_dim, num_heads, bias=True, batch_first=True, attention_dropout=0.0,
                 causal=False, use_rotary_emb=None, device=None, dtype=None, **kwargs) -> None:
        assert batch_first, "FlashMHA only supports batch_first=True"
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.causal = causal

        self.num_heads = num_heads
        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads
        assert self.head_dim in [16, 32, 64, 128], "Only support head_dim == 16, 32, 64, or 128"

        assert use_rotary_emb in [None, '1d', '2d'], "use_rotary_emb must be None, '1d', or '2d'"
        self.use_rotary_emb = use_rotary_emb
        if self.use_rotary_emb == '1d':
            self.rotary_emb = RotaryEmbedding(self.head_dim)
        elif self.use_rotary_emb == '2d':
            self.rotary_emb = RotaryEmbedding2D(self.head_dim)

        # One fused projection produces Q, K and V side by side.
        self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
        self.inner_attn = FlashAttention(attention_dropout=attention_dropout, **factory_kwargs)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)

    def forward(self, x, key_padding_mask=None, need_weights=False):
        """Apply self-attention to ``x``.

        Arguments
        ---------
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim)
            key_padding_mask: bool tensor of shape (batch, seqlen), or None.
            need_weights: forwarded to ``FlashAttention``, which rejects True.

        Returns
        -------
            A tuple ``(output, attn_weights)`` where output is the projected
            attention result with heads merged back into the last dimension,
            and attn_weights is whatever ``FlashAttention`` returns as its
            second element (always None in this file).
        """
        qkv = rearrange(self.Wqkv(x), 'b s (three h d) -> b s three h d', three=3,
                        h=self.num_heads)
        if self.use_rotary_emb:
            query, key, value = qkv.unbind(dim=2)
            query, key = self.rotary_emb(query, key, seq_dimension=-3)
            # Cast Q/K back to the projection's dtype (value's dtype) rather
            # than x.dtype: under autocast the projection output can be
            # half precision while x is float32, and stacking mixed dtypes
            # would hand FlashAttention a tensor it rejects. When x and the
            # projection share a dtype this is identical to casting to x.dtype.
            qkv = torch.stack([query.to(value.dtype), key.to(value.dtype), value], dim=2)
        context, attn_weights = self.inner_attn(qkv, key_padding_mask=key_padding_mask,
                                                need_weights=need_weights, causal=self.causal)
        return self.out_proj(rearrange(context, 'b s h d -> b s (h d)')), attn_weights