0x00 Materials

本文是对Flash Attention的学习笔记,其中有不少内容是摘自业内的前辈们的文章,在此一并感谢。所参考的资料、摘录的文章来源在下面列出:

  1. From Online Softmax to FlashAttention(CSE599m, ML for ML System) ,本文的行文逻辑也是按照这篇文章来的。强烈安利CSE599m给入门ML System的新人。

  2. FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

  3. FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning

  4. 【BBuf的CUDA笔记】十四,OpenAI Triton入门笔记三 FusedAttention

  5. [Attention优化][2w字]🔥原理&图解: 从Online-Softmax到FlashAttention V1/V2/V3

0x01 问题定义

$$ \text{Attention} = \text{Softmax}(\frac{QK^T}{\sqrt{d_k}})V $$

$$ Q,K,V \in \mathbb{R}^{N\times D} $$

其中$N$表示Sequence Length,$D$表示Dimension。我们先来考虑最简单的计算方式:

$$ S = QK^T $$

$$ P = \text{Softmax}(S) $$

$$ O = PV $$

在这个Naive的计算方式中,$P,S \in \mathbb R^{N\times N}$,这意味着为了计算P,我们需要多保存一个$N\times N$的矩阵,这个情况下内存的需求是$O(N^2)$的,很容易爆显存;且为了计算$S和P$势必需要从HBM中进行大量的读写操作,IO的访问次数是$O(N^2 + ND)$复杂度的。随着现在的Context Length需求越来越大,在$N$变大的时候,是很容易爆显存的。总结问题,主要有:

  1. Sequence Length($N$)越大,传统的Attention计算方法很容易爆显存。
  2. 传统的Attention计算方式对HBM的访问复杂度是平方级别的,越长的$N$,耗时越长。
From FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
HBM / Shared Mem IO Bandwidth

From FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

考虑到HBM,Shared Memory的速度差异,我们希望能够减少HBM Access而将更多的IO Access操作放在Shared Memory中。

0x02 Online Softmax

我们再来看下Safe Softmax的逻辑:

$$ S_i = \frac{e^{x_i - M}}{\sum_{i=0}^N e^{x_i - M}}, M = \max{X},X \in \mathbb{R}^{N} $$ 。我们先用一个非常naive的思路来实现这个Softmax,这里使用From Online Softmax to FlashAttention文章中的伪代码来解释:

From From Online Softmax to FlashAttention(CSE599m, ML for ML System)
Online Softmax 伪代码

From From Online Softmax to FlashAttention(CSE599m, ML for ML System)

在这个简单的例子中,我们使用了三个循环来进行计算,这要求我们对$[1; N]$进行三次迭代。而Self-Attention中,因为SRAM放不下那么多的数据,所以我们需要三次访问$Q$和 $K$(并且重新计算),这在$I/O$效率上是不利的。

那么,有没有一种方法可以合并一些Pass,就像是我们经常在Kernel Fusion中做的那样呢?初看似乎困难,因为公式(8)依赖于公式(7)所得到的计算结果,但是,使用一些变换,可以允许我们以重计算一部分数据为代价来合并公式(7, 8)。

现在,我们来推导下公式,

$$ \begin{aligned} d_{i}^{\prime}& =\sum_{j=1}^ie^{x_j-m_i} \\ &= \left(\sum_{j=1}^{i-1} e^{x_j-m_i}\right)+e^{x_i-m_i} \\ &= \left(\sum_{j=1}^{i-1} e^{x_j-m_{i-1}}\right)e^{m_{i-1}-m_i}+e^{x_i-m_i} \\ &= d_{i-1}’ e^{m_{i-1}-m_i}+e^{x_i-m_i} \end{aligned} $$

我们可以得到一个递推的公式,其中$d_N^{\prime}$为最后我们需要的加和,即$\sum_{i=0}^{N}e^{x_i-m_N}$。在这个递推公式中,我们使用新的$m$来修正之前的$d_i^{\prime}$,之前错误的$m$可以通过幂相乘的计算规则消去。总的计算流程被缩减为2个Pass,如下图所示:

From From Online Softmax to FlashAttention(CSE599m, ML for ML System)
Online Softmax 2 passes 伪代码

From From Online Softmax to FlashAttention(CSE599m, ML for ML System)

但是,这个计算方式还是有两个Pass,我们能不能将所有的计算Fuse到一个Pass中去呢?

在Online Softmax中很难做到这一点,因为$a_i$所需要的$m_N,d_N^{\prime}$依赖于全局更新。而$a_i$是一个无法全局更新的变量,除非在第一个Pass中再嵌套一个循环,这样违背了我们简化计算的初衷。但是,将问题放在Self-Attention的计算的时候,就变得不一样了。

我们在这里再理解下,为什么2Pass的Online Safe Softmax是重要的,在Self-Attention的计算中,我们有下面2个主要的问题:

  1. 需要提前计算好$QK^T$,保存在全局显存中,需要$O(N^2)$的显存,容易爆显存。
  2. 在算法中Online计算,每次循环中去加载一部分$Q,K$到片上内存,计算得到部分的$QK^T$。

总的来说,Online Softmax解决的是显存不足的问题,但是因为有两个Pass,还是存在HBM R/W次数较多,有Memory Bound,所以我们需要消除这个瓶颈。虽然现在我们需要对每一个$d_i^{\prime}$做Scale,但是考虑到目前显卡并不是Compute Bound,这多余的计算是可以暂时不去考虑的。

0x03 FA1

虽然在Online Softmax中,我们没有办法得到一个1 Pass的算法,但是在Self-Attention中,我们需要的是计算出$O=A\times V$,而不是$A$,这有什么不同呢?我们来推导下公式,不过首先,我们先来看一下原始的Self-Attention是怎么求解的:

From From Online Softmax to FlashAttention(CSE599m, ML for ML System)
原始的Self-Attention 伪代码

From From Online Softmax to FlashAttention(CSE599m, ML for ML System)

这张未打码流程图仍然是从CSE 599m中借用的。可以看到,在第一个Pass中,就是0x02章节中提及的Online Softmax;在第二个Pass中,$o_i$的计算可能稍有点难以理解,可以画张图。实际上就是遍历$a_i$就是$\text{Attention}$矩阵的一行,拿每一行的每个值$a_i$去乘$V$矩阵的每一行,就是行乘列操作。这个操作可以同时把$O$矩阵的一行给算出来。

$$o_i^{\prime}:=\left(\sum_{j=1}^i\frac{e^{x_j-m_i}}{d_i^{\prime}}V[j,:]\right)$$

上面的公式就是把Pass2内部的计算整合在了一起,和0x02章节的推导一样,我们也去尝试做递推:

$$ \begin{aligned} o_i^{\prime}& =\sum_{j=1}^i\frac{e^{x_j-m_i}}{d_i'}V[j,:] \\ &= \left(\sum_{j=1}^{i-1}\frac{e^{x_j-m_i}}{d_i'}V[j,:] \right)+\frac{e^{x_i-m_i}}{d_i'}V[i,: ] \\ &= \left(\sum_{j=1}^{i-1}\frac{e^{x_j-m_{i-1}}}{d_{i-1}^{\prime}}\frac{e^{x_j-m_i}}{e^{x_j-m_{i-1}}}\frac{d_{i-1}^{\prime}}{d_i^{\prime}}V[j,.]\right)+\frac{e^{x_i-m_i}}{d_i^{\prime}}V[i,.] \\ &= \left(\sum_{j=1}^{i-1}\frac{e^{x_j-m_{i-1}}}{d_{i-1}^{\prime}}V[j,:]\right)\frac{d_{i-1}^{\prime}}{d_i^{\prime}}e^{m_{i-1}-m_i}+\frac{e^{x_i-m_i}}{d_i^{\prime}}V[i,:] \\ &=\begin{array}{c}\boldsymbol{o}_{i-1}^{\prime}\frac{d_{i-1}^{\prime}e^{m_{i-1}-m_i}}{d_i^{\prime}}+\frac{e^{x_i-m_i}}{d_i^{\prime}}V[i,:]\end{array} \end{aligned} $$

可以推导出和Online Softmax相似的形式,至此,我们推导出了FA算法。

From From Online Softmax to FlashAttention(CSE599m, ML for ML System)
FA1 伪代码

From From Online Softmax to FlashAttention(CSE599m, ML for ML System)

可以看出,在FA算法中,$Q,K,V$都可以分块载入,我们可以进一步得到FA的Tiling方法:

From From Online Softmax to FlashAttention(CSE599m, ML for ML System)
FA1 Tiled

From From Online Softmax to FlashAttention(CSE599m, ML for ML System)

在这种改进的Tiling技术中,K矩阵被划分为多个较小的区块,同样的方法也适用于Q矩阵。这些较小的区块可以被加载到SRAM中,以便于进行高效的计算。一旦这些区块被加载,就可以在kernel内部完成整个注意力机制的计算过程。从算法的角度来看,现在只需要一次性加载Q、K、V矩阵,就能在内核中完成所有的注意力计算。这种优化方法将原始3-pass Self Attention转变为1-pass FlashAttention,不仅节省了存储中间矩阵所需的显存,还减少了对Q和K矩阵的HBM R/W的次数。

最终,FA的算法可以被下面的伪代码来表示:

From From Online Softmax to FlashAttention(CSE599m, ML for ML System)
FA1 tiled 伪代码

From From Online Softmax to FlashAttention(CSE599m, ML for ML System)

此时,我们再看FA的算法流程图,就不感觉陌生了。和上文中的推导思路一致:

From FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
FA1 原文 伪代码

From FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

在第6行,FA载入$K,V$分块,然后在第8行遍历完成所有的$Q$(这里有个显而易见的问题,$Q$的遍历放在最外面会好很多)。我们在这里再探讨下为什么分块$B_c=\lceil \frac{M}{4d} \rceil, B_r=\min (\lceil \frac{M}{4d} \rceil, d)$。

这样设置的目的是,为了确保SRAM能够放下所有$Q, K, V$的小块,其中$M$就是系统可用的SRAM上限。那么,对于每一个$Q$的分块$Q_i,O_i$以及$K, V$的分块$K_i, V_i$需要的共享内存为:

$$ \begin{gathered} SRAM(Q_{i})=B_{r}\times d=\min\left(\left\lceil\frac{M}{4d}\right\rceil,d\right)\times d<\lceil\frac{M}{4}\rceil \\ SRAM(O_i)=B_r\times d=\min\left(\left\lceil\frac{M}{4d}\right\rceil,d\right)\times d<\lceil\frac{M}{4}\rceil \\ SRAM(K_{j},V_{j})=2\times B_{c}\times d=2\times\left\lceil\frac{M}{4d}\right\rceil\times d<\lceil\frac{M}{2}\rceil \end{gathered} $$

在这个情况下,SRAM基本上可以被占满。FA1原始论文中说道,Block Size 越大,HBM Accesses 越低,在256附近基本就是效率最优的转折点。

From FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
FA1 Block Size 实验

From FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

文中的实验条件是A100GPU,GPT-2 medium (seq. length 1024, head dim. 64, 16 heads, batch size 64)

0x04 FA2

在0x03章节中我们提到:然后在第8行遍历完成所有的$Q$(这里有个显而易见的问题,$Q$的遍历放在最外面会好很多),这点就是FA2优化的很重要的一点。

FA2一共做了主要的几种优化:

  1. 优化了Scale的时机,使得除法的次数被大大减少

  2. Forward优化了循环的顺序,使得HBM Access更加的高效。Backward没有

  3. Forward/Backward均增加了Seq维度的并行

  4. Warp的分配更加的合理,避免Split-K(不是很理解?)

优化了Scale的时机,使得除法的次数被大大减少

虽然一般来说,非matmul运算FLOPs要比matmul低,但是非matmul计算使用的是CUDA Cores,而矩阵计算可以利用Tensor Cores加速。基于Tensor Cores的matmul运算吞吐是不使用Tensor Cores的非matmul运算吞吐的16x。

与FA1相比,FA2的主要不同点是计算每一次的$\boldsymbol{O}^{(n)}$的逻辑,这里以$\boldsymbol{O}^{(1)},\boldsymbol{O}^{(2)}$为例来说明,在FA2中:

$$ \begin{gathered} \tilde{\mathbf{o}}^{(1)} =e^{s^{(1)}-m^{(1)}}\mathbf{V}^{(1)}\in\mathbb{R}^{B_{r}\times d} \\ \tilde{\mathrm{o}}^{(2)} =e^{s^{(1)}-m}\mathbf{V}^{(1)}+e^{s^{(2)}-m}\mathbf{V}^{(2)} \\ \mathrm{o}^{(2)} =\mathrm{diag}\left(\ell^{(2)}\right)^{-1}\tilde{\mathbf{O}}^{(2)}=\mathbf{O} \end{gathered} $$

其中,$\tilde{\mathrm{o}}^{(2)} =e^{s^{(1)}-m}\mathbf{V}^{(1)}+e^{s^{(2)}-m}\mathbf{V}^{(2)}$在计算的时候,$e^{s^{(1)}-m}\mathbf{V}^{(1)}$这一项是对$\tilde{\mathbf{o}}^{(1)}$做了缩放,缩放因子是$e^{m^{(1)} - m}$。也就是:

$$\tilde{\mathrm{o}}^{(2)} = e^{m^{(1)} - m} \tilde{\mathbf{o}}^{(1)} +e^{s^{(2)}-m}\mathbf{V}^{(2)}$$

相比于原来的FA1,我们首先计算Softmax的分子部分,在最后才算上分母。这样减少了每次迭代而必须的分母缩放。而原本的FA1的计算过程如下式所示:

$$ \mathbf{O}_{i}\leftarrow\mathrm{diag}\left(\ell_{i}^{\mathrm{new}}\right)^{-1}\left(\mathrm{diag}(\ell_{i})e^{{m_{i}-m_{i}^{\mathrm{new}}}}\mathbf{O}_{i}+e^{{\tilde{m}_{ij}-m_{i}^{\mathrm{new}}}}\mathbf{\tilde{P}}_{ij}\mathbf{V}_{j}\right) $$

FA2的计算中,先不在每个block的每次迭代计算中执行全部的rescale操作,而是最后执行一次rescale。每次计算可以减少一次除法运算。

From From Online Softmax to FlashAttention(CSE599m, ML for ML System)
FA2 伪代码

From From Online Softmax to FlashAttention(CSE599m, ML for ML System)

可以看到在原文的伪代码中,在$T_c$循环结束后,才去做了分母上的计算。

第十行的$\text{diag}^{-1}$是错的,把$^{-1}$去掉。

优化了循环的顺序,增加了Seq维度的并行

FA1的两重循环中,是先外层循环load K, V,然后内层循环再load Q。这就会导致内层循环,每次计算的只是Qi的一部分,每次内循环的迭代都需要对Oi进行全局内存的读写。而且,一个显而易见的事实就是,在Attention的计算中,不同query的Attention计算是完全独立的。也就是说,如果外部循环是先load Q,那么就可以把不同的query块的Attention分配不同thread block进行计算,这些thread block之间是不需要通信的。没错,在FA2中,正是这样做的,对于forward pass,算法调换了循环的顺序,先load Q,再load K, V。

FA2增加seqlen并行,提高了occupancy,并且对于forward pass,Q*K^T在【行】方向的seqlen上天然可以并行,thread block之间不需要额外的通信。

Warp的分配更加的合理,避免Split-K

摘自 FlashAttention核心逻辑以及V1 V2差异总结

From From Online Softmax to FlashAttention(CSE599m, ML for ML System)
Warp Split-K

From From Online Softmax to FlashAttention(CSE599m, ML for ML System)

首先看fwd,相比V1,V2改进了Warp Partition:4个warp会从smem的K/V tile load同样的数据做mma计算,但是load 不同Q,把V1 sliced-K sliced-V 改成了v2 sliced-Q,V1的做法是需要warp之间产生同步通信的,因为在计算QK结果乘V的时候,如图所示需要跨warp reduction得到O的结果,而且fwd的目的是沿着行方向计算softmax,行方向信息最后要汇总的,这也需要跨warp不同。V2就不需要了,这样可以减少同步开销。

0x05 Causal Mask怎么用?

摘自 [Attention优化][2w字]🔥原理&图解: 从Online-Softmax到FlashAttention V1/V2/V3

非常简单的Early Exit逻辑:

情况0: 全Early Exit。全0的mask可以直接返回0,无需$Q\times K^T$,无需causal mask。

情况1: 部分Early Exit。全1的mask,只需$\text{Softmax}(Q\times K^T)$,无需causal mask。

情况3: 无法Early Exit。0-1混合的causal mask,需QxK^T,需要causal mask,然后$\text{Softmax}(\text{Mask}(Q \times K^T))$。

[Attention优化][2w字]🔥原理&amp;图解: 从Online-Softmax到FlashAttention V1/V2/V3
Masked 示意图

[Attention优化][2w字]🔥原理&图解: 从Online-Softmax到FlashAttention V1/V2/V3

0x06 MHA/GQA/MQA

在FlashAttention中,也支持MQA和GQA。对于MQA和GQA的情形,FlashAttention采用Indexing的方式,而不是直接复制多份KV Head的内容到显存然后再进行计算。Indexing,即通过传入KV/KV Head索引到Kernel中,然后计算内存地址,直接从内存中读取KV。

0x07 IO复杂度分析

因为FA主要是优化IO Acces,所以我们分析下FA的IO复杂度。我们假设Sequence的长度是$N$,每个头的维度是$d$,SRAM的大小是$M,d \le M \le Nd$。

使用原始的Self Attention算法的IO复杂度是$\Theta(Nd + N^2)$,FA1的IO复杂度是$\Theta(N^2d^2M^{-1})$,考虑到$d$一般是64-128,而$M$一般是100KB,所以FA1的访存次数小于原始的做法。

Memory Accesses和d的平方成正比关系,当d越大,FA的Memory Accesses会增长剧烈。比如对于N=2K, M=192KB, 当d=256时,依然满足 FA IO Acesses < Naive Attention,但是当d=512时,这个结论就会反过来,变成是 FA IO Acesses > Naive Attention IO Acesses,并且由于FA本身的FLOPS就是比Naive Attention高的,于是,此时无论是IO还是FLOPS,FA都会比Naive Attention高,无论是访存还是计算量都没有优势,唯一剩下的优势,应该就只剩节省显存了(不需要保存中间的S和P矩阵,O(N^2)的内存复杂度)

0x08 Triton代码

先再来复习下Block是怎么切块的,这里的图摘自BBuf的 笔记图解大模型计算加速系列:Flash Attention V2,从原理到并行计算

图解大模型计算加速系列:Flash Attention V2,从原理到并行计算
Block切块方向

图解大模型计算加速系列:Flash Attention V2,从原理到并行计算

增加了Seq维度的并行以后:

图解大模型计算加速系列:Flash Attention V2,从原理到并行计算
Seq维度切块方向

图解大模型计算加速系列:Flash Attention V2,从原理到并行计算

与V1不同的是,我们在Q的seq_len维度上也做了切分,将其分成四份,即num_m_block = 4。所以现在我们共有1_2_4 = 8个block在跑。这些block之间的运算也是独立的, 因为:

  • head的计算是独立的,所以红色block和蓝色block互不干扰
  • 采用Q做外循环,KV做内循环时,行与行之间的block是独立的,因此不同行的block互相不干扰。

每个block从Q上加载对应位置的切块,同时从KV上加载head0的切块,计算出自己所维护的那部分O,然后写入O的对应位置。

我们使用OpenAI Triton的FA2 Tutorial代码来分析。

下面的代码是每一个子Block中的最内层的代码,其中q是最外层循环的子块;K_block_ptrV_block_ptr是$K$、$V$的子块,需要一次for循环完整的遍历。

@triton.jit
def _attn_fwd_inner(acc, l_i, m_i, q,  #
                    K_block_ptr, V_block_ptr,  #
                    start_m, qk_scale,  #
                    BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr,  #
                    STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr,  #
                    N_CTX: tl.constexpr, fp8_v: tl.constexpr):
    # range of values handled by this stage
    # 根据STAGE的值,函数定义了处理的键(K)和值(V)的范围。
    # 不同的STAGE对应不同的处理范围,支持因果(causal)和非因果(non-causal)的自注意力。
    if STAGE == 1: # 使用 Mask
        lo, hi = 0, start_m * BLOCK_M
    elif STAGE == 2: # 使用 Mask
        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
        lo = tl.multiple_of(lo, BLOCK_M)
    # causal = False,不使用 Mask
    else:
        lo, hi = 0, N_CTX
        
    # tl.advance 根据步长调整K_block_ptr的指向    
    K_block_ptr = tl.advance(K_block_ptr, (0, lo))
    V_block_ptr = tl.advance(V_block_ptr, (lo, 0))

    # 对K,V Block做完整的遍历
    for start_n in range(lo, hi, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        # 加载 K Block
        k = tl.load(K_block_ptr)

        # 伪代码 line8: q x k
        qk = tl.dot(q, k)
        if STAGE == 2:
            # Mask
            mask = offs_m[:, None] >= (start_n + offs_n[None, :])

            # Mask 区域加上 -INF
            qk = qk * qk_scale + tl.where(mask, 0, -1.0e6)

            # 伪代码 line 9: Safe online softmax 的 max
            m_ij = tl.maximum(m_i, tl.max(qk, 1))

            # 伪代码 line 9: s - m
            qk -= m_ij[:, None]
        else:
            # 伪代码 line 9: Safe online softmax 的 max,和伪代码的区别是这里有 qk_scale,稍后解释
            m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)

            # 伪代码 line 9: s - m. 和伪代码的区别是这里有 qk_scale,稍后解释
            qk = qk * qk_scale - m_ij[:, None]
        # 伪代码 line 9: p = exp(s-m)
        p = tl.math.exp2(qk)

        # 伪代码 line 9: rowsum(p)
        l_ij = tl.sum(p, 1)
        
        # -- update m_i and l_i
        # 伪代码 line 10
        alpha = tl.math.exp2(m_i - m_ij)
        l_i = l_i * alpha + l_ij
        # -- update output accumulator --
        # 伪代码 line 10: 这里的 acc 是伪代码中的 O_i
        acc = acc * alpha[:, None]
        # update acc
        v = tl.load(V_block_ptr)
        if fp8_v:
            p = p.to(tl.float8e5)
        else:
            p = p.to(tl.float16)

        # 伪代码 line 10.
        acc = tl.dot(p, v, acc)
        # update m_i and l_i
        m_i = m_ij

        # 更新下一轮的 K,V Block的指针
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
    return acc, l_i, m_i

下面我们来看一下调用这个子块函数的函数。

@triton.autotune(list(filter(keep, configs)), key=["N_CTX", "HEAD_DIM"])
@triton.jit
def _attn_fwd(Q, K, V, sm_scale, M, Out,  #
              stride_qz, stride_qh, stride_qm, stride_qk,  #
              stride_kz, stride_kh, stride_kn, stride_kk,  #
              stride_vz, stride_vh, stride_vk, stride_vn,  #
              stride_oz, stride_oh, stride_om, stride_on,  #
              Z, H, N_CTX,  #
              HEAD_DIM: tl.constexpr,  #
              BLOCK_M: tl.constexpr,  #
              BLOCK_N: tl.constexpr,  #
              STAGE: tl.constexpr  #
              ):
    tl.static_assert(BLOCK_N <= HEAD_DIM)

    # 输入参数里的Z和H分别表示batch size和注意力头数
    # q.shape is [Batch, Head, Seq, Dim]
    # 启动的时候 [grid] 是
    # grid = lambda args: (triton.cdiv(q.shape[2], args["BLOCK_M"]), q.shape[0] * q.shape[1], 1)
    # start_m表示当前kernel program 实例对应的seq维度的偏移,而off_hz表示的是batch*heads维度的偏移。
    start_m = tl.program_id(0) # seq
    off_hz = tl.program_id(1) # batch * heads

    # 这两行计算了两个偏移量off_z和off_h,它们分别代表在batch(或heads)中的位置。
    off_z = off_hz // H # 表示在哪个 Batch
    off_h = off_hz % H # 表示在哪个 Head
    # 计算用于定位Q、K和V张量中当前处理块的偏移量。这是基于先前计算的偏移量和提供的步长参数。
    qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh

    # block pointers
    # 使用tl.make_block_ptr创建一个指向Q张量当前处理块的指针。这个函数调用指定了基础地址、形状、步长、偏移量和块形状等,以及如何在内存中访问这个数据块。
    # N_CTX 是q.shape[2],表示的是序列长度,BLOCK_DMODEL是Lk,表示的是每个注意力头的隐藏层维度大小
    # 下面几个make_block_ptr创建的张量类似,分别是对K,V以及输出O创建指向当前处理块的指针
    Q_block_ptr = tl.make_block_ptr(
        base=Q + qvk_offset,
        shape=(N_CTX, HEAD_DIM),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, HEAD_DIM),
        order=(1, 0),
    )
    v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1, 0)
    V_block_ptr = tl.make_block_ptr(
        base=V + qvk_offset,
        shape=(N_CTX, HEAD_DIM),
        strides=(stride_vk, stride_vn),
        offsets=(0, 0),
        block_shape=(BLOCK_N, HEAD_DIM),
        order=v_order,
    )
    K_block_ptr = tl.make_block_ptr(
        base=K + qvk_offset,
        shape=(HEAD_DIM, N_CTX),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(HEAD_DIM, BLOCK_N),
        order=(0, 1),
    )
    O_block_ptr = tl.make_block_ptr(
        base=Out + qvk_offset,
        shape=(N_CTX, HEAD_DIM),
        strides=(stride_om, stride_on),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, HEAD_DIM),
        order=(1, 0),
    )
    # initialize offsets
    # 计算M维度(seq维度)上每个线程应处理的元素的起始偏移量。
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # 计算N维度(batch*heads维度)上每个线程应处理的元素的偏移量。
    offs_n = tl.arange(0, BLOCK_N)
    # initialize pointer to m and l
    
    # 初始化m向量,m用于存储每个m维度上的最大logit,初始化为负无穷大。
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    
    # 初始化l向量,l用于累计softmax的分母,初始化为1。
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
    
    # 初始化累加器,用于累积注意力加权和。注意这里的shape是(BLOCK_M, BLOCK_DMODEL)
    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
    
    # load scales
    qk_scale = sm_scale
    qk_scale *= 1.44269504  # 1/log(2)
    
    # load q: it will stay in SRAM throughout
    # 将Q矩阵的当前块加载到SRAM中,此数据在整个计算过程中保持不变。
    q = tl.load(Q_block_ptr)
    
    # stage 1: off-band
    # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
    # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
    if STAGE & 1:
        acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr,  #
                                        start_m, qk_scale,  #
                                        BLOCK_M, HEAD_DIM, BLOCK_N,  #
                                        4 - STAGE, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5  #
                                        )
    # stage 2: on-band
    if STAGE & 2:
        # barrier makes it easier for compielr to schedule the
        # two loops independently
        acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr,  #
                                        start_m, qk_scale,  #
                                        BLOCK_M, HEAD_DIM, BLOCK_N,  #
                                        2, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5  #
                                        )
    # epilogue
    m_i += tl.math.log2(l_i)
    acc = acc / l_i[:, None]
    m_ptrs = M + off_hz * N_CTX + offs_m
    tl.store(m_ptrs, m_i)
    tl.store(O_block_ptr, acc.to(Out.type.element_ty))

需要特别注意的是这段代码最后的epilogue部分就对应了FlashAttention V2伪代码中的12行以后的内容,根据softmax的分母部分较正输出。此外,Triton的实现里面考虑了一些paper里面没有的东西比如qk_scalecausal mask,对Q*K的结果S应用了减掉m,使得整个实现看起来要复杂不少,但整体的算法逻辑和并行设置和paper还是一致的。

最后在Attention中使用这个函数

class _attention(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, causal, sm_scale):
        # shape constraints
        HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1]
        # when v is in float8_e5m2 it is transposed.
        HEAD_DIM_V = v.shape[-1]
        assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
        assert HEAD_DIM_K in {16, 32, 64, 128, 256}
        o = torch.empty_like(q)
        stage = 3 if causal else 1
        extra_kern_args = {}
        # Tuning for AMD target
        if is_hip():
            waves_per_eu = 3 if HEAD_DIM_K <= 64 else 2
            extra_kern_args = {"waves_per_eu": waves_per_eu, "allow_flush_denorm": True}

        # q.shape is [Batch, Head, Seq, Dim]
        grid = lambda args: (triton.cdiv(q.shape[2], args["BLOCK_M"]), q.shape[0] * q.shape[1], 1)
        M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
        
        # Launch Kernel.
        _attn_fwd[grid](
            q, k, v, sm_scale, M, o,  #
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),  #
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),  #
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),  #
            o.stride(0), o.stride(1), o.stride(2), o.stride(3),  #
            q.shape[0], q.shape[1],  #
            N_CTX=q.shape[2],  #
            HEAD_DIM=HEAD_DIM_K,  #
            STAGE=stage,  #
            **extra_kern_args)

        ctx.save_for_backward(q, k, v, o, M)
        ctx.grid = grid
        ctx.sm_scale = sm_scale
        ctx.HEAD_DIM = HEAD_DIM_K
        ctx.causal = causal
        return o

0x09 CUDA代码

0x0A FA 3

0x0B 思考

  • CPU上使用这个靠谱吗?CPU上并行度较低,用这个没有必要,但是可以考虑分块和Mask混合的MatMul来减少计算量,也就是Early Exit。