MLLM Qnn AOT Support for Full Graph Execution on NPU

Translated by Gemini 3 Flash

github repo: https://github.com/UbiquitousLearning/mllm

website: https://ubiquitouslearning.github.io/mllm/index.html

1.1 Introduction

Deploying Large Language Models (LLMs) on Qualcomm NPUs has long been a challenge for the industry because of the NPU's static graph architecture and full-integer computation requirements. To make on-device deployment more convenient, we have implemented support for the Qualcomm Qnn framework within the MLLM framework. Users can now compile LLMs with MLLM's Qnn AOT mode on X86 machines and then run them efficiently on edge devices equipped with Qualcomm NPUs using the MLLM Qnn AOT Runtime. Through careful design by the MLLM team, the framework supports the entire workflow of model quantization, AOT compilation, and execution in a lightweight manner via a unified intermediate representation (IR). With our optimizations, the Qwen3 1.7B model (W4A16KV8 quantization configuration, chunk=32, pad=1024) running on a Qualcomm SM8650 achieves 758.28 tokens/s for prefill and 25.84 tokens/s for decode with 640 prefill tokens and 128 decode tokens. We plan to further accelerate prefill and decode using Graph Selection methods in the future.

Edge devices are highly sensitive to power consumption and heat dissipation, while LLM inference itself has demanding latency requirements. Although current Arm CPUs are generally equipped with acceleration units such as SVE and SME, their performance still struggles to meet the computational demands of LLM inference. Furthermore, on edge devices the CPU is typically needed for other low-latency tasks, so offloading computationally intensive LLM work to the CPU is neither reasonable nor realistic (because of power consumption, heat dissipation, and the impact on other applications, CPUs are often restricted to single-threaded LLM execution in actual deployments). Similarly, while edge devices usually feature decent GPUs and executing models on GPUs via OpenCL is a viable direction, the GPU is also responsible for real-time UI rendering, and allocating its resources to LLM computation would significantly degrade the user experience. Therefore, it is particularly important to run all LLM workloads (such as prefill, decode, and ViT) on the low-power NPU.

Inference on NPUs (hereafter, “NPU” refers to Qualcomm’s NPU unless otherwise specified) presents the following difficulties:

  1. Slow Decode Speed. The prefill stage is typically fast, but the decode stage is relatively slow. This is primarily because each decode step processes only one token, so the small computational load cannot keep the NPU's compute units busy and LLM inference becomes bound by memory bandwidth.

  2. Severe Accuracy Loss in Full-Integer Computation. All operators on the NPU expect Per-Tensor integer-quantized inputs and outputs to achieve maximum performance. However, full-integer Per-Tensor quantization causes significant accuracy degradation for LLMs. (Because computation inside the NPU is fixed-point, operators like SiLU, RMSNorm, and RoPE require integer inputs and outputs. In particular, RoPE has to be decomposed into a series of element-wise operators, each of which requires integer inputs and outputs. We have some recent work related to NPU-friendly quantization; please stay tuned.)

  3. Difficulty Implementing LLM Dynamic Shapes on NPU Static Graphs. Qnn computation graphs are constructed offline, but autoregressive LLM inference involves dynamic shapes in the Attention mechanism. These must be handled with significant padding, which incurs performance overhead.

During the implementation of MLLM Qnn AOT, we encountered numerous technical challenges, which are detailed in the Appendix. The entire process of adapting the Qnn AOT mode was fraught with difficulties: the team could not fully trust Qnn's documentation and implementation, the accuracy of the quantization algorithms, or the implementation of MLLM's own AOT Compile Stack, which led to a “chain of suspicion.”

Figure 1: Who is the imposter? Is it Qnn, the quantization algorithm, or MLLM Qnn AOT?

The root cause lies in the closed nature of the Qnn software stack and its incomplete or erroneous documentation. The black-box nature of Qnn’s internal computation graphs makes debugging extremely difficult, often leading us to wonder if the issue lies within Qnn’s own implementation. We also discovered inconsistencies between the description and behavior of Qnn’s RMSNorm Op, as well as inconsistent descriptions of quantization methods. These issues are listed individually in the Appendix.

The MLLM Qnn AOT workflow primarily consists of three parts:

  1. Implementing modeling files with QDQ. This requires both Python modeling files in the Hugging Face Transformers style and C++ modeling files in MLLM; the two are constructed in a very similar way, and AI can write them effectively.
  2. Exporting the quantized model to MLLM format and compiling the mllm-aot-compiler specific to the model.
  3. Using the mllm-aot-compiler to compile the model and running it with the mllm-aot-runner.

Figure 2: MLLM Qnn AOT Workflow

In this MLLM Qnn AOT implementation, AI played a massive role, becoming an integral part of the workflow. We exposed many QDQ (Quantize, Dequantize) details at the modeling level rather than hiding them in the compiler’s annotation phase, as seen in ExecuTorch. The core philosophy is: If AI can do it, why hide it from the compiler? Exposing sufficient details at the frontend makes it easier for both humans and AI to read. Once we provide a well-implemented Qwen3 modeling script, AI can easily generate modeling files for other models like Llama or Qwen2.5.

Finally, we would like to thank the students who participated in the development of MLLM Qnn AOT. We also extend special thanks to the ExecuTorch community’s Qualcomm backend implementation. We drew significant inspiration from ExecuTorch, and its community members were very friendly, providing responses and assistance for the issues we encountered.

While Qualcomm's NPU is currently top-tier, its weak software ecosystem keeps open-source developers from doing further work on its chips. After being tormented by Qualcomm's Qnn, I must borrow Torvalds' words to summarize: XXXX you, Qualcomm!

In this article, I will first explain why the NPU expects both inputs and outputs to be integers. Then, following the MLLM Qnn AOT workflow, I will introduce in turn the quantization algorithm configuration, writing QDQ modeling files, MLLM IR, Qnn AOT generation, and the Qnn AOT Runtime.

1.2 Why Are Integer Inputs and Outputs Necessary?

In mobile AI development (especially when using SNPE or the QNN SDK), developers often encounter a rigid requirement: the model must be quantized, and the NPU strongly prefers Int8/Uint8 formats for inputs and outputs. Given that modern mobile CPUs and GPUs have such strong floating-point performance, why does the Hexagon NPU insist on “returning to basics” with fixed-point computation? Why are even the inputs and outputs locked to integers? This is likely due to two reasons: (1) historical inertia from convolutional networks, and (2) the NPU's design goal of extreme energy efficiency (TOPS/Watt).

Let’s take a basic operator, Per-Tensor Uint8 Elementwise Multiplication, as an example to deconstruct the underlying fixed-point logic.

We want to perform an element-wise multiplication of two tensors $A$ and $B$ to get $C$:

$$C = A \odot B$$

However, inside the NPU, we don't have real floating-point values ($r$, real), only quantized integers ($q$, quantized). According to the Per-Tensor quantization formula (where all elements share a single Scale and Zero-point):

$$r = S\,(q - Z)$$

where $S$ is the floating-point Scale, $Z$ is the integer Zero-point, and $q$ is the stored uint8 value.

Substituting the formula into the multiplication operation:

$$S_C\,(q_C - Z_C) = S_A\,(q_A - Z_A) \cdot S_B\,(q_B - Z_B)$$

Our goal is to find the integer $q_C$ that the NPU needs to output. Rearranging the terms gives the core calculation formula:

$$q_C = Z_C + \underbrace{\frac{S_A S_B}{S_C}}_{M}\,(q_A - Z_A)(q_B - Z_B)$$

Note the term labeled $M$:

$$M = \frac{S_A S_B}{S_C}$$

$S_A$, $S_B$, and $S_C$ are all floating-point numbers, so the calculated $M$ is also a floating-point number. If the NPU were to compute this formula directly, it would need a built-in floating-point multiplier to handle $M \cdot (q_A - Z_A)(q_B - Z_B)$, which defeats the purpose of full fixed-point computation.

Therefore, the NPU performs Fixed-point Rescaling.

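The NPU approximates the floating-point $M$ as “an integer multiplier + a right-shift operation”:

$$M \approx \frac{\text{Multiplier}}{2^{\text{Shift}}}, \qquad \text{Multiplier} = \operatorname{round}\left(M \cdot 2^{\text{Shift}}\right)$$

Here Multiplier is a large integer (e.g., int32) and Shift is the bitwise shift amount. Thus, the entire computation process becomes pure integer arithmetic: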

  1. Input Preparation: The NPU reads the uint8 values $q_A$ and $q_B$ (this is why inputs must be integers; otherwise, it would fail at the first step).
  2. Integer Subtraction: $x = q_A - Z_A$, $y = q_B - Z_B$.
  3. Integer Multiplication: $p = x \cdot y$ (the result could be int16 or int32).
  4. Fixed-point Rescaling (Crucial Step): $p' = (p \cdot \text{Multiplier}) \gg \text{Shift}$. Here, integer multiplication and shifting replace floating-point multiplication.
  5. Add Zero-point: $q' = p' + Z_C$.
  6. Saturation: Clamp $q'$ to the uint8 range $[0, 255]$ and output it as $q_C$.
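The six steps above can be sketched in Python as follows. This is a minimal illustration, not Qualcomm's kernel: the helper names are hypothetical, the rescale uses a fixed Shift of 31 with round-half-up, and it assumes $M < 1$ so the Multiplier fits in 31 bits (real implementations typically normalize $M$ and choose the shift per tensor). A floating-point reference is included for comparison.

```python
def quantize_multiplier(m: float, shift: int = 31) -> int:
    """Approximate the real rescale factor m by the integer round(m * 2**shift)."""
    return round(m * (1 << shift))


def rounding_right_shift(value: int, shift: int) -> int:
    """Arithmetic right shift that rounds half up instead of truncating."""
    return (value + (1 << (shift - 1))) >> shift


def requantized_mul_u8(qa, qb, sa, za, sb, zb, sc, zc, shift=31):
    """Per-tensor uint8 elementwise multiply; only integer ops run per element."""
    m = sa * sb / sc                          # M = S_A * S_B / S_C, computed offline
    multiplier = quantize_multiplier(m, shift)
    out = []
    for a, b in zip(qa, qb):                  # step 1: integer inputs q_A, q_B
        x, y = a - za, b - zb                 # step 2: integer subtraction
        p = x * y                             # step 3: integer multiplication
        p = rounding_right_shift(p * multiplier, shift)  # step 4: fixed-point rescaling
        q = p + zc                            # step 5: add output zero point
        out.append(min(max(q, 0), 255))       # step 6: saturate to uint8
    return out


def reference_mul_u8(qa, qb, sa, za, sb, zb, sc, zc):
    """Floating-point reference: dequantize, multiply, requantize."""
    out = []
    for a, b in zip(qa, qb):
        r = (sa * (a - za)) * (sb * (b - zb))
        out.append(min(max(round(r / sc) + zc, 0), 255))
    return out


args = (0.02, 128, 0.01, 100, 0.005, 10)      # S_A, Z_A, S_B, Z_B, S_C, Z_C
print(requantized_mul_u8([130, 200, 128], [110, 150, 90], *args))
print(reference_mul_u8([130, 200, 128], [110, 150, 90], *args))
```

For these inputs both paths agree; the difference only shows up when $p \cdot M$ lands close to a rounding boundary, which is exactly the kind of error discussed next.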

You might think $\text{Multiplier} / 2^{\text{Shift}}$ is just another way of writing $M$, but in reality, it is a lossy transformation, not a mathematical equivalence.

Mathematically, the fractional part of a real-valued $M$ can need arbitrarily many (even infinitely many) binary digits. However, in hardware, the Multiplier is typically restricted to a 32-bit integer (Int32).

The round operation in $\text{Multiplier} = \operatorname{round}(M \cdot 2^{\text{Shift}})$ discards the fractional part of $M \cdot 2^{\text{Shift}}$. This introduces unavoidable quantization error.

Suppose the true ratio is $M = 0.1$. Since the hardware requires the Multiplier to be an integer, let's try representing it with different Shift values:

Attempt 1 (Shift=8):

$$0.1 \times 2^{8} = 25.6$$

After rounding, Multiplier = 26. Reverting back: $26 / 2^{8} = 0.1015625$. The relative error is about 1.56%.

Attempt 2 (Shift=31, the limit of a 32-bit signed integer):

$$0.1 \times 2^{31} = 214748364.8$$

After rounding, Multiplier = 214748365. Reverting back: $214748365 / 2^{31} \approx 0.10000000009$. The error is extremely small, but it still exists.
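The two attempts can be reproduced, and extended to intermediate shifts, with a few lines of Python. The function name is illustrative only; it simply evaluates $\operatorname{round}(M \cdot 2^{\text{Shift}})$ and maps it back.

```python
def fixed_point_approx(m: float, shift: int):
    """Return (multiplier, recovered M, relative error) for M ~ multiplier / 2**shift."""
    multiplier = round(m * 2**shift)
    recovered = multiplier / 2**shift
    return multiplier, recovered, abs(recovered - m) / m


for shift in (8, 16, 24, 31):
    mult, recovered, rel_err = fixed_point_approx(0.1, shift)
    print(f"Shift={shift:2d}  Multiplier={mult:>10d}  M~{recovered:.12f}  rel_err={rel_err:.2e}")
```

The relative error shrinks roughly by a factor of $2^{-\Delta\text{Shift}}$ but never reaches zero, because 0.1 has no finite binary expansion.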

This means that ordinary FakeQuant, which performs this rescaling in floating point, still cannot exactly reproduce the numerical errors of Qnn's computation! The more operators there are, the larger the accumulated discrepancy!

1.3 Quantization Algorithm

In MLLM, the default quantization configuration we implement is W4A16KV8, as detailed in the table below:

Table 1: MLLM Default Quantization Configuration W4A16KV8 Description

Tensor Type            | Quantization Config                     | Precision
Linear Weight          | per-block, block size 16/32, symmetric  | int4
RMSNorm Weight         | per-tensor, asymmetric                  | uint16
RoPE Embedding Weight  | per-tensor, asymmetric                  | uint16
Attention Sink Weight  | per-tensor, asymmetric                  | uint16
Activation             | per-tensor, asymmetric                  | uint16
KV Cache               | per-tensor, symmetric                   | uint8

1.3.1 LPBQ

LPBQ stands for Low Power Block Quantization. It is a fine-grained quantization scheme designed by Qualcomm for edge hardware (specifically Hexagon NPU/DSP).

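Traditional quantization granularities are typically of two types:

  1. Per-Tensor: the entire tensor shares a single Scale and Zero Point.
  2. Per-Channel: each channel (for weights, usually each output channel) has its own Scale and Zero Point.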

LPBQ introduces the concept of “blocks.” It divides the tensor into smaller blocks and calculates quantization parameters independently for each block.

Taking Linear as an example, its weight is usually represented as an $(O, I)$ (output features × input features) or $(I, O)$ matrix. In the low-level implementation of the DSP/NPU, weights often use the HWIO format to optimize vectorized loading.

For a Linear layer viewed as a 1×1 convolution, $H = W = 1$, $I$ is the number of input features, and $O$ is the number of output features.

LPBQ performs standard per-block quantization along the I (Input Channel) dimension, yielding fp32 Scales and int32 Zero Points. It then performs a second level of quantization on the Scales of all blocks along the O (Output Channel) dimension. After quantization, each block's Scale is stored as uint4, while each output channel keeps a single fp32 Scale.
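The two-level idea can be illustrated with a small pure-Python sketch. Qualcomm's exact rounding and clamping rules for the uint4 block scales are not given in this article, so the choices below are assumptions: weights are symmetric int4 stored as O rows of I floats, the per-channel fp32 scale is the largest block scale divided by 15, and block scales are clamped to [1, 15]. The function names are hypothetical.

```python
def lpbq_quantize(weight, block_size, weight_bits=4, scale_bits=4):
    """Two-level, LPBQ-style quantization of a weight given as O rows of I floats."""
    q_max = 2 ** (weight_bits - 1) - 1      # 7 for int4 (symmetric, zero point 0)
    q_min = -q_max - 1                      # -8 for int4
    s_max = 2 ** scale_bits - 1             # 15 for a uint4 block scale
    q_weight, block_scales, channel_scales = [], [], []
    for row in weight:                      # one output channel (O)
        blocks = [row[i:i + block_size] for i in range(0, len(row), block_size)]
        # Level 1: fp32 symmetric scale per block along the input dimension (I).
        fp_scales = [max(max(abs(v) for v in blk), 1e-12) / q_max for blk in blocks]
        # Level 2: one fp32 scale per output channel; block scales become uint4 multiples.
        ch_scale = max(fp_scales) / s_max
        int_scales = [min(max(round(s / ch_scale), 1), s_max) for s in fp_scales]
        q_row = []
        for blk, k in zip(blocks, int_scales):
            eff = ch_scale * k              # effective scale actually used by this block
            q_row += [min(max(round(v / eff), q_min), q_max) for v in blk]
        q_weight.append(q_row)
        block_scales.append(int_scales)
        channel_scales.append(ch_scale)
    return q_weight, block_scales, channel_scales


def lpbq_dequantize(q_weight, block_scales, channel_scales, block_size):
    """Reconstruct floats as q * channel_scale * uint4_block_scale."""
    return [[q * ch * ks[i // block_size] for i, q in enumerate(row)]
            for row, ks, ch in zip(q_weight, block_scales, channel_scales)]


w = [[0.5, -0.25, 0.1, 0.05], [1.0, -1.0, 0.01, -0.02]]
qw, ks, cs = lpbq_quantize(w, block_size=2)
w_hat = lpbq_dequantize(qw, ks, cs, block_size=2)
print(qw, ks)
```

The second row shows why blocks help: its small-magnitude block gets a much finer effective scale than a single per-channel scale would allow.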

1.3.2 Special Observer Constraints

During the quantization process, the Observer's role is to collect the numerical distribution of a tensor (Min/Max or Histogram) and use it to calculate the quantization parameters Scale ($S$) and Zero Point ($Z$). However, not all Observers can “freely” determine these parameters. To ensure numerical stability or meet hardware constraints, we must impose special restrictions in the following cases:

1. Epsilon Constraint to Prevent Division-by-Zero Errors

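The formula for calculating the Scale is typically $S = \frac{x_{max} - x_{min}}{q_{max} - q_{min}}$. When the observed numerical variation is extremely small (e.g., $x_{max} \approx x_{min}$), the computed Scale approaches zero, so any later division by $S$ explodes or produces NaN.

For numerical stability, we must clamp the Scale from below to a small $\epsilon$, i.e., $S = \max\left(\frac{x_{max} - x_{min}}{q_{max} - q_{min}}, \epsilon\right)$, where the threshold $\epsilon$ is usually chosen separately for each quantization bit-width.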

This ensures that the Scale remains within a reasonable floating-point range even in flat regions.
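A minimal observer-side sketch of this clamp, assuming asymmetric unsigned quantization and the common convention of widening the observed range to include 0.0. The default `eps=2**-20` is an illustrative value only, not necessarily the threshold MLLM uses.

```python
def asymmetric_qparams(x_min, x_max, bits=16, eps=2**-20):
    """(scale, zero_point) for asymmetric unsigned quantization with an eps floor."""
    q_min, q_max = 0, 2**bits - 1
    # Widen the range to include 0.0 so that zero is exactly representable.
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)
    # Without the floor, a nearly flat tensor gives scale ~ 0 and x / scale blows up.
    scale = max((x_max - x_min) / (q_max - q_min), eps)
    zero_point = q_min - round(x_min / scale)
    return scale, min(max(zero_point, q_min), q_max)


print(asymmetric_qparams(1e-9, 1e-9))                     # flat tensor: scale hits the floor
print(asymmetric_qparams(-1.0, 3.0, bits=8, eps=2**-12))  # normal range: floor inactive
```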

2. Fixed ZP for Operators like Sigmoid/Tanh

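Activation functions such as Sigmoid and Tanh have fixed output ranges (Sigmoid's output always lies in $(0, 1)$, Tanh's in $(-1, 1)$).

Their quantized outputs therefore typically use a fixed Scale and Zero Point. For example, for asymmetric uint8 quantization of Sigmoid, the following constraint might be enforced:

$$S = \frac{1}{q_{max} - q_{min} + 1} = \frac{1}{256}, \qquad Z = 0$$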

Or it may be fixed to another specific integer value based on the hardware implementation, rather than changing with the dynamic distribution of the data.

3. Observer Sharing for Concat Operators (Shared Scale & ZP)

The Concat operation involves moving data from multiple tensors into a single contiguous block of physical memory.

If the multiple tensors input to Concat have different Scales and ZPs, they cannot be directly concatenated in their physical values (e.g., a value of 5 in Tensor A represents the real number 0.5, while a value of 5 in Tensor B represents the real number 2.0, leading to physical ambiguity when concatenated).

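To avoid the computational overhead of inserting extra Requantize (re-alignment) operations before Concat, it is usually mandatory that all input tensors and the output tensor share the same Observer, i.e.:

$$S_{in_1} = S_{in_2} = \cdots = S_{in_n} = S_{out}, \qquad Z_{in_1} = Z_{in_2} = \cdots = Z_{in_n} = Z_{out}$$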
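A small uint8 sketch of why sharing works, assuming a min/max observer taken over the union of all Concat inputs (the helper names are hypothetical). With one shared Scale and Zero Point, concatenating the quantized inputs gives exactly the same integers as quantizing the concatenated floats, so Concat reduces to a memory copy.

```python
def shared_qparams(tensors, bits=8):
    """One (scale, zero_point) observed over the union of all Concat inputs."""
    q_max = 2**bits - 1
    lo = min(0.0, min(min(t) for t in tensors))
    hi = max(0.0, max(max(t) for t in tensors))
    scale = (hi - lo) / q_max if hi > lo else 1.0
    return scale, round(-lo / scale)


def quantize(t, scale, zp, bits=8):
    return [min(max(round(v / scale) + zp, 0), 2**bits - 1) for v in t]


def dequantize(q, scale, zp):
    return [scale * (v - zp) for v in q]


a, b = [0.1, 0.5, -0.2], [2.0, 1.5]
scale, zp = shared_qparams([a, b])
# With shared parameters, Concat is a plain copy of the integer payloads:
# concatenating the quantized inputs equals quantizing the concatenated floats.
q_cat = quantize(a, scale, zp) + quantize(b, scale, zp)
assert q_cat == quantize(a + b, scale, zp)
print(q_cat, dequantize(q_cat, scale, zp))
```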

1.3.3 SpinQuant

Figure 3: SpinQuant (Liu et al., 2024)

SpinQuant reduces the impact of outliers in weights and activations by applying orthogonal rotation matrices, learned on a small calibration set, that leave the full-precision network output unchanged while making the rotated tensors easier to quantize. This method has been implemented in ExecuTorch's Qualcomm backend and has become a de facto industry standard. However, the current MLLM AOT framework does not yet support rotation-based quantization, which will be a key focus of our future work.
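The invariance that rotation methods rely on can be checked numerically. This sketch uses a fixed, normalized Hadamard matrix in place of SpinQuant's learned rotations, so it shows the principle only: $(WR)(R^\top x) = Wx$ for orthogonal $R$, while the rotated activation spreads one outlier channel across all channels.

```python
import math


def hadamard(n):
    """Normalized Sylvester Hadamard matrix; n must be a power of two. H is orthogonal."""
    h = [[1.0]]
    while len(h) < n:
        h = [row + row for row in h] + [row + [-v for v in row] for row in h]
    s = 1.0 / math.sqrt(n)
    return [[v * s for v in row] for row in h]


def matvec(m, v):
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def transpose(m):
    return [list(col) for col in zip(*m)]


x = [0.1, -0.2, 8.0, 0.3]                      # activation with one outlier channel
w = [[0.5, -0.1, 0.02, 0.3], [0.2, 0.4, -0.01, -0.6]]
r_t = transpose(hadamard(4))

x_rot = matvec(r_t, x)                          # R^T x
w_rot = [matvec(r_t, row) for row in w]         # rows of W R
y = matvec(w, x)                                # original layer output
y_rot = matvec(w_rot, x_rot)                    # (W R)(R^T x) = W x

print("max |x|     =", max(abs(v) for v in x))
print("max |R^T x| =", max(abs(v) for v in x_rot))
```

A smaller peak magnitude lets a Per-Tensor Scale be finer, which is where the quantization benefit comes from.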

1.4 QDQ Modeling

Given that NPU operators require integer inputs and outputs to achieve peak performance, we need to perform QDQ before and after every operator. This raises a question: how can we quickly insert the correct QDQ operators? ExecuTorch uses certain passes to annotate the quantization settings for each operator and inserts FakeQuant nodes or Observers based on these settings.

However, implementing a complete and complex pattern matching system is very time-consuming. Therefore, MLLM chose a different path: explicitly declaring QDQ nodes directly in the modeling file and explicitly declaring the types used for each operator’s inputs and outputs.

Let's first look at how to insert QDQ nodes in Python. pymllm provides a series of QDQ classes for this purpose, and QLinearLPBQ defaults to W4A16 precision:

import torch.nn as nn
import torch.nn.functional as F

from pymllm.backends.qualcomm.transformers.core.qlinear import (
    QLinearLPBQ,
)
from pymllm.backends.qualcomm.transformers.core.qdq import (
    ActivationQDQ,
    FixedActivationQDQ,
)


class Qwen3MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = QLinearLPBQ(
            self.hidden_size, self.intermediate_size, bias=False, block_size=16
        )
        self.up_proj = QLinearLPBQ(
            self.hidden_size, self.intermediate_size, bias=False, block_size=16
        )
        self.down_proj = QLinearLPBQ(
            self.intermediate_size, self.hidden_size, bias=False, block_size=16
        )
        # QDQ
        self.up_proj_input_qdq = ActivationQDQ(bits=16)
        self.up_proj_output_qdq = ActivationQDQ(bits=16)
        self.gate_proj_output_qdq = ActivationQDQ(bits=16)
        self.act_output_qdq = ActivationQDQ(bits=16)
        self.down_proj_input_qdq = ActivationQDQ(bits=16)
        # For sigmoid output: scale = 1 / (q_max - q_min + 1), zp = 0
        # For 16-bit: q_min = 0, q_max = 65535
        sigmoid_scale = 1.0 / (65535 - 0 + 1)  # 1 / 65536
        self.sigmoid_output_qdq = FixedActivationQDQ(
            scale=sigmoid_scale, zero_point=0, bits=16
        )

    def forward(self, x):
        x = self.up_proj_input_qdq(x)
        up_result = self.up_proj_output_qdq(self.up_proj(x))
        gate_result = self.gate_proj_output_qdq(self.gate_proj(x))

        # SiLU(x) = x * sigmoid(x), with QDQ on the sigmoid output and on the product
        gate_result = self.act_output_qdq(
            gate_result * self.sigmoid_output_qdq(F.sigmoid(gate_result))
        )

        o = self.down_proj_input_qdq(gate_result * up_result)
        o = self.down_proj(o)
        return o
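The internals of ActivationQDQ and FixedActivationQDQ are not shown in this article. As a rough, hypothetical sketch of the arithmetic a QDQ pair applies once its Scale and Zero Point are fixed (observer logic omitted), here it is with the fixed Sigmoid parameters used above:

```python
def fake_quant(values, scale, zero_point, bits=16):
    """Quantize then dequantize: the result is float but snapped to the integer grid."""
    q_min, q_max = 0, 2**bits - 1
    out = []
    for v in values:
        q = min(max(round(v / scale) + zero_point, q_min), q_max)   # Quantize
        out.append((q - zero_point) * scale)                         # Dequantize
    return out


# Fixed Sigmoid parameters from the module above: scale = 1 / 65536, zero_point = 0.
print(fake_quant([0.0, 0.5, 0.123456789, 1.0], scale=1 / 65536, zero_point=0))
```

With scale $1/65536$ the grid covers $[0, 65535/65536]$, which matches Sigmoid's $(0, 1)$ output range; an input of exactly 1.0 saturates to the top code.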

Now let’s look at how a standard modeling file (without QDQ) is written in MLLM. In MLLM, we represent the graph using modeling files implemented in C++. Don’t be alarmed; we have wrapped the C++ API extensively to provide an experience similar to Torch. For instance, a common MLP layer looks like this:

class Qwen3MLP final : public nn::Module {
  nn::Linear gate_proj_;
  nn::Linear up_proj_;
  nn::Linear down_proj_;
  nn::SiLU silu_;

 public:
  Qwen3MLP() = default;
  Qwen3MLP(const std::string& name, const Qwen3Config& cfg) : nn::Module(name) {
    gate_proj_ = reg<nn::Linear>("gate_proj", cfg.hidden_size, cfg.intermediate_size, false, cfg.linear_impl_type);
    silu_ = reg<nn::SiLU>("act");
    up_proj_ = reg<nn::Linear>("up_proj", cfg.hidden_size, cfg.intermediate_size, false, cfg.linear_impl_type);
    down_proj_ = reg<nn::Linear>("down_proj", cfg.intermediate_size, cfg.hidden_size, false, cfg.linear_impl_type);
  }

  std::vector<Tensor> forward(const std::vector<Tensor>& inputs, const std::vector<AnyValue>& args) override {
    auto x = gate_proj_(inputs[0]);
    x = silu_(x);
    auto y = up_proj_(inputs[0]);
    x = x * y;
    x = down_proj_(x);
    return {x};
  }
};

Users can build their own LLM structures using these simple APIs.

With QDQ added, the MLP layer implementation is as follows. (In this example, we replaced Linear with Conv2D to accelerate inference. The third parameter of ptq::QDQ is the instantiation name of the FakeQuant module in the Python modeling file.)

class Qwen3MLP final : public nn::Module {
  nn::Conv2D gate_proj_;
  nn::Conv2D up_proj_;
  nn::Conv2D down_proj_;
  nn::SiLU silu_;
  int hidden_size_;
  int intermediate_size_;

 public:
  Qwen3MLP() = default;
  Qwen3MLP(const std::string& name, const Qwen3Config& cfg) : nn::Module(name) {
    gate_proj_ = reg<nn::Conv2D>("gate_proj", cfg.hidden_size, cfg.intermediate_size, CONV2D_PROPERTY);
    silu_ = reg<nn::SiLU>("act");
    up_proj_ = reg<nn::Conv2D>("up_proj", cfg.hidden_size, cfg.intermediate_size, CONV2D_PROPERTY);
    down_proj_ = reg<nn::Conv2D>("down_proj", cfg.intermediate_size, cfg.hidden_size, CONV2D_PROPERTY);
    hidden_size_ = cfg.hidden_size;
    intermediate_size_ = cfg.intermediate_size;
  }

  std::vector<Tensor> forward(const std::vector<Tensor>& inputs, const std::vector<AnyValue>& args) override {
    auto x = inputs[0];
    x = ptq::QDQ(this, x, "up_proj_input_qdq");
    x = x.view({1, 1, -1, hidden_size_}, true);
    auto up_result = ptq::QDQ(this, up_proj_(x), "up_proj_output_qdq").view({1, -1, intermediate_size_}, true);
    auto gate_result = ptq::QDQ(this, gate_proj_(x), "gate_proj_output_qdq").view({1, -1, intermediate_size_}, true);
    // SiLU
    gate_result = ptq::QDQ(this, (gate_result * ptq::QDQ(this, nn::functional::sigmoid(gate_result), "sigmoid_output_qdq")),
                           "act_output_qdq");
    auto o = ptq::QDQ(this, gate_result * up_result, "down_proj_input_qdq");
    o = o.view({1, 1, -1, intermediate_size_}, true);
    o = down_proj_(o).view({1, -1, hidden_size_}, true);

    return {o};
  }
};

1.5 MLLM IR

MLLM IR is a faithful representation of MLLM's modeling files. It is a static-graph, non-SSA (Static Single Assignment) IR. We chose not to use SSA because a computation-graph IR is simpler and faster to implement from an engineering perspective. Furthermore, since computation graphs consist of interconnected nodes and all memory allocations are represented as tensors, SSA is less necessary. For an MLP, MLLM generates the following IR (this example is from the Qwen2 VL model):

graph.SubGraphOp @model.layers.6.mlp <CPU> {
  (%1454:tensor<[1, 192, 1536], Float32, CPU>) -> (%1459:tensor<[1, 192, 1536], Float32, CPU>) {
    linalg.CPU.LinearOp(%1454:tensor<[1, 192, 1536], Float32, CPU>) -> (%1455:tensor<[1, 192, 8960], Float32, CPU>)
    linalg.CPU.SiLUOp(%1455:tensor<[1, 192, 8960], Float32, CPU>) -> (%1456:tensor<[1, 192, 8960], Float32, CPU>)
    linalg.CPU.LinearOp(%1454:tensor<[1, 192, 1536], Float32, CPU>) -> (%1457:tensor<[1, 192, 8960], Float32, CPU>)
    linalg.CPU.MulOp(%1456:tensor<[1, 192, 8960], Float32, CPU>, %1457:tensor<[1, 192, 8960], Float32, CPU>) -> (%1458:tensor<[1, 192, 8960], Float32, CPU>)
    linalg.CPU.LinearOp(%1458:tensor<[1, 192, 8960], Float32, CPU>) -> (%1459:tensor<[1, 192, 1536], Float32, CPU>)
    cf.ReturnOp (%1459:tensor<[1, 192, 1536], Float32, CPU>) -> ()
  }
}

The IR above is the simplest case obtained through tracing. To ensure that quantization settings and model structure comply with Qnn restrictions, MLLM Qnn AOT mode implements a series of passes to assist in IR conversion. These passes are as follows (the order in the table is the actual execution order):

Table 2: MLLM Qnn AOT Passes Description

Pass Name                      | Function                                                                                                   | Constraints
MarkQnnGraphPass               | Marks Graph Ops and some Linalg Ops to be executed on the NPU based on the provided configuration file     |
OpNamingPass                   | Assigns unique names to anonymous operators (e.g., from nn.functional) created during tracing              |
MergeLLMHeadIntoMainGraphPass  | Moves the LM Head operator into the Model Graph, since it is typically independent of other Graph Ops      | Only applied to LLMs
LLMQuantRecipePass             | Deduces whether each operator's QDQ can be reused based on the frontend quantization type (detailed in the quantization section) |
PTQPass                        | Maps Scales and Zero Points from parameters to each activation and weight via the Quant Recipe; quantizes constant values if present |
SplitLLMGraphPass              | Splits the graph into different contexts if an LLM exceeds 4GB; also flattens all subgraphs into a single large graph | Currently only supports single graphs; future work will improve this
MarkTensorIOPass               | Tags graph inputs and outputs as IO                                                                        |
LLM2QnnLoweringPass            | Traverses the IR and builds the graph using Qnn's offline graph construction API                           |
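The article does not describe OpNamingPass's algorithm, so the following is only a hypothetical sketch of what such a pass can look like; `Op` and `op_naming_pass` are invented names. It gives anonymous operators deterministic names that do not collide with existing ones, presumably so that later passes and Qnn graph construction can refer to every node unambiguously.

```python
from dataclasses import dataclass


@dataclass
class Op:
    kind: str       # operator type, e.g. "Linear" or "Sigmoid"
    name: str = ""  # empty for anonymous ops such as nn.functional calls


def op_naming_pass(ops, scope):
    """Assign deterministic, unique names to anonymous ops within one scope."""
    used = {op.name for op in ops if op.name}
    counters = {}
    for op in ops:
        if op.name:
            continue
        while True:
            idx = counters.get(op.kind, 0)
            counters[op.kind] = idx + 1
            candidate = f"{scope}.{op.kind}_{idx}"
            if candidate not in used:
                break
        op.name = candidate
        used.add(candidate)
    return ops


ops = [Op("Linear", "model.layers.0.mlp.up_proj"), Op("Sigmoid"), Op("Mul"), Op("Mul")]
for op in op_naming_pass(ops, "model.layers.0.mlp"):
    print(op.kind, "->", op.name)
```

Because named ops are skipped, running the pass twice leaves the names unchanged.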

1.6 Qnn AOT Generation

We extended the original Qnn Backend (the first version of MLLM's Qnn Backend prepared the Qnn graph on the mobile device) and wrapped a compilation environment usable on X86. Users can compile Qnn graphs for different devices by specifying hardware parameters for the target_machine in the AOTEnv during compilation:

{
  "target_machine": {
    "htp_arch": "V75",
    "htp_chipset": "SM8650",
    "htp_try_best_performance": "HtpBurst",
    "htp_security_pd_session": "HtpSignedPd",
    "htp_vtcm_capability_in_mb": 8
  }
}

As shown in the JSON above, when compiling a Qnn graph on X86, hardware-related information such as the HTP chipset and architecture must be provided to MLLM’s AOTEnv.

1.7 Full NPU Execution

Executing both the prefill and decode stages of an LLM on the NPU is challenging. The current MLLM AOT solution is to compile two separate computation graphs, one with a chunk size of 32 and another with a chunk size of 1, which are used for the prefill and decode stages, respectively.
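The two-graph scheme implies a simple invocation plan for each request. The sketch below is hypothetical (names are invented) and assumes the context length of 1024 from the introduction, padding of the last prefill chunk, and that the first new token is sampled from the last prefill call, so the chunk-size-1 graph runs once per remaining token.

```python
def plan_invocations(prompt_len: int, new_tokens: int, prefill_chunk: int = 32,
                     context_len: int = 1024):
    """List the (graph, valid_tokens) calls needed for one request."""
    if prompt_len < 1 or prompt_len + new_tokens > context_len:
        raise ValueError("request does not fit the static context length")
    calls = []
    for start in range(0, prompt_len, prefill_chunk):
        valid = min(prefill_chunk, prompt_len - start)  # the last chunk may be padded
        calls.append(("prefill", valid))
    # Assumption: the first new token is sampled from the last prefill call's logits,
    # so the chunk-size-1 decode graph runs once per remaining token.
    calls += [("decode", 1)] * max(new_tokens - 1, 0)
    return calls


calls = plan_invocations(640, 128)
print(sum(g == "prefill" for g, _ in calls), "prefill calls,",
      sum(g == "decode" for g, _ in calls), "decode calls")
```

For the 640-token prompt in the introduction, this amounts to 20 prefill-graph calls with no padding waste; a 33-token prompt would instead waste 31 of the 32 slots in its second call.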

1.7.1 KV Cache Management

We adopted the KV Cache management method used in ExecuTorch.

Figure 4: KV Cache Management, Official ExecuTorch analysis

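All KV Caches have a fixed length, with Key shape [B, H, D, S] and Value shape [B, H, S, D]. For a graph with a given chunk_size, the cache length S relates to the context length as:

$$S = \text{context\_length} - \text{chunk\_size}$$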

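During each Attention computation, the newly computed chunk is appended directly to the end of the cached Key and Value, forming a Key and Value of length context_length. Consequently, the Causal Mask also needs adjustment, as shown below (quoted from ExecuTorch's Python comments; each row is a token of the current chunk, the first 10 columns are cache slots, the last 5 columns are the current chunk, and 1 means the token may attend to that position):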

1. Full Attention

Step = 0

0| 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0
1| 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0
2| 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0
3| 0 0 0 0 0 0 0 0 0 0 1 1 1 1 0
4| 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1

Step = 1

0| 1 1 1 1 1 0 0 0 0 0 1 0 0 0 0
1| 1 1 1 1 1 0 0 0 0 0 1 1 0 0 0
2| 1 1 1 1 1 0 0 0 0 0 1 1 1 0 0
3| 1 1 1 1 1 0 0 0 0 0 1 1 1 1 0
4| 1 1 1 1 1 0 0 0 0 0 1 1 1 1 1

Step = 2

0| 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0
1| 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0
2| 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0
3| 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0
4| 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1

2. Sliding Window Attention

Step = 0

0| 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0
1| 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0
2| 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0
3| 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0
4| 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1

Step = 1

0| 0 0 0 1 1 0 0 0 0 0 1 0 0 0 0
1| 0 0 0 0 1 0 0 0 0 0 1 1 0 0 0
2| 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0
3| 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0
4| 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1

Step = 2

0| 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0
1| 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0
2| 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0
3| 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0
4| 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1
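The masks above can be regenerated with a short function. This is a sketch that reproduces the printed examples, not ExecuTorch's code: it assumes cache slot j holds absolute position j (no eviction yet), a cache of 10 slots, a chunk of 5, and a sliding window of 3, and uses 1/0 instead of the additive large-negative values a real attention mask would carry.

```python
def chunk_mask(step, cache_len=10, chunk=5, window=None):
    """Rows = tokens of the current chunk; columns = cache slots, then chunk slots."""
    past = step * chunk
    assert past <= cache_len, "sketch does not model cache eviction"
    rows = []
    for i in range(chunk):
        pos = past + i                          # absolute position of this query token
        row = []
        for j in range(cache_len):              # cached tokens (slot j = position j)
            row.append(int(j < past and (window is None or pos - j < window)))
        for k in range(chunk):                  # tokens of the current chunk
            row.append(int(k <= i and (window is None or i - k < window)))
        rows.append(row)
    return rows


for r in chunk_mask(step=1, window=3):
    print(" ".join(map(str, r)))
```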

1.7.2 Multiple Graph Selection

Since the NPU uses a static computation graph architecture, the computation process must be defined in its entirety beforehand. Taking the Attention mechanism as an example, each computation must be padded to the maximum length (e.g., 1024), which leads to significant redundant computation and increased time overhead. To address this, a compromise can be adopted: pre-preparing multiple Qnn graphs of different shapes (i.e., graphs for different input lengths) and dynamically selecting the graph that best matches the current input length (minimizing computational waste) during execution. This achieves a balance between performance and flexibility.
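The selection step itself is cheap. A minimal sketch, assuming a hypothetical set of pre-compiled static lengths, picks the smallest graph that can hold the input and reports how many padded slots are wasted:

```python
import bisect


def select_graph(seq_len: int, graph_lengths):
    """Return (chosen graph length, wasted padded slots) for the smallest graph >= seq_len."""
    lengths = sorted(graph_lengths)
    idx = bisect.bisect_left(lengths, seq_len)   # first length >= seq_len
    if idx == len(lengths):
        raise ValueError(f"no pre-compiled graph can hold {seq_len} tokens")
    chosen = lengths[idx]
    return chosen, chosen - seq_len


compiled = [128, 256, 512, 1024]      # hypothetical set of pre-compiled static lengths
for n in (100, 300, 1000):
    print(n, "->", select_graph(n, compiled))
```

The trade-off is memory and compile time: every extra graph shape must be compiled and loaded, in exchange for less padding at run time.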

1.8 Qnn Model Runtime for LLM

The implementation of the Qnn Model Runtime is relatively straightforward, primarily involving the execution of graphs via the Qnn Graph Execute API. However, one point to note and improve is that switching from the prefill stage to the decode stage requires reordering all KV Caches, which involves significant data copying.

Reordering is necessary because the Key and Value computed for the current chunk are appended to the end of the Past Key and Past Value, so the space reserved for the cache depends on the chunk size: the prefill graph reserves chunk_size (32) slots for the new chunk, while the decode graph reserves only one (chunk_size = 1). Since the Key is stored as [B, H, D, S], the sequence dimension S is innermost, so every (B, H, D) row starts at a different offset in the two graphs' buffers. The Past Key produced during prefill therefore cannot be handed to the decode graph through pointer offsets and must be reordered during the transition from prefill to decode.
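The copy can be sketched on a flat buffer. The function name is hypothetical, and the example lengths are an assumption: if each graph's cache length is context_length minus its chunk size (consistent with the appended chunk forming a context_length-long Key), then with pad=1024 the prefill rows hold 992 slots and the decode rows 1023. The test uses tiny sizes instead.

```python
def reorder_key_cache(old, b, h, d, s_old, s_new, valid):
    """Copy a flat [B, H, D, S_old] key buffer into a flat [B, H, D, S_new] buffer.

    Only the first `valid` positions of each row hold real keys; the rest is padding.
    Because S is the innermost axis, each (b, h, d) row starts at a different offset
    in the two layouts, so the copy has to go row by row.
    """
    assert valid <= min(s_old, s_new)
    new = [0] * (b * h * d * s_new)
    for row in range(b * h * d):                # flattened (b, h, d) index
        src, dst = row * s_old, row * s_new
        new[dst:dst + valid] = old[src:src + valid]
    return new


old = [1, 2, 0, 3, 4, 0]                        # B=H=1, D=2, S_old=3, two valid keys
print(reorder_key_cache(old, b=1, h=1, d=2, s_old=3, s_new=4, valid=2))
```

The cost is B·H·D small copies per layer, which is why the switch from prefill to decode shows up as noticeable data movement.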

1.9 Performance

We conducted experiments on an SM8650 device.

Table 3: LLM Performance on SM8650, w4a16kv8. A: prefill 256 tokens, decode 128 tokens; B: prefill 512 tokens, decode 256 tokens; C: prefill 640 tokens, decode 128 tokens.

Model       | Setting | Prefill (tokens/s) | Decode (tokens/s)
Qwen2.5 3B  | A       | 512.75             | 18.29
Qwen2.5 3B  | B       | 545.62             | 18.11
Qwen3 1.7B  | C       | 758.28             | 25.84

1.10 Summary

The MLLM Qnn AOT framework successfully addresses the key challenges of efficiently deploying LLMs on Qualcomm NPUs through a unified intermediate representation, explicit QDQ modeling, and an innovative dual-graph execution strategy, achieving high-performance on-device inference. This workflow integrates quantization, compilation, and runtime, providing a convenient deployment path for developers. Future work will focus on supporting more advanced quantization methods (e.g., SpinQuant) and further optimizing multi-graph selection strategies.

1.11 Appendix

1.11.1 Quantize/Dequantize, Cast, and Convert

Qualcomm provides a Convert Op for converting between quantization types, which removes the need to chain a Dequantize and a Quantize. The Cast Op, by contrast, ignores whether a tensor has a quantization type; it is a forced data-type conversion.

1.11.2 Zero Point Should Be Negative in QNN

Although many examples in Qualcomm's documentation show positive offsets, Qualcomm's quantization actually follows the formula below, in which the offset is the negated zero point:

$$r = S\,(q + \text{offset}), \qquad \text{offset} = -Z$$
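A quick check that the two conventions agree once the sign is flipped. QNN's scale/offset quantization parameters (`Qnn_ScaleOffset_t`) carry a `scale` and an `offset` field; the Python function names here are illustrative only.

```python
def dequant_zero_point(q, scale, zero_point):
    """Common convention: r = S * (q - Z)."""
    return scale * (q - zero_point)


def dequant_qnn_offset(q, scale, offset):
    """QNN-style convention: r = S * (q + offset)."""
    return scale * (q + offset)


scale, zero_point = 0.05, 12
offset = -zero_point            # the value to hand to QNN is the negated zero point
print(all(dequant_zero_point(q, scale, zero_point) == dequant_qnn_offset(q, scale, offset)
          for q in range(256)))
```

Passing the zero point itself (12) as the offset would shift every dequantized value by $2Z \cdot S$, a silent bias rather than a crash.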

1.11.3 HMX Enable on Short Depth Convolution

Use the QNN_HTP_GRAPH_CONFIG_OPTION_SHORT_DEPTH_CONV_ON_HMX_OFF flag with caution.

// Disable running short-depth convolutions on HMX (Hexagon Matrix eXtensions).
p_custom_config = (QnnHtpGraph_CustomConfig_t*)malloc(sizeof(QnnHtpGraph_CustomConfig_t));
p_custom_config->option = QNN_HTP_GRAPH_CONFIG_OPTION_SHORT_DEPTH_CONV_ON_HMX_OFF;
p_custom_config->shortDepthConvOnHmxOff = true;
htp_graph_configs.push_back(static_cast<QnnGraph_CustomConfig_t>(p_custom_config));

Setting the Optimization Level to 3 may provide some performance gains:

// Request optimization level 3 when the HTP graph is finalized.
p_custom_config = (QnnHtpGraph_CustomConfig_t*)malloc(sizeof(QnnHtpGraph_CustomConfig_t));
p_custom_config->option = QNN_HTP_GRAPH_CONFIG_OPTION_OPTIMIZATION;
p_custom_config->optimizationOption.type = QNN_HTP_GRAPH_OPTIMIZATION_TYPE_FINALIZE_OPTIMIZATION_FLAG;
p_custom_config->optimizationOption.floatValue = 3;
htp_graph_configs.push_back(static_cast<QnnGraph_CustomConfig_t>(p_custom_config));

1.11.4 LPBQ Packing

The LPBQ quantization method requires users to provide an int8 array to represent int4 weights. Qualcomm then packs this array. However, Qualcomm’s packing algorithm is implemented as follows:

byte = val1 | (val2 << 4)

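Because val1 is not masked, the sign-extended upper bits of a negative int8 (for example, -1 is 0xFF) are ORed into the neighboring value's nibble. In MLLM, we therefore perform a bitwise AND with 0x0F on the int8 array so that the upper four bits of every value are zero.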

Python code:

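import torch

# quantized_weight holds the signed int4 values produced by the quantizer.
weight_int4 = quantized_weight.to(torch.int8)
mask = torch.full(
    weight_int4.size(), 0x0F, dtype=torch.int8, device=weight_int4.device
)
# Keep only the low nibble so the upper four bits are zero before Qualcomm packs pairs.
weight_int4 = torch.bitwise_and(mask, weight_int4)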

That is, Qualcomm's packer does not use the following masked form, which would be safe for negative inputs:

byte = (val1 & 0x0F) | ((val2 & 0x0F) << 4)
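The failure mode can be reproduced with plain integers. This sketch assumes val1 is the low nibble and val2 the high nibble, as in the formulas above, and that the packer operates on the raw two's-complement bytes of the int8 array; the function names are hypothetical.

```python
def as_byte(v: int) -> int:
    """Two's-complement byte of an int8 value (e.g. -1 -> 0xFF)."""
    return v & 0xFF


def pack_unmasked(val1: int, val2: int) -> int:
    """byte = val1 | (val2 << 4), applied to raw int8 bytes (no masking)."""
    return (as_byte(val1) | (as_byte(val2) << 4)) & 0xFF


def pack_masked(val1: int, val2: int) -> int:
    """byte = (val1 & 0x0F) | ((val2 & 0x0F) << 4)."""
    return (val1 & 0x0F) | ((val2 & 0x0F) << 4)


def unpack(byte: int):
    """Recover the two signed int4 values (low nibble, high nibble)."""
    def sign_extend(nibble):
        return nibble - 16 if nibble >= 8 else nibble
    return sign_extend(byte & 0x0F), sign_extend(byte >> 4)


print(unpack(pack_unmasked(-1, 3)))                 # sign bits of -1 leak into the high nibble
print(unpack(pack_unmasked(-1 & 0x0F, 3 & 0x0F)))   # pre-masking, as MLLM does, fixes it
```

Only the low-nibble operand is at risk: the shift in `val2 << 4` pushes val2's sign bits out of the byte, while val1's sign bits land directly on val2's nibble.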

References

  • Liu, Z., Zhao, C., Fedorov, I., Soran, B., Choudhary, D., Krishnamoorthi, R., Chandra, V., Tian, Y., & Blankevoort, T. (2024). SpinQuant: LLM quantization with learned rotations. arXiv preprint arXiv:2405.16406.