diff --git a/src/strategies/avx2/dct-avx2.c b/src/strategies/avx2/dct-avx2.c
index 905232af..70530531 100644
--- a/src/strategies/avx2/dct-avx2.c
+++ b/src/strategies/avx2/dct-avx2.c
@@ -872,7 +872,7 @@ static void matmul_32x32_a_bt(const __m256i *__restrict a,
   // ...
   // 17 | 1f
   __m256i rp1[16];
-  for (i = 0; i < 16; i++) {
+  for (i = 0; i < 16; i += 4) {
     // Loop in order:
     // 0, 1, 2, 3, 4, 5, 6, 7,
     // 16, 17, 18, 19, 20, 21, 22, 23
@@ -881,22 +881,64 @@ static void matmul_32x32_a_bt(const __m256i *__restrict a,
 
     __m256i bt0_lo = b_t[off + 0];
     __m256i bt0_hi = b_t[off + 1];
+    __m256i bt1_lo = b_t[off + 2];
+    __m256i bt1_hi = b_t[off + 3];
+    __m256i bt2_lo = b_t[off + 4];
+    __m256i bt2_hi = b_t[off + 5];
+    __m256i bt3_lo = b_t[off + 6];
+    __m256i bt3_hi = b_t[off + 7];
+
     __m256i bt8_lo = b_t[off + 16];
     __m256i bt8_hi = b_t[off + 17];
+    __m256i bt9_lo = b_t[off + 18];
+    __m256i bt9_hi = b_t[off + 19];
+    __m256i bta_lo = b_t[off + 20];
+    __m256i bta_hi = b_t[off + 21];
+    __m256i btb_lo = b_t[off + 22];
+    __m256i btb_hi = b_t[off + 23];
 
     __m256i pr0_lo = _mm256_madd_epi16(a_lo, bt0_lo);
     __m256i pr0_hi = _mm256_madd_epi16(a_hi, bt0_hi);
     __m256i pr8_lo = _mm256_madd_epi16(a_lo, bt8_lo);
     __m256i pr8_hi = _mm256_madd_epi16(a_hi, bt8_hi);
-
     __m256i sum0 = _mm256_add_epi32 (pr0_lo, pr0_hi);
     __m256i sum8 = _mm256_add_epi32 (pr8_lo, pr8_hi);
 
+    __m256i pr1_lo = _mm256_madd_epi16(a_lo, bt1_lo);
+    __m256i pr1_hi = _mm256_madd_epi16(a_hi, bt1_hi);
+    __m256i pr9_lo = _mm256_madd_epi16(a_lo, bt9_lo);
+    __m256i pr9_hi = _mm256_madd_epi16(a_hi, bt9_hi);
+    __m256i sum1 = _mm256_add_epi32 (pr1_lo, pr1_hi);
+    __m256i sum9 = _mm256_add_epi32 (pr9_lo, pr9_hi);
+
+    __m256i pr2_lo = _mm256_madd_epi16(a_lo, bt2_lo);
+    __m256i pr2_hi = _mm256_madd_epi16(a_hi, bt2_hi);
+    __m256i pra_lo = _mm256_madd_epi16(a_lo, bta_lo);
+    __m256i pra_hi = _mm256_madd_epi16(a_hi, bta_hi);
+    __m256i sum2 = _mm256_add_epi32 (pr2_lo, pr2_hi);
+    __m256i suma = _mm256_add_epi32 (pra_lo, pra_hi);
+
+    __m256i pr3_lo = _mm256_madd_epi16(a_lo, bt3_lo);
+    __m256i pr3_hi = _mm256_madd_epi16(a_hi, bt3_hi);
+    __m256i prb_lo = _mm256_madd_epi16(a_lo, btb_lo);
+    __m256i prb_hi = _mm256_madd_epi16(a_hi, btb_hi);
+    __m256i sum3 = _mm256_add_epi32 (pr3_lo, pr3_hi);
+    __m256i sumb = _mm256_add_epi32 (prb_lo, prb_hi);
+
     // Arrange all parts for one number to be inside one lane
     __m256i s08_lo = _mm256_permute2x128_si256(sum0, sum8, 0x20);
     __m256i s08_hi = _mm256_permute2x128_si256(sum0, sum8, 0x31);
+    __m256i s19_lo = _mm256_permute2x128_si256(sum1, sum9, 0x20);
+    __m256i s19_hi = _mm256_permute2x128_si256(sum1, sum9, 0x31);
+    __m256i s2a_lo = _mm256_permute2x128_si256(sum2, suma, 0x20);
+    __m256i s2a_hi = _mm256_permute2x128_si256(sum2, suma, 0x31);
+    __m256i s3b_lo = _mm256_permute2x128_si256(sum3, sumb, 0x20);
+    __m256i s3b_hi = _mm256_permute2x128_si256(sum3, sumb, 0x31);
 
-    rp1[i] = _mm256_add_epi32(s08_lo, s08_hi);
+    rp1[i + 0] = _mm256_add_epi32(s08_lo, s08_hi);
+    rp1[i + 1] = _mm256_add_epi32(s19_lo, s19_hi);
+    rp1[i + 2] = _mm256_add_epi32(s2a_lo, s2a_hi);
+    rp1[i + 3] = _mm256_add_epi32(s3b_lo, s3b_hi);
   }
 
   // 00 00 01 01 | 08 08 09 09
@@ -906,22 +948,32 @@ static void matmul_32x32_a_bt(const __m256i *__restrict a,
   // 10 10 11 11 | 18 18 19 19
   // ...
   // 16 16 17 17 | 1e 1e 1f 1f
-  __m256i rp2[8];
-  for (i = 0; i < 8; i++) {
-    uint32_t id = i << 1;
-    rp2[i] = _mm256_hadd_epi32(rp1[id + 0], rp1[id + 1]);
-  }
+  __m256i rp2[8] = {
+    _mm256_hadd_epi32(rp1[ 0], rp1[ 1]),
+    _mm256_hadd_epi32(rp1[ 2], rp1[ 3]),
+    _mm256_hadd_epi32(rp1[ 4], rp1[ 5]),
+    _mm256_hadd_epi32(rp1[ 6], rp1[ 7]),
+    _mm256_hadd_epi32(rp1[ 8], rp1[ 9]),
+    _mm256_hadd_epi32(rp1[10], rp1[11]),
+    _mm256_hadd_epi32(rp1[12], rp1[13]),
+    _mm256_hadd_epi32(rp1[14], rp1[15]),
+  };
 
   // 00 01 02 03 | 08 09 0a 0b
   // 04 05 06 07 | 0c 0d 0e 0f
   // 10 11 12 13 | 18 19 1a 1b
   // 14 15 16 17 | 1c 1d 1e 1f
-  __m256i rp3[4];
-  for (i = 0; i < 4; i++) {
-    uint32_t id = i << 1;
-    __m256i finals = _mm256_hadd_epi32(rp2[id + 0], rp2[id + 1]);
-    rp3[i] = truncate(finals, debias, shift);
-  }
+  __m256i rp3[4] = {
+    _mm256_hadd_epi32(rp2[0], rp2[1]),
+    _mm256_hadd_epi32(rp2[2], rp2[3]),
+    _mm256_hadd_epi32(rp2[4], rp2[5]),
+    _mm256_hadd_epi32(rp2[6], rp2[7]),
+  };
+
+  rp3[0] = truncate(rp3[0], debias, shift);
+  rp3[1] = truncate(rp3[1], debias, shift);
+  rp3[2] = truncate(rp3[2], debias, shift);
+  rp3[3] = truncate(rp3[3], debias, shift);
 
   dst[jd + 0] = _mm256_packs_epi32(rp3[0], rp3[1]);
   dst[jd + 1] = _mm256_packs_epi32(rp3[2], rp3[3]);
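For readers who don't think in intrinsics: per output coefficient, the unrolled loop is computing a rounded, shifted, saturated 32-element dot product. Below is a minimal scalar sketch of that computation; the function name `dot32_round_shift` and the `main` driver are made up for illustration, and the `debias = 1 << (shift - 1)` rounding term follows the usual convention for this kind of truncate step rather than anything shown in the diff itself.

```c
#include <stdint.h>
#include <stdio.h>

/*
 * Hypothetical scalar model of one output coefficient of
 * matmul_32x32_a_bt: a 32-element dot product of a row of `a` with a
 * row of `b_t`, rounded by `debias`, shifted, and saturated to 16 bits.
 * The vector path computes many of these at once: _mm256_madd_epi16
 * forms the pairwise products, add/permute/hadd reduce them,
 * truncate() applies the debias+shift, and _mm256_packs_epi32 does
 * the final saturating narrow.
 */
static int16_t dot32_round_shift(const int16_t a[32], const int16_t bt[32],
                                 int32_t debias, int32_t shift)
{
  int32_t acc = 0; /* assumes sums fit in int32, as the vector code also does */
  for (int k = 0; k < 32; k++) {
    acc += (int32_t)a[k] * (int32_t)bt[k];
  }
  acc = (acc + debias) >> shift; /* truncate() in the diff */

  /* Saturating narrow to 16 bits, matching _mm256_packs_epi32. */
  if (acc > INT16_MAX) acc = INT16_MAX;
  if (acc < INT16_MIN) acc = INT16_MIN;
  return (int16_t)acc;
}

int main(void)
{
  int16_t a[32], bt[32];
  for (int k = 0; k < 32; k++) {
    a[k]  = (int16_t)(k + 1);
    bt[k] = 64;
  }
  /* shift = 7, debias = 1 << 6: (64 * 528 + 64) >> 7 == 264 */
  printf("%d\n", dot32_round_shift(a, bt, 1 << 6, 7));
  return 0;
}
```

As for the change itself: the 4x unroll gives each loop iteration four independent madd/add/permute dependency chains instead of one, so the out-of-order core can overlap them instead of stalling on a single chain, and replacing the rp2/rp3 loops with brace initializers and unrolled truncate calls removes the loop-carried index arithmetic so the compiler can schedule the hadds freely.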