MLLM Qnn AOT Support for Full Graph Execution on NPU
github repo: https://github.com/UbiquitousLearning/mllm
website: https://ubiquitouslearning.github.io/mllm/index.html
1.1 Introduction
由于高通 NPU 的静态图,全 Int 计算的特性,如何快速的、准确的在高通的 NPU 上部署 LLM 一直是业界面临的挑战。为了实现更便捷的端侧部署,我们在 MLLM 框架中完成了对高通 Qnn 框架的支持。现在,用户可以在 X86 机器上,通过 MLLM 的 Qnn AOT 模式编译 LLM,随后利用 MLLM Qnn AOT Runtime,在搭载高通 NPU 的端侧设备上高效运行。经过 MLLM 团队成员的精心设计,MLLM 框架通过一套统一的中间表示,以轻量的方式完整支撑了模型量化、AOT 编译及实际运行的全流程。在我们的优化下,Qwen3 1.7B 模型11 采用 W4A16KV8 量化配置,chunk=32, pad=1024,在 Qualcomm SM8650 上运行在 prefill 640 tokens、decode 128 tokens 的场景下,实现了 prefill 758.28 tokens/s 与 decode 25.84 tokens/s 的性能。我们之后会通过 Graph Selection 的方法继续加速 prefill 和 decode 的速度。
端侧设备对功耗与发热高度敏感,而 LLM 的推理本身又是一项对延迟要求极高的任务。尽管当前Arm CPU普遍搭载了 SVE、SME 等加速单元,其性能仍难以充分满足 LLM 推理的计算需求。此外,在端侧设备中,CPU 通常还需处理其他低延迟任务,将计算密集型的 LLM 任务交由 CPU 执行既不合理也不现实22 出于功耗、散热以及对其他应用影响的考量,实际部署中 CPU 往往仅使用单线程运行 LLM。同理,虽然端侧设备通常配备性能尚可的 GPU,且采用 OpenCL 在 GPU 上执行模型也是可行的方向,但 GPU 同样需负责用户界面的实时渲染,若将其资源用于 LLM 计算,将显著影响用户体验。因此,在低功耗的 NPU 上统一执行 LLM 的全部任务(如 prefill,decode,ViT)显得尤为重要。
在 NPU(下文中,NPU 若无说明,均指代高通的 NPU) 上做推理有如下的难点:
-
Decode 速度慢
Prefill 阶段通常较快,但 Decode 阶段相对较慢。这主要是因为较小的计算量无法充分利用 NPU 的计算单元,使得 LLM 推理受限于内存带宽瓶颈。
-
全 Int 计算的精度损失严重
NPU 上所有算子的输入和输出都期望是 Per Tensor Int 量化的,这样 NPU 才能发挥出最大的性能。但是全 Int 的 Per Tensor 量化会给 LLM 带来不小的性能损失。33 由于NPU 内部的定点计算的特性,SiLU、RMSNorm、RoPE 都是需要 Int 输入和输出的。特别是 RoPE,需要拆分出一系列的 elementwise 算子,这里的每一个 elementwise 算子都是需要 Int 输入和输出的。 我们近期有一些工作是和 NPU 友好的量化相关的,大家可以关注下
-
LLM 的动态形状特性难在 NPU 的静态图上实现
Qnn 的计算图是离线构建的,但 LLM 的自回归推理在 Attention 部分存在动态形状问题,需要在 Attention 中进行大量 Padding,这会对性能造成损耗。
在实现 MLLM Qnn AOT 的过程中,我们遇到了诸多技术挑战,这些细节将在附录中详细说明。整个适配 Qnn AOT 模式的过程充满困难——团队对 Qnn 的文档和实现、对量化算法的精度,以及对 MLLM 中 AOT Compile Stack 的实现都缺乏足够信任,由此逐渐形成了一条猜疑链。
图 1 谁是卧底?是 Qnn 的锅,还是量化算法和 MLLM Qnn AOT?
究其原因,是 Qnn 软件栈的封闭特性和不完善、错误的文档导致的。Qnn 因为内部计算图的黑盒特性,使得调试起来非常的困难,我们时常怀疑是不是 Qnn 本身的实现有问题。我们还发现 Qnn 的 RMSNorm Op 的描述和行为不一致;量化方法描述不一致的问题等等。这些问题我们都会在附录中一一列出。
MLLM 的 Qnn AOT 工作流主要由三部分组成:1. 带有 QDQ 的 modeling 文件实现44 这里需要同时包含 huggingface transformer 形式的 python modeling 文件;还需要 mllm 中的 cpp modeling 文件。这两个文件的模型搭建方式颇为类似,AI 可以很好的做出来 2. 将量化模型导出到 mllm 格式,编译具体模型的对应的 mllm-aot-compiler 3. 使用 mllm-aot-compiler 来编译模型,并使用 mllm-aot-runner 来运行。
图 2 MLLM Qnn AOT 工作流
在此次的 MLLM Qnn AOT 实现中,AI 发挥了巨大的作用,其成为了 MLLM Qnn AOT 工作流的一部分。我们把许多的 QDQ(Quantize, Dequantize)细节全都暴露在了 Modeling 层面,而不是和 executorch 一样全都要隐藏在了 compile 的 annotate 阶段。核心思想是:能让 AI 做的,为什么要隐藏给 Compiler 呢?,暴露出足够的细节在前端,人阅读也方便,AI 阅读也方便。当我们给出一个实现好的 Qwen3 modeling 脚本,AI 可以很容易的写出一个 llama or Qwen2.5 等模型的 modeling 文件。
最后,感谢参与了 MLLM Qnn AOT 开发的同学们。我们也特别感谢 executorch 社区的高通后端实现,我们从 executorch 的实现中汲取了很多的养分,executorch 社区的成员们也非常的友善,对我们碰到的问题给予了回复和帮助。
尽管高通的 NPU 是目前的第一梯队,但是其羸弱的软件生态限制了开源社区的开发者在其芯片上面做进一步的工作。在被高通的 Qnn 折磨后,我不得不借用 torvalds 的话来总结:XXXX you, Qualcomm!
在本文中,我将首先阐述为何 NPU 期望输入与输出均为整数类型,随后依据 MLLM Qnn AOT 的工作流程,依次介绍量化算法配置、QDQ建模文件编写、MLLM IR、Qnn AOT 生成以及 Qnn AOT 运行时。
1.2 为什么需要 Int 输入和输出?
在移动端 AI 开发中(特别是使用 SNPE 或 QNN SDK 时),开发者常会遇到一个强硬的要求:模型必须量化,且 NPU 极度偏好 Int8/Uint8 格式的输入和输出。现在的手机 CPU 和 GPU 浮点性能这么强,为什么 Hexagon NPU 非要在这个时候“返璞归真”去做定点数计算?甚至连输入输出都要卡死在 Int 上?可能是因为两点:1. 卷积网络的历史惯性 2. NPU 的设计初衷是 极致的能效比(TOPS/Watt)。
我们就以一个最基础的算子 Per-Tensor Uint8 Elementwise Multiplication 为例,拆解其背后的定点数计算逻辑。
我们要执行的操作是两个张量 和 的逐元素相乘,得到 。即:
但在 NPU 内部,我们没有真实浮点值 (, real),只有量化后的整数 (, quantized)。根据 Per-Tensor 量化公式(所有元素共享一组 Scale 和 Zero-point):
其中:
- : 真实浮点值 (Real value)
- : 缩放因子 (Scale, usually float32)
- : 量化后的整数 (Quantized int/uint8)
- : 零点 (Zero-point, int/uint8)
将公式代入乘法操作:
我们的目标是求出 NPU 需要输出的整数 。经过移项整理,得到核心计算公式:
请注意公式中被标记为 的这一项:
都是浮点数,算出来的 也是一个浮点数(例如 )。如果 NPU 直接算这个公式,它还得内置一个浮点乘法器来处理 ,这就破坏了全定点计算的初衷。
所以在 NPU 内部会做定点化缩放 (Fixed-point Rescaling)
NPU 会将浮点数 近似表示为 “一个整数乘数 + 一个右移操作”:
其中 Multiplier 是一个大的整数(比如 int32),Shift 是位移量。于是,整个计算流程变成了纯整数运算:
- 输入准备:NPU 读取 Int8 的 和 (这就是为什么输入必须是 Int,否则第一步就卡住了)。
- 整数减法:。
- 整数乘法: (结果可能是 int16 或 int32)。
- 定点缩放 (关键一步):。这里用整数乘法和位移代替了浮点乘法。
- 加零点:。
- 饱和截断 (Saturate):将结果限制在 范围内,输出 。
你可能会认为 只是换了一种写法,但事实上,这是一个有损变换,并非数学上的等价。
在数学上,浮点数 (例如 )的小数位理论上可以无限延伸。但在硬件中,"Multiplier" 通常被限制为 32位整数 (Int32)。
这里的 round 操作强制丢弃了 结果中的小数部分。这就引入了不可避免的量化误差。
假设真实比例 ,硬件限制 Multiplier 必须是整数,我们尝试用不同的位移量 Shift 来表示它:
尝试 1 (Shift=8):
四舍五入后,Multiplier = 26。 还原回去:。 误差约为 1.5%。
尝试 2 (Shift=31, 32位整数的极限):
四舍五入后,Multiplier = 214748365。 还原回去:。 误差极小,但依然存在。
这就意味着,寻常的 FakeQuant 以后的结果仍然无法完全模拟 Qnn 的计算精度误差! 算子越多,损失越大!
1.3 量化算法
在 MLLM 中,我们默认实现的量化配置是 W4A16KV8,具体见下表:
表 1 MLLM 默认量化配置 W4A16KV8 说明
| Tensor 类型 | 量化配置 | 精度 |
| Linear Weight | per-block, block size(16/32), 对称 | int4 |
| RMSNorm Weight | per-tensor,非对称 | uint16 |
| RoPE Embedding Weight | per-tensor,非对称 | uint16 |
| Attention Sink Weight | per-tensor,非对称 | uint16 |
| Activation | per-tensor,非对称 | uint16 |
| KV Cache | per-tensor,对称 | uint8 |
1.3.1 LPBQ
LPBQ 全称为 Low Power Block Quantization。它是 Qualcomm 针对端侧硬件(特别是 Hexagon NPU/DSP)设计的一种细粒度量化方案。
传统的量化粒度通常分为两种:
- Per-Tensor: 整个层共享一个缩放因子 (Scale),计算最快,但精度最差。
- Per-Channel: 每个输出通道有一个缩放因子,精度较好,是 CNN 的标配。
LPBQ 引入了“块”的概念。它将张量划分为更小的块,并为每个块独立计算量化参数。
我们以 Linear 为例,对于 Linear 算子,权重通常表示为矩阵 () 或 。在 DSP/NPU 的底层实现中,为了优化向量化加载,权重常采用 HWIO 格式。
对于 Linear 层:
- H (Height) = 1
- W (Width) = 1
- I (Input Channel): 累加求和维度 (Reduction Dimension)。
- O (Output Channel): 独立计算维度。
LPBQ 对 I(Input Channel)维度做标准的 per block 量化,得到 fp32 的 Scale 和 int32 的 Zero Point。然后对 O(Output Channel)上的所有 Blocks 的 Scale 做二级量化,量化后,每个 Block 的 Scale 为 uint4,O(Output Channel)上的 Scale 为 fp32 精度。
1.3.2 特殊的 Observer 约束
在量化过程中,Observer 的职责是统计 Tensor 的数值分布(Min/Max 或 Histogram)并由此计算量化参数 Scale () 和 Zero Point (). 然而,并非所有的 Observer 都能完全“自由”地确定参数。为了保证数值稳定性或满足硬件约束,我们必须对以下几种情况施加特殊的限制:
1. 防止除零错误的 Epsilon 约束
量化参数 Scale 的计算公式通常为 。当统计到的数值波动极小(例如 )时,分母可能趋近于零,导致计算出的 Scale 爆炸或出现 NaN。
为了数值稳定性,我们需要限制 的最小值为一个 。针对不同的量化位宽,这个阈值通常设定为:
- INT8:
- INT16:
这确保了即使在平坦区域,Scale 也能保持在一个合理的浮点数范围内。
2. Sigmoid/Tanh 等算子的固定 ZP
对于 Sigmoid 或 Tanh 这类输出范围固定的激活函数(Sigmoid 输出恒为 )。
量化后的输出通常要求固定的 Zero Point。例如,在非对称量化(uint8)中,可能会强制约束:
或者根据具体硬件实现固定为其他特定整数值,不再随数据的动态分布而改变。
3. Concat 算子的 Observer 共享 (Share Scale & ZP)
Concat 操作在物理内存上是将多个 Tensor 的数据搬运到一块连续的内存中。
如果输入 Concat 的多个 Tensor 拥有不同的 Scale 和 ZP,它们在物理数值上就无法直接拼接(例如:Tensor A 的数值 5 代表实数 0.5,而 Tensor B 的数值 5 代表实数 2.0,拼接在一起后物理含义混乱)。
为了避免在 Concat 之前插入额外的 Requantize(重对齐)操作从而增加计算开销,通常强制要求 所有输入 Tensor 以及输出 Tensor 共享同一个 Observer,即:
1.3.3 SpinQuant
图 3 SpinQuant (Liu 等, 2024)
SpinQuant 通过引入正交旋转矩阵,在无需训练的情况下有效降低了权重与激活层中的离群值影响。该方法已在 executorch 的 Qualcomm 后端中实现,并已成为事实上的行业标准。然而,在当前的 MLLM AOT 框架中,我们尚未支持旋转变换量化,这将是我们后续的重点工作方向。
1.4 QDQ Modeling
考虑到 NPU 的算子需要 Int 输入和 Int 输出来获得最快的性能。我们需要对每一个算子的前后都做 QDQ,这就带来了一个问题:如何快速地插入正确的 QDQ 算子?在 executorch 中,其使用了一些 pass 来 annotate 每个算子的量化设置,并根据这些量化设置来插入 FakeQuant 节点或者 Observer。
但是完整实现非常复杂的模式匹配是非常耗时的,所以 MLLM 选择了另外一条路:直接在 modeling 文件中显示的声明出来 QDQ 的节点,同时,也是显示的声明出来每个算子的输入和输出使用的是什么类型。
我们先来看一下 Python 中如何插入 QDQ 节点55 pymllm 中提供了一系列的 QDQ class 来辅助。QLinearLPBQ 默认是 W4A16 的精度:
from pymllm.backends.qualcomm.transformers.core.qlinear import (
QLinearLPBQ,
)
from pymllm.backends.qualcomm.transformers.core.qdq import (
ActivationQDQ,
FixedActivationQDQ,
)
class Qwen3MLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = QLinearLPBQ(
self.hidden_size, self.intermediate_size, bias=False, block_size=16
)
self.up_proj = QLinearLPBQ(
self.hidden_size, self.intermediate_size, bias=False, block_size=16
)
self.down_proj = QLinearLPBQ(
self.intermediate_size, self.hidden_size, bias=False, block_size=16
)
# QDQ
self.up_proj_input_qdq = ActivationQDQ(bits=16)
self.up_proj_output_qdq = ActivationQDQ(bits=16)
self.gate_proj_output_qdq = ActivationQDQ(bits=16)
self.act_output_qdq = ActivationQDQ(bits=16)
self.down_proj_input_qdq = ActivationQDQ(bits=16)
# For sigmoid output: scale = 1 / (q_max - q_min + 1), zp = 0
# For 16-bit: q_min = 0, q_max = 65535
sigmoid_scale = 1.0 / (65535 - 0 + 1) # 1 / 65536
self.sigmoid_output_qdq = FixedActivationQDQ(
scale=sigmoid_scale, zero_point=0, bits=16
)
def forward(self, x):
x = self.up_proj_input_qdq(x)
up_result = self.up_proj_output_qdq(self.up_proj(x))
gate_result = self.gate_proj_output_qdq(self.gate_proj(x))
# SiLU
gate_result = self.act_output_qdq(
gate_result * self.sigmoid_output_qdq(F.sigmoid(gate_result))
)
o = self.down_proj_input_qdq(gate_result * up_result)
o = self.down_proj(o)
return o
我们再看一下 MLLM 中的普通 modeling 文件是怎么写的(不带有 QDQ)。在 MLLM 中,我们通过 C++ 实现的 modeling 文件来表示图,但也不必惊慌,我们对 C++ API 做了许多的封装让其使用起来具有类似于 torch 的体验。比如常见的 MLP Layer 就如下所示:
class Qwen3MLP final : public nn::Module {
nn::Linear gate_proj_;
nn::Linear up_proj_;
nn::Linear down_proj_;
nn::SiLU silu_;
public:
Qwen3MLP() = default;
Qwen3MLP(const std::string& name, const Qwen3Config& cfg) : nn::Module(name) {
gate_proj_ = reg<nn::Linear>("gate_proj", cfg.hidden_size, cfg.intermediate_size, false, cfg.linear_impl_type);
silu_ = reg<nn::SiLU>("act");
up_proj_ = reg<nn::Linear>("up_proj", cfg.hidden_size, cfg.intermediate_size, false, cfg.linear_impl_type);
down_proj_ = reg<nn::Linear>("down_proj", cfg.intermediate_size, cfg.hidden_size, false, cfg.linear_impl_type);
}
std::vector<Tensor> forward(const std::vector<Tensor>& inputs, const std::vector<AnyValue>& args) override {
auto x = gate_proj_(inputs[0]);
x = silu_(x);
auto y = up_proj_(inputs[0]);
x = x * y;
x = down_proj_(x);
return {x};
}
};
用户可以通过这些简单的 API 来搭建出自己的 LLM 结构。
在带上 QDQ 后,MLP Layer 的实现如下66 在这个例子中,我们把 Linear 换成了 Conv2D,这是为了加速推理。其中 ptq::QDQ 的第三个参数就是 Python Modeling 文件中 QDQ 的 FakeQuant Module 的实例化名称:
class Qwen3MLP final : public nn::Module {
nn::Conv2D gate_proj_;
nn::Conv2D up_proj_;
nn::Conv2D down_proj_;
nn::SiLU silu_;
int hidden_size_;
int intermediate_size_;
public:
Qwen3MLP() = default;
Qwen3MLP(const std::string& name, const Qwen3Config& cfg) : nn::Module(name) {
gate_proj_ = reg<nn::Conv2D>("gate_proj", cfg.hidden_size, cfg.intermediate_size, CONV2D_PROPERTY);
silu_ = reg<nn::SiLU>("act");
up_proj_ = reg<nn::Conv2D>("up_proj", cfg.hidden_size, cfg.intermediate_size, CONV2D_PROPERTY);
down_proj_ = reg<nn::Conv2D>("down_proj", cfg.intermediate_size, cfg.hidden_size, CONV2D_PROPERTY);
hidden_size_ = cfg.hidden_size;
intermediate_size_ = cfg.intermediate_size;
}
std::vector<Tensor> forward(const std::vector<Tensor>& inputs, const std::vector<AnyValue>& args) override {
auto x = inputs[0];
x = ptq::QDQ(this, x, "up_proj_input_qdq");
x = x.view({1, 1, -1, hidden_size_}, true);
auto up_result = ptq::QDQ(this, up_proj_(x), "up_proj_output_qdq").view({1, -1, intermediate_size_}, true);
auto gate_result = ptq::QDQ(this, gate_proj_(x), "gate_proj_output_qdq").view({1, -1, intermediate_size_}, true);
// SiLU
gate_result = ptq::QDQ(this, (gate_result * ptq::QDQ(this, nn::functional::sigmoid(gate_result), "sigmoid_output_qdq")),
"act_output_qdq");
auto o = ptq::QDQ(this, gate_result * up_result, "down_proj_input_qdq");
o = o.view({1, 1, -1, intermediate_size_}, true);
o = down_proj_(o).view({1, -1, hidden_size_}, true);
return {o};
}
};
1.5 MLLM IR
MLLM 的 IR 是对 MLLM 中 modeling 文件的忠实表示。MLLM IR 是一个静态图的,非 SSA(Static Single Assignment) 形式的 IR。之所以不选择 SSA,是因为计算图的 IR 实现比较简单,从工程角度来看实现起来比较快;再者,由于计算图都是节点和节点之间的互联,所有的内存申请都以 Tensor 的形式来表现,SSA 就显得没有那么必要了。对于一个 MLP,MLLM 会生成如下的 IR77 这个例子来源于 Qwen2 VL 模型:
graph.SubGraphOp @model.layers.6.mlp <CPU> {
(%1454:tensor<[1, 192, 1536], Float32, CPU>) -> (%1459:tensor<[1, 192, 1536], Float32, CPU>) {
linalg.CPU.LinearOp(%1454:tensor<[1, 192, 1536], Float32, CPU>) -> (%1455:tensor<[1, 192, 8960], Float32, CPU>)
linalg.CPU.SiLUOp(%1455:tensor<[1, 192, 8960], Float32, CPU>) -> (%1456:tensor<[1, 192, 8960], Float32, CPU>)
linalg.CPU.LinearOp(%1454:tensor<[1, 192, 1536], Float32, CPU>) -> (%1457:tensor<[1, 192, 8960], Float32, CPU>)
linalg.CPU.MulOp(%1456:tensor<[1, 192, 8960], Float32, CPU>, %1457:tensor<[1, 192, 8960], Float32, CPU>) -> (%1458:tensor<[1, 192, 8960], Float32, CPU>)
linalg.CPU.LinearOp(%1458:tensor<[1, 192, 8960], Float32, CPU>) -> (%1459:tensor<[1, 192, 1536], Float32, CPU>)
cf.ReturnOp (%1459:tensor<[1, 192, 1536], Float32, CPU>) -> ()
}
}
上述的 IR 是 Trace 出来的最简单的情况。为了能够让量化的设置和模型的结构能符合 Qnn 的限制,MLLM Qnn AOT 模式实现了一系列的 Passes 来辅助 IR 转换。这些 Passes 如下(表格中的顺序就是实际的执行顺序):
表 2 MLLM Qnn AOT Passes 说明
| Pass Name | 作用 | 限制 |
| MarkQnnGraphPass | 通过传入的配置文件来将 Graph Op 和一些 Linalg Op 标记为需要在 NPU 上执行 | |
| OpNamingPass | 一些 nn.functional 的算子在 trace 的时候是匿名的,在这个 Pass 中,给这些匿名算子赋值独立的名称 | |
| MergeLLMHeadIntoMainGraphPass | 通常,LLM Head 是独立于所有的 Graph Op 的,这个 Pass 将 LM Head 算子给移动到 Model Graph 内部 | 这个 Pattern 只作用在 LLM 中 |
| LLMQuantRecipePass | 通过前端出入的量化类型,推导出来每个算子的 QDQ 能不能被复用,会在量化章节细讲 | |
| PTQPass | 通过 Quant Recipe 将 params 中的 scale 和 zero point 对应到每一个 activation 和 weight 上。 若有 constant value,那么对 constant value 做量化 | |
| SplitLLMGraphPass | 如果一个 LLM 在 4GB 内放不下,那么就需要拆图到不同的 Context。除了拆图,这个 Pass 还会把所有的图展平为一张大图 | 目前只支持一个图,是未来需要完善的工作 |
| MarkTensorIOPass | 将输入输出都打上 IO tag | |
| LLM2QnnLoweringPass | 遍历 IR,通过 Qnn 的离线建图 API 建图 |
1.6 Qnn AOT Generation
我们在原有的 Qnn Backend88 MLLM 的 Qnn Backend 的第一版是在手机上 Prepare Qnn Graph的 的基础上做了拓展,wrap 出了 X86 上可以使用的编译环境。用户可以通过在编译的时候指定 AOTEnv 中的 target machine 的硬件参数来编译适配不同设备的 Qnn Graph。
{
"target_machine": {
"htp_arch": "V75",
"htp_chipset": "SM8650",
"htp_try_best_performance": "HtpBurst",
"htp_security_pd_session": "HtpSignedPd",
"htp_vtcm_capability_in_mb": 8
}
}
如上述的 json 文件,在 X86 上编译 Qnn Graph 的时候,需要向 MLLM 的 AOTEnv 提供 htp 的 chipset、arch 等硬件相关信息。
1.7 Full NPU Execution
在 NPU 上执行 LLM 的 Prefill 和 Decode 两个阶段是一个挑战。目前 MLLM AOT 的解决方案是:分别编译两张计算图,一张对应 chunk size 为 32 的图,另一张对应 chunk size 为 1 的图。这两张图分别用于 Prefill 和 Decode 阶段。
1.7.1 KV Cache 管理
我们学习并实现了 executorch 的 KV Cache 管理方法。
图 4 KV Cache 管理, executorch 的官方解析
所有的 KV Cache 都是定长的大小,其中 Key 的形状是 [B, H, D, S],Value 的形状是 [B, H, S, D]。对于 Context Length 的长度,
每次做 Attention 计算的时候,就直接把新算好的 chunk 拼接到 Key 和 Value 的末尾,组成长度为 的 Key 和 Value。为此,Causal Mask 也需要做一定的更改,如下99 Cite Executorch 的 py 注释:
1. Full Attention
Step = 0
0| 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0
1| 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0
2| 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0
3| 0 0 0 0 0 0 0 0 0 0 1 1 1 1 0
4| 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1
Step = 1
0| 1 1 1 1 1 0 0 0 0 0 1 0 0 0 0
1| 1 1 1 1 1 0 0 0 0 0 1 1 0 0 0
2| 1 1 1 1 1 0 0 0 0 0 1 1 1 0 0
3| 1 1 1 1 1 0 0 0 0 0 1 1 1 1 0
4| 1 1 1 1 1 0 0 0 0 0 1 1 1 1 1
Step = 2
0| 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0
1| 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0
2| 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0
3| 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0
4| 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
2. Sliding Window Attention
Step = 0
0| 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0
1| 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0
2| 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0
3| 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0
4| 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1
Step = 1
0| 0 0 0 1 1 0 0 0 0 0 1 0 0 0 0
1| 0 0 0 0 1 0 0 0 0 0 1 1 0 0 0
2| 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0
3| 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0
4| 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1
Step = 2
0| 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0
1| 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0
2| 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0
3| 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0
4| 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1
1.7.2 Multiple Graph Selection
由于 NPU 采用静态计算图的架构,其运算过程需要预先定义好完整的计算流程。以 Attention 机制为例,每次计算都需要按照最大长度(例如 1024)进行填充处理,这会导致大量冗余计算,显著增加时间开销。为了解决这一问题,可以采取一种折中方案:预先准备多组不同形状的 Qnn 图(即针对不同输入长度的计算图),并在每次执行时动态选择与当前输入最匹配(计算损失最小)的图来执行,从而在性能与灵活性之间取得平衡。
1.8 Qnn Model Runtime for LLM
Qnn Model Runtime 的实现相对简单,主要通过调用 Qnn Graph Execute API 执行图即可。但需要注意并可以改进的一点是:在从 Prefill 阶段切换到 Decode 阶段时,需要对所有 KV Cache 进行重排,这涉及大量数据拷贝。
之所以需要重排,是因为 Key 和 Value 每次计算出的小块都会被追加到 Past Key 和 Past Value 的末尾。然而,Prefill 阶段的 Key 占用了 chunk size 的空间,而 Decode 阶段的 Key 仅需 chunk size = 1 的空间。由于 Key 的格式为 [B, H, D, S],无法通过指针偏移直接获取 Decode 阶段所需的 Past Key,因此在 Prefill 到 Decode 的切换过程中必须进行重排。
1.9 Performance
我们在 SM8650 设备上做了实验
表 3 SM8650 上 LLM 性能, w4a16kv8. A: prefill 256 tokens, decode 128 tokens; B: prefill 512 tokens, decode 256 tokens; C: prefill 640 tokens, decode 128 tokens.
| 模型 | Prefill(token/s) | Decode(token/s) |
| Qwen2.5 3B, A | 512.75 | 18.29 |
| Qwen2.5 3B, B | 545.62 | 18.11 |
| Qwen3 1.7B, C | 758.28 | 25.84 |
1.10 总结
MLLM Qnn AOT 框架通过统一的中间表示、显式的 QDQ 建模以及创新的双图执行策略,成功解决了在高通 NPU 上高效部署 LLM 的关键难题,实现了高性能的端侧推理。该工作流将量化、编译和运行时集成一体,为开发者提供了便捷的部署路径。未来工作将聚焦于支持更先进的量化方法(如 SpinQuant)以及进一步优化多图选择策略。
1.11 Appendix
1.11.1 Quantize/Dequantize, Cast And Convert
高通提供了一个 Convert Op 来辅助量化类型之间的 Convert,不需要使用 Quantize 和 Dequantize 的组合。Cast Op 并不会考虑 Tensor 是否具有量化类型,Cast 是强制的类型转换。
1.11.2 zero point should be negative in QNN
尽管高通在文档中给出的很多例子的 offset 是正的。但是实际上高通的量化方法应当如下面的公式一样:
1.11.3 HMX Enable on short depth Convolution
慎用 QNN_HTP_GRAPH_CONFIG_OPTION_SHORT_DEPTH_CONV_ON_HMX_OFF flag。
p_custom_config = (QnnHtpGraph_CustomConfig_t*)malloc(sizeof(QnnHtpGraph_CustomConfig_t));
p_custom_config->option = QNN_HTP_GRAPH_CONFIG_OPTION_SHORT_DEPTH_CONV_ON_HMX_OFF;
p_custom_config->shortDepthConvOnHmxOff = true;
htp_graph_configs.push_back(static_cast<QnnGraph_CustomConfig_t>(p_custom_config));
将 Optimize Level 设置为 3 可能可以带来一定的性能提升:
p_custom_config = (QnnHtpGraph_CustomConfig_t*)malloc(sizeof(QnnHtpGraph_CustomConfig_t));
p_custom_config->option = QNN_HTP_GRAPH_CONFIG_OPTION_OPTIMIZATION;
p_custom_config->optimizationOption.type = QNN_HTP_GRAPH_OPTIMIZATION_TYPE_FINALIZE_OPTIMIZATION_FLAG;
p_custom_config->optimizationOption.floatValue = 3;
htp_graph_configs.push_back(static_cast<QnnGraph_CustomConfig_t>(p_custom_config));
1.11.4 LPBQ Packing
LPBQ的量化方法要求用户提供一个int8数组来表示int4权重。高通会对这个数组进行打包处理。然而,高通的打包算法实现方式如下:
byte = val1 | (val2 << 4)
在 MLLM 中,我们会在一个 int8 的数组上做 0x0F 的与来获得干净的高四位。 Python 代码如下:weight_int4 = quantized_weight.to(torch.int8)
mask = torch.full(
weight_int4.size(), 0x0F, dtype=torch.int8, device=weight_int4.device
)
weight_int4 = torch.bitwise_and(mask, weight_int4)
而不是下面的这个做法:
byte = (val1 & 0x0F) | ((val2 & 0x0F) << 4)
References
- Liu, Z., Zhao, C., Fedorov, I., Soran, B., Choudhary, D., Krishnamoorthi, R., Chandra, V., Tian, Y., & Blankevoort, T. (2024). Spinquant: Llm quantization with learned rotations. Arxiv Preprint Arxiv:2405.16406.