1,Naive softmax
给定输入向量 Softmax
。):
Naive Softmax 算法主要包括两个步骤,其算法实现步骤和 FLOPs
分析如下:
-
计算归一化项
$dn$ :先对矩阵每个元素都需要进行指数运算,涉及FLOPs
为$N^2$ (逐元素操作),假设是对每一行进行Softmax
,每一行有$N$ 个元素,需要进行$N - 1$ 次加法,矩阵总共有$N$ 行,因此需要$s\times(N - 1)$ 次加法,最后计算归一化项$d_N$ 的FLOPs
为$2N^2N$ -
计算 softmax 输出:分为两步进行每个元素都需要除以所在行的总和,总共
$N^2$ 个元素,FLOPs
为$N^2$ 。
综上,Native Softmax
的总 FLOPs 为:
算法的 python
代码实现和其对 global memory 的访存量 MAC
数值如下所示:
"""
在 attenion 算子中, softmax 函数的输入 QK^T, 输入矩阵大小就是 [s,s]
"""
# [N, N] -> [N, N], 每个元素进行 3 次内存访问:2次读取和一次写入.
# mac = 3N^2, flops = 3N^2 - N
def native_softmax(x):
s, s = x.shape # 第一个维度是序列长度,第二个维度是隐藏层大小
output = np.array(x) # np.array() 将 python 中的数据结构(如列表、元组等)转换为 NumPy 的数组
for r in range(s):
sum = 0
for i in range(s):
sum += np.exp(x[r][i])
for i in range(s):
output[r][i] = np.exp(x[r][i]) / sum
return output
2,Safe Softmax
和 Native Softmax
相比,Safe Softmax
为了防止数值溢出还需要将 max
最大值:
Safe Softmax
涉及三个步骤,其算法实现步骤和 FLOPs
分析如下:
-
对每行求最大值:遍历每行元素,做
$N-1$ 次比较,得到每行元素的最大值,总共$N$ 行,因此该操作涉及FLOPs
为$N(N-1)$ -
计算指数并求和得到归一化项
$d_N$ :将每个元素减去最大值后,再计算指数,这个过程是逐元素操作,FLOPs
为$N^2 + N^2$ 。对每行进行求和,每行进行$N - 1$ 次加法,整个矩阵共$N\times(N - 1)$ 次加法。 -
计算 softmax 输出:将每个元素减去最大值后,再计算指数,最后除以行总和,需要
$2N^2$ 次除法。
值的注意的是,这里计算 max 需要一次独立的全局 reduce,计算分母的 sum 再需要一次独立的全局 reduce,最后分别计算每一个元素的 softmax 值。三个步骤之间存在数据依赖。
结合前面 Native Softmax
的 FLOPs 计算,再加上对每行求最大值的操作,可知 Safe Softmax
总 FLOPs
:
Safe Softmax
算法的 python
代码实现和其对 global memory 的访存量 MAC
数值如下所示:
# [N, N] -> [N, N], 每个元素进行 4 次内存访问:3次读取和一次写入.
# mac = 4N^2, flops = 4N^2 - 2N
def safe_softmax(x):
s, s = x.shape # 第一个维度是序列长度,第二个维度是隐藏层大小
output = np.array(x) # np.array() 将 python 中的数据结构(如列表、元组等)转换为 NumPy 的数组
for r in range(s):
max_r = 0
for i in range(s):
max_r = max(max_r, x[r][i]) # flops 为 1
sum = 0
for i in range(s):
sum += np.exp(x[r][i] - max_r) # flops 为 2 + 1
for i in range(s):
output[r][i] = np.exp(x[r][i] - max_r) / sum # flops 为 2
return output
IO 复杂度分析:Safe Softmax
需要 4
次内存访问,即 Safe Softmax 算法的内存访问(MAC
)偏大,即 softmax 函数的 HBM 访问次数为
从 Safe Softmax
公式很明显看出,MAC
大原因是因为存在数据依赖:(2) 需要依赖 MAC
),又因为 Softmax 典型情况都是内存受限,所以这肯定能提高 Softmax 算子的运行速度。
Online normalizer calculation for softmax 论文将 3 步 safe softmax 合并成 2 步完成的方法,并证明了
$$\begin{aligned} d_i' &= \sum^i_{j=1}e^{x_j - m_i} \ &= \sum^{i-1}{j=1}e^{x_j - m_i} + e^{x_i-m_i} \ &= \left ({\sum^{i-1}{j=1}e^{x_j - m_{i-1}}} \right ) * e^{m_{i-1} - m_i} + e^{x_i-m_i} \ &= d_{i-1}'* e^{m_{i-1} - m_i} + e^{x_i-m_i} \ \end{aligned}$$
即 Online Softmax
计算公式如下:
这里
如果想继续优化,则使用分块技术计算归一化常数,假设
分块计算完
算法分析和公式证明过程,本文不再描述,感兴趣的可以看我上一篇文章-《online-softmax 论文解读》。
这篇论文在算法上其实有两个创新:
- 提出并证明了通过一次遍历输入数据来计算 Softmax 函数归一化项的方法,该方法将 Softmax 函数的内存访问次数减少了
$1.33 (4/3 = 1.33)$ 倍 - 证明了可以分块计算归一化常数,这个方法可以发挥 GPU 多线程的特性。
这里针对上面两个创新,我分别给出 online softmax 算法的 python
代码实现以及 global memory 的访存量 MAC
。
import numpy as np
import torch.nn.functional as F
import torch
def online_softmax_update(m0, d0, m1, d1):
# x 1
m = max(m0, m1) # flops: 1
d = d0 * np.exp(m0 - m) + d1 * np.exp(m1-m) # flops: 5
return m, d
# [N, N] -> [N, N], 每个元素进行 3 次内存访问:2 次读取和一次写入.
# mac = 3N^2, flops = 8N^2
def online_softmax(x):
s, s = x.shape
output = np.array(x)
for r in range(s):
m = x[r][0]
d = 1
for j in range(1, s):
m, d = online_softmax_update(m, d, x[r][j], 1) # flops 为 6
for i in range(s):
output[r][i] = np.exp(x[r][i] - m) / d # flops 为 2
return output
# [N, N] -> [N, N], 每个元素进行 3 次内存访问:2 次读取和一次写入.
# mac = 3N^2, flops = 8N^2,分块计算,可发挥并行计算优势
def block_online_softmax(x, block_size=256):
assert x.shape[1] % block_size == 0
s, s = x.shape
output = np.array(x)
for r in range(s):
m = x[r][0]
d = 0
# 可使用多线程并行计算,实际 mac 为 N^2
for b in range(0, s // block_size):
# Calculate m,d of single block
m_block = x[r][b*block_size]
d_block = 0
for j in range(0, block_size):
m_block, d_block = online_softmax_update(m_block, d_block, x[r][b*block_size + j], 1)
# Merge all block's result to total
m, d = online_softmax_update(m, d, m_block, d_block)
for i in range(s):
output[r][i] = np.exp(x[r][i] - m) / d
return output
if __name__ == "__main__":
x = np.random.randn(1024, 1024)
# 对每一行执行 softmax 操作
pytorch_softmax_out = F.softmax(torch.tensor(x), dim=1) # dim=0表示按列计算;dim=1表示按行计算。
native_softmax_out = native_softmax(x)
safe_softmax_out = safe_softmax(x)
online_softmax_out = online_softmax(x)
block_online_softmax_out = block_online_softmax(x, 256)
if torch.allclose(pytorch_softmax_out, torch.tensor(native_softmax_out), atol=1e-4):
print("naive softmax 与 PyTorch softmax 结果一致!")
else:
print("naive softmax safe_softmax 与 PyTorch softmax 结果不一致!")
if torch.allclose(pytorch_softmax_out, torch.tensor(safe_softmax_out), atol=1e-4):
print("safe softmax 与 PyTorch softmax 结果一致!")
else:
print("safe softmax 与 PyTorch softmax 结果不一致!")
if torch.allclose(pytorch_softmax_out, torch.tensor(online_softmax_out), atol=1e-4):
print("online softmax 与 PyTorch softmax 结果一致!")
else:
print("online softmax 与 PyTorch softmax 结果不一致!")
if torch.allclose(pytorch_softmax_out, torch.tensor(block_online_softmax_out), atol=1e-4):
print("block online softmax 与 PyTorch softmax 结果一致!")
else:
print("block online softmax 与 PyTorch softmax 结果不一致!")
程序运行后输出结果如下所示:
naive softmax 与 PyTorch softmax 结果一致! safe softmax 与 PyTorch softmax 结果一致! online softmax 与 PyTorch softmax 结果一致! block online softmax 与 PyTorch softmax 结果一致!
给定输入二维矩阵
标准的 Attention
运算大致可以描述为以下三个步骤:
- 将
$Q, K$ 矩阵以块的形式从HBM
中加载到SRAM
中,计算$S=QK^T$ ,将$S$ 写入到HBM
中。 - 将
$S$ 矩阵从HBM
中加载到SRAM
中,计算$P = Softmax(S)$ ,将$P$ 写入到 HBM 中。 - 将
$P, V$ 矩阵以块的形式从 HBM 中加载到 SRAM 中,计算$O=PV$ , 将$O$ 写入到 HBM 中。
self-attention 算子涉及到的和 HBM 数据传输过程如上图所示,很明显需要从HBM 中读取 5次,写入 HBM 3 次,HBM
访存量
Roofline
性能分析模型是一种用于衡量和分析计算性能的工具,通过将应用程序的实际计算性能与硬件的理论峰值性能进行对比,以揭示应用是受到计算性能的限制还是受到内存带宽的限制,这里的内存带宽是指芯片外内存带宽。
Roofline 模型的有两个关键指标:操作强度和性能上限。操作强度定义:每字节内存数据传输所对应的操作次数,即每字节 flops,单位一般为 GFlops/sec。Roofline
模型将浮点运算性能、操作强度和内存性能整合在一个二维图中。浮点运算的峰值性能可以通过硬件规格或微基准测试得出。
图表采用对数-对数刻度,Y 轴为可实现的浮点性能,X 轴为操作强度,范围从每 1/4
Flops/DRAM 字节
到 16Flops/DRAM 字节
。Operational Intensity, OI, 也称算术强度 Arithmetic Intensity。
计算机内存系统在不同操作强度下支持的最大浮点性能的计算公式:
Roofline
模型有两个作用:
- 上限分析:为浮点程序性能设定了一个上限(水平线)。
- 瓶颈分析:比较浮点程序的操作强度硬件的操作强度,判断程序是处于内存还是计算受限。
FlashAttention 论文就是基于 Roofline
模型分析了 self-attention 层处于内存受限状态,从而得到了减少 HBM 访问次数的思路。
FlashAttention 论文中说的 SRAM
是指哪种 GPU 内存类型?
1,可以从 cuda 编程和算法角度理解 SRAM 是 L1 Cache (数据缓冲)。
FlashAttention 核心是分块计算注意力,可以简单理解为就是将输入张量划分成很多块,每个数据块放到 sm 里面去计算(cuda/triton 编程的核心就是在于如何将数据分块),sm 里面 L1 cache/共享内存的大小基本就决定了 这个数据块的上限空间大小,所以论文里面说的 SRAM 大小其实值的是 L1 Cache 大小,L2 Cache 是所有 SM 能共同访问的,明显不是论文里指的 SRAM。
2,可以从 GPU 内存层次角度直接看出 SRAM 是 L1 Cache (数据缓冲)。
论文 2.1 节明确都说了 A100 的 SRAM 大小是 192 KB,而英伟达官网给出的 A100 白皮书也明确说了 A100 的 L1 cache 大小是 192KB( 组合共享内存和 L1 数据缓存),所以论文的 SRAM 肯定指的是 L1 cache 了。
“As an example, the A100 GPU has 40-80GB of high bandwidth memory (HBM) with bandwidth 1.5-2.0TB/s and 192KB of on-chip SRAM per each of 108 streaming multiprocessors with bandwidth estimated around 19TB/s [44, 45].”
computations block by block。
Online Softmax 实现在一个 for 循环中计算
注意力输出
$$O_i \leftarrow \text{diag}(\ell_i^{\text{new}})^{-1} (\text{diag}(\ell_i) e^{m_{i} - m_i^{\text{new}}}O_i + e^{\tilde{m}{ij} - m_i^{\text{new}}} \tilde{P}{ij} V_j)$$
这里原论文给出的推导不是很容易看懂,我参考文章给出了推导证明,其证明了
回到一开始的标准 attention,将 online-softmax 算法套进去,这里的 Softmax 是对
$$\begin{aligned} S_{r, i} &= \sum^{d}{j=0}Q{r,j}K_{j,i}\ M_{r, i} &= \max(M_{r, i-1}, S_{r, i}), \quad D_{r, i}' = D_{r, i-1}' * e^{M_{r, i-1} - M_{r, i}} + e^{S_{r, i}-M_{r, i}}, \ P_{r, i} &= \frac{e^{S_{r, i} - M_{r, N}}}{D_{r, N}'}, \ O_{r, c} &= \sum^N_{i=0}(P_{r, i} * V_{i, c}) \end{aligned} $$
对应代码如下:
def online_softmax_update(m0, d0, m1, d1):
# x 1
m = max(m0, m1) # flops: 1
d = d0 * np.exp(m0 - m) + d1 * np.exp(m1-m) # flops: 5
return m, d
def flashattn_0(Q, K, V):
N, Dim = Q.shape
# 1, Load Q K and write S. and Compute S[r][i] by matrix multiply
S = np.zeros([N, N], "float32")
for r in range(0, N):
for i in range(0, N):
for j in range(0, Dim):
S[r][i] += Q[r][j] * K[i][j] # K^T 的列就是 K 的行
# 2, Load S and write O. Compute softmax[i] and O[r][c]
O = np.zeros([N, Dim], "float32")
for r in range(0, N):
m = S[r][0]
d = 1
for i in range(1, N):
m, d = online_softmax_update(m, d, S[r][i], 1) # flops 为 6
softmax = np.zeros([N], "float32")
for i in range(0, N):
softmax[i] = np.exp(S[r][i] - m) / d
for c in range(0, Dim):
for i in range(0, N):
O[r][c] += softmax[i] * V[i][c] # V[i][c] 的加载不连续
return O
将 online softmax 应用到标准 attention 后,Softmax 的 HBM 访存减少了,但还能不能继续优化,在一个 for 循环内完成注意力 O[r][c] 的计算呢?就像 online-softmax 那样,实际是可以的。
$$\begin{aligned} O_{r,c, i} &= O_{r,c, i-1} + Softmax_{r, i} * V[i, c] \ &= O_{r,c, i-1} + \frac{e^{S_{r, i} - M_{r,N}}}{D'{r,N}} * V[i, c] \ &= \sum{j=1}^i \frac{e^{S_{r, j} - M_{r,N}}}{D'_{r,N}} * V[j, c] \end{aligned}$$
可以发现
$$\begin{aligned} O'{r,c, i} &= \sum{j=1}^i \frac{e^{S_{r, j} - M_{r, i}}}{D'{r, i}} * V[j, c] \ &= \sum{j=1}^{i-1} \frac{e^{S_{r, j} - M_{r, i}}}{D'{r,i}} * V[j, c] + \frac{e^{S{r, i} - M_{r, i}}}{D'{r,i}} * V[i, c] \ &= \sum{j=1}^{i-1} \frac{e^{S_{r, j} - M_{r, i-1}}}{D'{r, i-1}} * V[j, c] * \frac{D'{r, i-1} * e^{M_{r,i-1} - M_{r,i}}}{D'{r, i}} + \frac{e^{S{r, i} - M_{r, i}}}{D'{r,i}} * V[i, c] \ &= O'{r,c, i-1} * \frac{e^{M_{r,i-1} - M_{r,i}} * D'{r, i-1}}{D'{r, i}} + \frac{e^{S_{r, i} - M_{r, i}}}{D'_{r,i}} * V[i, c] \end{aligned}$$
可以看到 $O'{r,c, i}$ 仅仅和 $O'{r,c, i-1}$ 以及
$$\begin{aligned} S_{r, i} &= \sum^{Dim}{j=1}Q[r, j]K[j, i]\ M{r, i} &= \max(M_{r, i-1}, S_{r, i}), \quad D_{r, i}' = D_{r, i-1}' * e^{M_{r, i-1} - M_{r, i}} + e^{S_{r, i}-M_{r, i}}\ O'{r,c, i} &=O{r,c,i-1}'*\frac{e^{M_{r, i-1} - M_{r, i}}D_{r,i-1}'}{D_{r,i}'} + \frac{e^{S_{r, i} - M_{r, i}}}{D_{r, i}'}V[i, c]\ \end{aligned} $$
最终,我们想要的注意力输出结果为:
【定理 1】 算法 1 注意力输出矩阵
上述就是 FlashAttention 算法的等效计算公式,对应的伪代码可以写为:
for (r = 1 to N)
for (i = 1 to N)
// [N, Dim] * [Dim, N] -> [N, N]
for (j = 1 to Dim)
S[r, i] += Q[r, j] * K[j, i]
// [N, N]
M[r, i] = max(M[r, i-1], S[r, i])
// [N, N]
D'[r, i] = D'[r, i-1] * exp(M[r, i-1] - M[r, i]) + exp(S[r, i] - M[r, i])
// [N, Dim]
for (c = 0 to Dim)
for i in range(0, N):
o += o * e(...) * D'[r, i-1] / D'[r, i] + e(...) / D'[r, i] * V[i, c]
O[r][c] = o
再用 python 实现如下所示:
def online_softmax_update(m0, d0, m1, d1):
# x 1
m = max(m0, m1) # flops: 1
d = d0 * np.exp(m0 - m) + d1 * np.exp(m1-m) # flops: 5
return m, d
def flashattn_update(m, d, m0, d0, o0, m1, d1, o1):
# | | | | | |
# | | | x v 1
# Init value: MIN_M 0 0
o = o0 * np.exp(m0 - m) * d0 / d + o1 * np.exp(m1 - m) * d1 / d
return o
def flashattn_1(Q, K, V):
N, Dim = Q.shape
# 1, Load Q K and write S. and Compute S[r][i] by matrix multiply
S = np.zeros([N, N], "float32")
O = np.zeros([N, Dim], "float32")
m = np.zeros([N], "float32")
d = np.zeros([N], "float32")
for r in range(0, N):
# 计算 QK^T 的第 i 行结果 S[r][i]
for i in range(0, N):
# QK^T
for j in range(0, Dim):
S[r][i] += Q[r][j] * K[i][j] # K^T 的列就是 K 的行
# softmax: [N,N] -> [N,N]
if i == 0:
mm = S[r][0]
dd = 0
mm, dd = online_softmax_update(mm, dd, S[r][i], 1) # flops 为 6
m[i] = mm
d[i] = dd
# PV: [N, N] * [N, Dim] -> [N, dim]
for c in range(0, Dim):
o = 0
for i in range(0, N):
# 迭代更新注意力计算输出
o = flashattn_update(
m[i],
d[i],
m[i-1] if i > 0 else MIN_M,
d[i-1] if i > 0 else 0,
o,
S[r][i],
V[i][c],
1
)
O[r][c] = o
return O
继续优化,上面的公式和代码只是实现了在一个 for 循环中计算
因此,FlashAttention-1 的分块计算 python 代码如下。
def block_flashattn(Q, K, V, block_size=32):
N, Dim = Q.shape
# 1, Load Q K and write S. and Compute S[r][i] by matrix multiply
S = np.zeros([N, N], "float32")
O = np.zeros([N, Dim], "float32")
for r in range(0, N):
for i in range(0, N):
# QK^T
for j in range(0, Dim):
S[r][i] += Q[r][j] * K[i][j]
for r in range(0, N):
# Softmax
mm = np.zeros([N], "float32")
dd = np.zeros([N], "float32")
m = np.zeros([N // block_size], "float32")
d = np.zeros([N // block_size], "float32")
for b in range(0, N // block_size):
# Calculate m,d of single block
for i in range(0, block_size):
mm[b*block_size + i], dd[b*block_size + i] = online_softmax_update(
mm[b*block_size + i-1] if i > 0 else MIN_M,
dd[b*block_size + i-1] if j > 0 else 0,
S[r, b*block_size + i],
1,
)
# Merge all block's result to total
m[b], d[b] = online_softmax_update(
m[b-1] if b > 0 else MIN_M,
d[b-1] if b > 0 else 0,
mm[(b + 1) * block_size - 1], # 当前块的 mm 和 dd
dd[(b + 1) * block_size - 1])
# PV: [N, N] * [N, Dim] -> [N, dim]
for c in range(0, Dim):
o = 0
for b in range(0, N //block_size):
# Calculate single block
oo = 0
for i in range(0, block_size):
oo = flashattn_update(
mm[b * block_size + i], # 当前迭代位置的 m
dd[b * block_size + i], # 当前迭代位置的 d
mm[b * block_size + i-1] if i > 0 else MIN_M,
dd[b * block_size + i-1] if i > 0 else 0,
oo,
S[r, b * block_size + i], # 当前迭代位置的 s[r,i]
V[b * block_size + i, c],
1
)
# Merge all blocks to total
o = flashattn_update(
m[b],
d[b],
m[b - 1] if b > 0 else MIN_M,
d[b - 1] if b > 0 else 0,
o,
mm[(b + 1) * block_size - 1],
dd[(b + 1) * block_size - 1],
oo,
)
O[r][c] = o
return O
FlashAttention-v1 其实并没有提出新的算法和网络结构上的优化,但是其在算法上综合了过往的两个创新点:分块和重计算,并将其应用于 Attention 结构,给出了详尽的数学计算、证明和 IO 复杂度分析(论文长达 34 页大头都是公式),可以说是过往 transformer 模型在 gpu 上优化的集大成者,而且最重要的是提供了非常易用的前向传播和反向传播的代码库,这使得其广为引用和应用于工业界。
可见,优秀的代码功底、扎实的理论基础、底层硬件和框架的熟悉对于科研工作非常重要,即使你没有提出新的算法,但是你的工作依然可以广为传播和应用。
总的来说,FlashAttention 在算法层面通过重排注意力计算,并利用经典技术(分块和重计算)显著加速了注意力计算,将内存占用从二次方降低到线性。使得在 sequence length 偏长和 attention 计算处于内存密集型的情况下有着明显的加速效果。并直接带来了相对于优化基准 2-4 倍的实际运行时间加速,以及高达 10-20 倍的内存节省,并且计算结果是精确而非近似的。
本文主要分析其在模型推理阶段的优化,因此重计算方法的分析就略过了。
论文总结的一些定理:
【定理 1】 算法 1 注意力输出矩阵
【定理 2】假设 SRAM
大小,且 HBM
访问次数是
【命题 3】设 SRAM
的大小,且 HBM
访问来计算精确的注意力。
【定理 4】假设 SRAM
大小,且
【定理 5】设 SRAM
的大小,且
FlashAttention
算法实现步骤如下所示。
$\text{算法 1 FlashAttention} \ 要求:矩阵; Q, K, V \in \mathbb{R}^{N \times d} ;存储在;\text{HBM}(高带宽内存)中,片上;\text{SRAM};大小为;M. \$
$1: 设置块大小;B_c = \left\lceil \frac{M}{4d} \right\rceil , B_r = \min \left(\left\lceil \frac{M}{4d} \right\rceil , d\right). \ 2: 初始化;O = (0){N \times d} \in \mathbb{R}^{N \times d} , \ell = (0)N \in \mathbb{R}^N , m = (-\infty)N \in \mathbb{R}^N;存储在; \text{HBM} 中. \ 3: 将 ;Q;分成; T_r = \left\lceil \frac{N}{B_r} \right\rceil ;块 Q_1, \dots, Q{T_r},每块大小为;B_r\times d;将;K, V;分为; T_c = \left\lceil \frac{N}{B_c} \right\rceil ;块; K_1, \dots, K{T_c} ;和; V_1, \dots, V{T_c},每块大小为; B_c \times d. \ 4: 将 ;O;分为;T_r; 块;O_1, \dots, O_{T_r},每块大小为 ;B_r\times d,将 ;\ell;分为;T_r;块 \ell_1, \dots, \ell_{T_r},将; m ;分为;T_r;块 m_1, \dots, m_{T_r},每块大小为;B_r. \ 5: for ;1 \leq j \leq T_c;\text{do} \ 6: \quad 从;\text{HBM} 加载;K_j, V_j;到片上 ;\text{SRAM}. \ 7: \quad for ; 1 \leq i \leq T_r; \text{do} \ 8: \quad \quad 从 ; \text{HBM}; 加载 ; Q_i, O_i, \ell_i, m_i ;到片上; \text{SRAM}. \ 9: \quad \quad 在片上计算; S_{ij} = Q_i K_j^T \in \mathbb{R}^{B_r \times B_c}. \ 10: \quad \quad 在片上计算; \tilde{m}{ij} = \text{rowmax}(S{ij}) \in \mathbb{R}^{B_r} , \tilde{P}{ij} = \exp(S{ij} - \tilde{m}{ij}) \in \mathbb{R}^{B_r \times B_c} (逐元素操作),计算; \tilde{\ell}{ij} = \text{rowsum}(\tilde{P}{ij}) \in \mathbb{R}^{B_r}. \ 11: \quad \quad 在片上计算; m_i^{\text{new}} = \max(m_i, \tilde{m}{ij}) \in \mathbb{R}^{B_r} , \ell_i^{\text{new}} = e^{m_i - m_i^{\text{new}}} \ell_i + e^{\tilde{m}{ij} - m_i^{\text{new}}} \tilde{\ell}{ij} \in \mathbb{R}^{B_r}. \ 12: \quad \quad 将; O_i \leftarrow \text{diag}(\ell_i^{\text{new}})^{-1} (\text{diag}(\ell_i) e^{m{i} - m_i^{\text{new}}}O_i + e^{\tilde{m}{ij} - m_i^{\text{new}}} \tilde{P}{ij} V_j) ; 写回到; \text{HBM}. \ 13: \quad \quad 将; \ell_i \leftarrow \ell_i^{\text{new}}, m_i \leftarrow m_i^{\text{new}} ;写回到; \text{HBM}. \ 14: \quad \text{end for} \ 15: \text{end for} \ 16: 返回; O$
上面的是纯 python 代码,下面我们继续优化,利用 triton 框架写出极度优化的 FlashAttention-1 内核代码。
FlashAttention-2 主要是在 cuda 工程上进行极致优化,算法上的改动很小,$O_{r,c,i}$ 的迭代计算过程中需要反复除以
$$O'{r,c, i} =O{r,c,i-1}'*\frac{e^{M_{r, i-1} - M_{r, i}}D_{r,i-1}'}{D_{r,i}'} + \frac{e^{S_{r, i} - M_{r, i}}}{D_{r, i}'}V[i, c]$$
先将原递推公式简写成:
将等式两边同时乘以
$$O[i]D[i] = O[i-1] * E * D[i-1] + EV$$
再记
这样公式就大大简化了,减少了 FlashAttention 的计算量。对应的 python 代码:
# m, d, m0, d0, o0, m1, d1, o1):
def flashattn_2_update(m, m0, od0, m1, d1, od1):
# | | | | |
# | | x v 1
# Init value: MIN_M 0
od = od0 * np.exp(m0 - m) + od1 * np.exp(m1 - m) * d1
return od
def block_flashattn2(Q, K, V, block_size=32):
N, Dim = Q.shape
# 1, Load Q K and write S. and Compute S[r][i] by matrix multiply
S = np.zeros([N, N], "float32")
O = np.zeros([N, Dim], "float32")
for r in range(0, N):
for i in range(0, N):
# QK^T
for j in range(0, Dim):
S[r][i] += Q[r][j] * K[i][j]
for r in range(0, N):
# Softmax
mm = np.zeros([N], "float32")
dd = np.zeros([N], "float32")
m = np.zeros([N // block_size], "float32")
d = np.zeros([N // block_size], "float32")
for b in range(0, N // block_size):
# Calculate m,d of single block
for i in range(0, block_size):
mm[b*block_size + i], dd[b*block_size + i] = online_softmax_update(
mm[b*block_size + i-1] if i > 0 else MIN_M,
dd[b*block_size + i-1] if j > 0 else 0,
S[r, b*block_size + i],
1,
)
# Merge all block's result to total
m[b], d[b] = online_softmax_update(
m[b-1] if b > 0 else MIN_M,
d[b-1] if b > 0 else 0,
mm[(b + 1) * block_size - 1], # 当前块的 mm 和 dd
dd[(b + 1) * block_size - 1])
# PV: [N, N] * [N, Dim] -> [N, dim]
for c in range(0, Dim):
o = 0
for b in range(0, N //block_size):
# Calculate single block
od = 0
for i in range(0, block_size):
od = flashattn_2_update(
mm[b * block_size + i], # 当前迭代位置的 m
mm[b * block_size + i-1] if i > 0 else MIN_M,
od,
S[r, b * block_size + i], # 当前迭代位置的 s[r,i]
V[b * block_size + i, c],
1
)
# Merge all blocks to total
o = flashattn_2_update(
m[b], # 当前迭代 block 的 m
m[b - 1] if b > 0 else MIN_M,
o, # 上一个 block 的结果
mm[(b + 1) * block_size - 1], # m1
dd[(b + 1) * block_size - 1], # d1
od / dd[(b + 1) * block_size - 1], # od1
)
O[r][c] = o / d[b - 1] # 上一轮 block 的 d
return O
FlashAttention-2 的缺陷是对训练提速较多,对推理加速不大:主要是因为推理阶段查询的长度通常是 1,这意味着如果批量大小小于 GPU 上流处理器(SM)的数量(A100 GPU 上有 108 个 SM),那么 atttention 操作只能使用一小部分 GPU!尤其是在使用较长的上下文时,由于需要更小的批量大小以适应 GPU 内存,批量大小为 1 的情况下,FlashAttention 的 GPU 利用率不到 1%。
为了提高 attention 在推理阶段的计算速度,提出了 FlashAttention-3
。
Flash-Decoding 在前作对 batch size
和 query length
并行的基础上增加了一个新的并行化维度:keys/values
的序列长度,代价是最后一个小的归约步骤。
Flash-Decoding 的工作流程分为三个步骤:
- 首先,将键/值拆分成更小的块。
- 然后,使用 FlashAttention 并行计算查询与每个拆分块的注意力值,同时为每行和每个块记录一个额外的标量:注意力值的 log-sum-exp。
- 最后,通过对所有拆分块进行归约,结合 log-sum-exp 调整各个块的贡献,计算出最终的结果。
上述步骤之所以可行,是因为注意力/softmax 可以迭代计算(前作的贡献)。在 Flash-Decoding 中,它在两个层次上使用:在拆分内(类似于 FlashAttention)和跨拆分来执行最终归约。
实际上,步骤 (1) 不涉及任何 GPU 操作,因为键/值块是完整键/值张量的视图。接下来,我们有两个独立的内核分别执行步骤 (2) 和 (3)。
总结: Flash-Decoding 主要是针对 llm 推理的加速,在 batch_size 较小和序列长度较大时有着明显的加速效果,且性能对序列长度的增加并不敏感。
- Online normalizer calculation for softmax
- FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
- FlashAttention-2:Faster Attention with Better Parallelism and Work Partitioning
- FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision
- FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
- Chenfan Blog-FlashAttentions
- FlashAttentions 原理及代码实现
- Self Attention 固定激活值显存分析与优化及PyTorch实现
- 榨干 GPU 效能的 Flash Attention 3
- 图解大模型计算加速系列:FlashAttention V1,从硬件到计算逻辑
- FlashAttention 实现算子融合原理的可视化
- FlashAttention: Fast and Memory-Efficient Exact Attention With IO-Awareness