From 043f53539f7cc20439f1e9bdd82a1afee355986b Mon Sep 17 00:00:00 2001 From: Pauli Oikkonen Date: Thu, 10 Oct 2019 18:32:39 +0300 Subject: [PATCH] Implement a streamlined matrix-multiply 32x32 DCT --- src/strategies/avx2/dct-avx2.c | 474 +++++++++++---------------------- 1 file changed, 157 insertions(+), 317 deletions(-) diff --git a/src/strategies/avx2/dct-avx2.c b/src/strategies/avx2/dct-avx2.c index bb02ec99..4693371c 100644 --- a/src/strategies/avx2/dct-avx2.c +++ b/src/strategies/avx2/dct-avx2.c @@ -605,14 +605,6 @@ static void transpose_16x16(const int16_t *src, int16_t *dst) transpose_16x16_stride(src, dst, 0, 0); } -static void transpose_32x32(const int16_t *src, int16_t *dst) -{ - transpose_16x16_stride(src + 0, dst + 0, 1, 1); - transpose_16x16_stride(src + 16, dst + 16 * 32, 1, 1); - transpose_16x16_stride(src + 16 * 32, dst + 16, 1, 1); - transpose_16x16_stride(src + 16 * 33, dst + 16 * 33, 1, 1); -} - static __m256i truncate_inv(__m256i v, int32_t shift) { int32_t add = 1 << (shift - 1); @@ -855,313 +847,160 @@ static void matrix_dct_16x16_avx2(int8_t bitdepth, const int16_t *input, int16_t matmul_16x16_a_bt (dct, tmpres, output, shift_2nd); } -static __m256i get_overflows(const __m256i a, - const __m256i b, - const __m256i res, - const __m256i of_adjust_mask) -{ - const __m256i ones = _mm256_set1_epi16(1); - - __m256i src_signdiff = _mm256_xor_si256 (a, b); - __m256i a_r_signdiff = _mm256_xor_si256 (a, res); - - __m256i of_possible = _mm256_xor_si256 (src_signdiff, of_adjust_mask); - - __m256i overflows = _mm256_and_si256 (of_possible, a_r_signdiff); - overflows = _mm256_srai_epi16 (overflows, 15); - - __m256i of_signs = _mm256_srai_epi16 (a, 15); - of_signs = _mm256_or_si256 (of_signs, ones); - return _mm256_and_si256 (overflows, of_signs); -} - -/* - * You need more than 16 bits to store the result of signed 16b-16b operation, - * this one stores the low 16b in lo and high 16b in hi. The high 16 bits can - * only be -1, 0 or 1. - * - * of_possible_mask is either all zero bits for subtraction, or all ones for - * addition - */ -static void sub_16_16_hilo(const __m256i a, const __m256i b, - __m256i *lo, __m256i *hi) -{ - const __m256i zero = _mm256_setzero_si256(); - - *lo = _mm256_sub_epi16(a, b); - *hi = get_overflows(a, b, *lo, zero); -} - -static void add_16_16_hilo(const __m256i a, const __m256i b, - __m256i *lo, __m256i *hi) -{ - const __m256i ff = _mm256_set1_epi8(-1); - - *lo = _mm256_add_epi16(a, b); - *hi = get_overflows(a, b, *lo, ff); -} - -static INLINE __m256i reverse_16x16b_in_lanes(const __m256i v) -{ - const __m256i lanerev = _mm256_setr_epi16(0x0f0e, 0x0d0c, 0x0b0a, 0x0908, - 0x0706, 0x0504, 0x0302, 0x0100, - 0x0f0e, 0x0d0c, 0x0b0a, 0x0908, - 0x0706, 0x0504, 0x0302, 0x0100); - return _mm256_shuffle_epi8(v, lanerev); -} - -static INLINE __m256i reverse_16x16b(const __m256i v) -{ - __m256i tmp = reverse_16x16b_in_lanes(v); - return _mm256_permute4x64_epi64(tmp, _MM_SHUFFLE(1, 0, 3, 2)); -} - -static INLINE __m256i m256_from_2xm128(const __m128i lo, const __m128i hi) -{ - __m256i result = _mm256_castsi128_si256 (lo); - return _mm256_inserti128_si256(result, hi, 1); -} - -// Get a vector consisting of the divisible-by-8 coeffs in DCT's first two -// columns, in order: -// DC00 DC00 DC00 DC00 DC10 DC10 DC10 DC10 | DC08 DC08 DC08 DC08 DC18 DC18 DC18 DC18 -static __m256i get_dct_db8_vec(const int16_t *dct_t, uint32_t offset) -{ - const __m256i reorder_mask = _mm256_setr_epi32(0x01000100, 0x01000100, 0x05040504, 0x05040504, - 0x03020302, 0x03020302, 0x07060706, 0x07060706); - - uint16_t coeff00 = (uint16_t)dct_t[offset + 0]; - uint16_t coeff08 = (uint16_t)dct_t[offset + 8]; - uint16_t coeff10 = (uint16_t)dct_t[offset + 32]; - uint16_t coeff18 = (uint16_t)dct_t[offset + 40]; - - uint64_t col_db8_packed = (((uint64_t)coeff00) << 0) | - (((uint64_t)coeff08) << 16) | - (((uint64_t)coeff10) << 32) | - (((uint64_t)coeff18) << 48); - - __m256i col_db8 = _mm256_set1_epi64x (col_db8_packed); - col_db8 = _mm256_shuffle_epi8(col_db8, reorder_mask); - __m256i col_db8_shifted = _mm256_slli_epi16 (col_db8, 8); - - return _mm256_blend_epi32 (col_db8, col_db8_shifted, 0xaa); -} - -// Get first 4 DCT coeffs from DCT rows (rowoff + 4) and (rowoff + 12), -// shifting copies of the coeffs 8 bits left to do half of the 65536-factor -// multiplication required for high words of EEO coeffs -static __m256i get_dct_db4_vec(const int16_t *dct, uint32_t rowoff) -{ - ALIGNED(32) uint64_t buf[4]; - - uint64_t r4_0123 = *(uint64_t *)(dct + (rowoff + 0x04) * 32); - uint64_t rc_0123 = *(uint64_t *)(dct + (rowoff + 0x0c) * 32); - - buf[0] = r4_0123; - buf[1] = r4_0123 << 8; - buf[2] = rc_0123; - buf[3] = rc_0123 << 8; - - return _mm256_load_si256((const __m256i *)buf); -} - -// Get first 8 coeffs from DCT rows (rowoff + 2) and (rowoff + 10) and return -// them in a single YMM -static __m256i get_dct_db2_vec(const int16_t *dct, uint32_t rowoff) -{ - __m128i row2 = _mm_load_si128((const __m128i *)(dct + (rowoff + 2) * 32)); - __m128i rowa = _mm_load_si128((const __m128i *)(dct + (rowoff + 10) * 32)); - - return m256_from_2xm128(row2, rowa); -} - -static void partial_butterfly_32_avx2(const int16_t *src, int16_t *dst, int32_t shift) +void matmul_32x32_a_bt_t(const __m256i *a, const __m256i *b_t, __m256i *dst, const uint8_t shift) { const int32_t add = 1 << (shift - 1); const __m256i debias = _mm256_set1_epi32(add); - const int16_t *dct = (const int16_t *)(&kvz_g_dct_32 [0][0]); - const int16_t *dct_t = (const int16_t *)(&kvz_g_dct_32_t[0][0]); + uint32_t i, j; + for (j = 0; j < 32; j++) { + uint32_t jd = j << 1; + __m256i bt_lo = b_t[jd + 0]; + __m256i bt_hi = b_t[jd + 1]; - const __m256i ff = _mm256_set1_epi32 (-1); - const __m256i ones = _mm256_set1_epi16 ( 1); - const __m128i ff_128 = _mm256_castsi256_si128 (ff); - const __m256i lolane_smask = _mm256_inserti128_si256(ones, ff_128, 0); - const __m256i hilane_mask = _mm256_cmpeq_epi16 (ones, lolane_smask); + // Each lane here is the 4 i32's that add up to one number in the result + // vector: + // 00 | 08 + // 01 | 09 + // ... + // 07 | 0f + // 10 | 18 + // 11 | 19 + // ... + // 17 | 1f + __m256i rp1[16]; + for (i = 0; i < 16; i++) { + // Loop in order: + // 0, 1, 2, 3, 4, 5, 6, 7, + // 16, 17, 18, 19, 20, 21, 22, 23 + uint32_t id = (i & 7) | ((i & ~7) << 1); + uint32_t off = id << 1; - const __m256i eee_lohi_shuf = _mm256_setr_epi32(0x01000706, 0x09080f0e, 0x03020504, 0x0b0a0d0c, - 0x01000706, 0x09080f0e, 0x03020504, 0x0b0a0d0c); + __m256i a0_lo = a[off + 0]; + __m256i a0_hi = a[off + 1]; + __m256i a8_lo = a[off + 16]; + __m256i a8_hi = a[off + 17]; - const __m256i eee_eee_sgnmask = _mm256_setr_epi32(0x00010001, 0x00010001, 0x00010001, 0x00010001, - 0xffff0001, 0xffff0001, 0xffff0001, 0xffff0001); + __m256i pr0_lo = _mm256_madd_epi16(bt_lo, a0_lo); + __m256i pr0_hi = _mm256_madd_epi16(bt_hi, a0_hi); + __m256i pr8_lo = _mm256_madd_epi16(bt_lo, a8_lo); + __m256i pr8_hi = _mm256_madd_epi16(bt_hi, a8_hi); - const __m256i dct_c0_db8 = get_dct_db8_vec(dct_t, 0); - const __m256i dct_c1_db8 = get_dct_db8_vec(dct_t, 16); + __m256i sum0 = _mm256_add_epi32 (pr0_lo, pr0_hi); + __m256i sum8 = _mm256_add_epi32 (pr8_lo, pr8_hi); - const __m256i dct_r4c_db4[2] = { - get_dct_db4_vec(dct, 0x00), - get_dct_db4_vec(dct, 0x10), - }; + // Arrange all parts for one number to be inside one lane + __m256i s08_lo = _mm256_permute2x128_si256(sum0, sum8, 0x20); + __m256i s08_hi = _mm256_permute2x128_si256(sum0, sum8, 0x31); - const __m256i dct_r_db2[4] = { - get_dct_db2_vec(dct, 0), - get_dct_db2_vec(dct, 4), - get_dct_db2_vec(dct, 16), - get_dct_db2_vec(dct, 20), - }; - - __m256i res_tp[2 * 32]; - for (uint32_t i = 0; i < 32; i++) { - __m256i lo = _mm256_load_si256((const __m256i *)src + 2 * i + 0); - __m256i hi = _mm256_load_si256((const __m256i *)src + 2 * i + 1); - - __m256i hi_rev = reverse_16x16b(hi); - - __m256i e_lo, e_hi, o_lo, o_hi; - add_16_16_hilo(lo, hi_rev, &e_lo, &e_hi); - sub_16_16_hilo(lo, hi_rev, &o_lo, &o_hi); - - __m256i erev_lo = reverse_16x16b(e_lo); - __m256i erev_hi = reverse_16x16b(e_hi); - - // Hack! Negate low lanes to do subtractions there, but retain non-negated - // low lanes for overflow detection because the original value is - // essentially what was subtracted there by adding its complement - __m256i erev_lo_n = _mm256_sign_epi16(erev_lo, lolane_smask); - __m256i erev_hi_n = _mm256_sign_epi16(erev_hi, lolane_smask); - - // eo0 eo1 eo2 eo3 eo4 eo5 eo6 eo7 | ee7 ee6 ee5 ee4 ee3 ee2 ee1 ee0 - __m256i eo_ee_lo = _mm256_add_epi16 (e_lo, erev_lo_n); - __m256i eo_ee_hi = _mm256_add_epi16 (e_hi, erev_hi_n); - - __m256i eo_ee_hi2 = get_overflows (e_lo, erev_lo, eo_ee_lo, hilane_mask); - eo_ee_hi = _mm256_add_epi16 (eo_ee_hi, eo_ee_hi2); - - __m256i eo_eo_lo = _mm256_permute4x64_epi64(eo_ee_lo, _MM_SHUFFLE(1, 0, 1, 0)); - __m256i eo_eo_hi = _mm256_permute4x64_epi64(eo_ee_hi, _MM_SHUFFLE(1, 0, 1, 0)); - - // Rev: ee7 ee6 ee5 ee4 ee7 ee6 ee5 ee4 | ee3 ee2 ee1 ee0 ee3 ee2 ee1 ee0 - // Fwd: ee0 ee1 ee2 ee3 ee0 ee1 ee2 ee3 | ee4 ee5 ee6 ee7 ee4 ee5 ee6 ee7 - __m256i ee_ee_rev_lo = _mm256_permute4x64_epi64(eo_ee_lo, _MM_SHUFFLE(3, 3, 2, 2)); - __m256i ee_ee_lo = _mm256_permute4x64_epi64(eo_ee_lo, _MM_SHUFFLE(2, 2, 3, 3)); - ee_ee_lo = reverse_16x16b_in_lanes (ee_ee_lo); - - __m256i ee_ee_rev_hi = _mm256_permute4x64_epi64(eo_ee_hi, _MM_SHUFFLE(3, 3, 2, 2)); - __m256i ee_ee_hi = _mm256_permute4x64_epi64(eo_ee_hi, _MM_SHUFFLE(2, 2, 3, 3)); - ee_ee_hi = reverse_16x16b_in_lanes (ee_ee_hi); - - __m256i ee_ee_rev_lo_n = _mm256_sign_epi16(ee_ee_rev_lo, lolane_smask); - __m256i ee_ee_rev_hi_n = _mm256_sign_epi16(ee_ee_rev_hi, lolane_smask); - - // eeo0 eeo1 eeo2 eeo3 eeo0 eeo1 eeo2 eeo3 | eee3 eee2 eee1 eee0 eee3 eee2 eee1 eee0 - __m256i eeo_eee_lo = _mm256_add_epi16 (ee_ee_lo, ee_ee_rev_lo_n); - __m256i eeo_eee_hi = _mm256_add_epi16 (ee_ee_hi, ee_ee_rev_hi_n); - - __m256i eeo_eee_hi2 = get_overflows(ee_ee_lo, ee_ee_rev_lo, eeo_eee_lo, hilane_mask); - eeo_eee_hi = _mm256_add_epi16(eeo_eee_hi, eeo_eee_hi2); - - // Multiply these guys by 256 and also the corresponding DCT coefficients - // by 256, to multiply their product by 65536. Neither these values nor - // the coeffs will exceed 255, so this is overflow safe. - __m256i eeo_eee_hi_shed = _mm256_slli_epi16(eeo_eee_hi, 8); - - // Discard eeo's (low lane), duplicate the high lane, and reorder eee's - // in-lane. Finally invert the eee2/eee3 components in the high lane. - __m256i eee_lohi = _mm256_blend_epi32(eeo_eee_lo, eeo_eee_hi_shed, 0xc0); - eee_lohi = _mm256_permute4x64_epi64(eee_lohi, _MM_SHUFFLE(3, 2, 3, 2)); - - __m256i eee_lh_ordered = _mm256_shuffle_epi8(eee_lohi, eee_lohi_shuf); - __m256i eee_lh_final = _mm256_sign_epi16 (eee_lh_ordered, eee_eee_sgnmask); - - // D00, D08, D10 and D18 in four parts fit for hadding - __m256i d00_d08 = _mm256_madd_epi16(eee_lh_final, dct_c0_db8); - __m256i d10_d18 = _mm256_madd_epi16(eee_lh_final, dct_c1_db8); - - __m256i eeo_lohi = _mm256_unpacklo_epi64 (eeo_eee_lo, eeo_eee_hi_shed); - eeo_lohi = _mm256_permute4x64_epi64(eeo_lohi, _MM_SHUFFLE(1, 0, 1, 0)); - - __m256i d_db4[2] = { - _mm256_madd_epi16(eeo_lohi, dct_r4c_db4[0]), - _mm256_madd_epi16(eeo_lohi, dct_r4c_db4[1]), - }; - - __m256i db2_parts[4]; - for (uint32_t j = 0; j < 4; j++) { - const __m256i dr2a = dct_r_db2[j]; - __m256i lo = _mm256_madd_epi16(dr2a, eo_eo_lo); - __m256i hi = _mm256_madd_epi16(dr2a, eo_eo_hi); - hi = _mm256_slli_epi32(hi, 16); - - db2_parts[j] = _mm256_add_epi32 (lo, hi); + rp1[i] = _mm256_add_epi32(s08_lo, s08_hi); } - __m256i odd_parts[16]; - for (uint32_t j = 0; j < 16; j++) { - __m256i drow_lo = _mm256_load_si256((const __m256i *)(dct + (j * 2 + 1) * 32)); - __m256i odds_lo = _mm256_madd_epi16(o_lo, drow_lo); - __m256i odds_hi = _mm256_madd_epi16(o_hi, drow_lo); - odds_hi = _mm256_slli_epi32(odds_hi, 16); - - odd_parts[j] = _mm256_add_epi32 (odds_lo, odds_hi); + // 00 00 01 01 | 08 08 09 09 + // 02 02 03 03 | 0a 0a 0b 0b + // 04 04 05 05 | 0c 0c 0d 0d + // 06 06 07 07 | 0e 0e 0f 0f + // 10 10 11 11 | 18 18 19 19 + // ... + // 16 16 17 17 | 1e 1e 1f 1f + __m256i rp2[8]; + for (i = 0; i < 8; i++) { + uint32_t id = i << 1; + rp2[i] = _mm256_hadd_epi32(rp1[id + 0], rp1[id + 1]); } - // Rearrange odds so that parts belonging to any single one are all inside - // one lane - combine 01 | 09 ; 03 | 0b ; 05 | 0d ; 07 | 0f - // 11 | 19 ; 13 | 1b ; 15 | 1d ; 17 | 1f - __m256i odd_parts2[8]; - for (uint32_t j = 0; j < 8; j++) { - - // Turn 0, 1, 2, 3, 4, 5, 6, 7 into: - // 0, 1, 2, 3, 8, 9, a, b - uint32_t j_lo = j & 0x03; - uint32_t j_hi = j & 0x04; - uint32_t id_lo = j_lo | (j_hi << 1); - uint32_t id_hi = id_lo | 4; - - __m256i odd_lo = _mm256_permute2x128_si256(odd_parts[id_lo], - odd_parts[id_hi], - 0x20); - - __m256i odd_hi = _mm256_permute2x128_si256(odd_parts[id_lo], - odd_parts[id_hi], - 0x31); - - odd_parts2[j] = _mm256_add_epi32(odd_lo, odd_hi); + // 00 01 02 03 | 08 09 0a 0b + // 04 05 06 07 | 0c 0d 0e 0f + // 10 11 12 13 | 18 19 1a 1b + // 14 15 16 17 | 1c 1d 1e 1f + __m256i rp3[4]; + for (i = 0; i < 4; i++) { + uint32_t id = i << 1; + __m256i finals = _mm256_hadd_epi32(rp2[id + 0], rp2[id + 1]); + rp3[i] = truncate(finals, debias, shift); } - // First stage HADDs... - __m256i d0001_0809 = _mm256_hadd_epi32(d00_d08, odd_parts2[0]); - __m256i d1011_1819 = _mm256_hadd_epi32(d10_d18, odd_parts2[4]); - - __m256i d0405_0c0d = _mm256_hadd_epi32(d_db4[0], odd_parts2[2]); - __m256i d1415_1c1d = _mm256_hadd_epi32(d_db4[1], odd_parts2[6]); - - __m256i d0203_0a0b = _mm256_hadd_epi32(db2_parts[0], odd_parts2[1]); - __m256i d0607_0e0f = _mm256_hadd_epi32(db2_parts[1], odd_parts2[3]); - - __m256i d1213_1a1b = _mm256_hadd_epi32(db2_parts[2], odd_parts2[5]); - __m256i d1617_1e1f = _mm256_hadd_epi32(db2_parts[3], odd_parts2[7]); - - // .. and second stage - __m256i d0123_89ab_lo = _mm256_hadd_epi32(d0001_0809, d0203_0a0b); - __m256i d0123_89ab_hi = _mm256_hadd_epi32(d1011_1819, d1213_1a1b); - - __m256i d4567_cdef_lo = _mm256_hadd_epi32(d0405_0c0d, d0607_0e0f); - __m256i d4567_cdef_hi = _mm256_hadd_epi32(d1415_1c1d, d1617_1e1f); - - d0123_89ab_lo = truncate(d0123_89ab_lo, debias, shift); - d0123_89ab_hi = truncate(d0123_89ab_hi, debias, shift); - - d4567_cdef_lo = truncate(d4567_cdef_lo, debias, shift); - d4567_cdef_hi = truncate(d4567_cdef_hi, debias, shift); - - __m256i final_lo = _mm256_packs_epi32(d0123_89ab_lo, d4567_cdef_lo); - __m256i final_hi = _mm256_packs_epi32(d0123_89ab_hi, d4567_cdef_hi); - - _mm256_store_si256(res_tp + (i * 2) + 0, final_lo); - _mm256_store_si256(res_tp + (i * 2) + 1, final_hi); + dst[jd + 0] = _mm256_packs_epi32(rp3[0], rp3[1]); + dst[jd + 1] = _mm256_packs_epi32(rp3[2], rp3[3]); + } +} + +void matmul_32x32_a_bt(const __m256i *a, const __m256i *b_t, __m256i *dst, const uint8_t shift) +{ + const int32_t add = 1 << (shift - 1); + const __m256i debias = _mm256_set1_epi32(add); + + uint32_t i, j; + for (j = 0; j < 32; j++) { + uint32_t jd = j << 1; + __m256i a_lo = a[jd + 0]; + __m256i a_hi = a[jd + 1]; + + // Each lane here is the 4 i32's that add up to one number in the result + // vector: + // 00 | 08 + // 01 | 09 + // ... + // 07 | 0f + // 10 | 18 + // 11 | 19 + // ... + // 17 | 1f + __m256i rp1[16]; + for (i = 0; i < 16; i++) { + // Loop in order: + // 0, 1, 2, 3, 4, 5, 6, 7, + // 16, 17, 18, 19, 20, 21, 22, 23 + uint32_t id = (i & 7) | ((i & ~7) << 1); + uint32_t off = id << 1; + + __m256i bt0_lo = b_t[off + 0]; + __m256i bt0_hi = b_t[off + 1]; + __m256i bt8_lo = b_t[off + 16]; + __m256i bt8_hi = b_t[off + 17]; + + __m256i pr0_lo = _mm256_madd_epi16(a_lo, bt0_lo); + __m256i pr0_hi = _mm256_madd_epi16(a_hi, bt0_hi); + __m256i pr8_lo = _mm256_madd_epi16(a_lo, bt8_lo); + __m256i pr8_hi = _mm256_madd_epi16(a_hi, bt8_hi); + + __m256i sum0 = _mm256_add_epi32 (pr0_lo, pr0_hi); + __m256i sum8 = _mm256_add_epi32 (pr8_lo, pr8_hi); + + // Arrange all parts for one number to be inside one lane + __m256i s08_lo = _mm256_permute2x128_si256(sum0, sum8, 0x20); + __m256i s08_hi = _mm256_permute2x128_si256(sum0, sum8, 0x31); + + rp1[i] = _mm256_add_epi32(s08_lo, s08_hi); + } + + // 00 00 01 01 | 08 08 09 09 + // 02 02 03 03 | 0a 0a 0b 0b + // 04 04 05 05 | 0c 0c 0d 0d + // 06 06 07 07 | 0e 0e 0f 0f + // 10 10 11 11 | 18 18 19 19 + // ... + // 16 16 17 17 | 1e 1e 1f 1f + __m256i rp2[8]; + for (i = 0; i < 8; i++) { + uint32_t id = i << 1; + rp2[i] = _mm256_hadd_epi32(rp1[id + 0], rp1[id + 1]); + } + + // 00 01 02 03 | 08 09 0a 0b + // 04 05 06 07 | 0c 0d 0e 0f + // 10 11 12 13 | 18 19 1a 1b + // 14 15 16 17 | 1c 1d 1e 1f + __m256i rp3[4]; + for (i = 0; i < 4; i++) { + uint32_t id = i << 1; + __m256i finals = _mm256_hadd_epi32(rp2[id + 0], rp2[id + 1]); + rp3[i] = truncate(finals, debias, shift); + } + + dst[jd + 0] = _mm256_packs_epi32(rp3[0], rp3[1]); + dst[jd + 1] = _mm256_packs_epi32(rp3[2], rp3[3]); } - transpose_32x32((const int16_t *)res_tp, dst); } // 32x32 matrix multiplication with value clipping. @@ -1257,27 +1096,29 @@ static void matrix_dct_32x32_avx2(int8_t bitdepth, const int16_t *input, int16_t { int32_t shift_1st = kvz_g_convert_to_bit[32] + 1 + (bitdepth - 8); int32_t shift_2nd = kvz_g_convert_to_bit[32] + 8; - ALIGNED(64) int16_t tmp[32 * 32]; - partial_butterfly_32_avx2(input, tmp, shift_1st); - partial_butterfly_32_avx2(tmp, output, shift_2nd); + const __m256i *dct = (const __m256i *)(&kvz_g_dct_32[0][0]); + const __m256i *inp = (const __m256i *)input; + __m256i *out = ( __m256i *)output; + + /* + * Multiply input by the tranpose of DCT matrix into tmpres, and DCT matrix + * by tmpres - this is then our output matrix + * + * It's easier to implement an AVX2 matrix multiplication if you can multiply + * the left term with the transpose of the right term. Here things are stored + * row-wise, not column-wise, so we can effectively read DCT_T column-wise + * into YMM registers by reading DCT row-wise. Also because of this, the + * first multiplication is hacked to produce the transpose of the result + * instead, since it will be used in similar fashion as the right operand + * in the second multiplication. + */ + + __m256i tmp[2 * 32]; + matmul_32x32_a_bt_t(inp, dct, tmp, shift_1st); + matmul_32x32_a_bt (dct, tmp, out, shift_2nd); } -// Macro that generates 2D transform functions with clipping values. -// Sets correct shift values and matrices according to transform type and -// block size. Performs matrix multiplication horizontally and vertically. -#define TRANSFORM(type, n) static void matrix_ ## type ## _ ## n ## x ## n ## _avx2(int8_t bitdepth, const int16_t *input, int16_t *output)\ -{\ - int32_t shift_1st = kvz_g_convert_to_bit[n] + 1 + (bitdepth - 8); \ - int32_t shift_2nd = kvz_g_convert_to_bit[n] + 8; \ - ALIGNED(64) int16_t tmp[n * n];\ - const int16_t *tdct = &kvz_g_ ## type ## _ ## n ## _t[0][0];\ - const int16_t *dct = &kvz_g_ ## type ## _ ## n [0][0];\ -\ - mul_clip_matrix_ ## n ## x ## n ## _avx2(input, tdct, tmp, shift_1st);\ - mul_clip_matrix_ ## n ## x ## n ## _avx2(dct, tmp, output, shift_2nd);\ -}\ - // Macro that generates 2D inverse transform functions with clipping values. // Sets correct shift values and matrices according to transform type and // block size. Performs matrix multiplication horizontally and vertically. @@ -1303,7 +1144,6 @@ static void matrix_i ## type ## _## n ## x ## n ## _avx2(int8_t bitdepth, const // ITRANSFORM(dct, 8); // TRANSFORM(dct, 16); // ITRANSFORM(dct, 16); - // TRANSFORM(dct, 32); // Generate all the transform functions