void sgemm_micro_6x16_ac_br_cr(int m, int n, int k, float alpha, const float* A, const float* B,
float beta, float* C, int ldc) {
assert(
m == 6 && n == 16
&& "sgemm micro kernel expects A: 6xk(col major), B: kx16(row major) and C: 6x16(row major)");
uint64_t iters = k / 4;
uint64_t remaining = k % 4;
uint64_t ldc_ = ldc;
#if defined(__AVX__)
const float* a_ptr = A;
const float* b_ptr = B;
__m256 ymm4, ymm5, ymm6, ymm7, ymm8, ymm9, ymm10, ymm11, ymm12, ymm13, ymm14, ymm15;
// set all outputs ymm register to zeros.
ymm4 = _mm256_setzero_ps();
ymm5 = _mm256_setzero_ps();
ymm6 = _mm256_setzero_ps();
ymm7 = _mm256_setzero_ps();
ymm8 = _mm256_setzero_ps();
ymm9 = _mm256_setzero_ps();
ymm10 = _mm256_setzero_ps();
ymm11 = _mm256_setzero_ps();
ymm12 = _mm256_setzero_ps();
ymm13 = _mm256_setzero_ps();
ymm14 = _mm256_setzero_ps();
ymm15 = _mm256_setzero_ps();
// For C: 6 x 16
//
// line0(16xfp32): ymm4, ymm5
// line1(16xfp32): ymm6, ymm7
// line2(16xfp32): ymm8, ymm9
// line3(16xfp32): ymm10, ymm11
// line4(16xfp32): ymm12, ymm13
// line5(16xfp32): ymm14, ymm15
for (uint64_t k_index = 0; k_index < iters; ++k_index) {
__m256 ymm0, ymm1, ymm2, ymm3;
// performance issues?
__builtin_prefetch(b_ptr);
__builtin_prefetch(a_ptr);
__builtin_prefetch(b_ptr + 64);
__builtin_prefetch(a_ptr + 24);
// iteration 0
// get a line of matrix B -> 1 x 16
ymm0 = _mm256_load_ps(b_ptr); // -> 1 x 8
ymm1 = _mm256_load_ps(b_ptr + 8); // -> 1 x 8
ymm2 = _mm256_broadcast_ss(a_ptr);
ymm3 = _mm256_broadcast_ss(a_ptr + 1);
ymm4 = _mm256_fmadd_ps(ymm0, ymm2, ymm4);
ymm5 = _mm256_fmadd_ps(ymm1, ymm2, ymm5);
ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);
ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);
ymm2 = _mm256_broadcast_ss(a_ptr + 2);
ymm3 = _mm256_broadcast_ss(a_ptr + 3);
ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);
ymm9 = _mm256_fmadd_ps(ymm1, ymm2, ymm9);
ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10);
ymm11 = _mm256_fmadd_ps(ymm1, ymm3, ymm11);
ymm2 = _mm256_broadcast_ss(a_ptr + 4);
ymm3 = _mm256_broadcast_ss(a_ptr + 5);
ymm12 = _mm256_fmadd_ps(ymm0, ymm2, ymm12);
ymm13 = _mm256_fmadd_ps(ymm1, ymm2, ymm13);
ymm14 = _mm256_fmadd_ps(ymm0, ymm3, ymm14);
ymm15 = _mm256_fmadd_ps(ymm1, ymm3, ymm15);
// iteration 1
// get a line of matrix B -> 1 x 16
ymm0 = _mm256_load_ps(b_ptr + 16); // -> 1 x 8
ymm1 = _mm256_load_ps(b_ptr + 24); // -> 1 x 8
ymm2 = _mm256_broadcast_ss(a_ptr + 6);
ymm3 = _mm256_broadcast_ss(a_ptr + 7);
ymm4 = _mm256_fmadd_ps(ymm0, ymm2, ymm4);
ymm5 = _mm256_fmadd_ps(ymm1, ymm2, ymm5);
ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);
ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);
ymm2 = _mm256_broadcast_ss(a_ptr + 8);
ymm3 = _mm256_broadcast_ss(a_ptr + 9);
ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);
ymm9 = _mm256_fmadd_ps(ymm1, ymm2, ymm9);
ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10);
ymm11 = _mm256_fmadd_ps(ymm1, ymm3, ymm11);
ymm2 = _mm256_broadcast_ss(a_ptr + 10);
ymm3 = _mm256_broadcast_ss(a_ptr + 11);
ymm12 = _mm256_fmadd_ps(ymm0, ymm2, ymm12);
ymm13 = _mm256_fmadd_ps(ymm1, ymm2, ymm13);
ymm14 = _mm256_fmadd_ps(ymm0, ymm3, ymm14);
ymm15 = _mm256_fmadd_ps(ymm1, ymm3, ymm15);
// iteration 2
// get a line of matrix B -> 1 x 16
ymm0 = _mm256_load_ps(b_ptr + 32); // -> 1 x 8
ymm1 = _mm256_load_ps(b_ptr + 40); // -> 1 x 8
ymm2 = _mm256_broadcast_ss(a_ptr + 12);
ymm3 = _mm256_broadcast_ss(a_ptr + 13);
ymm4 = _mm256_fmadd_ps(ymm0, ymm2, ymm4);
ymm5 = _mm256_fmadd_ps(ymm1, ymm2, ymm5);
ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);
ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);
ymm2 = _mm256_broadcast_ss(a_ptr + 14);
ymm3 = _mm256_broadcast_ss(a_ptr + 15);
ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);
ymm9 = _mm256_fmadd_ps(ymm1, ymm2, ymm9);
ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10);
ymm11 = _mm256_fmadd_ps(ymm1, ymm3, ymm11);
ymm2 = _mm256_broadcast_ss(a_ptr + 16);
ymm3 = _mm256_broadcast_ss(a_ptr + 17);
ymm12 = _mm256_fmadd_ps(ymm0, ymm2, ymm12);
ymm13 = _mm256_fmadd_ps(ymm1, ymm2, ymm13);
ymm14 = _mm256_fmadd_ps(ymm0, ymm3, ymm14);
ymm15 = _mm256_fmadd_ps(ymm1, ymm3, ymm15);
// iteration 3
// get a line of matrix B -> 1 x 16
ymm0 = _mm256_load_ps(b_ptr + 48); // -> 1 x 8
ymm1 = _mm256_load_ps(b_ptr + 56); // -> 1 x 8
ymm2 = _mm256_broadcast_ss(a_ptr + 18);
ymm3 = _mm256_broadcast_ss(a_ptr + 19);
ymm4 = _mm256_fmadd_ps(ymm0, ymm2, ymm4);
ymm5 = _mm256_fmadd_ps(ymm1, ymm2, ymm5);
ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);
ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);
ymm2 = _mm256_broadcast_ss(a_ptr + 20);
ymm3 = _mm256_broadcast_ss(a_ptr + 21);
ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);
ymm9 = _mm256_fmadd_ps(ymm1, ymm2, ymm9);
ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10);
ymm11 = _mm256_fmadd_ps(ymm1, ymm3, ymm11);
ymm2 = _mm256_broadcast_ss(a_ptr + 22);
ymm3 = _mm256_broadcast_ss(a_ptr + 23);
ymm12 = _mm256_fmadd_ps(ymm0, ymm2, ymm12);
ymm13 = _mm256_fmadd_ps(ymm1, ymm2, ymm13);
ymm14 = _mm256_fmadd_ps(ymm0, ymm3, ymm14);
ymm15 = _mm256_fmadd_ps(ymm1, ymm3, ymm15);
a_ptr += 24;
b_ptr += 64;
}
for (uint64_t k_index = 0; k < remaining; ++k_index) {
__m256 ymm0, ymm1, ymm2, ymm3;
// get a line of matrix B -> 1 x 16
ymm0 = _mm256_load_ps(b_ptr); // -> 1 x 8
ymm1 = _mm256_load_ps(b_ptr + 8); // -> 1 x 8
ymm2 = _mm256_broadcast_ss(a_ptr);
ymm3 = _mm256_broadcast_ss(a_ptr + 1);
ymm4 = _mm256_fmadd_ps(ymm0, ymm2, ymm4);
ymm5 = _mm256_fmadd_ps(ymm1, ymm2, ymm5);
ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);
ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);
ymm2 = _mm256_broadcast_ss(a_ptr + 2);
ymm3 = _mm256_broadcast_ss(a_ptr + 3);
ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);
ymm9 = _mm256_fmadd_ps(ymm1, ymm2, ymm9);
ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10);
ymm11 = _mm256_fmadd_ps(ymm1, ymm3, ymm11);
ymm2 = _mm256_broadcast_ss(a_ptr + 4);
ymm3 = _mm256_broadcast_ss(a_ptr + 5);
ymm12 = _mm256_fmadd_ps(ymm0, ymm2, ymm12);
ymm13 = _mm256_fmadd_ps(ymm1, ymm2, ymm13);
ymm14 = _mm256_fmadd_ps(ymm0, ymm3, ymm14);
ymm15 = _mm256_fmadd_ps(ymm1, ymm3, ymm15);
a_ptr += 6;
b_ptr += 16;
}
__m256 c0, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11;
// For C: 6 x 16
//
// line0(16xfp32): c0, c1
// line1(16xfp32): c2, c3
// line2(16xfp32): c4, c5
// line3(16xfp32): c6, c7
// line4(16xfp32): c8, c9
// line5(16xfp32): c10, c11
float* c_ptr0 = C;
float* c_ptr1 = C + ldc_;
float* c_ptr2 = c_ptr1 + ldc_;
float* c_ptr3 = c_ptr2 + ldc_;
float* c_ptr4 = c_ptr3 + ldc_;
float* c_ptr5 = c_ptr4 + ldc_;
c0 = _mm256_load_ps(c_ptr0);
c1 = _mm256_load_ps(c_ptr0 + 8);
c2 = _mm256_load_ps(c_ptr1);
c3 = _mm256_load_ps(c_ptr1 + 8);
c4 = _mm256_load_ps(c_ptr2);
c5 = _mm256_load_ps(c_ptr2 + 8);
c6 = _mm256_load_ps(c_ptr3);
c7 = _mm256_load_ps(c_ptr3 + 8);
c8 = _mm256_load_ps(c_ptr4);
c9 = _mm256_load_ps(c_ptr4 + 8);
c10 = _mm256_load_ps(c_ptr5);
c11 = _mm256_load_ps(c_ptr5 + 8);
c0 = _mm256_add_ps(c0, ymm4);
c1 = _mm256_add_ps(c1, ymm5);
c2 = _mm256_add_ps(c2, ymm6);
c3 = _mm256_add_ps(c3, ymm7);
c4 = _mm256_add_ps(c4, ymm8);
c5 = _mm256_add_ps(c5, ymm9);
c6 = _mm256_add_ps(c6, ymm10);
c7 = _mm256_add_ps(c7, ymm11);
c8 = _mm256_add_ps(c8, ymm12);
c9 = _mm256_add_ps(c9, ymm13);
c10 = _mm256_add_ps(c10, ymm14);
c11 = _mm256_add_ps(c11, ymm15);
_mm256_store_ps(c_ptr0, c0);
_mm256_store_ps(c_ptr0 + 8, c1);
_mm256_store_ps(c_ptr1, c2);
_mm256_store_ps(c_ptr1 + 8, c3);
_mm256_store_ps(c_ptr2, c4);
_mm256_store_ps(c_ptr2 + 8, c5);
_mm256_store_ps(c_ptr3, c6);
_mm256_store_ps(c_ptr3 + 8, c7);
_mm256_store_ps(c_ptr4, c8);
_mm256_store_ps(c_ptr4 + 8, c9);
_mm256_store_ps(c_ptr5, c10);
_mm256_store_ps(c_ptr5 + 8, c11);
#endif //! defined(__AVX__)
}