diff --git a/src/strategies/avx2/dct-avx2.c b/src/strategies/avx2/dct-avx2.c index 82c4c086..3304c923 100644 --- a/src/strategies/avx2/dct-avx2.c +++ b/src/strategies/avx2/dct-avx2.c @@ -95,65 +95,6 @@ static void mul_clip_matrix_4x4_avx2(const int16_t *left, const int16_t *right, _mm256_storeu_si256((__m256i*)dst, result); } -// 8x8 matrix multiplication with value clipping. -// Parameters: Two 8x8 matrices containing 16-bit values in consecutive addresses, -// destination for the result and the shift value for clipping. -// -static void mul_clip_matrix_8x8_avx2(const int16_t *left, const int16_t *right, int16_t *dst, const int32_t shift) -{ - int i, j; - __m256i b[2], accu[8], even[2], odd[2]; - - const int32_t add = 1 << (shift - 1); - - b[0] = _mm256_loadu_si256((__m256i*) right); - - b[1] = _mm256_unpackhi_epi16(b[0], _mm256_castsi128_si256(_mm256_extracti128_si256(b[0], 1))); - b[0] = _mm256_unpacklo_epi16(b[0], _mm256_castsi128_si256(_mm256_extracti128_si256(b[0], 1))); - b[0] = _mm256_inserti128_si256(b[0], _mm256_castsi256_si128(b[1]), 1); - - for (i = 0; i < 8; i += 2) { - - even[0] = _mm256_set1_epi32(((int32_t*)left)[4 * i]); - even[0] = _mm256_madd_epi16(even[0], b[0]); - accu[i] = even[0]; - - odd[0] = _mm256_set1_epi32(((int32_t*)left)[4 * (i + 1)]); - odd[0] = _mm256_madd_epi16(odd[0], b[0]); - accu[i + 1] = odd[0]; - } - - for (j = 1; j < 4; ++j) { - - b[0] = _mm256_loadu_si256((__m256i*)right + j); - - b[1] = _mm256_unpackhi_epi16(b[0], _mm256_castsi128_si256(_mm256_extracti128_si256(b[0], 1))); - b[0] = _mm256_unpacklo_epi16(b[0], _mm256_castsi128_si256(_mm256_extracti128_si256(b[0], 1))); - b[0] = _mm256_inserti128_si256(b[0], _mm256_castsi256_si128(b[1]), 1); - - for (i = 0; i < 8; i += 2) { - - even[0] = _mm256_set1_epi32(((int32_t*)left)[4 * i + j]); - even[0] = _mm256_madd_epi16(even[0], b[0]); - accu[i] = _mm256_add_epi32(accu[i], even[0]); - - odd[0] = _mm256_set1_epi32(((int32_t*)left)[4 * (i + 1) + j]); - odd[0] = _mm256_madd_epi16(odd[0], b[0]); - accu[i + 1] = _mm256_add_epi32(accu[i + 1], odd[0]); - } - } - - for (i = 0; i < 8; i += 2) { - __m256i result, first_half, second_half; - - first_half = _mm256_srai_epi32(_mm256_add_epi32(accu[i], _mm256_set1_epi32(add)), shift); - second_half = _mm256_srai_epi32(_mm256_add_epi32(accu[i + 1], _mm256_set1_epi32(add)), shift); - result = _mm256_permute4x64_epi64(_mm256_packs_epi32(first_half, second_half), 0 + 8 + 16 + 192); - _mm256_storeu_si256((__m256i*)dst + i / 2, result); - - } -} - static INLINE __m256i swap_lanes(__m256i v) { return _mm256_permute4x64_epi64(v, _MM_SHUFFLE(1, 0, 3, 2)); @@ -165,18 +106,83 @@ static INLINE __m256i truncate(__m256i v, __m256i debias, int32_t shift) return _mm256_srai_epi32(truncable, shift); } -static void matrix_dct_8x8_avx2(int8_t bitdepth, const int16_t *input, int16_t *output) +static void mul_clip_matrix_8x8_avx2(const int16_t *left, const int16_t *right, int16_t *dst, const int32_t shift) { - int32_t shift_1st = kvz_g_convert_to_bit[8] + 1 + (bitdepth - 8); - int32_t shift_2nd = kvz_g_convert_to_bit[8] + 8; + const __m256i transp_mask = _mm256_broadcastsi128_si256(_mm_setr_epi8(0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15)); - const int32_t add1 = 1 << (shift_1st - 1); - const __m256i debias1 = _mm256_set1_epi32(add1); + const int32_t add = 1 << (shift - 1); + const __m256i debias = _mm256_set1_epi32(add); - const int32_t add2 = 1 << (shift_2nd - 1); - const __m256i debias2 = _mm256_set1_epi32(add2); + __m256i left_dr[4] = { + _mm256_loadu_si256((const __m256i *)left + 0), + _mm256_loadu_si256((const __m256i *)left + 1), + _mm256_loadu_si256((const __m256i *)left + 2), + _mm256_loadu_si256((const __m256i *)left + 3), + }; + __m256i right_dr[4] = { + _mm256_loadu_si256((const __m256i *)right + 0), + _mm256_loadu_si256((const __m256i *)right + 1), + _mm256_loadu_si256((const __m256i *)right + 2), + _mm256_loadu_si256((const __m256i *)right + 3), + }; - const __m256i *dct = (__m256i *)&(kvz_g_dct_8[0][0]); + __m256i rdrs_rearr[8]; + + // Rearrange right matrix + for (int32_t dry = 0; dry < 4; dry++) { + __m256i rdr = right_dr[dry]; + __m256i rdr_los = _mm256_permute4x64_epi64(rdr, _MM_SHUFFLE(2, 0, 2, 0)); + __m256i rdr_his = _mm256_permute4x64_epi64(rdr, _MM_SHUFFLE(3, 1, 3, 1)); + + __m256i rdr_lo_rearr = _mm256_shuffle_epi8(rdr_los, transp_mask); + __m256i rdr_hi_rearr = _mm256_shuffle_epi8(rdr_his, transp_mask); + + rdrs_rearr[dry * 2 + 0] = rdr_lo_rearr; + rdrs_rearr[dry * 2 + 1] = rdr_hi_rearr; + } + + // Double-Row Y for destination matrix + for (int32_t dry = 0; dry < 4; dry++) { + __m256i ldr = left_dr[dry]; + + __m256i ldr_slice12 = _mm256_shuffle_epi32(ldr, _MM_SHUFFLE(0, 0, 0, 0)); + __m256i ldr_slice34 = _mm256_shuffle_epi32(ldr, _MM_SHUFFLE(1, 1, 1, 1)); + __m256i ldr_slice56 = _mm256_shuffle_epi32(ldr, _MM_SHUFFLE(2, 2, 2, 2)); + __m256i ldr_slice78 = _mm256_shuffle_epi32(ldr, _MM_SHUFFLE(3, 3, 3, 3)); + + __m256i prod1 = _mm256_madd_epi16(ldr_slice12, rdrs_rearr[0]); + __m256i prod2 = _mm256_madd_epi16(ldr_slice12, rdrs_rearr[1]); + __m256i prod3 = _mm256_madd_epi16(ldr_slice34, rdrs_rearr[2]); + __m256i prod4 = _mm256_madd_epi16(ldr_slice34, rdrs_rearr[3]); + __m256i prod5 = _mm256_madd_epi16(ldr_slice56, rdrs_rearr[4]); + __m256i prod6 = _mm256_madd_epi16(ldr_slice56, rdrs_rearr[5]); + __m256i prod7 = _mm256_madd_epi16(ldr_slice78, rdrs_rearr[6]); + __m256i prod8 = _mm256_madd_epi16(ldr_slice78, rdrs_rearr[7]); + + __m256i lo_1 = _mm256_add_epi32(prod1, prod3); + __m256i hi_1 = _mm256_add_epi32(prod2, prod4); + __m256i lo_2 = _mm256_add_epi32(prod5, prod7); + __m256i hi_2 = _mm256_add_epi32(prod6, prod8); + + __m256i lo = _mm256_add_epi32(lo_1, lo_2); + __m256i hi = _mm256_add_epi32(hi_1, hi_2); + + __m256i lo_tr = truncate(lo, debias, shift); + __m256i hi_tr = truncate(hi, debias, shift); + + __m256i final_dr = _mm256_packs_epi32(lo_tr, hi_tr); + + _mm256_storeu_si256((__m256i *)dst + dry, final_dr); + } +} + +// Multiplies A by B_T's transpose and stores result's transpose in output, +// which should be an array of 4 __m256i's +static void matmul_8x8_a_bt_t(const int16_t *a, const int16_t *b_t, + const int8_t shift, __m256i *output) +{ + const int32_t add = 1 << (shift - 1); + const __m256i debias = _mm256_set1_epi32(add); // Keep upper row intact and swap neighboring 16-bit words in lower row const __m256i shuf_lorow_mask = @@ -185,18 +191,113 @@ static void matrix_dct_8x8_avx2(int8_t bitdepth, const int16_t *input, int16_t * 18, 19, 16, 17, 22, 23, 20, 21, 26, 27, 24, 25, 30, 31, 28, 29); - __m256i tmpres[4]; + const __m256i *b_t_256 = (const __m256i *)b_t; // Dual Rows, because two 8x16b words fit in one YMM - __m256i i_dr_0 = _mm256_loadu_si256((__m256i *)input + 0); - __m256i i_dr_1 = _mm256_loadu_si256((__m256i *)input + 1); - __m256i i_dr_2 = _mm256_loadu_si256((__m256i *)input + 2); - __m256i i_dr_3 = _mm256_loadu_si256((__m256i *)input + 3); + __m256i a_dr_0 = _mm256_loadu_si256((__m256i *)a + 0); + __m256i a_dr_1 = _mm256_loadu_si256((__m256i *)a + 1); + __m256i a_dr_2 = _mm256_loadu_si256((__m256i *)a + 2); + __m256i a_dr_3 = _mm256_loadu_si256((__m256i *)a + 3); - __m256i i_dr_0_swp = swap_lanes(i_dr_0); - __m256i i_dr_1_swp = swap_lanes(i_dr_1); - __m256i i_dr_2_swp = swap_lanes(i_dr_2); - __m256i i_dr_3_swp = swap_lanes(i_dr_3); + __m256i a_dr_0_swp = swap_lanes(a_dr_0); + __m256i a_dr_1_swp = swap_lanes(a_dr_1); + __m256i a_dr_2_swp = swap_lanes(a_dr_2); + __m256i a_dr_3_swp = swap_lanes(a_dr_3); + + for (int dry = 0; dry < 4; dry++) { + + // Read dual columns of B matrix by reading rows of its transpose + __m256i b_dc = _mm256_loadu_si256(b_t_256 + dry); + + __m256i prod0 = _mm256_madd_epi16(b_dc, a_dr_0); + __m256i prod0_swp = _mm256_madd_epi16(b_dc, a_dr_0_swp); + __m256i prod1 = _mm256_madd_epi16(b_dc, a_dr_1); + __m256i prod1_swp = _mm256_madd_epi16(b_dc, a_dr_1_swp); + __m256i prod2 = _mm256_madd_epi16(b_dc, a_dr_2); + __m256i prod2_swp = _mm256_madd_epi16(b_dc, a_dr_2_swp); + __m256i prod3 = _mm256_madd_epi16(b_dc, a_dr_3); + __m256i prod3_swp = _mm256_madd_epi16(b_dc, a_dr_3_swp); + + __m256i hsum0 = _mm256_hadd_epi32(prod0, prod0_swp); + __m256i hsum1 = _mm256_hadd_epi32(prod1, prod1_swp); + __m256i hsum2 = _mm256_hadd_epi32(prod2, prod2_swp); + __m256i hsum3 = _mm256_hadd_epi32(prod3, prod3_swp); + + __m256i hsum2c_0 = _mm256_hadd_epi32(hsum0, hsum1); + __m256i hsum2c_1 = _mm256_hadd_epi32(hsum2, hsum3); + + __m256i hsum2c_0_tr = truncate(hsum2c_0, debias, shift); + __m256i hsum2c_1_tr = truncate(hsum2c_1, debias, shift); + + __m256i tmp_dc = _mm256_packs_epi32(hsum2c_0_tr, hsum2c_1_tr); + + output[dry] = _mm256_shuffle_epi8(tmp_dc, shuf_lorow_mask); + } +} + +// Multiplies A by B_T's transpose and stores result in output +// which should be an array of 4 __m256i's +static void matmul_8x8_a_bt(const int16_t *a, const __m256i *b_t, + const int8_t shift, int16_t *output) +{ + const int32_t add = 1 << (shift - 1); + const __m256i debias = _mm256_set1_epi32(add); + + const __m256i shuf_lorow_mask = + _mm256_setr_epi8(0, 1, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, + 18, 19, 16, 17, 22, 23, 20, 21, + 26, 27, 24, 25, 30, 31, 28, 29); + + const __m256i *a_256 = (const __m256i *)a; + + __m256i b_dc_0 = b_t[0]; + __m256i b_dc_1 = b_t[1]; + __m256i b_dc_2 = b_t[2]; + __m256i b_dc_3 = b_t[3]; + + __m256i b_dc_0_swp = swap_lanes(b_dc_0); + __m256i b_dc_1_swp = swap_lanes(b_dc_1); + __m256i b_dc_2_swp = swap_lanes(b_dc_2); + __m256i b_dc_3_swp = swap_lanes(b_dc_3); + + for (int dry = 0; dry < 4; dry++) { + __m256i a_dr = _mm256_loadu_si256(a_256 + dry); + + __m256i prod0 = _mm256_madd_epi16(a_dr, b_dc_0); + __m256i prod0_swp = _mm256_madd_epi16(a_dr, b_dc_0_swp); + __m256i prod1 = _mm256_madd_epi16(a_dr, b_dc_1); + __m256i prod1_swp = _mm256_madd_epi16(a_dr, b_dc_1_swp); + __m256i prod2 = _mm256_madd_epi16(a_dr, b_dc_2); + __m256i prod2_swp = _mm256_madd_epi16(a_dr, b_dc_2_swp); + __m256i prod3 = _mm256_madd_epi16(a_dr, b_dc_3); + __m256i prod3_swp = _mm256_madd_epi16(a_dr, b_dc_3_swp); + + __m256i hsum0 = _mm256_hadd_epi32(prod0, prod0_swp); + __m256i hsum1 = _mm256_hadd_epi32(prod1, prod1_swp); + __m256i hsum2 = _mm256_hadd_epi32(prod2, prod2_swp); + __m256i hsum3 = _mm256_hadd_epi32(prod3, prod3_swp); + + __m256i hsum2c_0 = _mm256_hadd_epi32(hsum0, hsum1); + __m256i hsum2c_1 = _mm256_hadd_epi32(hsum2, hsum3); + + __m256i hsum2c_0_tr = truncate(hsum2c_0, debias, shift); + __m256i hsum2c_1_tr = truncate(hsum2c_1, debias, shift); + + __m256i tmp_dr = _mm256_packs_epi32(hsum2c_0_tr, hsum2c_1_tr); + + __m256i final_dr = _mm256_shuffle_epi8(tmp_dr, shuf_lorow_mask); + + _mm256_storeu_si256((__m256i *)output + dry, final_dr); + } +} + +static void matrix_dct_8x8_avx2(int8_t bitdepth, const int16_t *input, int16_t *output) +{ + int32_t shift_1st = kvz_g_convert_to_bit[8] + 1 + (bitdepth - 8); + int32_t shift_2nd = kvz_g_convert_to_bit[8] + 8; + + const int16_t *dct = &kvz_g_dct_8[0][0]; /* * Multiply input by the tranpose of DCT matrix into tmpres, and DCT matrix @@ -210,75 +311,36 @@ static void matrix_dct_8x8_avx2(int8_t bitdepth, const int16_t *input, int16_t * * instead, since it will be used in similar fashion as the right operand * in the second multiplication. */ - for (int dry = 0; dry < 4; dry++) { - // Read columns of DCT matrix's transpose by reading rows of DCT matrix - __m256i d_dr = _mm256_loadu_si256(dct + dry); + __m256i tmpres[4]; - __m256i prod0 = _mm256_madd_epi16(d_dr, i_dr_0); - __m256i prod0_swp = _mm256_madd_epi16(d_dr, i_dr_0_swp); - __m256i prod1 = _mm256_madd_epi16(d_dr, i_dr_1); - __m256i prod1_swp = _mm256_madd_epi16(d_dr, i_dr_1_swp); - __m256i prod2 = _mm256_madd_epi16(d_dr, i_dr_2); - __m256i prod2_swp = _mm256_madd_epi16(d_dr, i_dr_2_swp); - __m256i prod3 = _mm256_madd_epi16(d_dr, i_dr_3); - __m256i prod3_swp = _mm256_madd_epi16(d_dr, i_dr_3_swp); + matmul_8x8_a_bt_t(input, dct, shift_1st, tmpres); + matmul_8x8_a_bt (dct, tmpres, shift_2nd, output); +} - __m256i hsum0 = _mm256_hadd_epi32(prod0, prod0_swp); - __m256i hsum1 = _mm256_hadd_epi32(prod1, prod1_swp); - __m256i hsum2 = _mm256_hadd_epi32(prod2, prod2_swp); - __m256i hsum3 = _mm256_hadd_epi32(prod3, prod3_swp); +static void matrix_idct_8x8_avx2(int8_t bitdepth, const int16_t *input, int16_t *output) +{ + int32_t shift_1st = 7; + int32_t shift_2nd = 12 - (bitdepth - 8); + int16_t tmp[8 * 8]; - __m256i hsum2c_0 = _mm256_hadd_epi32(hsum0, hsum1); - __m256i hsum2c_1 = _mm256_hadd_epi32(hsum2, hsum3); + const int16_t *tdct = &kvz_g_dct_8_t[0][0]; + const int16_t *dct = &kvz_g_dct_8 [0][0]; - __m256i hsum2c_0_tr = truncate(hsum2c_0, debias1, shift_1st); - __m256i hsum2c_1_tr = truncate(hsum2c_1, debias1, shift_1st); + mul_clip_matrix_8x8_avx2(tdct, input, tmp, shift_1st); + mul_clip_matrix_8x8_avx2(tmp, dct, output, shift_2nd); - __m256i tmp_dr = _mm256_packs_epi32(hsum2c_0_tr, hsum2c_1_tr); - - tmpres[dry] = _mm256_shuffle_epi8(tmp_dr, shuf_lorow_mask); - } - - __m256i t_dr_0 = tmpres[0]; - __m256i t_dr_1 = tmpres[1]; - __m256i t_dr_2 = tmpres[2]; - __m256i t_dr_3 = tmpres[3]; - - __m256i t_dr_0_swp = swap_lanes(t_dr_0); - __m256i t_dr_1_swp = swap_lanes(t_dr_1); - __m256i t_dr_2_swp = swap_lanes(t_dr_2); - __m256i t_dr_3_swp = swap_lanes(t_dr_3); - - for (int dry = 0; dry < 4; dry++) { - __m256i d_dr = _mm256_loadu_si256(dct + dry); - - __m256i prod0 = _mm256_madd_epi16(d_dr, t_dr_0); - __m256i prod0_swp = _mm256_madd_epi16(d_dr, t_dr_0_swp); - __m256i prod1 = _mm256_madd_epi16(d_dr, t_dr_1); - __m256i prod1_swp = _mm256_madd_epi16(d_dr, t_dr_1_swp); - __m256i prod2 = _mm256_madd_epi16(d_dr, t_dr_2); - __m256i prod2_swp = _mm256_madd_epi16(d_dr, t_dr_2_swp); - __m256i prod3 = _mm256_madd_epi16(d_dr, t_dr_3); - __m256i prod3_swp = _mm256_madd_epi16(d_dr, t_dr_3_swp); - - __m256i hsum0 = _mm256_hadd_epi32(prod0, prod0_swp); - __m256i hsum1 = _mm256_hadd_epi32(prod1, prod1_swp); - __m256i hsum2 = _mm256_hadd_epi32(prod2, prod2_swp); - __m256i hsum3 = _mm256_hadd_epi32(prod3, prod3_swp); - - __m256i hsum2c_0 = _mm256_hadd_epi32(hsum0, hsum1); - __m256i hsum2c_1 = _mm256_hadd_epi32(hsum2, hsum3); - - __m256i hsum2c_0_tr = truncate(hsum2c_0, debias2, shift_2nd); - __m256i hsum2c_1_tr = truncate(hsum2c_1, debias2, shift_2nd); - - __m256i tmp_dr = _mm256_packs_epi32(hsum2c_0_tr, hsum2c_1_tr); - - __m256i final_dr = _mm256_shuffle_epi8(tmp_dr, shuf_lorow_mask); - - _mm256_storeu_si256(((__m256i *)output) + dry, final_dr); - } + /* + * Because: + * out = tdct * input * dct = tdct * (input * dct) = tdct * (input * transpose(tdct)) + * This could almost be done this way: + * + * matmul_8x8_a_bt_t(input, tdct, debias1, shift_1st, tmp); + * matmul_8x8_a_bt (tdct, tmp, debias2, shift_2nd, output); + * + * But not really, since it will fall victim to some very occasional + * rounding errors. Sadly. + */ } // 16x16 matrix multiplication with value clipping. @@ -468,15 +530,15 @@ static void matrix_i ## type ## _## n ## x ## n ## _avx2(int8_t bitdepth, const TRANSFORM(dst, 4); TRANSFORM(dct, 4); -// Ha, we've got a tailored implementation for this +// Ha, we've got a tailored implementation for these // TRANSFORM(dct, 8); +// ITRANSFORM(dct, 8); TRANSFORM(dct, 16); TRANSFORM(dct, 32); ITRANSFORM(dst, 4); ITRANSFORM(dct, 4); -ITRANSFORM(dct, 8); ITRANSFORM(dct, 16); ITRANSFORM(dct, 32);