在本文中我以mllm的实现为例。mllm中的大部分混合精度的矩阵乘法是从llama.cpp中更改过来的。我们先来看下Q8_0和Q4_0代表什么。Huggingface的Doc中给出了一张表,大家可以去看一下:GGUF Quantization Type,我在这里也截图给出来
对于量化操作不是很熟悉的读者可以看下我之前的blog: [Fundamental] 模型量化
在mllm中,Q8_0和Q4_0的实现是这样的:
typedef struct {
mllm_fp16_t d; // delta
int8_t qs[QK8_0]; // quants QK8_0 = 32
} block_q8_0;
// QK4_0 = 32
typedef struct {
mllm_fp16_t d; // delta
uint8_t qs[QK4_0 / 2]; // nibbles / quants
} block_q4_0;
而Q4_0x4实际上就是将4个Q4_0打包成一组,这样在GEMM的时候可以利用起指令并行性。
我们首先来看下GEMV问题定义,然后再推广到GEMM上。我们有两个矩阵,分别是A($1 \times nr$), B($nc \times nr$),矩阵乘法后的结果是C$1 \times nc$。一个不是非常恰当的图例如下图所示:
为了更好的理解怎么分块,我们先来看下Q4_0x4的数据排布是怎么样的:Q4_0x4实际上是在$nc$的方向上以4分块,在$nr$的方向上以32分块,最终得到的block形状如下图所示:
我们在$nc$的方向上以4分块,在$nr$的方向上以32分块,将gemv拆解成一个更小的子问题。
如图所示,Tiled A在每一词迭代,是取出x, x+16的位置的数据和Q4_0x4一行中的元素做点乘。在遍历完成后再reduce。Q4_0x4的每一行都是采用了优化的存储方法,这种存储方式在很多的量化算法中都有体现。该方式存储的数据可以很容易的使用左移和Mask等速度更快的操作来获得我们想要的信息。
现在我们来看下汇编,在看汇编中,主要注意vector寄存器的排布,我在下文中给出了排布示意图
"movi v31.16b, #0x4\n" // for sshl. to get high bits.
"movi v30.16b, #0xf0\n" // for mask. to get low bits.
"add %x[b_ptr], %x[b_ptr], #0x8\n" // to qs
"1:" // Column loop
"add x22, %x[a_ptr], #0x2\n" // to qs
"movi v29.16b, #0x0\n" // acc is on register v29(16x8bits). Set to 0.
"mov x21, %x[nb]\n" // move num of blocks to register x21
"2:" // Block loop
"ldr q28, [%x[b_ptr], #0x0]\n" // load 128 bits from b matrix
"ldr q27, [x22, #0x0]\n" // load 128 bits from a matrix
"movi v26.4s, #0x0\n" // acc is on register v26(4x32bits). Set to 0.
"sub x20, x22, #0x2\n" // to get scalar
"ldr q25, [x22, #0x10]\n" // load 128 bits to q25. offsets is 16B
"ldr q24, [%x[b_ptr], #0x10]\n" // load 128 bits to q24. offsets is 16B
"sub x21, x21, #0x1\n" // nb = nb - 1
"add x22, x22, #0x22\n" // a_ptr = aptr + 34B
"ldr q23, [%x[b_ptr], #0x20]\n" // load 128 bits to q23. offset is 32
"ldr q22, [%x[b_ptr], #0x30]\n" // load 128 bits to q22. offset is 48
"ld1r { v21.8h }, [x20]\n" // scalar 4x16bit
"ldr q20, [%x[b_ptr], #-0x8]\n" // scalar 1x16bit
"sshl v16.16b, v28.16b, v31.16b\n" // get high bits in q4_0x4
"and v28.16b, v28.16b, v30.16b\n" // get low bits in q4_0x4
"sshl v19.16b, v24.16b, v31.16b\n" // get high bits in q4_0x4
"and v24.16b, v24.16b, v30.16b\n" // get low bits in q4_0x4
"add %x[b_ptr], %x[b_ptr], #0x48\n" // b_ptr = b_ptr + 72
"sshl v18.16b, v23.16b, v31.16b\n" // get high bits in q4_0x4
"and v23.16b, v23.16b, v30.16b\n" // get low bits in q4_0x4
".inst 0x4f9be21a // sdot v26.4s, v16.16b, v27.4b[0]\n"
"sshl v17.16b, v22.16b, v31.16b\n" // get high bits in q4_0x4
"and v22.16b, v22.16b, v30.16b\n" // get low bits in q4_0x4
"fcvtl v21.4s, v21.4h\n" // cvt 8x16b to 4x32b, scalar of a matrix
"fcvtl v16.4s, v20.4h\n" // cvt 8x16b to 4x32b, scalar of b matrix. reuse v16 register
".inst 0x4f99e39a // sdot v26.4s, v28.16b, v25.4b[0]\n"
"fmul v16.4s, v16.4s, v21.4s\n" // v16 = v16 * v21, scalar a * scalar b
".inst 0x4fbbe27a // sdot v26.4s, v19.16b, v27.4b[1]\n" // v19(8 bits) + v27(32bit, 1B) to v26(32bit)
".inst 0x4fb9e31a // sdot v26.4s, v24.16b, v25.4b[1]\n"
".inst 0x4f9bea5a // sdot v26.4s, v18.16b, v27.4b[2]\n"
".inst 0x4f99eafa // sdot v26.4s, v23.16b, v25.4b[2]\n"
".inst 0x4fbbea3a // sdot v26.4s, v17.16b, v27.4b[3]\n"
".inst 0x4fb9eada // sdot v26.4s, v22.16b, v25.4b[3]\n"
"scvtf v26.4s, v26.4s, #0x4\n" // cvt int to float. the #0x4 is scale factor
"fmla v29.4s, v26.4s, v16.4s\n" // v29 = v26 * v16 + v29
"cbnz x21, 2b\n" // is x21 is not zero, jmp to label 2. num block loop.
"sub %x[nc], %x[nc], #0x4\n" // sub col by 4
"str q29, [%x[res_ptr], #0x0]\n" // store value to res_ptr
"add %x[res_ptr], %x[res_ptr], #0x10\n" // res_ptr move 16B. 4xf32.
"cbnz %x[nc], 1b\n" // if nc is not zero, jump to label 1. num col loop.
有了GEMV,GEMM的实现就相对轻松的多了,GEMM在GEMV的分块基础上,还在$n$维度上进行了大小为4的分块,然后对每个row为4的分块做GEMV,并且通过strip-mining来获得指令并行。