mirror of
https://github.com/ultravideo/uvg266.git
synced 2024-11-27 19:24:06 +00:00
Unroll 32x32 matrix multiply, use all regs
This commit is contained in:
parent
a58608d0b8
commit
ac4d710e23
|
@ -872,7 +872,7 @@ static void matmul_32x32_a_bt(const __m256i *__restrict a,
|
||||||
// ...
|
// ...
|
||||||
// 17 | 1f
|
// 17 | 1f
|
||||||
__m256i rp1[16];
|
__m256i rp1[16];
|
||||||
for (i = 0; i < 16; i++) {
|
for (i = 0; i < 16; i += 4) {
|
||||||
// Loop in order:
|
// Loop in order:
|
||||||
// 0, 1, 2, 3, 4, 5, 6, 7,
|
// 0, 1, 2, 3, 4, 5, 6, 7,
|
||||||
// 16, 17, 18, 19, 20, 21, 22, 23
|
// 16, 17, 18, 19, 20, 21, 22, 23
|
||||||
|
@ -881,22 +881,64 @@ static void matmul_32x32_a_bt(const __m256i *__restrict a,
|
||||||
|
|
||||||
__m256i bt0_lo = b_t[off + 0];
|
__m256i bt0_lo = b_t[off + 0];
|
||||||
__m256i bt0_hi = b_t[off + 1];
|
__m256i bt0_hi = b_t[off + 1];
|
||||||
|
__m256i bt1_lo = b_t[off + 2];
|
||||||
|
__m256i bt1_hi = b_t[off + 3];
|
||||||
|
__m256i bt2_lo = b_t[off + 4];
|
||||||
|
__m256i bt2_hi = b_t[off + 5];
|
||||||
|
__m256i bt3_lo = b_t[off + 6];
|
||||||
|
__m256i bt3_hi = b_t[off + 7];
|
||||||
|
|
||||||
__m256i bt8_lo = b_t[off + 16];
|
__m256i bt8_lo = b_t[off + 16];
|
||||||
__m256i bt8_hi = b_t[off + 17];
|
__m256i bt8_hi = b_t[off + 17];
|
||||||
|
__m256i bt9_lo = b_t[off + 18];
|
||||||
|
__m256i bt9_hi = b_t[off + 19];
|
||||||
|
__m256i bta_lo = b_t[off + 20];
|
||||||
|
__m256i bta_hi = b_t[off + 21];
|
||||||
|
__m256i btb_lo = b_t[off + 22];
|
||||||
|
__m256i btb_hi = b_t[off + 23];
|
||||||
|
|
||||||
__m256i pr0_lo = _mm256_madd_epi16(a_lo, bt0_lo);
|
__m256i pr0_lo = _mm256_madd_epi16(a_lo, bt0_lo);
|
||||||
__m256i pr0_hi = _mm256_madd_epi16(a_hi, bt0_hi);
|
__m256i pr0_hi = _mm256_madd_epi16(a_hi, bt0_hi);
|
||||||
__m256i pr8_lo = _mm256_madd_epi16(a_lo, bt8_lo);
|
__m256i pr8_lo = _mm256_madd_epi16(a_lo, bt8_lo);
|
||||||
__m256i pr8_hi = _mm256_madd_epi16(a_hi, bt8_hi);
|
__m256i pr8_hi = _mm256_madd_epi16(a_hi, bt8_hi);
|
||||||
|
|
||||||
__m256i sum0 = _mm256_add_epi32 (pr0_lo, pr0_hi);
|
__m256i sum0 = _mm256_add_epi32 (pr0_lo, pr0_hi);
|
||||||
__m256i sum8 = _mm256_add_epi32 (pr8_lo, pr8_hi);
|
__m256i sum8 = _mm256_add_epi32 (pr8_lo, pr8_hi);
|
||||||
|
|
||||||
|
__m256i pr1_lo = _mm256_madd_epi16(a_lo, bt1_lo);
|
||||||
|
__m256i pr1_hi = _mm256_madd_epi16(a_hi, bt1_hi);
|
||||||
|
__m256i pr9_lo = _mm256_madd_epi16(a_lo, bt9_lo);
|
||||||
|
__m256i pr9_hi = _mm256_madd_epi16(a_hi, bt9_hi);
|
||||||
|
__m256i sum1 = _mm256_add_epi32 (pr1_lo, pr1_hi);
|
||||||
|
__m256i sum9 = _mm256_add_epi32 (pr9_lo, pr9_hi);
|
||||||
|
|
||||||
|
__m256i pr2_lo = _mm256_madd_epi16(a_lo, bt2_lo);
|
||||||
|
__m256i pr2_hi = _mm256_madd_epi16(a_hi, bt2_hi);
|
||||||
|
__m256i pra_lo = _mm256_madd_epi16(a_lo, bta_lo);
|
||||||
|
__m256i pra_hi = _mm256_madd_epi16(a_hi, bta_hi);
|
||||||
|
__m256i sum2 = _mm256_add_epi32 (pr2_lo, pr2_hi);
|
||||||
|
__m256i suma = _mm256_add_epi32 (pra_lo, pra_hi);
|
||||||
|
|
||||||
|
__m256i pr3_lo = _mm256_madd_epi16(a_lo, bt3_lo);
|
||||||
|
__m256i pr3_hi = _mm256_madd_epi16(a_hi, bt3_hi);
|
||||||
|
__m256i prb_lo = _mm256_madd_epi16(a_lo, btb_lo);
|
||||||
|
__m256i prb_hi = _mm256_madd_epi16(a_hi, btb_hi);
|
||||||
|
__m256i sum3 = _mm256_add_epi32 (pr3_lo, pr3_hi);
|
||||||
|
__m256i sumb = _mm256_add_epi32 (prb_lo, prb_hi);
|
||||||
|
|
||||||
// Arrange all parts for one number to be inside one lane
|
// Arrange all parts for one number to be inside one lane
|
||||||
__m256i s08_lo = _mm256_permute2x128_si256(sum0, sum8, 0x20);
|
__m256i s08_lo = _mm256_permute2x128_si256(sum0, sum8, 0x20);
|
||||||
__m256i s08_hi = _mm256_permute2x128_si256(sum0, sum8, 0x31);
|
__m256i s08_hi = _mm256_permute2x128_si256(sum0, sum8, 0x31);
|
||||||
|
__m256i s19_lo = _mm256_permute2x128_si256(sum1, sum9, 0x20);
|
||||||
|
__m256i s19_hi = _mm256_permute2x128_si256(sum1, sum9, 0x31);
|
||||||
|
__m256i s2a_lo = _mm256_permute2x128_si256(sum2, suma, 0x20);
|
||||||
|
__m256i s2a_hi = _mm256_permute2x128_si256(sum2, suma, 0x31);
|
||||||
|
__m256i s3b_lo = _mm256_permute2x128_si256(sum3, sumb, 0x20);
|
||||||
|
__m256i s3b_hi = _mm256_permute2x128_si256(sum3, sumb, 0x31);
|
||||||
|
|
||||||
rp1[i] = _mm256_add_epi32(s08_lo, s08_hi);
|
rp1[i + 0] = _mm256_add_epi32(s08_lo, s08_hi);
|
||||||
|
rp1[i + 1] = _mm256_add_epi32(s19_lo, s19_hi);
|
||||||
|
rp1[i + 2] = _mm256_add_epi32(s2a_lo, s2a_hi);
|
||||||
|
rp1[i + 3] = _mm256_add_epi32(s3b_lo, s3b_hi);
|
||||||
}
|
}
|
||||||
|
|
||||||
// 00 00 01 01 | 08 08 09 09
|
// 00 00 01 01 | 08 08 09 09
|
||||||
|
@ -906,22 +948,32 @@ static void matmul_32x32_a_bt(const __m256i *__restrict a,
|
||||||
// 10 10 11 11 | 18 18 19 19
|
// 10 10 11 11 | 18 18 19 19
|
||||||
// ...
|
// ...
|
||||||
// 16 16 17 17 | 1e 1e 1f 1f
|
// 16 16 17 17 | 1e 1e 1f 1f
|
||||||
__m256i rp2[8];
|
__m256i rp2[8] = {
|
||||||
for (i = 0; i < 8; i++) {
|
_mm256_hadd_epi32(rp1[ 0], rp1[ 1]),
|
||||||
uint32_t id = i << 1;
|
_mm256_hadd_epi32(rp1[ 2], rp1[ 3]),
|
||||||
rp2[i] = _mm256_hadd_epi32(rp1[id + 0], rp1[id + 1]);
|
_mm256_hadd_epi32(rp1[ 4], rp1[ 5]),
|
||||||
}
|
_mm256_hadd_epi32(rp1[ 6], rp1[ 7]),
|
||||||
|
_mm256_hadd_epi32(rp1[ 8], rp1[ 9]),
|
||||||
|
_mm256_hadd_epi32(rp1[10], rp1[11]),
|
||||||
|
_mm256_hadd_epi32(rp1[12], rp1[13]),
|
||||||
|
_mm256_hadd_epi32(rp1[14], rp1[15]),
|
||||||
|
};
|
||||||
|
|
||||||
// 00 01 02 03 | 08 09 0a 0b
|
// 00 01 02 03 | 08 09 0a 0b
|
||||||
// 04 05 06 07 | 0c 0d 0e 0f
|
// 04 05 06 07 | 0c 0d 0e 0f
|
||||||
// 10 11 12 13 | 18 19 1a 1b
|
// 10 11 12 13 | 18 19 1a 1b
|
||||||
// 14 15 16 17 | 1c 1d 1e 1f
|
// 14 15 16 17 | 1c 1d 1e 1f
|
||||||
__m256i rp3[4];
|
__m256i rp3[4] = {
|
||||||
for (i = 0; i < 4; i++) {
|
_mm256_hadd_epi32(rp2[0], rp2[1]),
|
||||||
uint32_t id = i << 1;
|
_mm256_hadd_epi32(rp2[2], rp2[3]),
|
||||||
__m256i finals = _mm256_hadd_epi32(rp2[id + 0], rp2[id + 1]);
|
_mm256_hadd_epi32(rp2[4], rp2[5]),
|
||||||
rp3[i] = truncate(finals, debias, shift);
|
_mm256_hadd_epi32(rp2[6], rp2[7]),
|
||||||
}
|
};
|
||||||
|
|
||||||
|
rp3[0] = truncate(rp3[0], debias, shift);
|
||||||
|
rp3[1] = truncate(rp3[1], debias, shift);
|
||||||
|
rp3[2] = truncate(rp3[2], debias, shift);
|
||||||
|
rp3[3] = truncate(rp3[3], debias, shift);
|
||||||
|
|
||||||
dst[jd + 0] = _mm256_packs_epi32(rp3[0], rp3[1]);
|
dst[jd + 0] = _mm256_packs_epi32(rp3[0], rp3[1]);
|
||||||
dst[jd + 1] = _mm256_packs_epi32(rp3[2], rp3[3]);
|
dst[jd + 1] = _mm256_packs_epi32(rp3[2], rp3[3]);
|
||||||
|
|
Loading…
Reference in a new issue