0x00 参考资料

本文是对RoPE的学习笔记,其中有不少内容是摘自业内的前辈们的文章,在此一并感谢。所参考的资料、摘录的文章来源在下面列出:

  1. 苏老师的Blog,Transformer升级之路:2、博采众长的旋转式位置编码

  2. Reformer,Github repo

  3. 视屏讲解,【解密旋转位置编码:数学基础、代码实现与绝对编码一体化探索】

0x01 为什么需要位置编码

在自然语言处理中,Attention机制是一种用于建立输入序列中不同元素间关联的方法,它允许模型在处理序列数据时,能够关注到当前元素之外的其他元素。然而,原始的Attention机制存在一个问题:它没有考虑序列中元素的顺序信息。 在处理序列数据时,元素的顺序是非常重要的,因为不同的顺序可能会表达完全不同的意义。例如,在句子中,“I saw the man with the telescope” 和 “The man saw I with the telescope” 虽然使用了相同的词汇,但意义却大相径庭,这正是因为词序不同。

为了解决这个问题,引入了位置编码(Positional Encoding)。位置编码是一种向量,它与原始的输入序列中的每个元素相乘,为模型提供关于每个元素在序列中位置的信息。这样,即使在计算Attention权重时,模型也能够区分不同位置的元素,从而更好地理解序列数据中的顺序信息。

0x02 绝对位置编码

可学习的绝对位置编码

BERT模型中引入了可学习的绝对位置编码(position embeddings),其目的是在模型中为每个位置提供序列中单词的位置信息。这种位置编码是模型参数的一部分,可以在训练过程中自动学习得到最优的表示。具体来说,BERT使用了一个可训练的嵌入层来生成位置编码,这使得模型能够捕捉到序列中单词的顺序信息,从而更好地理解句子结构和语义。

但是,可学习的位置编码受长度限制,无法应用在长文本。

Sinusoidal位置编码

Sinusoidal位置编码的基本思想是利用正弦和余弦函数的周期性来编码位置信息。具体来说,对于序列中的每个元素,位置编码会生成一个与位置相关的向量:

$$PE_{(p, 2i)} = \sin\left(\frac{p}{10000^{(i/d)}}\right)$$

$$PE_{(p, 2i+1)} = \cos\left(\frac{p}{10000^{(i/d)}}\right)$$

其中,$p$表示位置,$i$表示维度信息。

0x03 RoPE

这里的公式借用苏老师文章中的用法。假设我们需要对$\boldsymbol{q},\boldsymbol{k}$添加上位置信息$m,n$,我们可以假设有$\boldsymbol{f}(\cdot, \text{pos})$这样的操作可以做到这一点:

$$\tilde{\boldsymbol{q}}_m=\boldsymbol{f}(\boldsymbol{q},m),\quad\tilde{\boldsymbol{k}}_n=\boldsymbol{f}(\boldsymbol{k},n)$$

在Attention操作中,$q,k^*$会做内积,我们希望在做内积的时候可以体现出相对位置关系,也就是下面的式子展示的:

$$\langle \boldsymbol{f}(\boldsymbol{q},m), \boldsymbol{f}(\boldsymbol{k},n) \rangle = \boldsymbol{g}(\boldsymbol{q},\boldsymbol{k}, m-n)$$

其中$m-n$就是相对位置关系。

$\langle \boldsymbol{x}, \boldsymbol{y} \rangle$表示希尔伯特空间的内积,即高维度欧几里得空间。

那么接下来的问题就是如何找到这样的$\boldsymbol{f}$以满足我们上面的假设。为了简化问题,我们不妨假设

$$\boldsymbol{f}(\boldsymbol{q},0) = q, \boldsymbol{f}(\boldsymbol{k},0)=k$$。

我们可以借用复数的概念来求解这个问题,在此,我们先考虑二维情况下。

在复数中有$\langle \boldsymbol{q} ,\boldsymbol{k} \rangle = \text{Re}[\boldsymbol{q} \boldsymbol{k}^*]$,即

$$\text{Re}[ \boldsymbol{f}(\boldsymbol{q},m)\boldsymbol{f}^*(\boldsymbol{k},n)]= \boldsymbol{g}(\boldsymbol{q},\boldsymbol{k}, m-n)$$

为了简化问题,我们假设存在复数$\boldsymbol{g}(\boldsymbol{q},\boldsymbol{k}, m-n)$使得:

$$\boldsymbol{f}(\boldsymbol{q},m)\boldsymbol{f}^*(\boldsymbol{k},n)= \boldsymbol{g}(\boldsymbol{q},\boldsymbol{k}, m-n)$$

我们使用复数的指数形式可以得到:

$$\boldsymbol{f}(\boldsymbol{q},m) = R_f(\boldsymbol{q},m)e^{i \Theta_f(\boldsymbol{q},m)}$$

$$\boldsymbol{f}(\boldsymbol{k},n) = R_f(\boldsymbol{k},n)e^{i \Theta_f(\boldsymbol{k},m)}$$

$$\boldsymbol{g}(\boldsymbol{q},\boldsymbol{k}, m-n) = R_g(\boldsymbol{q},\boldsymbol{k},m-n)e^{i \Theta_g(\boldsymbol{q},\boldsymbol{k},m-n)}$$

带入之前的方程后可以得到下述方程组:

$$R_f(\boldsymbol{q},m)R_f(\boldsymbol{k},n)=R_g(\boldsymbol{q},\boldsymbol{k},m-n)$$

$$\Theta_f(\boldsymbol{q},m) - \Theta_f(\boldsymbol{k},n) = \Theta_g(\boldsymbol{q},\boldsymbol{k},m-n)$$

对于上述两个方程,我们带入$m=n$,由于一开始的假设$\boldsymbol{f}(\boldsymbol{q},0) = q, \boldsymbol{f}(\boldsymbol{k},0)=k$,我们可以得到下面的方程:

$$R_f(\boldsymbol{q},m)R_f(\boldsymbol{k},m)=R_g(\boldsymbol{q},\boldsymbol{k},0)=R_f(\boldsymbol{q},0)R_f(\boldsymbol{k},0)=\Vert \boldsymbol{q} \Vert \Vert \boldsymbol{k} \Vert$$

$$\Theta_f(\boldsymbol{q},m) - \Theta_f(\boldsymbol{k},m) = \Theta_g(\boldsymbol{q},\boldsymbol{k},0)=\Theta_f(\boldsymbol{q},0) - \Theta_f(\boldsymbol{k},0)=\Theta_f(\boldsymbol{q}) - \Theta_f(\boldsymbol{k})$$

现在,对于第一个式子,我们简单的假设

$$R_f(\boldsymbol{q},m)=\Vert \boldsymbol{q} \Vert,R_f(\boldsymbol{k},n)=\Vert \boldsymbol{k} \Vert$$

对于第二个式子,经过变换得到:

$$\Theta_f(\boldsymbol{q},m) - \Theta_f(\boldsymbol{q}) = 0,\Theta_f(\boldsymbol{k},m) - \Theta_f(\boldsymbol{k}) = 0$$

所以,$\Theta_f(\boldsymbol{q},m) - \Theta_f(\boldsymbol{q})$是一个与$\boldsymbol{q}$无关,与$\boldsymbol{m}$相关的一个函数,我们设它为$\varphi(m)$,即:

$$\Theta_f(\boldsymbol{q},m) = \Theta_f(\boldsymbol{q}) + \varphi(m)$$

此时,

$$\varphi(m) - \varphi(m-1) = \Theta(\boldsymbol{q},\boldsymbol{k},1)+\Theta(\boldsymbol{k}) - \Theta(\boldsymbol{q})$$

即${\varphi(m)}$是等差数列,设右端是$\theta$,那么$\varphi(m) = \theta m$

此时就得到了二维情况下的RoPE:

$$\boldsymbol{f}(\boldsymbol{q}, m) = \boldsymbol{q}e^{im\theta}$$

这实际上就是对于$\boldsymbol{q}$的旋转公式:

$$\boldsymbol{f}(\boldsymbol{q}, m) = \begin{pmatrix} \cos m\theta & -\sin m \theta \\ \sin m \theta & \cos m \theta \end{pmatrix} \begin{pmatrix} q_0 \\ q_1 \end{pmatrix}$$

$\text{Re}$表示复数的Real部分。 $Z = \text{a} + \text{b}i=r(\cos(\Theta) + i\sin(\Theta)) = re^{i\Theta}$,复数形式变换

由于内积满足线性叠加性,因此任意偶数维的RoPE,我们都可以表示为二维情形的拼接

$$\underbrace{\begin{pmatrix}\cos m\theta_0&-\sin m\theta_0&0&0&\cdots&0&0\\\sin m\theta_0&\cos m\theta_0&0&0&\cdots&0&0\\0&0&\cos m\theta_1&-\sin m\theta_1&\cdots&0&0\\0&0&\sin m\theta_1&\cos m\theta_1&\cdots&0&0\\\vdots&\vdots&\vdots&\vdots&\ddots&\vdots&\vdots\\0&0&0&0&\cdots&\cos m\theta_{d/2-1}&-\sin m\theta_{d/2-1}\\0&0&0&0&\cdots&\sin m\theta_{d/2-1}&\cos m\theta_{d/2-1}\end{pmatrix}}_{\mathbf{W}_m}\begin{pmatrix}q_0\\q_1\\q_2\\q_3\\\vdots\\q_{d-2}\\q_{d-1}\end{pmatrix}$$

可以看到,RoPE形式上和Sinusoidal位置编码有点相似,只不过Sinusoidal位置编码是加性的,而RoPE可以视为乘性的。

对于远程衰减,在$\theta_i$的选择上,RoPEs同样沿用了Sinusoidal位置编码的方案,即:

$$\theta_i = 10000^{-\frac{2i}{d}}$$

远程衰减: 对于$q,k$,我们希望距离近的$q,k$有较大的相关性,距离远的$q,k$相关性小。

0x04 思考

  1. 既然是旋转,会出现旋转角度重复的现象吗?

在任意的第$k$个子空间上,只要$\theta_k$中不包含$\pi$,那么旋转角度序列${ i \theta_k }$就不会有角度重复。

  1. $\theta_i$中10000的设置会影响外推的性能吗?

是会的,较小的值会造成外推性能严重下降。可以把旋转角度的函数图画出来,或者设置一个$q,k$算一下就明了了。

使用$\text{base}=10000$,跑出下图的结果:

Fig 1. $\text{Base}=10000$

而$\text{base}=50000$和$\text{base}=10000$比,远程衰减的效果有降低:

Fig 2. $\text{Base}=50000$

目前很多大长度外推的模型都是通过调大base来提升模型的输入长度。

  1. 二维的子空间可以是任意的吗

二维的子空间可以是任意的,只要是成对的就可以了,无需按照文章中所说的来。像HF Llama中的那样,就是half分割的形式。

0x05 代码实现

参考的是HF的Llama RoPE实现。

在实现的过程中,我们需要避开旋转矩阵的相乘,因为旋转矩阵是非常稀疏的。在论文中,作者使用的是:

${0, 1, 2, 3}$ -> { 0, 1 }, { 2, 3 }
Fig 3. 原文子空间的选择方法

${0, 1, 2, 3}$ -> { 0, 1 }, { 2, 3 }

但是实际上,如HF Llama,其使用的是:

${0, 1, 2, 3}$ -> { 0, 2 }, { 1, 3 }
Fig 4. HF Llama 子空间的选择方法

${0, 1, 2, 3}$ -> { 0, 2 }, { 1, 3 }

两者都是合理的实现,只是子空间的划分不同。对于half划分,在HF llama的实现中是这样的:

def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

在乘上$\cos,\sin$后是:

def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

在$\cos,\sin$的生成上就不使用Llama的代码了,它封装的太多了,我自己写了一个:

base = 1e5
d = D / 2
B = base ** (1/d)
theta_beta = 1.0 / (B ** torch.arrange(0, d))
theta_0 = q.outer(theta_beta)
theta = torch.concat([theta_0, theta_0], dim=-1)
cos = theta.cos()
sin = theta.sin()