0. 前言
本问分析的两篇文章是:
-
2024-05, SpinQuant: LLM quantization with learned rotations from meta,一作是 Zechun Liu
-
2024-04, QuaRot: Outlier-Free 4-Bit Inference in Rotated LLMs from ETH Zurich、EPFL、Microsoft Research、IST Austria & NeuralMagic
这两篇文章以相同的视角去解决问题,并且量化后的模型保持了相当好的性能,应该就是未来模型量化的一个主要的跟进方向。QuaRot和SpinQuant可以算作是同一系列的工作,SpinQuant在QuaRot的基础上做了可学习旋转矩阵的改进。
1. Computational Invariance
Computational Invariance是Quarot和SpinQuant的基础。
Computational Invariance定理[1]指出,可以使用正交矩阵对Transformer 中的权重和块间激活进行变换,而模型输出不变。这个定理是说,如果$W_{in}$是一个在transformer block(i.e. $W_k,W_q,W_v$等)左边的权重,我们可以左乘上一个正交的矩阵$Q$,为了在最后的结果里消除这个影响,我们可以在输出矩阵(i.e. $W_{out}, W_{down}$)右边乘上$Q_T$。
尽管 RMSNorm 被应用于两个数据块之间,但只要 RMSNorm 数据块中没有重新缩放(在实际操作中,我们会先将任何重新缩放吸收到相邻的权重矩阵中),这一点还是适用的。从概念上讲,这是因为 RMSNorm 将激活值除以它们的模长,而对激活值应用旋转矩阵$Q$不会影响模长(因为正交矩阵的特性):
$$\text{RMSNorm}(\boldsymbol{X}) = \text{RMSNorm}(\boldsymbol{X\boldsymbol{Q}^T})\boldsymbol{Q}$$
我们这里假设RMSNorm对激活值$\boldsymbol{X}$的每一行做的操作都是$x_{i} \larr x_i/ \Vert x_I \Vert$。这意味着,将输出矩阵乘以 $Q^T$ 会使线性层输出$XQ^T$,$XQ^T$被归一化,然后传入下一个区块,其输入权重矩阵现在是$QW$,因此该线性层不做任何修改就会输出原始激活。
2. 🥕Quarot
Quarot有两个阶段,一个阶段是在全精度下操作模型权重,并在模型的前向传递中插入两个额外的乘Hadamard矩阵操作;在第二个阶段使用现在的量化方法来量化夹在Hadamard矩阵中的模型权重,因为这些权重被削减了峰度,outliers减少,可以使量化的压力小很多。
Quarot文中的图片已经非常清晰的描绘了什么时候做旋转以及何时做量化和量化的数据流。SpinQuant和Quarot比较了实验结果,Quarot的实验结果请看第三章节的SpinQuant的表格。
3. SpinQuant
spin把成对的旋转和反旋转操作用到了整个模型中去,包括残差的地方。对于我来说,SpinQuant更加的直接一点。SpinQuant的动机和Quarot相似都是为了解决Outlier问题, Outliers会拉伸量化范围,导致大部分值的在对应量化网格上的有效映射减少;或者需要裁剪Outliers。通过对激活或权重矩阵进行旋转可以帮助消除Outliers。
作者的文章写的很好,展示了旋转之后的数据分布情况:
作者使用的旋转矩阵不同于Quarot,她使用了Cayley SGD 优化器来学到一个旋转矩阵。学习得到的旋转矩阵比Quarot的Hadamard矩阵有更好的效果。作者也探究了不同的旋转矩阵对模型性能的影响:
作者在文中和Quarot一样使用了两个Stage的方式,第一个Stage是训练学习到旋转矩阵,第二个Stage是将旋转矩阵和权重合并之后再做量化。如下面的流程图所示:
需要说明的是,图中$q$分支的的$R_3$应该是标注错误了,$R_3^{-1}$才是准确的。
旋转矩阵是学习出来的,其流程是 WikiText2 校准数据集上的 800 个样本使用 Cayley 优化来更新旋转 100 次迭代。
结果是非常惊艳的,在W4A4KV4上SpinQuant算法的表现非常的优秀!!!在W4KV4的情况下不损失过多的性能,看得出SpinQuant在Outlier的剔除上有非常大的作用。
Ref:
- Saleh Ashkboos, Maximilian L Croci, Marcelo Gennari do Nascimento, Torsten Hoefler, and James Hensman. Slicegpt: Compress large language models by deleting rows and columns. arXiv preprint arXiv:2401.15024, 2024.