[Fundamental] FlashDecoding Series
0x00 前沿和阅读材料 FlashDecoding系列的文章是对FA在推理场景下的改进,目前包含两篇文章: Flash-Decoding for long-context inference, Torch团队的Blog FlashDecoding++: Faster Large Language Model Inference on GPUs 我们知道,在FA2中特别对Seq方向做了并行化,但是在推理的时候Seq=1。此时,并不能占用满GPU的全部的SM,导致性能损失,FlashDecoding就是对此的优化。 0x01 FlashDecoding 在解码过程中,生成的每个新Token都需要考虑之前的所有Token,以便进行注意力计算。 在训练的时候,Attention这一算子已使用FlashAttentionV2算法进行了优化,其瓶颈在于读写中间结果(例如 Q @ K^T)的内存带宽不足。但是,这些优化并不能直接应用于推理,因为推理的时候不在是内存带宽的瓶颈。在训练过程中,FlashAttention 会在Batch和Seq两个维度上进行并行处理。在推理过程中,Seq=1,这意味着,如果Batch大小小于 GPU 上的SM数量(A100 为 108),则这个Attention操作只会使用 GPU 的一小部分!在使用长上下文时尤其如此。当Batch大小为 1 时,FlashAttention 将使用不到 1%的 GPU! Fig 1. Regular AttentionFlash-Decoding for long-context inference 为此我们不难想到可以实用Split-K的方法来使得Attention在推理的时候也有很好的并行性。如下图所示: Fig 1. Split-K AttentionFlash-Decoding for long-context inference 非常的好理解,但是这里需要注意的是,在最后的Reduce Op这里还是要做Online Softmax的,所以在SRAM里面保存的东西是比原来多的,除了Output,还有exp Sum和Max。 0x02 FlashDecoding++ flashdecoding++不是meta官方出品的。 FA中,求解Max需要遍历迭代,之后的子块依赖于之前的子块。Safe-softmax的计算公式中,需要先求每行x的最大值,然后减去这个max(x)之后,再做softmax以防止数值溢出。FlashDecoding++提出的创新点就是,我们可以实用一个先验的$\phi$来作为max值,只要它能让数值稳定就可以了。从Safe Softmax的公式上来看,无论是$\phi$还是max(x),他们的结果是一致的,我们需要追求的是数值上的稳定与否。 FlashDecoding++认为一个合理的先验值 $\phi$,可以直接从数据集中进行统计获得。对于不同的模型,这个先验值也是不一样的。在实现的时候,FlashDecoding++还使用了Fallback的思路,当出现数值溢出的时候,使用传统的FlashDecoding。 那为什么FalshDecoding++能异步,而FlashDecoding不行呢? 在FalshDecoding Split-K分成的几个区间内,还是使用的FA2的方法来计算,但是FA2的一次迭代是依赖于上一次迭代的结果的,也就是需要rescale。但是FlashDecoding++不需要,它大致上是这样的: $$\begin{aligned} &\ell^{(1)}=\mathrm{rowsum}\Big(e^{\mathbf{S}^{(1)}-\phi}\Big)\in\mathbb{R}^{B_{r}} \\ &\tilde{\mathbf{O}}^{(1)}=e^{\mathbf{S}^{(1)}-\phi}\mathbf{V}^{(1)}\in\mathbb{R}^{B_{r}\times d} \\ &\ell^{(2)}=\mathrm{rowsum}\left(e^{\mathbf{S}^{(2)}-\phi}\right) \\ &\tilde{\mathbf{O}}^{(2)}=e^{s^{(2)}-\phi}\mathbf{V}^{(2)} \\ &\mathbf{O}^{(2)}=\mathrm{diag}\left(\ell^{(1)}+\ell^{(2)}\right)^{-1}(\tilde{\mathbf{O}}^{(1)}+\tilde{\mathbf{O}}^{(2)})=\mathbf{O} \end{aligned}$$