0x00 Materials
本文是对Flash Attention的学习笔记,其中有不少内容是摘自业内的前辈们的文章,在此一并感谢。所参考的资料、摘录的文章来源在下面列出:
-
From Online Softmax to FlashAttention(CSE599m, ML for ML System) ,本文的行文逻辑也是按照这篇文章来的。强烈安利CSE599m给入门ML System的新人。
-
FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
-
FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
-
[Attention优化][2w字]🔥原理&图解: 从Online-Softmax到FlashAttention V1/V2/V3
0x01 问题定义
$$ \text{Attention} = \text{Softmax}(\frac{QK^T}{\sqrt{d_k}})V $$
$$ Q,K,V \in \mathbb{R}^{N\times D} $$
其中$N$表示Sequence Length,$D$表示Dimension。我们先来考虑最简单的计算方式:
$$ S = QK^T $$
$$ P = \text{Softmax}(S) $$
$$ O = PV $$
在这个Naive的计算方式中,$P,S \in \mathbb R^{N\times N}$,这意味着为了计算P,我们需要多保存一个$N\times N$的矩阵,这个情况下内存的需求是$O(N^2)$的,很容易爆显存;且为了计算$S和P$势必需要从HBM中进行大量的读写操作,IO的访问次数是$O(N^2 + ND)$复杂度的。随着现在的Context Length需求越来越大,在$N$变大的时候,是很容易爆显存的。总结问题,主要有:
- Sequence Length($N$)越大,传统的Attention计算方法很容易爆显存。
- 传统的Attention计算方式对HBM的访问复杂度是平方级别的,越长的$N$,耗时越长。
考虑到HBM,Shared Memory的速度差异,我们希望能够减少HBM Access而将更多的IO Access操作放在Shared Memory中。
0x02 Online Softmax
我们再来看下Safe Softmax的逻辑:
$$ S_i = \frac{e^{x_i - M}}{\sum_{i=0}^N e^{x_i - M}}, M = \max{X},X \in \mathbb{R}^{N} $$ 。我们先用一个非常naive的思路来实现这个Softmax,这里使用From Online Softmax to FlashAttention文章中的伪代码来解释:
在这个简单的例子中,我们使用了三个循环来进行计算,这要求我们对$[1; N]$进行三次迭代。而Self-Attention中,因为SRAM放不下那么多的数据,所以我们需要三次访问$Q$和 $K$(并且重新计算),这在$I/O$效率上是不利的。
那么,有没有一种方法可以合并一些Pass,就像是我们经常在Kernel Fusion中做的那样呢?初看似乎困难,因为公式(8)依赖于公式(7)所得到的计算结果,但是,使用一些变换,可以允许我们以重计算一部分数据为代价来合并公式(7, 8)。
现在,我们来推导下公式,
$$ \begin{aligned} d_{i}^{\prime}& =\sum_{j=1}^ie^{x_j-m_i} \\ &= \left(\sum_{j=1}^{i-1} e^{x_j-m_i}\right)+e^{x_i-m_i} \\ &= \left(\sum_{j=1}^{i-1} e^{x_j-m_{i-1}}\right)e^{m_{i-1}-m_i}+e^{x_i-m_i} \\ &= d_{i-1}’ e^{m_{i-1}-m_i}+e^{x_i-m_i} \end{aligned} $$
我们可以得到一个递推的公式,其中$d_N^{\prime}$为最后我们需要的加和,即$\sum_{i=0}^{N}e^{x_i-m_N}$。在这个递推公式中,我们使用新的$m$来修正之前的$d_i^{\prime}$,之前错误的$m$可以通过幂相乘的计算规则消去。总的计算流程被缩减为2个Pass,如下图所示:
但是,这个计算方式还是有两个Pass,我们能不能将所有的计算Fuse到一个Pass中去呢?
在Online Softmax中很难做到这一点,因为$a_i$所需要的$m_N,d_N^{\prime}$依赖于全局更新。而$a_i$是一个无法全局更新的变量,除非在第一个Pass中再嵌套一个循环,这样违背了我们简化计算的初衷。但是,将问题放在Self-Attention的计算的时候,就变得不一样了。
我们在这里再理解下,为什么2Pass的Online Safe Softmax是重要的,在Self-Attention的计算中,我们有下面2个主要的问题:
- 需要提前计算好$QK^T$,保存在全局显存中,需要$O(N^2)$的显存,容易爆显存。
- 在算法中Online计算,每次循环中去加载一部分$Q,K$到片上内存,计算得到部分的$QK^T$。
总的来说,Online Softmax解决的是显存不足的问题,但是因为有两个Pass,还是存在HBM R/W次数较多,有Memory Bound,所以我们需要消除这个瓶颈。虽然现在我们需要对每一个$d_i^{\prime}$做Scale,但是考虑到目前显卡并不是Compute Bound,这多余的计算是可以暂时不去考虑的。
0x03 FA1
虽然在Online Softmax中,我们没有办法得到一个1 Pass的算法,但是在Self-Attention中,我们需要的是计算出$O=A\times V$,而不是$A$,这有什么不同呢?我们来推导下公式,不过首先,我们先来看一下原始的Self-Attention是怎么求解的:
这张未打码流程图仍然是从CSE 599m中借用的。可以看到,在第一个Pass中,就是0x02章节中提及的Online Softmax;在第二个Pass中,$o_i$的计算可能稍有点难以理解,可以画张图。实际上就是遍历$a_i$就是$\text{Attention}$矩阵的一行,拿每一行的每个值$a_i$去乘$V$矩阵的每一行,就是行乘列操作。这个操作可以同时把$O$矩阵的一行给算出来。
$$o_i^{\prime}:=\left(\sum_{j=1}^i\frac{e^{x_j-m_i}}{d_i^{\prime}}V[j,:]\right)$$
上面的公式就是把Pass2内部的计算整合在了一起,和0x02章节的推导一样,我们也去尝试做递推:
$$ \begin{aligned} o_i^{\prime}& =\sum_{j=1}^i\frac{e^{x_j-m_i}}{d_i'}V[j,:] \\ &= \left(\sum_{j=1}^{i-1}\frac{e^{x_j-m_i}}{d_i'}V[j,:] \right)+\frac{e^{x_i-m_i}}{d_i'}V[i,: ] \\ &= \left(\sum_{j=1}^{i-1}\frac{e^{x_j-m_{i-1}}}{d_{i-1}^{\prime}}\frac{e^{x_j-m_i}}{e^{x_j-m_{i-1}}}\frac{d_{i-1}^{\prime}}{d_i^{\prime}}V[j,.]\right)+\frac{e^{x_i-m_i}}{d_i^{\prime}}V[i,.] \\ &= \left(\sum_{j=1}^{i-1}\frac{e^{x_j-m_{i-1}}}{d_{i-1}^{\prime}}V[j,:]\right)\frac{d_{i-1}^{\prime}}{d_i^{\prime}}e^{m_{i-1}-m_i}+\frac{e^{x_i-m_i}}{d_i^{\prime}}V[i,:] \\ &=\begin{array}{c}\boldsymbol{o}_{i-1}^{\prime}\frac{d_{i-1}^{\prime}e^{m_{i-1}-m_i}}{d_i^{\prime}}+\frac{e^{x_i-m_i}}{d_i^{\prime}}V[i,:]\end{array} \end{aligned} $$
可以推导出和Online Softmax相似的形式,至此,我们推导出了FA算法。
可以看出,在FA算法中,$Q,K,V$都可以分块载入,我们可以进一步得到FA的Tiling方法:
在这种改进的Tiling技术中,K矩阵被划分为多个较小的区块,同样的方法也适用于Q矩阵。这些较小的区块可以被加载到SRAM中,以便于进行高效的计算。一旦这些区块被加载,就可以在kernel内部完成整个注意力机制的计算过程。从算法的角度来看,现在只需要一次性加载Q、K、V矩阵,就能在内核中完成所有的注意力计算。这种优化方法将原始3-pass Self Attention转变为1-pass FlashAttention,不仅节省了存储中间矩阵所需的显存,还减少了对Q和K矩阵的HBM R/W的次数。
最终,FA的算法可以被下面的伪代码来表示:
此时,我们再看FA的算法流程图,就不感觉陌生了。和上文中的推导思路一致:
在第6行,FA载入$K,V$分块,然后在第8行遍历完成所有的$Q$(这里有个显而易见的问题,$Q$的遍历放在最外面会好很多)。我们在这里再探讨下为什么分块$B_c=\lceil \frac{M}{4d} \rceil, B_r=\min (\lceil \frac{M}{4d} \rceil, d)$。
这样设置的目的是,为了确保SRAM能够放下所有$Q, K, V$的小块,其中$M$就是系统可用的SRAM上限。那么,对于每一个$Q$的分块$Q_i,O_i$以及$K, V$的分块$K_i, V_i$需要的共享内存为:
$$ \begin{gathered} SRAM(Q_{i})=B_{r}\times d=\min\left(\left\lceil\frac{M}{4d}\right\rceil,d\right)\times d<\lceil\frac{M}{4}\rceil \\ SRAM(O_i)=B_r\times d=\min\left(\left\lceil\frac{M}{4d}\right\rceil,d\right)\times d<\lceil\frac{M}{4}\rceil \\ SRAM(K_{j},V_{j})=2\times B_{c}\times d=2\times\left\lceil\frac{M}{4d}\right\rceil\times d<\lceil\frac{M}{2}\rceil \end{gathered} $$
在这个情况下,SRAM基本上可以被占满。FA1原始论文中说道,Block Size 越大,HBM Accesses 越低,在256附近基本就是效率最优的转折点。
文中的实验条件是A100GPU,GPT-2 medium (seq. length 1024, head dim. 64, 16 heads, batch size 64)
0x04 FA2
在0x03章节中我们提到:然后在第8行遍历完成所有的$Q$(这里有个显而易见的问题,$Q$的遍历放在最外面会好很多),这点就是FA2优化的很重要的一点。
FA2一共做了主要的几种优化:
-
优化了Scale的时机,使得除法的次数被大大减少
-
Forward优化了循环的顺序,使得HBM Access更加的高效。Backward没有
-
Forward/Backward均增加了Seq维度的并行
-
Warp的分配更加的合理,避免Split-K(不是很理解?)
优化了Scale的时机,使得除法的次数被大大减少
虽然一般来说,非matmul运算FLOPs要比matmul低,但是非matmul计算使用的是CUDA Cores,而矩阵计算可以利用Tensor Cores加速。基于Tensor Cores的matmul运算吞吐是不使用Tensor Cores的非matmul运算吞吐的16x。
与FA1相比,FA2的主要不同点是计算每一次的$\boldsymbol{O}^{(n)}$的逻辑,这里以$\boldsymbol{O}^{(1)},\boldsymbol{O}^{(2)}$为例来说明,在FA2中:
$$ \begin{gathered} \tilde{\mathbf{o}}^{(1)} =e^{s^{(1)}-m^{(1)}}\mathbf{V}^{(1)}\in\mathbb{R}^{B_{r}\times d} \\ \tilde{\mathrm{o}}^{(2)} =e^{s^{(1)}-m}\mathbf{V}^{(1)}+e^{s^{(2)}-m}\mathbf{V}^{(2)} \\ \mathrm{o}^{(2)} =\mathrm{diag}\left(\ell^{(2)}\right)^{-1}\tilde{\mathbf{O}}^{(2)}=\mathbf{O} \end{gathered} $$
其中,$\tilde{\mathrm{o}}^{(2)} =e^{s^{(1)}-m}\mathbf{V}^{(1)}+e^{s^{(2)}-m}\mathbf{V}^{(2)}$在计算的时候,$e^{s^{(1)}-m}\mathbf{V}^{(1)}$这一项是对$\tilde{\mathbf{o}}^{(1)}$做了缩放,缩放因子是$e^{m^{(1)} - m}$。也就是:
$$\tilde{\mathrm{o}}^{(2)} = e^{m^{(1)} - m} \tilde{\mathbf{o}}^{(1)} +e^{s^{(2)}-m}\mathbf{V}^{(2)}$$
相比于原来的FA1,我们首先计算Softmax的分子部分,在最后才算上分母。这样减少了每次迭代而必须的分母缩放。而原本的FA1的计算过程如下式所示:
$$ \mathbf{O}_{i}\leftarrow\mathrm{diag}\left(\ell_{i}^{\mathrm{new}}\right)^{-1}\left(\mathrm{diag}(\ell_{i})e^{{m_{i}-m_{i}^{\mathrm{new}}}}\mathbf{O}_{i}+e^{{\tilde{m}_{ij}-m_{i}^{\mathrm{new}}}}\mathbf{\tilde{P}}_{ij}\mathbf{V}_{j}\right) $$
FA2的计算中,先不在每个block的每次迭代计算中执行全部的rescale操作,而是最后执行一次rescale。每次计算可以减少一次除法运算。
可以看到在原文的伪代码中,在$T_c$循环结束后,才去做了分母上的计算。
第十行的$\text{diag}^{-1}$是错的,把$^{-1}$去掉。
优化了循环的顺序,增加了Seq维度的并行
FA1的两重循环中,是先外层循环load K, V,然后内层循环再load Q。这就会导致内层循环,每次计算的只是Qi的一部分,每次内循环的迭代都需要对Oi进行全局内存的读写。而且,一个显而易见的事实就是,在Attention的计算中,不同query的Attention计算是完全独立的。也就是说,如果外部循环是先load Q,那么就可以把不同的query块的Attention分配不同thread block进行计算,这些thread block之间是不需要通信的。没错,在FA2中,正是这样做的,对于forward pass,算法调换了循环的顺序,先load Q,再load K, V。
FA2增加seqlen并行,提高了occupancy,并且对于forward pass,Q*K^T在【行】方向的seqlen上天然可以并行,thread block之间不需要额外的通信。
Warp的分配更加的合理,避免Split-K
摘自 FlashAttention核心逻辑以及V1 V2差异总结
首先看fwd,相比V1,V2改进了Warp Partition:4个warp会从smem的K/V tile load同样的数据做mma计算,但是load 不同Q,把V1 sliced-K sliced-V 改成了v2 sliced-Q,V1的做法是需要warp之间产生同步通信的,因为在计算QK结果乘V的时候,如图所示需要跨warp reduction得到O的结果,而且fwd的目的是沿着行方向计算softmax,行方向信息最后要汇总的,这也需要跨warp不同。V2就不需要了,这样可以减少同步开销。
0x05 Causal Mask怎么用?
摘自 [Attention优化][2w字]🔥原理&图解: 从Online-Softmax到FlashAttention V1/V2/V3
非常简单的Early Exit逻辑:
情况0: 全Early Exit。全0的mask可以直接返回0,无需$Q\times K^T$,无需causal mask。
情况1: 部分Early Exit。全1的mask,只需$\text{Softmax}(Q\times K^T)$,无需causal mask。
情况3: 无法Early Exit。0-1混合的causal mask,需QxK^T,需要causal mask,然后$\text{Softmax}(\text{Mask}(Q \times K^T))$。
0x06 MHA/GQA/MQA
在FlashAttention中,也支持MQA和GQA。对于MQA和GQA的情形,FlashAttention采用Indexing的方式,而不是直接复制多份KV Head的内容到显存然后再进行计算。Indexing,即通过传入KV/KV Head索引到Kernel中,然后计算内存地址,直接从内存中读取KV。
0x07 IO复杂度分析
因为FA主要是优化IO Acces,所以我们分析下FA的IO复杂度。我们假设Sequence的长度是$N$,每个头的维度是$d$,SRAM的大小是$M,d \le M \le Nd$。
使用原始的Self Attention算法的IO复杂度是$\Theta(Nd + N^2)$,FA1的IO复杂度是$\Theta(N^2d^2M^{-1})$,考虑到$d$一般是64-128,而$M$一般是100KB,所以FA1的访存次数小于原始的做法。
Memory Accesses和d的平方成正比关系,当d越大,FA的Memory Accesses会增长剧烈。比如对于N=2K, M=192KB, 当d=256时,依然满足 FA IO Acesses < Naive Attention,但是当d=512时,这个结论就会反过来,变成是 FA IO Acesses > Naive Attention IO Acesses,并且由于FA本身的FLOPS就是比Naive Attention高的,于是,此时无论是IO还是FLOPS,FA都会比Naive Attention高,无论是访存还是计算量都没有优势,唯一剩下的优势,应该就只剩节省显存了(不需要保存中间的S和P矩阵,O(N^2)的内存复杂度)
0x08 Triton代码
先再来复习下Block是怎么切块的,这里的图摘自BBuf的 笔记图解大模型计算加速系列:Flash Attention V2,从原理到并行计算。
增加了Seq维度的并行以后:
与V1不同的是,我们在Q的seq_len维度上也做了切分,将其分成四份,即num_m_block = 4。所以现在我们共有1_2_4 = 8个block在跑。这些block之间的运算也是独立的, 因为:
- head的计算是独立的,所以红色block和蓝色block互不干扰
- 采用Q做外循环,KV做内循环时,行与行之间的block是独立的,因此不同行的block互相不干扰。
每个block从Q上加载对应位置的切块,同时从KV上加载head0的切块,计算出自己所维护的那部分O,然后写入O的对应位置。
我们使用OpenAI Triton的FA2 Tutorial代码来分析。
下面的代码是每一个子Block中的最内层的代码,其中q
是最外层循环的子块;K_block_ptr
、V_block_ptr
是$K$、$V$的子块,需要一次for循环完整的遍历。
@triton.jit
def _attn_fwd_inner(acc, l_i, m_i, q, #
K_block_ptr, V_block_ptr, #
start_m, qk_scale, #
BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, #
STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, #
N_CTX: tl.constexpr, fp8_v: tl.constexpr):
# range of values handled by this stage
# 根据STAGE的值,函数定义了处理的键(K)和值(V)的范围。
# 不同的STAGE对应不同的处理范围,支持因果(causal)和非因果(non-causal)的自注意力。
if STAGE == 1: # 使用 Mask
lo, hi = 0, start_m * BLOCK_M
elif STAGE == 2: # 使用 Mask
lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
lo = tl.multiple_of(lo, BLOCK_M)
# causal = False,不使用 Mask
else:
lo, hi = 0, N_CTX
# tl.advance 根据步长调整K_block_ptr的指向
K_block_ptr = tl.advance(K_block_ptr, (0, lo))
V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
# 对K,V Block做完整的遍历
for start_n in range(lo, hi, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
# 加载 K Block
k = tl.load(K_block_ptr)
# 伪代码 line8: q x k
qk = tl.dot(q, k)
if STAGE == 2:
# Mask
mask = offs_m[:, None] >= (start_n + offs_n[None, :])
# Mask 区域加上 -INF
qk = qk * qk_scale + tl.where(mask, 0, -1.0e6)
# 伪代码 line 9: Safe online softmax 的 max
m_ij = tl.maximum(m_i, tl.max(qk, 1))
# 伪代码 line 9: s - m
qk -= m_ij[:, None]
else:
# 伪代码 line 9: Safe online softmax 的 max,和伪代码的区别是这里有 qk_scale,稍后解释
m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
# 伪代码 line 9: s - m. 和伪代码的区别是这里有 qk_scale,稍后解释
qk = qk * qk_scale - m_ij[:, None]
# 伪代码 line 9: p = exp(s-m)
p = tl.math.exp2(qk)
# 伪代码 line 9: rowsum(p)
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
# 伪代码 line 10
alpha = tl.math.exp2(m_i - m_ij)
l_i = l_i * alpha + l_ij
# -- update output accumulator --
# 伪代码 line 10: 这里的 acc 是伪代码中的 O_i
acc = acc * alpha[:, None]
# update acc
v = tl.load(V_block_ptr)
if fp8_v:
p = p.to(tl.float8e5)
else:
p = p.to(tl.float16)
# 伪代码 line 10.
acc = tl.dot(p, v, acc)
# update m_i and l_i
m_i = m_ij
# 更新下一轮的 K,V Block的指针
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
return acc, l_i, m_i
下面我们来看一下调用这个子块函数的函数。
@triton.autotune(list(filter(keep, configs)), key=["N_CTX", "HEAD_DIM"])
@triton.jit
def _attn_fwd(Q, K, V, sm_scale, M, Out, #
stride_qz, stride_qh, stride_qm, stride_qk, #
stride_kz, stride_kh, stride_kn, stride_kk, #
stride_vz, stride_vh, stride_vk, stride_vn, #
stride_oz, stride_oh, stride_om, stride_on, #
Z, H, N_CTX, #
HEAD_DIM: tl.constexpr, #
BLOCK_M: tl.constexpr, #
BLOCK_N: tl.constexpr, #
STAGE: tl.constexpr #
):
tl.static_assert(BLOCK_N <= HEAD_DIM)
# 输入参数里的Z和H分别表示batch size和注意力头数
# q.shape is [Batch, Head, Seq, Dim]
# 启动的时候 [grid] 是
# grid = lambda args: (triton.cdiv(q.shape[2], args["BLOCK_M"]), q.shape[0] * q.shape[1], 1)
# start_m表示当前kernel program 实例对应的seq维度的偏移,而off_hz表示的是batch*heads维度的偏移。
start_m = tl.program_id(0) # seq
off_hz = tl.program_id(1) # batch * heads
# 这两行计算了两个偏移量off_z和off_h,它们分别代表在batch(或heads)中的位置。
off_z = off_hz // H # 表示在哪个 Batch
off_h = off_hz % H # 表示在哪个 Head
# 计算用于定位Q、K和V张量中当前处理块的偏移量。这是基于先前计算的偏移量和提供的步长参数。
qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
# block pointers
# 使用tl.make_block_ptr创建一个指向Q张量当前处理块的指针。这个函数调用指定了基础地址、形状、步长、偏移量和块形状等,以及如何在内存中访问这个数据块。
# N_CTX 是q.shape[2],表示的是序列长度,BLOCK_DMODEL是Lk,表示的是每个注意力头的隐藏层维度大小
# 下面几个make_block_ptr创建的张量类似,分别是对K,V以及输出O创建指向当前处理块的指针
Q_block_ptr = tl.make_block_ptr(
base=Q + qvk_offset,
shape=(N_CTX, HEAD_DIM),
strides=(stride_qm, stride_qk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, HEAD_DIM),
order=(1, 0),
)
v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1, 0)
V_block_ptr = tl.make_block_ptr(
base=V + qvk_offset,
shape=(N_CTX, HEAD_DIM),
strides=(stride_vk, stride_vn),
offsets=(0, 0),
block_shape=(BLOCK_N, HEAD_DIM),
order=v_order,
)
K_block_ptr = tl.make_block_ptr(
base=K + qvk_offset,
shape=(HEAD_DIM, N_CTX),
strides=(stride_kk, stride_kn),
offsets=(0, 0),
block_shape=(HEAD_DIM, BLOCK_N),
order=(0, 1),
)
O_block_ptr = tl.make_block_ptr(
base=Out + qvk_offset,
shape=(N_CTX, HEAD_DIM),
strides=(stride_om, stride_on),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, HEAD_DIM),
order=(1, 0),
)
# initialize offsets
# 计算M维度(seq维度)上每个线程应处理的元素的起始偏移量。
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
# 计算N维度(batch*heads维度)上每个线程应处理的元素的偏移量。
offs_n = tl.arange(0, BLOCK_N)
# initialize pointer to m and l
# 初始化m向量,m用于存储每个m维度上的最大logit,初始化为负无穷大。
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
# 初始化l向量,l用于累计softmax的分母,初始化为1。
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
# 初始化累加器,用于累积注意力加权和。注意这里的shape是(BLOCK_M, BLOCK_DMODEL)
acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
# load scales
qk_scale = sm_scale
qk_scale *= 1.44269504 # 1/log(2)
# load q: it will stay in SRAM throughout
# 将Q矩阵的当前块加载到SRAM中,此数据在整个计算过程中保持不变。
q = tl.load(Q_block_ptr)
# stage 1: off-band
# For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
# For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
if STAGE & 1:
acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, #
start_m, qk_scale, #
BLOCK_M, HEAD_DIM, BLOCK_N, #
4 - STAGE, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5 #
)
# stage 2: on-band
if STAGE & 2:
# barrier makes it easier for compielr to schedule the
# two loops independently
acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, #
start_m, qk_scale, #
BLOCK_M, HEAD_DIM, BLOCK_N, #
2, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5 #
)
# epilogue
m_i += tl.math.log2(l_i)
acc = acc / l_i[:, None]
m_ptrs = M + off_hz * N_CTX + offs_m
tl.store(m_ptrs, m_i)
tl.store(O_block_ptr, acc.to(Out.type.element_ty))
需要特别注意的是这段代码最后的epilogue部分就对应了FlashAttention V2伪代码中的12行以后的内容,根据softmax的分母部分较正输出。此外,Triton的实现里面考虑了一些paper里面没有的东西比如qk_scale
,causal mask
,对Q*K
的结果S
应用了减掉m,使得整个实现看起来要复杂不少,但整体的算法逻辑和并行设置和paper还是一致的。
最后在Attention中使用这个函数
class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, causal, sm_scale):
# shape constraints
HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1]
# when v is in float8_e5m2 it is transposed.
HEAD_DIM_V = v.shape[-1]
assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
assert HEAD_DIM_K in {16, 32, 64, 128, 256}
o = torch.empty_like(q)
stage = 3 if causal else 1
extra_kern_args = {}
# Tuning for AMD target
if is_hip():
waves_per_eu = 3 if HEAD_DIM_K <= 64 else 2
extra_kern_args = {"waves_per_eu": waves_per_eu, "allow_flush_denorm": True}
# q.shape is [Batch, Head, Seq, Dim]
grid = lambda args: (triton.cdiv(q.shape[2], args["BLOCK_M"]), q.shape[0] * q.shape[1], 1)
M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
# Launch Kernel.
_attn_fwd[grid](
q, k, v, sm_scale, M, o, #
q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
o.stride(0), o.stride(1), o.stride(2), o.stride(3), #
q.shape[0], q.shape[1], #
N_CTX=q.shape[2], #
HEAD_DIM=HEAD_DIM_K, #
STAGE=stage, #
**extra_kern_args)
ctx.save_for_backward(q, k, v, o, M)
ctx.grid = grid
ctx.sm_scale = sm_scale
ctx.HEAD_DIM = HEAD_DIM_K
ctx.causal = causal
return o
0x09 CUDA代码
0x0A FA 3
0x0B 思考
- CPU上使用这个靠谱吗?CPU上并行度较低,用这个没有必要,但是可以考虑分块和Mask混合的MatMul来减少计算量,也就是Early Exit。