diff --git a/src/strategies/avx2/dct-avx2.c b/src/strategies/avx2/dct-avx2.c index 70530531..86127f57 100644 --- a/src/strategies/avx2/dct-avx2.c +++ b/src/strategies/avx2/dct-avx2.c @@ -861,17 +861,7 @@ static void matmul_32x32_a_bt(const __m256i *__restrict a, __m256i a_lo = a[jd + 0]; __m256i a_hi = a[jd + 1]; - // Each lane here is the 4 i32's that add up to one number in the result - // vector: - // 00 | 08 - // 01 | 09 - // ... - // 07 | 0f - // 10 | 18 - // 11 | 19 - // ... - // 17 | 1f - __m256i rp1[16]; + __m256i res_32[4]; for (i = 0; i < 16; i += 4) { // Loop in order: // 0, 1, 2, 3, 4, 5, 6, 7, @@ -935,48 +925,46 @@ static void matmul_32x32_a_bt(const __m256i *__restrict a, __m256i s3b_lo = _mm256_permute2x128_si256(sum3, sumb, 0x20); __m256i s3b_hi = _mm256_permute2x128_si256(sum3, sumb, 0x31); - rp1[i + 0] = _mm256_add_epi32(s08_lo, s08_hi); - rp1[i + 1] = _mm256_add_epi32(s19_lo, s19_hi); - rp1[i + 2] = _mm256_add_epi32(s2a_lo, s2a_hi); - rp1[i + 3] = _mm256_add_epi32(s3b_lo, s3b_hi); + // Each lane here is the 4 i32's that add up to one number in the result + // vector: + // 00 | 08 + // 01 | 09 + // ... + // 07 | 0f + // 10 | 18 + // 11 | 19 + // ... + // 17 | 1f + __m256i rp1_0 = _mm256_add_epi32(s08_lo, s08_hi); + __m256i rp1_1 = _mm256_add_epi32(s19_lo, s19_hi); + __m256i rp1_2 = _mm256_add_epi32(s2a_lo, s2a_hi); + __m256i rp1_3 = _mm256_add_epi32(s3b_lo, s3b_hi); + + // rp2_0 and rp2_1 values on different iterations + // 00 00 01 01 | 08 08 09 09 + // 02 02 03 03 | 0a 0a 0b 0b + // + // 04 04 05 05 | 0c 0c 0d 0d + // 06 06 07 07 | 0e 0e 0f 0f + // + // 10 10 11 11 | 18 18 19 19 + // 12 12 13 13 | 1a 1a 1b 1b + // + // 16 16 17 17 | 1e 1e 1f 1f + // 16 16 17 17 | 1e 1e 1f 1f + __m256i rp2_0 = _mm256_hadd_epi32(rp1_0, rp1_1); + __m256i rp2_1 = _mm256_hadd_epi32(rp1_2, rp1_3); + + // rp3 values on different iters + // 00 01 02 03 | 08 09 0a 0b + // 04 05 06 07 | 0c 0d 0e 0f + // 10 11 12 13 | 18 19 1a 1b + // 14 15 16 17 | 1c 1d 1e 1f + __m256i rp3 = _mm256_hadd_epi32(rp2_0, rp2_1); + res_32[i >> 2] = truncate(rp3, debias, shift); } - - // 00 00 01 01 | 08 08 09 09 - // 02 02 03 03 | 0a 0a 0b 0b - // 04 04 05 05 | 0c 0c 0d 0d - // 06 06 07 07 | 0e 0e 0f 0f - // 10 10 11 11 | 18 18 19 19 - // ... - // 16 16 17 17 | 1e 1e 1f 1f - __m256i rp2[8] = { - _mm256_hadd_epi32(rp1[ 0], rp1[ 1]), - _mm256_hadd_epi32(rp1[ 2], rp1[ 3]), - _mm256_hadd_epi32(rp1[ 4], rp1[ 5]), - _mm256_hadd_epi32(rp1[ 6], rp1[ 7]), - _mm256_hadd_epi32(rp1[ 8], rp1[ 9]), - _mm256_hadd_epi32(rp1[10], rp1[11]), - _mm256_hadd_epi32(rp1[12], rp1[13]), - _mm256_hadd_epi32(rp1[14], rp1[15]), - }; - - // 00 01 02 03 | 08 09 0a 0b - // 04 05 06 07 | 0c 0d 0e 0f - // 10 11 12 13 | 18 19 1a 1b - // 14 15 16 17 | 1c 1d 1e 1f - __m256i rp3[4] = { - _mm256_hadd_epi32(rp2[0], rp2[1]), - _mm256_hadd_epi32(rp2[2], rp2[3]), - _mm256_hadd_epi32(rp2[4], rp2[5]), - _mm256_hadd_epi32(rp2[6], rp2[7]), - }; - - rp3[0] = truncate(rp3[0], debias, shift); - rp3[1] = truncate(rp3[1], debias, shift); - rp3[2] = truncate(rp3[2], debias, shift); - rp3[3] = truncate(rp3[3], debias, shift); - - dst[jd + 0] = _mm256_packs_epi32(rp3[0], rp3[1]); - dst[jd + 1] = _mm256_packs_epi32(rp3[2], rp3[3]); + dst[jd + 0] = _mm256_packs_epi32(res_32[0], res_32[1]); + dst[jd + 1] = _mm256_packs_epi32(res_32[2], res_32[3]); } }