Implement 16x16 DCT as butterfly algorithm in AVX2

This commit is contained in:
Pauli Oikkonen 2019-07-08 19:47:48 +03:00
parent 7c69a26717
commit ca9409de2b

View file

@ -523,62 +523,290 @@ static void matmul_16x16_a_bt(const int16_t *a, const __m256i *b_t, int16_t *out
}
}
// 16x16 matrix multiplication with value clipping.
// Parameters: Two 16x16 matrices containing 16-bit values in consecutive addresses,
// destination for the result and the shift value for clipping.
static void mul_clip_matrix_16x16_avx2(const int16_t *left, const int16_t *right, int16_t *dst, const int32_t shift)
static void transpose_16x16(const int16_t *src, int16_t *dst)
{
const int32_t add = 1 << (shift - 1);
const __m256i debias = _mm256_set1_epi32(add);
__m256i tmp_128[16];
for (uint32_t i = 0; i < 16; i += 8) {
__m256i sliced_right[16];
for (int32_t dry = 0; dry < 16; dry += 2) {
__m256i right_up = _mm256_load_si256((const __m256i *)right + dry + 0);
__m256i right_dn = _mm256_load_si256((const __m256i *)right + dry + 1);
// After every n-bit unpack, 2n-bit units in the vectors will be in
// correct order. Pair words first, then dwords, then qwords. After that,
// whole lanes will be correct.
__m256i tmp_32[8];
__m256i tmp_64[8];
__m256i right_slices_lo = _mm256_unpacklo_epi16(right_up, right_dn);
__m256i right_slices_hi = _mm256_unpackhi_epi16(right_up, right_dn);
__m256i m[8] = {
_mm256_load_si256((const __m256i *)src + i + 0),
_mm256_load_si256((const __m256i *)src + i + 1),
_mm256_load_si256((const __m256i *)src + i + 2),
_mm256_load_si256((const __m256i *)src + i + 3),
_mm256_load_si256((const __m256i *)src + i + 4),
_mm256_load_si256((const __m256i *)src + i + 5),
_mm256_load_si256((const __m256i *)src + i + 6),
_mm256_load_si256((const __m256i *)src + i + 7),
};
sliced_right[dry + 0] = right_slices_lo;
sliced_right[dry + 1] = right_slices_hi;
tmp_32[0] = _mm256_unpacklo_epi16( m[0], m[1]);
tmp_32[1] = _mm256_unpacklo_epi16( m[2], m[3]);
tmp_32[2] = _mm256_unpackhi_epi16( m[0], m[1]);
tmp_32[3] = _mm256_unpackhi_epi16( m[2], m[3]);
tmp_32[4] = _mm256_unpacklo_epi16( m[4], m[5]);
tmp_32[5] = _mm256_unpacklo_epi16( m[6], m[7]);
tmp_32[6] = _mm256_unpackhi_epi16( m[4], m[5]);
tmp_32[7] = _mm256_unpackhi_epi16( m[6], m[7]);
tmp_64[0] = _mm256_unpacklo_epi32(tmp_32[0], tmp_32[1]);
tmp_64[1] = _mm256_unpacklo_epi32(tmp_32[2], tmp_32[3]);
tmp_64[2] = _mm256_unpackhi_epi32(tmp_32[0], tmp_32[1]);
tmp_64[3] = _mm256_unpackhi_epi32(tmp_32[2], tmp_32[3]);
tmp_64[4] = _mm256_unpacklo_epi32(tmp_32[4], tmp_32[5]);
tmp_64[5] = _mm256_unpacklo_epi32(tmp_32[6], tmp_32[7]);
tmp_64[6] = _mm256_unpackhi_epi32(tmp_32[4], tmp_32[5]);
tmp_64[7] = _mm256_unpackhi_epi32(tmp_32[6], tmp_32[7]);
tmp_128[i + 0] = _mm256_unpacklo_epi64(tmp_64[0], tmp_64[4]);
tmp_128[i + 1] = _mm256_unpackhi_epi64(tmp_64[0], tmp_64[4]);
tmp_128[i + 2] = _mm256_unpacklo_epi64(tmp_64[2], tmp_64[6]);
tmp_128[i + 3] = _mm256_unpackhi_epi64(tmp_64[2], tmp_64[6]);
tmp_128[i + 4] = _mm256_unpacklo_epi64(tmp_64[1], tmp_64[5]);
tmp_128[i + 5] = _mm256_unpackhi_epi64(tmp_64[1], tmp_64[5]);
tmp_128[i + 6] = _mm256_unpacklo_epi64(tmp_64[3], tmp_64[7]);
tmp_128[i + 7] = _mm256_unpackhi_epi64(tmp_64[3], tmp_64[7]);
}
for (int32_t dry = 0; dry < 16; dry += 2) {
__m256i accum1 = _mm256_setzero_si256();
__m256i accum2 = _mm256_setzero_si256();
__m256i accum3 = _mm256_setzero_si256();
__m256i accum4 = _mm256_setzero_si256();
for (int32_t lx = 0; lx < 16; lx += 2) {
const int32_t *curr_left_up = (const int32_t *)(left + (dry + 0) * 16 + lx);
const int32_t *curr_left_dn = (const int32_t *)(left + (dry + 1) * 16 + lx);
for (uint32_t i = 0; i < 8; i++) {
uint32_t loid = i + 0;
uint32_t hiid = i + 8;
__m256i left_slice_lo = _mm256_set1_epi32(*curr_left_up);
__m256i left_slice_hi = _mm256_set1_epi32(*curr_left_dn);
__m256i lo = tmp_128[loid];
__m256i hi = tmp_128[hiid];
__m256i final_lo = _mm256_permute2x128_si256(lo, hi, 0x20);
__m256i final_hi = _mm256_permute2x128_si256(lo, hi, 0x31);
__m256i right_slices_lo = sliced_right[lx + 0];
__m256i right_slices_hi = sliced_right[lx + 1];
__m256i prod1 = _mm256_madd_epi16(left_slice_lo, right_slices_lo);
__m256i prod2 = _mm256_madd_epi16(left_slice_hi, right_slices_lo);
__m256i prod3 = _mm256_madd_epi16(left_slice_lo, right_slices_hi);
__m256i prod4 = _mm256_madd_epi16(left_slice_hi, right_slices_hi);
accum1 = _mm256_add_epi32(accum1, prod1);
accum2 = _mm256_add_epi32(accum2, prod2);
accum3 = _mm256_add_epi32(accum3, prod3);
accum4 = _mm256_add_epi32(accum4, prod4);
_mm256_store_si256((__m256i *)dst + loid, final_lo);
_mm256_store_si256((__m256i *)dst + hiid, final_hi);
}
__m256i accum1_tr = truncate(accum1, debias, shift);
__m256i accum2_tr = truncate(accum2, debias, shift);
__m256i accum3_tr = truncate(accum3, debias, shift);
__m256i accum4_tr = truncate(accum4, debias, shift);
}
__m256i out_up = _mm256_packs_epi32(accum1_tr, accum3_tr);
__m256i out_dn = _mm256_packs_epi32(accum2_tr, accum4_tr);
static __m256i truncate_inv(__m256i v, int32_t shift)
{
int32_t add = 1 << (shift - 1);
_mm256_store_si256((__m256i *)dst + dry + 0, out_up);
_mm256_store_si256((__m256i *)dst + dry + 1, out_dn);
__m256i debias = _mm256_set1_epi32(add);
__m256i v2 = _mm256_add_epi32 (v, debias);
__m256i trunced = _mm256_srai_epi32(v2, shift);
return trunced;
}
static __m256i extract_odds(__m256i v)
{
// 0 1 2 3 4 5 6 7 | 8 9 a b c d e f => 1 3 5 7 1 3 5 7 | 9 b d f 9 b d f
const __m256i oddmask = _mm256_setr_epi8( 2, 3, 6, 7, 10, 11, 14, 15,
2, 3, 6, 7, 10, 11, 14, 15,
2, 3, 6, 7, 10, 11, 14, 15,
2, 3, 6, 7, 10, 11, 14, 15);
__m256i tmp = _mm256_shuffle_epi8 (v, oddmask);
return _mm256_permute4x64_epi64 (tmp, _MM_SHUFFLE(3, 1, 2, 0));
}
static __m256i extract_combine_odds(__m256i v0, __m256i v1)
{
// 0 1 2 3 4 5 6 7 | 8 9 a b c d e f => 1 3 5 7 1 3 5 7 | 9 b d f 9 b d f
const __m256i oddmask = _mm256_setr_epi8( 2, 3, 6, 7, 10, 11, 14, 15,
2, 3, 6, 7, 10, 11, 14, 15,
2, 3, 6, 7, 10, 11, 14, 15,
2, 3, 6, 7, 10, 11, 14, 15);
__m256i tmp0 = _mm256_shuffle_epi8(v0, oddmask);
__m256i tmp1 = _mm256_shuffle_epi8(v1, oddmask);
__m256i tmp2 = _mm256_blend_epi32 (tmp0, tmp1, 0xcc); // 1100 1100
return _mm256_permute4x64_epi64 (tmp2, _MM_SHUFFLE(3, 1, 2, 0));
}
// Extract items 2, 6, A and E from first four columns of DCT, order them as
// follows:
// D0,2 D0,6 D1,2 D1,6 D1,a D1,e D0,a D0,e | D2,2 D2,6 D3,2 D3,6 D3,a D3,e D2,a D2,e
static __m256i extract_26ae(const __m256i *tdct)
{
// 02 03 22 23 06 07 26 27 | 0a 0b 2a 2b 02 0f 2e 2f
// =>
// 02 06 22 26 02 06 22 26 | 2a 2e 0a 0e 2a 2e 0a 0e
const __m256i evens_mask = _mm256_setr_epi8( 0, 1, 8, 9, 4, 5, 12, 13,
0, 1, 8, 9, 4, 5, 12, 13,
4, 5, 12, 13, 0, 1, 8, 9,
4, 5, 12, 13, 0, 1, 8, 9);
__m256i shufd_0 = _mm256_shuffle_epi32(tdct[0], _MM_SHUFFLE(2, 3, 0, 1));
__m256i shufd_2 = _mm256_shuffle_epi32(tdct[2], _MM_SHUFFLE(2, 3, 0, 1));
__m256i cmbd_01 = _mm256_blend_epi32(shufd_0, tdct[1], 0xaa); // 1010 1010
__m256i cmbd_23 = _mm256_blend_epi32(shufd_2, tdct[3], 0xaa); // 1010 1010
__m256i evens_01 = _mm256_shuffle_epi8(cmbd_01, evens_mask);
__m256i evens_23 = _mm256_shuffle_epi8(cmbd_23, evens_mask);
__m256i evens_0123 = _mm256_unpacklo_epi64(evens_01, evens_23);
return _mm256_permute4x64_epi64(evens_0123, _MM_SHUFFLE(3, 1, 2, 0));
}
// 2 6 2 6 a e a e | 2 6 2 6 a e a e
static __m256i extract_26ae_vec(__m256i col)
{
const __m256i mask_26ae = _mm256_set1_epi32(0x0d0c0504);
// 2 6 2 6 2 6 2 6 | a e a e a e a e
__m256i reord = _mm256_shuffle_epi8 (col, mask_26ae);
__m256i final = _mm256_permute4x64_epi64(reord, _MM_SHUFFLE(3, 1, 2, 0));
return final;
}
// D00 D80 D01 D81 D41 Dc1 D40 Dc0 | D40 Dc0 D41 Dc1 D01 D81 D00 D80
static __m256i extract_d048c(const __m256i *tdct)
{
const __m256i final_shuf = _mm256_setr_epi8( 0, 1, 8, 9, 2, 3, 10, 11,
6, 7, 14, 15, 4, 5, 12, 13,
4, 5, 12, 13, 6, 7, 14, 15,
2, 3, 10, 11, 0, 1, 8, 9);
__m256i c0 = tdct[0];
__m256i c1 = tdct[1];
__m256i c1_2 = _mm256_slli_epi32 (c1, 16);
__m256i cmbd = _mm256_blend_epi16 (c0, c1_2, 0x22); // 0010 0010
__m256i cmbd2 = _mm256_shuffle_epi32 (cmbd, _MM_SHUFFLE(2, 0, 2, 0));
__m256i cmbd3 = _mm256_permute4x64_epi64(cmbd2, _MM_SHUFFLE(3, 1, 2, 0));
__m256i final = _mm256_shuffle_epi8 (cmbd3, final_shuf);
return final;
}
// 0 8 0 8 4 c 4 c | 4 c 4 c 0 8 0 8
static __m256i extract_d048c_vec(__m256i col)
{
const __m256i shufmask = _mm256_setr_epi8( 0, 1, 0, 1, 8, 9, 8, 9,
8, 9, 8, 9, 0, 1, 0, 1,
0, 1, 0, 1, 8, 9, 8, 9,
8, 9, 8, 9, 0, 1, 0, 1);
__m256i col_db4s = _mm256_shuffle_epi8 (col, shufmask);
__m256i col_los = _mm256_permute4x64_epi64(col_db4s, _MM_SHUFFLE(1, 1, 0, 0));
__m256i col_his = _mm256_permute4x64_epi64(col_db4s, _MM_SHUFFLE(3, 3, 2, 2));
__m256i final = _mm256_unpacklo_epi16 (col_los, col_his);
return final;
}
static void partial_butterfly_inverse_16_avx2(const int16_t *src, int16_t *dst, int32_t shift)
{
__m256i tsrc[16];
const uint32_t width = 16;
const int16_t *tdct = &kvz_g_dct_16_t[0][0];
const __m256i eo_signmask = _mm256_setr_epi32( 1, 1, 1, 1, -1, -1, -1, -1);
const __m256i eeo_signmask = _mm256_setr_epi32( 1, 1, -1, -1, -1, -1, 1, 1);
const __m256i o_signmask = _mm256_set1_epi32(-1);
const __m256i final_shufmask = _mm256_setr_epi8( 0, 1, 2, 3, 4, 5, 6, 7,
8, 9, 10, 11, 12, 13, 14, 15,
6, 7, 4, 5, 2, 3, 0, 1,
14, 15, 12, 13, 10, 11, 8, 9);
// TODO: this seems stoopid, fix 16x16 tp? :D
for (uint32_t i = 0; i < width; i++) {
__m256i v = _mm256_load_si256((const __m256i *)src + i);
_mm256_store_si256((__m256i *)tsrc + i, v);
}
transpose_16x16(src, (int16_t *)tsrc);
__m256i dct_cols[8];
for (uint32_t j = 0; j < 8; j++) {
dct_cols[j] = _mm256_load_si256((const __m256i *)tdct + j);
}
// These contain: D1,0 D3,0 D5,0 D7,0 D9,0 Db,0 Dd,0 Df,0 | D1,4 D3,4 D5,4 D7,4 D9,4 Db,4 Dd,4 Df,4
// D1,1 D3,1 D5,1 D7,1 D9,1 Db,1 Dd,1 Df,1 | D1,5 D3,5 D5,5 D7,5 D9,5 Db,5 Dd,5 Df,5
// D1,2 D3,2 D5,2 D7,2 D9,2 Db,2 Dd,2 Df,2 | D1,6 D3,6 D5,6 D7,6 D9,6 Db,6 Dd,6 Df,6
// D1,3 D3,3 D5,3 D7,3 D9,3 Db,3 Dd,3 Df,3 | D1,7 D3,7 D5,7 D7,7 D9,7 Db,7 Dd,7 Df,7
__m256i dct_col_odds[4];
for (uint32_t j = 0; j < 4; j++) {
dct_col_odds[j] = extract_combine_odds(dct_cols[j + 0], dct_cols[j + 4]);
}
for (uint32_t j = 0; j < width; j++) {
__m256i col = tsrc[j];
__m256i odds = extract_odds(col);
__m256i o04 = _mm256_madd_epi16(odds, dct_col_odds[0]);
__m256i o15 = _mm256_madd_epi16(odds, dct_col_odds[1]);
__m256i o26 = _mm256_madd_epi16(odds, dct_col_odds[2]);
__m256i o37 = _mm256_madd_epi16(odds, dct_col_odds[3]);
__m256i o0145 = _mm256_hadd_epi32(o04, o15);
__m256i o2367 = _mm256_hadd_epi32(o26, o37);
__m256i o = _mm256_hadd_epi32(o0145, o2367);
// D0,2 D0,6 D1,2 D1,6 D1,a D1,e D0,a D0,e | D2,2 D2,6 D3,2 D3,6 D3,a D3,e D2,a D2,e
__m256i d_db2 = extract_26ae(dct_cols);
// 2 6 2 6 a e a e | 2 6 2 6 a e a e
__m256i t_db2 = extract_26ae_vec (col);
__m256i eo_parts = _mm256_madd_epi16 (d_db2, t_db2);
__m256i eo_parts2 = _mm256_shuffle_epi32(eo_parts, _MM_SHUFFLE(0, 1, 2, 3));
// EO0 EO1 EO1 EO0 | EO2 EO3 EO3 EO2
__m256i eo = _mm256_add_epi32 (eo_parts, eo_parts2);
__m256i eo2 = _mm256_permute4x64_epi64(eo, _MM_SHUFFLE(1, 3, 2, 0));
__m256i eo3 = _mm256_sign_epi32 (eo2, eo_signmask);
__m256i d_db4 = extract_d048c (dct_cols);
__m256i t_db4 = extract_d048c_vec (col);
__m256i eee_eeo = _mm256_madd_epi16 (d_db4, t_db4);
__m256i eee_eee = _mm256_permute4x64_epi64(eee_eeo, _MM_SHUFFLE(3, 0, 3, 0));
__m256i eeo_eeo1 = _mm256_permute4x64_epi64(eee_eeo, _MM_SHUFFLE(1, 2, 1, 2));
__m256i eeo_eeo2 = _mm256_sign_epi32 (eeo_eeo1, eeo_signmask);
// EE0 EE1 EE2 EE3 | EE3 EE2 EE1 EE0
__m256i ee = _mm256_add_epi32 (eee_eee, eeo_eeo2);
__m256i e = _mm256_add_epi32 (ee, eo3);
__m256i o_neg = _mm256_sign_epi32 (o, o_signmask);
__m256i o_lo = _mm256_blend_epi32 (o, o_neg, 0xf0); // 1111 0000
__m256i o_hi = _mm256_blend_epi32 (o, o_neg, 0x0f); // 0000 1111
__m256i res_lo = _mm256_add_epi32 (e, o_lo);
__m256i res_hi = _mm256_add_epi32 (e, o_hi);
__m256i res_hi2 = _mm256_permute4x64_epi64(res_hi, _MM_SHUFFLE(1, 0, 3, 2));
__m256i res_lo_t = truncate_inv(res_lo, shift);
__m256i res_hi_t = truncate_inv(res_hi2, shift);
__m256i res_16_1 = _mm256_packs_epi32 (res_lo_t, res_hi_t);
__m256i final = _mm256_shuffle_epi8 (res_16_1, final_shufmask);
_mm256_store_si256((__m256i *)dst + j, final);
}
}
static void matrix_idct_16x16_avx2(int8_t bitdepth, const int16_t *input, int16_t *output)
{
int32_t shift_1st = 7;
int32_t shift_2nd = 12 - (bitdepth - 8);
ALIGNED(64) int16_t tmp[16 * 16];
partial_butterfly_inverse_16_avx2(input, tmp, shift_1st);
partial_butterfly_inverse_16_avx2(tmp, output, shift_2nd);
}
static void matrix_dct_16x16_avx2(int8_t bitdepth, const int16_t *input, int16_t *output)
@ -734,11 +962,11 @@ static void matrix_i ## type ## _## n ## x ## n ## _avx2(int8_t bitdepth, const
// TRANSFORM(dct, 8);
// ITRANSFORM(dct, 8);
// TRANSFORM(dct, 16);
// ITRANSFORM(dct, 16);
// Generate all the transform functions
TRANSFORM(dct, 32);
ITRANSFORM(dct, 16);
ITRANSFORM(dct, 32);
#endif //COMPILE_INTEL_AVX2