Implement a tailored AVX2 8x8 DCT

Pauli Oikkonen 2019-05-23 16:05:14 +03:00
parent ad7c8d40bc
commit 7ec7ab3361


@@ -154,6 +154,133 @@ static void mul_clip_matrix_8x8_avx2(const int16_t *left, const int16_t *right,
}
}
static INLINE __m256i swap_lanes(__m256i v)
{
return _mm256_permute4x64_epi64(v, _MM_SHUFFLE(1, 0, 3, 2));
}
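// Editorial sketch, not part of the original commit: swap_lanes() exchanges
// the two 128-bit halves of a YMM register, so a register holding two 8x16b
// rows (a, b) becomes (b, a). Roughly the scalar equivalent:
//
//   int16_t tmp[8];
//   memcpy(tmp,   v,     sizeof(tmp));  // save the first row
//   memcpy(v,     v + 8, sizeof(tmp));  // first row <- second row
//   memcpy(v + 8, tmp,   sizeof(tmp));  // second row <- saved first row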
static INLINE __m256i truncate(__m256i v, __m256i debias, int32_t shift)
{
__m256i truncable = _mm256_add_epi32(v, debias);
return _mm256_srai_epi32(truncable, shift);
}
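// Editorial note: truncate() is a rounding arithmetic right shift. With
// debias = 1 << (shift - 1) broadcast to all lanes (as add1/add2 below),
// each 32-bit lane computes
//
//   result = (v + (1 << (shift - 1))) >> shift;
//
// i.e. it rounds to nearest instead of plainly truncating.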
static void matrix_dct_8x8_avx2(int8_t bitdepth, const int16_t *input, int16_t *output)
{
int32_t shift_1st = kvz_g_convert_to_bit[8] + 1 + (bitdepth - 8);
int32_t shift_2nd = kvz_g_convert_to_bit[8] + 8;
const int32_t add1 = 1 << (shift_1st - 1);
const __m256i debias1 = _mm256_set1_epi32(add1);
const int32_t add2 = 1 << (shift_2nd - 1);
const __m256i debias2 = _mm256_set1_epi32(add2);
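// Editorial note: assuming kvz_g_convert_to_bit follows the usual
// log2(n) - 2 convention, kvz_g_convert_to_bit[8] == 1, so for 8-bit video
// shift_1st = 2 and shift_2nd = 9; add1/add2 are the matching rounding
// offsets applied by truncate().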
const __m256i *dct = (__m256i *)&(kvz_g_dct_8[0][0]);
// Keep the upper row (lower lane) intact and swap neighboring 16-bit words in the lower row (upper lane)
const __m256i shuf_lorow_mask =
_mm256_setr_epi8(0, 1, 2, 3, 4, 5, 6, 7,
8, 9, 10, 11, 12, 13, 14, 15,
18, 19, 16, 17, 22, 23, 20, 21,
26, 27, 24, 25, 30, 31, 28, 29);
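// Editorial note: bytes 0-15 of the mask (the lower lane, holding the first
// row of a dual-row register) map to themselves, while the upper lane's
// pattern 18,19,16,17,... (shuffle indices are taken mod 16 per lane)
// reorders its 16-bit words as 1,0,3,2,5,4,7,6 - the neighbor swap named in
// the comment above.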
__m256i tmpres[4];
// Dual rows: two 8x16b rows fit in one YMM register
__m256i i_dr_0 = _mm256_loadu_si256((__m256i *)input + 0);
__m256i i_dr_1 = _mm256_loadu_si256((__m256i *)input + 1);
__m256i i_dr_2 = _mm256_loadu_si256((__m256i *)input + 2);
__m256i i_dr_3 = _mm256_loadu_si256((__m256i *)input + 3);
__m256i i_dr_0_swp = swap_lanes(i_dr_0);
__m256i i_dr_1_swp = swap_lanes(i_dr_1);
__m256i i_dr_2_swp = swap_lanes(i_dr_2);
__m256i i_dr_3_swp = swap_lanes(i_dr_3);
/*
* Multiply input by the transpose of the DCT matrix into tmpres, then the
* DCT matrix by tmpres - this is then our output matrix
*
* It's easier to implement an AVX2 matrix multiplication if you can multiply
* the left term with the transpose of the right term. Here things are stored
* row-wise, not column-wise, so we can effectively read DCT_T column-wise
* into YMM registers by reading DCT row-wise. Also because of this, the
* first multiplication is hacked to produce the transpose of the result
* instead, since it will be used in a similar fashion as the right operand
* in the second multiplication.
*/
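/*
 * Editorial sketch, not part of the original commit: a plain scalar
 * reference of the same computation, assuming kvz_g_dct_8 is the row-major
 * 8x8 DCT coefficient table (the vector code keeps the intermediate result
 * transposed and saturates when packing back to 16 bits, which this sketch
 * omits):
 *
 *   int16_t tmp[8][8];
 *   for (int y = 0; y < 8; y++)            // tmp = input * DCT^T
 *     for (int x = 0; x < 8; x++) {
 *       int32_t sum = 0;
 *       for (int i = 0; i < 8; i++)
 *         sum += input[y * 8 + i] * kvz_g_dct_8[x][i];
 *       tmp[y][x] = (int16_t)((sum + add1) >> shift_1st);
 *     }
 *   for (int y = 0; y < 8; y++)            // output = DCT * tmp
 *     for (int x = 0; x < 8; x++) {
 *       int32_t sum = 0;
 *       for (int i = 0; i < 8; i++)
 *         sum += kvz_g_dct_8[y][i] * tmp[i][x];
 *       output[y * 8 + x] = (int16_t)((sum + add2) >> shift_2nd);
 *     }
 */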
for (int dry = 0; dry < 4; dry++) {
// Read columns of DCT matrix's transpose by reading rows of DCT matrix
__m256i d_dr = _mm256_loadu_si256(dct + dry);
__m256i prod0 = _mm256_madd_epi16(d_dr, i_dr_0);
__m256i prod0_swp = _mm256_madd_epi16(d_dr, i_dr_0_swp);
__m256i prod1 = _mm256_madd_epi16(d_dr, i_dr_1);
__m256i prod1_swp = _mm256_madd_epi16(d_dr, i_dr_1_swp);
__m256i prod2 = _mm256_madd_epi16(d_dr, i_dr_2);
__m256i prod2_swp = _mm256_madd_epi16(d_dr, i_dr_2_swp);
__m256i prod3 = _mm256_madd_epi16(d_dr, i_dr_3);
__m256i prod3_swp = _mm256_madd_epi16(d_dr, i_dr_3_swp);
__m256i hsum0 = _mm256_hadd_epi32(prod0, prod0_swp);
__m256i hsum1 = _mm256_hadd_epi32(prod1, prod1_swp);
__m256i hsum2 = _mm256_hadd_epi32(prod2, prod2_swp);
__m256i hsum3 = _mm256_hadd_epi32(prod3, prod3_swp);
__m256i hsum2c_0 = _mm256_hadd_epi32(hsum0, hsum1);
__m256i hsum2c_1 = _mm256_hadd_epi32(hsum2, hsum3);
__m256i hsum2c_0_tr = truncate(hsum2c_0, debias1, shift_1st);
__m256i hsum2c_1_tr = truncate(hsum2c_1, debias1, shift_1st);
__m256i tmp_dr = _mm256_packs_epi32(hsum2c_0_tr, hsum2c_1_tr);
tmpres[dry] = _mm256_shuffle_epi8(tmp_dr, shuf_lorow_mask);
}
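// Editorial note: _mm256_hadd_epi32 only adds horizontally within each
// 128-bit lane. Multiplying against both the original and the lane-swapped
// input rows ensures every lane of the hadd tree sees both rows of each
// dual-row register, so after the two hadd levels each lane holds complete
// 8-element dot products and only the word-order fixup through
// shuf_lorow_mask remains.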
__m256i t_dr_0 = tmpres[0];
__m256i t_dr_1 = tmpres[1];
__m256i t_dr_2 = tmpres[2];
__m256i t_dr_3 = tmpres[3];
__m256i t_dr_0_swp = swap_lanes(t_dr_0);
__m256i t_dr_1_swp = swap_lanes(t_dr_1);
__m256i t_dr_2_swp = swap_lanes(t_dr_2);
__m256i t_dr_3_swp = swap_lanes(t_dr_3);
for (int dry = 0; dry < 4; dry++) {
__m256i d_dr = _mm256_loadu_si256(dct + dry);
__m256i prod0 = _mm256_madd_epi16(d_dr, t_dr_0);
__m256i prod0_swp = _mm256_madd_epi16(d_dr, t_dr_0_swp);
__m256i prod1 = _mm256_madd_epi16(d_dr, t_dr_1);
__m256i prod1_swp = _mm256_madd_epi16(d_dr, t_dr_1_swp);
__m256i prod2 = _mm256_madd_epi16(d_dr, t_dr_2);
__m256i prod2_swp = _mm256_madd_epi16(d_dr, t_dr_2_swp);
__m256i prod3 = _mm256_madd_epi16(d_dr, t_dr_3);
__m256i prod3_swp = _mm256_madd_epi16(d_dr, t_dr_3_swp);
__m256i hsum0 = _mm256_hadd_epi32(prod0, prod0_swp);
__m256i hsum1 = _mm256_hadd_epi32(prod1, prod1_swp);
__m256i hsum2 = _mm256_hadd_epi32(prod2, prod2_swp);
__m256i hsum3 = _mm256_hadd_epi32(prod3, prod3_swp);
__m256i hsum2c_0 = _mm256_hadd_epi32(hsum0, hsum1);
__m256i hsum2c_1 = _mm256_hadd_epi32(hsum2, hsum3);
__m256i hsum2c_0_tr = truncate(hsum2c_0, debias2, shift_2nd);
__m256i hsum2c_1_tr = truncate(hsum2c_1, debias2, shift_2nd);
__m256i tmp_dr = _mm256_packs_epi32(hsum2c_0_tr, hsum2c_1_tr);
__m256i final_dr = _mm256_shuffle_epi8(tmp_dr, shuf_lorow_mask);
_mm256_storeu_si256(((__m256i *)output) + dry, final_dr);
}
}
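/* Editorial usage sketch, not part of the original commit: the function
 * takes a row-major 8x8 block of 16-bit residual samples and writes the 8x8
 * transform coefficients:
 *
 *   int16_t residual[8 * 8] = { 0 };  // fill with residual data
 *   int16_t coeff[8 * 8];
 *   matrix_dct_8x8_avx2(8, residual, coeff);  // bitdepth-8 video
 */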
// 16x16 matrix multiplication with value clipping.
// Parameters: Two 16x16 matrices containing 16-bit values in consecutive addresses,
// destination for the result and the shift value for clipping.
@@ -340,7 +467,10 @@ static void matrix_i ## type ## _## n ## x ## n ## _avx2(int8_t bitdepth, const
// Generate all the transform functions
TRANSFORM(dst, 4);
TRANSFORM(dct, 4);
// Ha, we've got a tailored implementation for this
// TRANSFORM(dct, 8);
TRANSFORM(dct, 16);
TRANSFORM(dct, 32);
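// Editorial note: the tailored matrix_dct_8x8_avx2 above keeps exactly the
// name that TRANSFORM(dct, 8) would have generated, so - assuming kvazaar's
// usual registration of these functions by name - existing callers pick up
// the new implementation unchanged.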