✅[April 2024] Prompt Cache: Modular Attention Reuse for Low-Latency Inference
背景和动机 以KV Cache为启发,探索了对time-to-first-token (TTFT) Latency的优化。类似于KV Cache,Prompt Cache(PC)推理加速的核心思想是复用注意力的中间状态(Attention States)。然而与KV Cache不同的是,PC是在不同的prompt之间进行复用。 在大部分的LLM任务中,prompt有重叠(overlapping)的现象,这些重叠的prompt可以被存储起来,进而在接下来的LLM处理阶段可以像KV Cache一样,提取出来直接使用。在TTFT的推理过程中,免去计算不同prompt中重叠部分的注意力状态,从而缩短TTFT的生成时间。 与KV Cache不同的点是: 相同的文本段可能出现在不同prompt的不同位置,如何对它们的Attention States进行复用。因为不同位置的文本段的Position Encoding进去的值是不一样的。在KV Cache中不需要考虑这一点,因为cache是从前往后线性增长的,但Prompt所在的位置是不确定的。 如何从不同的prompt中识别出已经缓存过的文本。 算法 实验经验 一段prompt的Position值不连续没有关系。只要这一段prompt本身的Position值是连续的就行。意思是部分连续对于LLM就够了,不一定要完全连续。请注意:这是一个实验性验证的结论。 Prompt Schema Fig 1. Prompt Schema 作者团队定义了一个Prompt Markup Language(PML)。上图中的例子有:可以复用的module和不能复用的填充部分,填充部分需要用Param指出,并给出长度。Prompt Attention States中的红色部分是可以被复用的区域。 Fig 2. 原始LLM/KV Cache/Prompt Cache 我们来对比下普通的自回归LLM、使用了KV Cache的LLM和使用了Prompt Cache的LLM。普通的LLM每次都要通过输入的Prompt来预测出下一个Token,Prompt是全量的计算。使用了KV Cache的LLM,每次Token预测不用全量计算了,可以使用上次Attention的中间结果。而使用了Prompt Cache的LLM,在后期预测Token的过程和原来的KV Cache没有什么区别。主要区别是在一开始的Prompt输入的阶段,Prompt Cache中常用的Prompt Attention States可以被利用起来,这会极大的缩减第一个Token输出的时间。 Prompt Schema有很多的细节,这里只讲大致的思路,具体的请看文章和代码仓库。 我对module怎么复用不是很理解,应该是通过将文本内容进行sha256编码来对其进行识别。 本文主要是对首Token输出时间的优化,对于用户来说可以有更好的体验。要是能做个全局的Prompt Cache数据库,应该可以给大规模的LLM Infer系统带来不少的好处。