1. Algorithm

2. Code

void sgemm_micro_6x16_ac_br_cr(int m, int n, int k, float alpha, const float* A, const float* B,
                               float beta, float* C, int ldc) {
  assert(
      m == 6 && n == 16
      && "sgemm micro kernel expects A: 6xk(col major), B: kx16(row major) and C: 6x16(row major)");
  uint64_t iters = k / 4;
  uint64_t remaining = k % 4;
  uint64_t ldc_ = ldc;
#if defined(__AVX__)
  const float* a_ptr = A;
  const float* b_ptr = B;

  __m256 ymm4, ymm5, ymm6, ymm7, ymm8, ymm9, ymm10, ymm11, ymm12, ymm13, ymm14, ymm15;

  // set all outputs ymm register to zeros.
  ymm4 = _mm256_setzero_ps();
  ymm5 = _mm256_setzero_ps();
  ymm6 = _mm256_setzero_ps();
  ymm7 = _mm256_setzero_ps();
  ymm8 = _mm256_setzero_ps();
  ymm9 = _mm256_setzero_ps();
  ymm10 = _mm256_setzero_ps();
  ymm11 = _mm256_setzero_ps();
  ymm12 = _mm256_setzero_ps();
  ymm13 = _mm256_setzero_ps();
  ymm14 = _mm256_setzero_ps();
  ymm15 = _mm256_setzero_ps();

  // For C: 6 x 16
  //
  // line0(16xfp32): ymm4,   ymm5
  // line1(16xfp32): ymm6,   ymm7
  // line2(16xfp32): ymm8,   ymm9
  // line3(16xfp32): ymm10,  ymm11
  // line4(16xfp32): ymm12,  ymm13
  // line5(16xfp32): ymm14,  ymm15
  for (uint64_t k_index = 0; k_index < iters; ++k_index) {
    __m256 ymm0, ymm1, ymm2, ymm3;

    // performance issues?
    __builtin_prefetch(b_ptr);
    __builtin_prefetch(a_ptr);
    __builtin_prefetch(b_ptr + 64);
    __builtin_prefetch(a_ptr + 24);

    // iteration 0
    // get a line of matrix B -> 1 x 16
    ymm0 = _mm256_load_ps(b_ptr);      // -> 1 x 8
    ymm1 = _mm256_load_ps(b_ptr + 8);  // -> 1 x 8

    ymm2 = _mm256_broadcast_ss(a_ptr);
    ymm3 = _mm256_broadcast_ss(a_ptr + 1);
    ymm4 = _mm256_fmadd_ps(ymm0, ymm2, ymm4);
    ymm5 = _mm256_fmadd_ps(ymm1, ymm2, ymm5);
    ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);
    ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);

    ymm2 = _mm256_broadcast_ss(a_ptr + 2);
    ymm3 = _mm256_broadcast_ss(a_ptr + 3);
    ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);
    ymm9 = _mm256_fmadd_ps(ymm1, ymm2, ymm9);
    ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10);
    ymm11 = _mm256_fmadd_ps(ymm1, ymm3, ymm11);

    ymm2 = _mm256_broadcast_ss(a_ptr + 4);
    ymm3 = _mm256_broadcast_ss(a_ptr + 5);
    ymm12 = _mm256_fmadd_ps(ymm0, ymm2, ymm12);
    ymm13 = _mm256_fmadd_ps(ymm1, ymm2, ymm13);
    ymm14 = _mm256_fmadd_ps(ymm0, ymm3, ymm14);
    ymm15 = _mm256_fmadd_ps(ymm1, ymm3, ymm15);

    // iteration 1
    // get a line of matrix B -> 1 x 16
    ymm0 = _mm256_load_ps(b_ptr + 16);  // -> 1 x 8
    ymm1 = _mm256_load_ps(b_ptr + 24);  // -> 1 x 8

    ymm2 = _mm256_broadcast_ss(a_ptr + 6);
    ymm3 = _mm256_broadcast_ss(a_ptr + 7);
    ymm4 = _mm256_fmadd_ps(ymm0, ymm2, ymm4);
    ymm5 = _mm256_fmadd_ps(ymm1, ymm2, ymm5);
    ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);
    ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);

    ymm2 = _mm256_broadcast_ss(a_ptr + 8);
    ymm3 = _mm256_broadcast_ss(a_ptr + 9);
    ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);
    ymm9 = _mm256_fmadd_ps(ymm1, ymm2, ymm9);
    ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10);
    ymm11 = _mm256_fmadd_ps(ymm1, ymm3, ymm11);

    ymm2 = _mm256_broadcast_ss(a_ptr + 10);
    ymm3 = _mm256_broadcast_ss(a_ptr + 11);
    ymm12 = _mm256_fmadd_ps(ymm0, ymm2, ymm12);
    ymm13 = _mm256_fmadd_ps(ymm1, ymm2, ymm13);
    ymm14 = _mm256_fmadd_ps(ymm0, ymm3, ymm14);
    ymm15 = _mm256_fmadd_ps(ymm1, ymm3, ymm15);

    // iteration 2
    // get a line of matrix B -> 1 x 16
    ymm0 = _mm256_load_ps(b_ptr + 32);  // -> 1 x 8
    ymm1 = _mm256_load_ps(b_ptr + 40);  // -> 1 x 8

    ymm2 = _mm256_broadcast_ss(a_ptr + 12);
    ymm3 = _mm256_broadcast_ss(a_ptr + 13);
    ymm4 = _mm256_fmadd_ps(ymm0, ymm2, ymm4);
    ymm5 = _mm256_fmadd_ps(ymm1, ymm2, ymm5);
    ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);
    ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);

    ymm2 = _mm256_broadcast_ss(a_ptr + 14);
    ymm3 = _mm256_broadcast_ss(a_ptr + 15);
    ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);
    ymm9 = _mm256_fmadd_ps(ymm1, ymm2, ymm9);
    ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10);
    ymm11 = _mm256_fmadd_ps(ymm1, ymm3, ymm11);

    ymm2 = _mm256_broadcast_ss(a_ptr + 16);
    ymm3 = _mm256_broadcast_ss(a_ptr + 17);
    ymm12 = _mm256_fmadd_ps(ymm0, ymm2, ymm12);
    ymm13 = _mm256_fmadd_ps(ymm1, ymm2, ymm13);
    ymm14 = _mm256_fmadd_ps(ymm0, ymm3, ymm14);
    ymm15 = _mm256_fmadd_ps(ymm1, ymm3, ymm15);

    // iteration 3
    // get a line of matrix B -> 1 x 16
    ymm0 = _mm256_load_ps(b_ptr + 48);  // -> 1 x 8
    ymm1 = _mm256_load_ps(b_ptr + 56);  // -> 1 x 8

    ymm2 = _mm256_broadcast_ss(a_ptr + 18);
    ymm3 = _mm256_broadcast_ss(a_ptr + 19);
    ymm4 = _mm256_fmadd_ps(ymm0, ymm2, ymm4);
    ymm5 = _mm256_fmadd_ps(ymm1, ymm2, ymm5);
    ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);
    ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);

    ymm2 = _mm256_broadcast_ss(a_ptr + 20);
    ymm3 = _mm256_broadcast_ss(a_ptr + 21);
    ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);
    ymm9 = _mm256_fmadd_ps(ymm1, ymm2, ymm9);
    ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10);
    ymm11 = _mm256_fmadd_ps(ymm1, ymm3, ymm11);

    ymm2 = _mm256_broadcast_ss(a_ptr + 22);
    ymm3 = _mm256_broadcast_ss(a_ptr + 23);
    ymm12 = _mm256_fmadd_ps(ymm0, ymm2, ymm12);
    ymm13 = _mm256_fmadd_ps(ymm1, ymm2, ymm13);
    ymm14 = _mm256_fmadd_ps(ymm0, ymm3, ymm14);
    ymm15 = _mm256_fmadd_ps(ymm1, ymm3, ymm15);

    a_ptr += 24;
    b_ptr += 64;
  }

  for (uint64_t k_index = 0; k < remaining; ++k_index) {
    __m256 ymm0, ymm1, ymm2, ymm3;

    // get a line of matrix B -> 1 x 16
    ymm0 = _mm256_load_ps(b_ptr);      // -> 1 x 8
    ymm1 = _mm256_load_ps(b_ptr + 8);  // -> 1 x 8

    ymm2 = _mm256_broadcast_ss(a_ptr);
    ymm3 = _mm256_broadcast_ss(a_ptr + 1);
    ymm4 = _mm256_fmadd_ps(ymm0, ymm2, ymm4);
    ymm5 = _mm256_fmadd_ps(ymm1, ymm2, ymm5);
    ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);
    ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);

    ymm2 = _mm256_broadcast_ss(a_ptr + 2);
    ymm3 = _mm256_broadcast_ss(a_ptr + 3);
    ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);
    ymm9 = _mm256_fmadd_ps(ymm1, ymm2, ymm9);
    ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10);
    ymm11 = _mm256_fmadd_ps(ymm1, ymm3, ymm11);

    ymm2 = _mm256_broadcast_ss(a_ptr + 4);
    ymm3 = _mm256_broadcast_ss(a_ptr + 5);
    ymm12 = _mm256_fmadd_ps(ymm0, ymm2, ymm12);
    ymm13 = _mm256_fmadd_ps(ymm1, ymm2, ymm13);
    ymm14 = _mm256_fmadd_ps(ymm0, ymm3, ymm14);
    ymm15 = _mm256_fmadd_ps(ymm1, ymm3, ymm15);

    a_ptr += 6;
    b_ptr += 16;
  }

  __m256 c0, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11;

  // For C: 6 x 16
  //
  // line0(16xfp32): c0,   c1
  // line1(16xfp32): c2,   c3
  // line2(16xfp32): c4,   c5
  // line3(16xfp32): c6,   c7
  // line4(16xfp32): c8,   c9
  // line5(16xfp32): c10,  c11
  float* c_ptr0 = C;
  float* c_ptr1 = C + ldc_;
  float* c_ptr2 = c_ptr1 + ldc_;
  float* c_ptr3 = c_ptr2 + ldc_;
  float* c_ptr4 = c_ptr3 + ldc_;
  float* c_ptr5 = c_ptr4 + ldc_;

  c0 = _mm256_load_ps(c_ptr0);
  c1 = _mm256_load_ps(c_ptr0 + 8);
  c2 = _mm256_load_ps(c_ptr1);
  c3 = _mm256_load_ps(c_ptr1 + 8);
  c4 = _mm256_load_ps(c_ptr2);
  c5 = _mm256_load_ps(c_ptr2 + 8);
  c6 = _mm256_load_ps(c_ptr3);
  c7 = _mm256_load_ps(c_ptr3 + 8);
  c8 = _mm256_load_ps(c_ptr4);
  c9 = _mm256_load_ps(c_ptr4 + 8);
  c10 = _mm256_load_ps(c_ptr5);
  c11 = _mm256_load_ps(c_ptr5 + 8);

  c0 = _mm256_add_ps(c0, ymm4);
  c1 = _mm256_add_ps(c1, ymm5);
  c2 = _mm256_add_ps(c2, ymm6);
  c3 = _mm256_add_ps(c3, ymm7);
  c4 = _mm256_add_ps(c4, ymm8);
  c5 = _mm256_add_ps(c5, ymm9);
  c6 = _mm256_add_ps(c6, ymm10);
  c7 = _mm256_add_ps(c7, ymm11);
  c8 = _mm256_add_ps(c8, ymm12);
  c9 = _mm256_add_ps(c9, ymm13);
  c10 = _mm256_add_ps(c10, ymm14);
  c11 = _mm256_add_ps(c11, ymm15);

  _mm256_store_ps(c_ptr0, c0);
  _mm256_store_ps(c_ptr0 + 8, c1);
  _mm256_store_ps(c_ptr1, c2);
  _mm256_store_ps(c_ptr1 + 8, c3);
  _mm256_store_ps(c_ptr2, c4);
  _mm256_store_ps(c_ptr2 + 8, c5);
  _mm256_store_ps(c_ptr3, c6);
  _mm256_store_ps(c_ptr3 + 8, c7);
  _mm256_store_ps(c_ptr4, c8);
  _mm256_store_ps(c_ptr4 + 8, c9);
  _mm256_store_ps(c_ptr5, c10);
  _mm256_store_ps(c_ptr5 + 8, c11);

#endif  //! defined(__AVX__)
}