From e382339182de9b40d7fe952ab43d72867359cbbf Mon Sep 17 00:00:00 2001 From: Pauli Oikkonen Date: Thu, 3 Oct 2019 23:00:31 +0300 Subject: [PATCH] Implement fast (butterfly) 32x32 DCT in AVX2 --- src/strategies/avx2/dct-avx2.c | 364 +++++++++++++++++++++++++++++++-- 1 file changed, 350 insertions(+), 14 deletions(-) diff --git a/src/strategies/avx2/dct-avx2.c b/src/strategies/avx2/dct-avx2.c index 497a6b1f..2ef0f63f 100644 --- a/src/strategies/avx2/dct-avx2.c +++ b/src/strategies/avx2/dct-avx2.c @@ -523,7 +523,12 @@ static void matmul_16x16_a_bt(const int16_t *a, const __m256i *b_t, int16_t *out } } -static void transpose_16x16(const int16_t *src, int16_t *dst) +// NOTE: The strides measured by s_stride_log2 and d_stride_log2 are in units +// of 16 coeffs, not 1! +static void transpose_16x16_stride(const int16_t *src, + int16_t *dst, + uint8_t s_stride_log2, + uint8_t d_stride_log2) { __m256i tmp_128[16]; for (uint32_t i = 0; i < 16; i += 8) { @@ -535,14 +540,14 @@ static void transpose_16x16(const int16_t *src, int16_t *dst) __m256i tmp_64[8]; __m256i m[8] = { - _mm256_load_si256((const __m256i *)src + i + 0), - _mm256_load_si256((const __m256i *)src + i + 1), - _mm256_load_si256((const __m256i *)src + i + 2), - _mm256_load_si256((const __m256i *)src + i + 3), - _mm256_load_si256((const __m256i *)src + i + 4), - _mm256_load_si256((const __m256i *)src + i + 5), - _mm256_load_si256((const __m256i *)src + i + 6), - _mm256_load_si256((const __m256i *)src + i + 7), + _mm256_load_si256((const __m256i *)src + ((i + 0) << s_stride_log2)), + _mm256_load_si256((const __m256i *)src + ((i + 1) << s_stride_log2)), + _mm256_load_si256((const __m256i *)src + ((i + 2) << s_stride_log2)), + _mm256_load_si256((const __m256i *)src + ((i + 3) << s_stride_log2)), + _mm256_load_si256((const __m256i *)src + ((i + 4) << s_stride_log2)), + _mm256_load_si256((const __m256i *)src + ((i + 5) << s_stride_log2)), + _mm256_load_si256((const __m256i *)src + ((i + 6) << s_stride_log2)), + _mm256_load_si256((const __m256i *)src + ((i + 7) << s_stride_log2)), }; tmp_32[0] = _mm256_unpacklo_epi16( m[0], m[1]); @@ -579,19 +584,35 @@ static void transpose_16x16(const int16_t *src, int16_t *dst) } for (uint32_t i = 0; i < 8; i++) { - uint32_t loid = i + 0; - uint32_t hiid = i + 8; + uint32_t loid = i + 0; + uint32_t hiid = i + 8; + + uint32_t dst_loid = loid << d_stride_log2; + uint32_t dst_hiid = hiid << d_stride_log2; __m256i lo = tmp_128[loid]; __m256i hi = tmp_128[hiid]; __m256i final_lo = _mm256_permute2x128_si256(lo, hi, 0x20); __m256i final_hi = _mm256_permute2x128_si256(lo, hi, 0x31); - _mm256_store_si256((__m256i *)dst + loid, final_lo); - _mm256_store_si256((__m256i *)dst + hiid, final_hi); + _mm256_store_si256((__m256i *)dst + dst_loid, final_lo); + _mm256_store_si256((__m256i *)dst + dst_hiid, final_hi); } } +static void transpose_16x16(const int16_t *src, int16_t *dst) +{ + transpose_16x16_stride(src, dst, 0, 0); +} + +static void transpose_32x32(const int16_t *src, int16_t *dst) +{ + transpose_16x16_stride(src + 0, dst + 0, 1, 1); + transpose_16x16_stride(src + 16, dst + 16 * 32, 1, 1); + transpose_16x16_stride(src + 16 * 32, dst + 16, 1, 1); + transpose_16x16_stride(src + 16 * 33, dst + 16 * 33, 1, 1); +} + static __m256i truncate_inv(__m256i v, int32_t shift) { int32_t add = 1 << (shift - 1); @@ -834,6 +855,310 @@ static void matrix_dct_16x16_avx2(int8_t bitdepth, const int16_t *input, int16_t matmul_16x16_a_bt (dct, tmpres, output, shift_2nd); } +static __m256i get_overflows(const __m256i a, const __m256i b, const __m256i res, const __m256i of_adjust_mask) +{ + const __m256i ones = _mm256_set1_epi16(1); + + __m256i src_signdiff = _mm256_xor_si256 (a, b); + __m256i a_r_signdiff = _mm256_xor_si256 (a, res); + + __m256i of_possible = _mm256_xor_si256 (src_signdiff, of_adjust_mask); + + __m256i overflows = _mm256_and_si256 (of_possible, a_r_signdiff); + overflows = _mm256_srai_epi16 (overflows, 15); + + __m256i of_signs = _mm256_srai_epi16 (a, 15); + of_signs = _mm256_or_si256 (of_signs, ones); + return _mm256_and_si256 (overflows, of_signs); +} + +/* + * You need more than 16 bits to store the result of signed 16b-16b operation, + * this one stores the low 16b in lo and high 16b in hi. The high 16 bits can + * only be -1, 0 or 1. + * + * of_possible_mask is either all zero bits for subtraction, or all ones for + * addition + */ +static void sub_16_16_hilo(const __m256i a, const __m256i b, __m256i *lo, __m256i *hi) +{ + const __m256i zero = _mm256_setzero_si256(); + + *lo = _mm256_sub_epi16(a, b); + *hi = get_overflows(a, b, *lo, zero); +} + +static void add_16_16_hilo(const __m256i a, const __m256i b, __m256i *lo, __m256i *hi) +{ + const __m256i ff = _mm256_set1_epi8(-1); + + *lo = _mm256_add_epi16(a, b); + *hi = get_overflows(a, b, *lo, ff); +} + +static __m256i reverse_16x16b_in_lanes(const __m256i v) +{ + const __m256i lanerev = _mm256_setr_epi16(0x0f0e, 0x0d0c, 0x0b0a, 0x0908, + 0x0706, 0x0504, 0x0302, 0x0100, + 0x0f0e, 0x0d0c, 0x0b0a, 0x0908, + 0x0706, 0x0504, 0x0302, 0x0100); + return _mm256_shuffle_epi8(v, lanerev); +} + +static __m256i reverse_16x16b(const __m256i v) +{ + __m256i tmp = reverse_16x16b_in_lanes(v); + return _mm256_permute4x64_epi64(tmp, _MM_SHUFFLE(1, 0, 3, 2)); +} + +static __m256i m256_from_2xm128(const __m128i lo, const __m128i hi) +{ + __m256i result = _mm256_castsi128_si256 (lo); + return _mm256_inserti128_si256(result, hi, 1); +} + +// Get a vector consisting of the divisible-by-8 coeffs in DCT's first two +// columns, in order: +// DC00 DC00 DC00 DC00 DC10 DC10 DC10 DC10 | DC08 DC08 DC08 DC08 DC18 DC18 DC18 DC18 +static __m256i get_dct_db8_vec(const int16_t *dct_t, uint32_t offset) +{ + const __m256i reorder_mask = _mm256_setr_epi32(0x01000100, 0x01000100, 0x05040504, 0x05040504, + 0x03020302, 0x03020302, 0x07060706, 0x07060706); + + uint16_t coeff00 = (uint16_t)dct_t[offset + 0]; + uint16_t coeff08 = (uint16_t)dct_t[offset + 8]; + uint16_t coeff10 = (uint16_t)dct_t[offset + 32]; + uint16_t coeff18 = (uint16_t)dct_t[offset + 40]; + + uint64_t col_db8_packed = (((uint64_t)coeff00) << 0) | + (((uint64_t)coeff08) << 16) | + (((uint64_t)coeff10) << 32) | + (((uint64_t)coeff18) << 48); + + __m256i col_db8 = _mm256_set1_epi64x (col_db8_packed); + col_db8 = _mm256_shuffle_epi8(col_db8, reorder_mask); + __m256i col_db8_shifted = _mm256_slli_epi16 (col_db8, 8); + + return _mm256_blend_epi32 (col_db8, col_db8_shifted, 0xaa); +} + +// Get first 4 DCT coeffs from DCT rows (rowoff + 4) and (rowoff + 12), +// shifting copies of the coeffs 8 bits left to do half of the 65536-factor +// multiplication required for high words of EEO coeffs +static __m256i get_dct_db4_vec(const int16_t *dct, uint32_t rowoff) +{ + ALIGNED(32) uint64_t buf[4]; + + uint64_t r4_0123 = *(uint64_t *)(dct + (rowoff + 0x04) * 32); + uint64_t rc_0123 = *(uint64_t *)(dct + (rowoff + 0x0c) * 32); + + buf[0] = r4_0123; + buf[1] = r4_0123 << 8; + buf[2] = rc_0123; + buf[3] = rc_0123 << 8; + + return _mm256_load_si256((const __m256i *)buf); +} + +// Get first 8 coeffs from DCT rows (rowoff + 2) and (rowoff + 10) and return +// them in a single YMM +static __m256i get_dct_db2_vec(const int16_t *dct, uint32_t rowoff) +{ + __m128i row2 = _mm_load_si128((const __m128i *)(dct + (rowoff + 2) * 32)); + __m128i rowa = _mm_load_si128((const __m128i *)(dct + (rowoff + 10) * 32)); + + return m256_from_2xm128(row2, rowa); +} + +static void partial_butterfly_32_avx2(const int16_t *src, int16_t *dst, int32_t shift) +{ + const int32_t add = 1 << (shift - 1); + const __m256i debias = _mm256_set1_epi32(add); + + const int16_t *dct = (const int16_t *)(&kvz_g_dct_32 [0][0]); + const int16_t *dct_t = (const int16_t *)(&kvz_g_dct_32_t[0][0]); + + const __m256i ff = _mm256_set1_epi32 (-1); + const __m256i ones = _mm256_set1_epi16 ( 1); + const __m128i ff_128 = _mm256_castsi256_si128 (ff); + const __m256i lolane_smask = _mm256_inserti128_si256(ones, ff_128, 0); + const __m256i hilane_mask = _mm256_cmpeq_epi16 (ones, lolane_smask); + + const __m256i eee_lohi_shuf = _mm256_setr_epi32(0x01000706, 0x09080f0e, 0x03020504, 0x0b0a0d0c, + 0x01000706, 0x09080f0e, 0x03020504, 0x0b0a0d0c); + + const __m256i eee_eee_sgnmask = _mm256_setr_epi32(0x00010001, 0x00010001, 0x00010001, 0x00010001, + 0xffff0001, 0xffff0001, 0xffff0001, 0xffff0001); + + const __m256i dct_c0_db8 = get_dct_db8_vec(dct_t, 0); + const __m256i dct_c1_db8 = get_dct_db8_vec(dct_t, 16); + + const __m256i dct_r4c_db4[2] = { + get_dct_db4_vec(dct, 0x00), + get_dct_db4_vec(dct, 0x10), + }; + + const __m256i dct_r_db2[4] = { + get_dct_db2_vec(dct, 0), + get_dct_db2_vec(dct, 4), + get_dct_db2_vec(dct, 16), + get_dct_db2_vec(dct, 20), + }; + + __m256i res_tp[2 * 32]; + for (uint32_t i = 0; i < 32; i++) { + __m256i lo = _mm256_load_si256((const __m256i *)src + 2 * i + 0); + __m256i hi = _mm256_load_si256((const __m256i *)src + 2 * i + 1); + + __m256i hi_rev = reverse_16x16b(hi); + + __m256i e_lo, e_hi, o_lo, o_hi; + add_16_16_hilo(lo, hi_rev, &e_lo, &e_hi); + sub_16_16_hilo(lo, hi_rev, &o_lo, &o_hi); + + __m256i erev_lo = reverse_16x16b(e_lo); + __m256i erev_hi = reverse_16x16b(e_hi); + + // Hack! Negate low lanes to do subtractions there, but retain non-negated + // low lanes for overflow detection because the original value is + // essentially what was subtracted there by adding its complement + __m256i erev_lo_n = _mm256_sign_epi16(erev_lo, lolane_smask); + __m256i erev_hi_n = _mm256_sign_epi16(erev_hi, lolane_smask); + + // eo0 eo1 eo2 eo3 eo4 eo5 eo6 eo7 | ee7 ee6 ee5 ee4 ee3 ee2 ee1 ee0 + __m256i eo_ee_lo = _mm256_add_epi16 (e_lo, erev_lo_n); + __m256i eo_ee_hi = _mm256_add_epi16 (e_hi, erev_hi_n); + + __m256i eo_ee_hi2 = get_overflows (e_lo, erev_lo, eo_ee_lo, hilane_mask); + eo_ee_hi = _mm256_add_epi16 (eo_ee_hi, eo_ee_hi2); + + __m256i eo_eo_lo = _mm256_permute4x64_epi64(eo_ee_lo, _MM_SHUFFLE(1, 0, 1, 0)); + __m256i eo_eo_hi = _mm256_permute4x64_epi64(eo_ee_hi, _MM_SHUFFLE(1, 0, 1, 0)); + + // Rev: ee7 ee6 ee5 ee4 ee7 ee6 ee5 ee4 | ee3 ee2 ee1 ee0 ee3 ee2 ee1 ee0 + // Fwd: ee0 ee1 ee2 ee3 ee0 ee1 ee2 ee3 | ee4 ee5 ee6 ee7 ee4 ee5 ee6 ee7 + __m256i ee_ee_rev_lo = _mm256_permute4x64_epi64(eo_ee_lo, _MM_SHUFFLE(3, 3, 2, 2)); + __m256i ee_ee_lo = _mm256_permute4x64_epi64(eo_ee_lo, _MM_SHUFFLE(2, 2, 3, 3)); + ee_ee_lo = reverse_16x16b_in_lanes (ee_ee_lo); + + __m256i ee_ee_rev_hi = _mm256_permute4x64_epi64(eo_ee_hi, _MM_SHUFFLE(3, 3, 2, 2)); + __m256i ee_ee_hi = _mm256_permute4x64_epi64(eo_ee_hi, _MM_SHUFFLE(2, 2, 3, 3)); + ee_ee_hi = reverse_16x16b_in_lanes (ee_ee_hi); + + __m256i ee_ee_rev_lo_n = _mm256_sign_epi16(ee_ee_rev_lo, lolane_smask); + __m256i ee_ee_rev_hi_n = _mm256_sign_epi16(ee_ee_rev_hi, lolane_smask); + + // eeo0 eeo1 eeo2 eeo3 eeo0 eeo1 eeo2 eeo3 | eee3 eee2 eee1 eee0 eee3 eee2 eee1 eee0 + __m256i eeo_eee_lo = _mm256_add_epi16 (ee_ee_lo, ee_ee_rev_lo_n); + __m256i eeo_eee_hi = _mm256_add_epi16 (ee_ee_hi, ee_ee_rev_hi_n); + + __m256i eeo_eee_hi2 = get_overflows(ee_ee_lo, ee_ee_rev_lo, eeo_eee_lo, hilane_mask); + eeo_eee_hi = _mm256_add_epi16(eeo_eee_hi, eeo_eee_hi2); + + // Multiply these guys by 256 and also the corresponding DCT coefficients + // by 256, to multiply their product by 65536. Neither these values nor + // the coeffs will exceed 255, so this is overflow safe. + __m256i eeo_eee_hi_shed = _mm256_slli_epi16(eeo_eee_hi, 8); + + // Discard eeo's (low lane), duplicate the high lane, and reorder eee's + // in-lane. Finally invert the eee2/eee3 components in the high lane. + __m256i eee_lohi = _mm256_blend_epi32(eeo_eee_lo, eeo_eee_hi_shed, 0xc0); + eee_lohi = _mm256_permute4x64_epi64(eee_lohi, _MM_SHUFFLE(3, 2, 3, 2)); + + __m256i eee_lh_ordered = _mm256_shuffle_epi8(eee_lohi, eee_lohi_shuf); + __m256i eee_lh_final = _mm256_sign_epi16 (eee_lh_ordered, eee_eee_sgnmask); + + // D00, D08, D10 and D18 in four parts fit for hadding + __m256i d00_d08 = _mm256_madd_epi16(eee_lh_final, dct_c0_db8); + __m256i d10_d18 = _mm256_madd_epi16(eee_lh_final, dct_c1_db8); + + __m256i eeo_lohi = _mm256_unpacklo_epi64 (eeo_eee_lo, eeo_eee_hi_shed); + eeo_lohi = _mm256_permute4x64_epi64(eeo_lohi, _MM_SHUFFLE(1, 0, 1, 0)); + + __m256i d_db4[2] = { + _mm256_madd_epi16(eeo_lohi, dct_r4c_db4[0]), + _mm256_madd_epi16(eeo_lohi, dct_r4c_db4[1]), + }; + + __m256i db2_parts[4]; + for (uint32_t j = 0; j < 4; j++) { + const __m256i dr2a = dct_r_db2[j]; + __m256i lo = _mm256_madd_epi16(dr2a, eo_eo_lo); + __m256i hi = _mm256_madd_epi16(dr2a, eo_eo_hi); + hi = _mm256_slli_epi32(hi, 16); + + db2_parts[j] = _mm256_add_epi32 (lo, hi); + } + + __m256i odd_parts[16]; + for (uint32_t j = 0; j < 16; j++) { + __m256i drow_lo = _mm256_load_si256((const __m256i *)(dct + (j * 2 + 1) * 32)); + __m256i odds_lo = _mm256_madd_epi16(o_lo, drow_lo); + __m256i odds_hi = _mm256_madd_epi16(o_hi, drow_lo); + odds_hi = _mm256_slli_epi32(odds_hi, 16); + + odd_parts[j] = _mm256_add_epi32 (odds_lo, odds_hi); + } + + // Rearrange odds so that parts belonging to any single one are all inside + // one lane - combine 01 | 09 ; 03 | 0b ; 05 | 0d ; 07 | 0f + // 11 | 19 ; 13 | 1b ; 15 | 1d ; 17 | 1f + __m256i odd_parts2[8]; + for (uint32_t j = 0; j < 8; j++) { + + // Turn 0, 1, 2, 3, 4, 5, 6, 7 into: + // 0, 1, 2, 3, 8, 9, a, b + uint32_t j_lo = j & 0x03; + uint32_t j_hi = j & 0x04; + uint32_t id_lo = j_lo | (j_hi << 1); + uint32_t id_hi = id_lo | 4; + + __m256i odd_lo = _mm256_permute2x128_si256(odd_parts[id_lo], + odd_parts[id_hi], + 0x20); + + __m256i odd_hi = _mm256_permute2x128_si256(odd_parts[id_lo], + odd_parts[id_hi], + 0x31); + + odd_parts2[j] = _mm256_add_epi32(odd_lo, odd_hi); + } + + // First stage HADDs... + __m256i d0001_0809 = _mm256_hadd_epi32(d00_d08, odd_parts2[0]); + __m256i d1011_1819 = _mm256_hadd_epi32(d10_d18, odd_parts2[4]); + + __m256i d0405_0c0d = _mm256_hadd_epi32(d_db4[0], odd_parts2[2]); + __m256i d1415_1c1d = _mm256_hadd_epi32(d_db4[1], odd_parts2[6]); + + __m256i d0203_0a0b = _mm256_hadd_epi32(db2_parts[0], odd_parts2[1]); + __m256i d0607_0e0f = _mm256_hadd_epi32(db2_parts[1], odd_parts2[3]); + + __m256i d1213_1a1b = _mm256_hadd_epi32(db2_parts[2], odd_parts2[5]); + __m256i d1617_1e1f = _mm256_hadd_epi32(db2_parts[3], odd_parts2[7]); + + // .. and second stage + __m256i d0123_89ab_lo = _mm256_hadd_epi32(d0001_0809, d0203_0a0b); + __m256i d0123_89ab_hi = _mm256_hadd_epi32(d1011_1819, d1213_1a1b); + + __m256i d4567_cdef_lo = _mm256_hadd_epi32(d0405_0c0d, d0607_0e0f); + __m256i d4567_cdef_hi = _mm256_hadd_epi32(d1415_1c1d, d1617_1e1f); + + d0123_89ab_lo = truncate(d0123_89ab_lo, debias, shift); + d0123_89ab_hi = truncate(d0123_89ab_hi, debias, shift); + + d4567_cdef_lo = truncate(d4567_cdef_lo, debias, shift); + d4567_cdef_hi = truncate(d4567_cdef_hi, debias, shift); + + __m256i final_lo = _mm256_packs_epi32(d0123_89ab_lo, d4567_cdef_lo); + __m256i final_hi = _mm256_packs_epi32(d0123_89ab_hi, d4567_cdef_hi); + + _mm256_store_si256(res_tp + (i * 2) + 0, final_lo); + _mm256_store_si256(res_tp + (i * 2) + 1, final_hi); + } + transpose_32x32((const int16_t *)res_tp, dst); +} + // 32x32 matrix multiplication with value clipping. // Parameters: Two 32x32 matrices containing 16-bit values in consecutive addresses, // destination for the result and the shift value for clipping. @@ -923,6 +1248,16 @@ static void mul_clip_matrix_32x32_avx2(const int16_t *left, const int16_t *right } } +static void matrix_dct_32x32_avx2(int8_t bitdepth, const int16_t *input, int16_t *output) +{ + int32_t shift_1st = kvz_g_convert_to_bit[32] + 1 + (bitdepth - 8); + int32_t shift_2nd = kvz_g_convert_to_bit[32] + 8; + ALIGNED(64) int16_t tmp[32 * 32]; + + partial_butterfly_32_avx2(input, tmp, shift_1st); + partial_butterfly_32_avx2(tmp, output, shift_2nd); +} + // Macro that generates 2D transform functions with clipping values. // Sets correct shift values and matrices according to transform type and // block size. Performs matrix multiplication horizontally and vertically. @@ -964,8 +1299,9 @@ static void matrix_i ## type ## _## n ## x ## n ## _avx2(int8_t bitdepth, const // TRANSFORM(dct, 16); // ITRANSFORM(dct, 16); +// TRANSFORM(dct, 32); + // Generate all the transform functions -TRANSFORM(dct, 32); ITRANSFORM(dct, 32);