Unroll 32x32 matrix multiply, use all regs

This commit is contained in:
Pauli Oikkonen 2019-10-15 16:54:01 +03:00
parent a58608d0b8
commit ac4d710e23

View file

@ -872,7 +872,7 @@ static void matmul_32x32_a_bt(const __m256i *__restrict a,
// ... // ...
// 17 | 1f // 17 | 1f
__m256i rp1[16]; __m256i rp1[16];
for (i = 0; i < 16; i++) { for (i = 0; i < 16; i += 4) {
// Loop in order: // Loop in order:
// 0, 1, 2, 3, 4, 5, 6, 7, // 0, 1, 2, 3, 4, 5, 6, 7,
// 16, 17, 18, 19, 20, 21, 22, 23 // 16, 17, 18, 19, 20, 21, 22, 23
@ -881,22 +881,64 @@ static void matmul_32x32_a_bt(const __m256i *__restrict a,
__m256i bt0_lo = b_t[off + 0]; __m256i bt0_lo = b_t[off + 0];
__m256i bt0_hi = b_t[off + 1]; __m256i bt0_hi = b_t[off + 1];
__m256i bt1_lo = b_t[off + 2];
__m256i bt1_hi = b_t[off + 3];
__m256i bt2_lo = b_t[off + 4];
__m256i bt2_hi = b_t[off + 5];
__m256i bt3_lo = b_t[off + 6];
__m256i bt3_hi = b_t[off + 7];
__m256i bt8_lo = b_t[off + 16]; __m256i bt8_lo = b_t[off + 16];
__m256i bt8_hi = b_t[off + 17]; __m256i bt8_hi = b_t[off + 17];
__m256i bt9_lo = b_t[off + 18];
__m256i bt9_hi = b_t[off + 19];
__m256i bta_lo = b_t[off + 20];
__m256i bta_hi = b_t[off + 21];
__m256i btb_lo = b_t[off + 22];
__m256i btb_hi = b_t[off + 23];
__m256i pr0_lo = _mm256_madd_epi16(a_lo, bt0_lo); __m256i pr0_lo = _mm256_madd_epi16(a_lo, bt0_lo);
__m256i pr0_hi = _mm256_madd_epi16(a_hi, bt0_hi); __m256i pr0_hi = _mm256_madd_epi16(a_hi, bt0_hi);
__m256i pr8_lo = _mm256_madd_epi16(a_lo, bt8_lo); __m256i pr8_lo = _mm256_madd_epi16(a_lo, bt8_lo);
__m256i pr8_hi = _mm256_madd_epi16(a_hi, bt8_hi); __m256i pr8_hi = _mm256_madd_epi16(a_hi, bt8_hi);
__m256i sum0 = _mm256_add_epi32 (pr0_lo, pr0_hi); __m256i sum0 = _mm256_add_epi32 (pr0_lo, pr0_hi);
__m256i sum8 = _mm256_add_epi32 (pr8_lo, pr8_hi); __m256i sum8 = _mm256_add_epi32 (pr8_lo, pr8_hi);
__m256i pr1_lo = _mm256_madd_epi16(a_lo, bt1_lo);
__m256i pr1_hi = _mm256_madd_epi16(a_hi, bt1_hi);
__m256i pr9_lo = _mm256_madd_epi16(a_lo, bt9_lo);
__m256i pr9_hi = _mm256_madd_epi16(a_hi, bt9_hi);
__m256i sum1 = _mm256_add_epi32 (pr1_lo, pr1_hi);
__m256i sum9 = _mm256_add_epi32 (pr9_lo, pr9_hi);
__m256i pr2_lo = _mm256_madd_epi16(a_lo, bt2_lo);
__m256i pr2_hi = _mm256_madd_epi16(a_hi, bt2_hi);
__m256i pra_lo = _mm256_madd_epi16(a_lo, bta_lo);
__m256i pra_hi = _mm256_madd_epi16(a_hi, bta_hi);
__m256i sum2 = _mm256_add_epi32 (pr2_lo, pr2_hi);
__m256i suma = _mm256_add_epi32 (pra_lo, pra_hi);
__m256i pr3_lo = _mm256_madd_epi16(a_lo, bt3_lo);
__m256i pr3_hi = _mm256_madd_epi16(a_hi, bt3_hi);
__m256i prb_lo = _mm256_madd_epi16(a_lo, btb_lo);
__m256i prb_hi = _mm256_madd_epi16(a_hi, btb_hi);
__m256i sum3 = _mm256_add_epi32 (pr3_lo, pr3_hi);
__m256i sumb = _mm256_add_epi32 (prb_lo, prb_hi);
// Arrange all parts for one number to be inside one lane // Arrange all parts for one number to be inside one lane
__m256i s08_lo = _mm256_permute2x128_si256(sum0, sum8, 0x20); __m256i s08_lo = _mm256_permute2x128_si256(sum0, sum8, 0x20);
__m256i s08_hi = _mm256_permute2x128_si256(sum0, sum8, 0x31); __m256i s08_hi = _mm256_permute2x128_si256(sum0, sum8, 0x31);
__m256i s19_lo = _mm256_permute2x128_si256(sum1, sum9, 0x20);
__m256i s19_hi = _mm256_permute2x128_si256(sum1, sum9, 0x31);
__m256i s2a_lo = _mm256_permute2x128_si256(sum2, suma, 0x20);
__m256i s2a_hi = _mm256_permute2x128_si256(sum2, suma, 0x31);
__m256i s3b_lo = _mm256_permute2x128_si256(sum3, sumb, 0x20);
__m256i s3b_hi = _mm256_permute2x128_si256(sum3, sumb, 0x31);
rp1[i] = _mm256_add_epi32(s08_lo, s08_hi); rp1[i + 0] = _mm256_add_epi32(s08_lo, s08_hi);
rp1[i + 1] = _mm256_add_epi32(s19_lo, s19_hi);
rp1[i + 2] = _mm256_add_epi32(s2a_lo, s2a_hi);
rp1[i + 3] = _mm256_add_epi32(s3b_lo, s3b_hi);
} }
// 00 00 01 01 | 08 08 09 09 // 00 00 01 01 | 08 08 09 09
@ -906,22 +948,32 @@ static void matmul_32x32_a_bt(const __m256i *__restrict a,
// 10 10 11 11 | 18 18 19 19 // 10 10 11 11 | 18 18 19 19
// ... // ...
// 16 16 17 17 | 1e 1e 1f 1f // 16 16 17 17 | 1e 1e 1f 1f
__m256i rp2[8]; __m256i rp2[8] = {
for (i = 0; i < 8; i++) { _mm256_hadd_epi32(rp1[ 0], rp1[ 1]),
uint32_t id = i << 1; _mm256_hadd_epi32(rp1[ 2], rp1[ 3]),
rp2[i] = _mm256_hadd_epi32(rp1[id + 0], rp1[id + 1]); _mm256_hadd_epi32(rp1[ 4], rp1[ 5]),
} _mm256_hadd_epi32(rp1[ 6], rp1[ 7]),
_mm256_hadd_epi32(rp1[ 8], rp1[ 9]),
_mm256_hadd_epi32(rp1[10], rp1[11]),
_mm256_hadd_epi32(rp1[12], rp1[13]),
_mm256_hadd_epi32(rp1[14], rp1[15]),
};
// 00 01 02 03 | 08 09 0a 0b // 00 01 02 03 | 08 09 0a 0b
// 04 05 06 07 | 0c 0d 0e 0f // 04 05 06 07 | 0c 0d 0e 0f
// 10 11 12 13 | 18 19 1a 1b // 10 11 12 13 | 18 19 1a 1b
// 14 15 16 17 | 1c 1d 1e 1f // 14 15 16 17 | 1c 1d 1e 1f
__m256i rp3[4]; __m256i rp3[4] = {
for (i = 0; i < 4; i++) { _mm256_hadd_epi32(rp2[0], rp2[1]),
uint32_t id = i << 1; _mm256_hadd_epi32(rp2[2], rp2[3]),
__m256i finals = _mm256_hadd_epi32(rp2[id + 0], rp2[id + 1]); _mm256_hadd_epi32(rp2[4], rp2[5]),
rp3[i] = truncate(finals, debias, shift); _mm256_hadd_epi32(rp2[6], rp2[7]),
} };
rp3[0] = truncate(rp3[0], debias, shift);
rp3[1] = truncate(rp3[1], debias, shift);
rp3[2] = truncate(rp3[2], debias, shift);
rp3[3] = truncate(rp3[3], debias, shift);
dst[jd + 0] = _mm256_packs_epi32(rp3[0], rp3[1]); dst[jd + 0] = _mm256_packs_epi32(rp3[0], rp3[1]);
dst[jd + 1] = _mm256_packs_epi32(rp3[2], rp3[3]); dst[jd + 1] = _mm256_packs_epi32(rp3[2], rp3[3]);