Implement fast (butterfly) 32x32 DCT in AVX2

This commit is contained in:
Pauli Oikkonen 2019-10-03 23:00:31 +03:00
parent b5962dadac
commit e382339182

View file

@ -523,7 +523,12 @@ static void matmul_16x16_a_bt(const int16_t *a, const __m256i *b_t, int16_t *out
}
}
static void transpose_16x16(const int16_t *src, int16_t *dst)
// NOTE: The strides measured by s_stride_log2 and d_stride_log2 are in units
// of 16 coeffs, not 1!
static void transpose_16x16_stride(const int16_t *src,
int16_t *dst,
uint8_t s_stride_log2,
uint8_t d_stride_log2)
{
__m256i tmp_128[16];
for (uint32_t i = 0; i < 16; i += 8) {
@ -535,14 +540,14 @@ static void transpose_16x16(const int16_t *src, int16_t *dst)
__m256i tmp_64[8];
__m256i m[8] = {
_mm256_load_si256((const __m256i *)src + i + 0),
_mm256_load_si256((const __m256i *)src + i + 1),
_mm256_load_si256((const __m256i *)src + i + 2),
_mm256_load_si256((const __m256i *)src + i + 3),
_mm256_load_si256((const __m256i *)src + i + 4),
_mm256_load_si256((const __m256i *)src + i + 5),
_mm256_load_si256((const __m256i *)src + i + 6),
_mm256_load_si256((const __m256i *)src + i + 7),
_mm256_load_si256((const __m256i *)src + ((i + 0) << s_stride_log2)),
_mm256_load_si256((const __m256i *)src + ((i + 1) << s_stride_log2)),
_mm256_load_si256((const __m256i *)src + ((i + 2) << s_stride_log2)),
_mm256_load_si256((const __m256i *)src + ((i + 3) << s_stride_log2)),
_mm256_load_si256((const __m256i *)src + ((i + 4) << s_stride_log2)),
_mm256_load_si256((const __m256i *)src + ((i + 5) << s_stride_log2)),
_mm256_load_si256((const __m256i *)src + ((i + 6) << s_stride_log2)),
_mm256_load_si256((const __m256i *)src + ((i + 7) << s_stride_log2)),
};
tmp_32[0] = _mm256_unpacklo_epi16( m[0], m[1]);
@ -579,19 +584,35 @@ static void transpose_16x16(const int16_t *src, int16_t *dst)
}
for (uint32_t i = 0; i < 8; i++) {
uint32_t loid = i + 0;
uint32_t hiid = i + 8;
uint32_t loid = i + 0;
uint32_t hiid = i + 8;
uint32_t dst_loid = loid << d_stride_log2;
uint32_t dst_hiid = hiid << d_stride_log2;
__m256i lo = tmp_128[loid];
__m256i hi = tmp_128[hiid];
__m256i final_lo = _mm256_permute2x128_si256(lo, hi, 0x20);
__m256i final_hi = _mm256_permute2x128_si256(lo, hi, 0x31);
_mm256_store_si256((__m256i *)dst + loid, final_lo);
_mm256_store_si256((__m256i *)dst + hiid, final_hi);
_mm256_store_si256((__m256i *)dst + dst_loid, final_lo);
_mm256_store_si256((__m256i *)dst + dst_hiid, final_hi);
}
}
static void transpose_16x16(const int16_t *src, int16_t *dst)
{
transpose_16x16_stride(src, dst, 0, 0);
}
static void transpose_32x32(const int16_t *src, int16_t *dst)
{
transpose_16x16_stride(src + 0, dst + 0, 1, 1);
transpose_16x16_stride(src + 16, dst + 16 * 32, 1, 1);
transpose_16x16_stride(src + 16 * 32, dst + 16, 1, 1);
transpose_16x16_stride(src + 16 * 33, dst + 16 * 33, 1, 1);
}
static __m256i truncate_inv(__m256i v, int32_t shift)
{
int32_t add = 1 << (shift - 1);
@ -834,6 +855,310 @@ static void matrix_dct_16x16_avx2(int8_t bitdepth, const int16_t *input, int16_t
matmul_16x16_a_bt (dct, tmpres, output, shift_2nd);
}
static __m256i get_overflows(const __m256i a, const __m256i b, const __m256i res, const __m256i of_adjust_mask)
{
const __m256i ones = _mm256_set1_epi16(1);
__m256i src_signdiff = _mm256_xor_si256 (a, b);
__m256i a_r_signdiff = _mm256_xor_si256 (a, res);
__m256i of_possible = _mm256_xor_si256 (src_signdiff, of_adjust_mask);
__m256i overflows = _mm256_and_si256 (of_possible, a_r_signdiff);
overflows = _mm256_srai_epi16 (overflows, 15);
__m256i of_signs = _mm256_srai_epi16 (a, 15);
of_signs = _mm256_or_si256 (of_signs, ones);
return _mm256_and_si256 (overflows, of_signs);
}
/*
* You need more than 16 bits to store the result of signed 16b-16b operation,
* this one stores the low 16b in lo and high 16b in hi. The high 16 bits can
* only be -1, 0 or 1.
*
* of_possible_mask is either all zero bits for subtraction, or all ones for
* addition
*/
static void sub_16_16_hilo(const __m256i a, const __m256i b, __m256i *lo, __m256i *hi)
{
const __m256i zero = _mm256_setzero_si256();
*lo = _mm256_sub_epi16(a, b);
*hi = get_overflows(a, b, *lo, zero);
}
static void add_16_16_hilo(const __m256i a, const __m256i b, __m256i *lo, __m256i *hi)
{
const __m256i ff = _mm256_set1_epi8(-1);
*lo = _mm256_add_epi16(a, b);
*hi = get_overflows(a, b, *lo, ff);
}
static __m256i reverse_16x16b_in_lanes(const __m256i v)
{
const __m256i lanerev = _mm256_setr_epi16(0x0f0e, 0x0d0c, 0x0b0a, 0x0908,
0x0706, 0x0504, 0x0302, 0x0100,
0x0f0e, 0x0d0c, 0x0b0a, 0x0908,
0x0706, 0x0504, 0x0302, 0x0100);
return _mm256_shuffle_epi8(v, lanerev);
}
static __m256i reverse_16x16b(const __m256i v)
{
__m256i tmp = reverse_16x16b_in_lanes(v);
return _mm256_permute4x64_epi64(tmp, _MM_SHUFFLE(1, 0, 3, 2));
}
static __m256i m256_from_2xm128(const __m128i lo, const __m128i hi)
{
__m256i result = _mm256_castsi128_si256 (lo);
return _mm256_inserti128_si256(result, hi, 1);
}
// Get a vector consisting of the divisible-by-8 coeffs in DCT's first two
// columns, in order:
// DC00 DC00 DC00 DC00 DC10 DC10 DC10 DC10 | DC08 DC08 DC08 DC08 DC18 DC18 DC18 DC18
static __m256i get_dct_db8_vec(const int16_t *dct_t, uint32_t offset)
{
const __m256i reorder_mask = _mm256_setr_epi32(0x01000100, 0x01000100, 0x05040504, 0x05040504,
0x03020302, 0x03020302, 0x07060706, 0x07060706);
uint16_t coeff00 = (uint16_t)dct_t[offset + 0];
uint16_t coeff08 = (uint16_t)dct_t[offset + 8];
uint16_t coeff10 = (uint16_t)dct_t[offset + 32];
uint16_t coeff18 = (uint16_t)dct_t[offset + 40];
uint64_t col_db8_packed = (((uint64_t)coeff00) << 0) |
(((uint64_t)coeff08) << 16) |
(((uint64_t)coeff10) << 32) |
(((uint64_t)coeff18) << 48);
__m256i col_db8 = _mm256_set1_epi64x (col_db8_packed);
col_db8 = _mm256_shuffle_epi8(col_db8, reorder_mask);
__m256i col_db8_shifted = _mm256_slli_epi16 (col_db8, 8);
return _mm256_blend_epi32 (col_db8, col_db8_shifted, 0xaa);
}
// Get first 4 DCT coeffs from DCT rows (rowoff + 4) and (rowoff + 12),
// shifting copies of the coeffs 8 bits left to do half of the 65536-factor
// multiplication required for high words of EEO coeffs
static __m256i get_dct_db4_vec(const int16_t *dct, uint32_t rowoff)
{
ALIGNED(32) uint64_t buf[4];
uint64_t r4_0123 = *(uint64_t *)(dct + (rowoff + 0x04) * 32);
uint64_t rc_0123 = *(uint64_t *)(dct + (rowoff + 0x0c) * 32);
buf[0] = r4_0123;
buf[1] = r4_0123 << 8;
buf[2] = rc_0123;
buf[3] = rc_0123 << 8;
return _mm256_load_si256((const __m256i *)buf);
}
// Get first 8 coeffs from DCT rows (rowoff + 2) and (rowoff + 10) and return
// them in a single YMM
static __m256i get_dct_db2_vec(const int16_t *dct, uint32_t rowoff)
{
__m128i row2 = _mm_load_si128((const __m128i *)(dct + (rowoff + 2) * 32));
__m128i rowa = _mm_load_si128((const __m128i *)(dct + (rowoff + 10) * 32));
return m256_from_2xm128(row2, rowa);
}
static void partial_butterfly_32_avx2(const int16_t *src, int16_t *dst, int32_t shift)
{
const int32_t add = 1 << (shift - 1);
const __m256i debias = _mm256_set1_epi32(add);
const int16_t *dct = (const int16_t *)(&kvz_g_dct_32 [0][0]);
const int16_t *dct_t = (const int16_t *)(&kvz_g_dct_32_t[0][0]);
const __m256i ff = _mm256_set1_epi32 (-1);
const __m256i ones = _mm256_set1_epi16 ( 1);
const __m128i ff_128 = _mm256_castsi256_si128 (ff);
const __m256i lolane_smask = _mm256_inserti128_si256(ones, ff_128, 0);
const __m256i hilane_mask = _mm256_cmpeq_epi16 (ones, lolane_smask);
const __m256i eee_lohi_shuf = _mm256_setr_epi32(0x01000706, 0x09080f0e, 0x03020504, 0x0b0a0d0c,
0x01000706, 0x09080f0e, 0x03020504, 0x0b0a0d0c);
const __m256i eee_eee_sgnmask = _mm256_setr_epi32(0x00010001, 0x00010001, 0x00010001, 0x00010001,
0xffff0001, 0xffff0001, 0xffff0001, 0xffff0001);
const __m256i dct_c0_db8 = get_dct_db8_vec(dct_t, 0);
const __m256i dct_c1_db8 = get_dct_db8_vec(dct_t, 16);
const __m256i dct_r4c_db4[2] = {
get_dct_db4_vec(dct, 0x00),
get_dct_db4_vec(dct, 0x10),
};
const __m256i dct_r_db2[4] = {
get_dct_db2_vec(dct, 0),
get_dct_db2_vec(dct, 4),
get_dct_db2_vec(dct, 16),
get_dct_db2_vec(dct, 20),
};
__m256i res_tp[2 * 32];
for (uint32_t i = 0; i < 32; i++) {
__m256i lo = _mm256_load_si256((const __m256i *)src + 2 * i + 0);
__m256i hi = _mm256_load_si256((const __m256i *)src + 2 * i + 1);
__m256i hi_rev = reverse_16x16b(hi);
__m256i e_lo, e_hi, o_lo, o_hi;
add_16_16_hilo(lo, hi_rev, &e_lo, &e_hi);
sub_16_16_hilo(lo, hi_rev, &o_lo, &o_hi);
__m256i erev_lo = reverse_16x16b(e_lo);
__m256i erev_hi = reverse_16x16b(e_hi);
// Hack! Negate low lanes to do subtractions there, but retain non-negated
// low lanes for overflow detection because the original value is
// essentially what was subtracted there by adding its complement
__m256i erev_lo_n = _mm256_sign_epi16(erev_lo, lolane_smask);
__m256i erev_hi_n = _mm256_sign_epi16(erev_hi, lolane_smask);
// eo0 eo1 eo2 eo3 eo4 eo5 eo6 eo7 | ee7 ee6 ee5 ee4 ee3 ee2 ee1 ee0
__m256i eo_ee_lo = _mm256_add_epi16 (e_lo, erev_lo_n);
__m256i eo_ee_hi = _mm256_add_epi16 (e_hi, erev_hi_n);
__m256i eo_ee_hi2 = get_overflows (e_lo, erev_lo, eo_ee_lo, hilane_mask);
eo_ee_hi = _mm256_add_epi16 (eo_ee_hi, eo_ee_hi2);
__m256i eo_eo_lo = _mm256_permute4x64_epi64(eo_ee_lo, _MM_SHUFFLE(1, 0, 1, 0));
__m256i eo_eo_hi = _mm256_permute4x64_epi64(eo_ee_hi, _MM_SHUFFLE(1, 0, 1, 0));
// Rev: ee7 ee6 ee5 ee4 ee7 ee6 ee5 ee4 | ee3 ee2 ee1 ee0 ee3 ee2 ee1 ee0
// Fwd: ee0 ee1 ee2 ee3 ee0 ee1 ee2 ee3 | ee4 ee5 ee6 ee7 ee4 ee5 ee6 ee7
__m256i ee_ee_rev_lo = _mm256_permute4x64_epi64(eo_ee_lo, _MM_SHUFFLE(3, 3, 2, 2));
__m256i ee_ee_lo = _mm256_permute4x64_epi64(eo_ee_lo, _MM_SHUFFLE(2, 2, 3, 3));
ee_ee_lo = reverse_16x16b_in_lanes (ee_ee_lo);
__m256i ee_ee_rev_hi = _mm256_permute4x64_epi64(eo_ee_hi, _MM_SHUFFLE(3, 3, 2, 2));
__m256i ee_ee_hi = _mm256_permute4x64_epi64(eo_ee_hi, _MM_SHUFFLE(2, 2, 3, 3));
ee_ee_hi = reverse_16x16b_in_lanes (ee_ee_hi);
__m256i ee_ee_rev_lo_n = _mm256_sign_epi16(ee_ee_rev_lo, lolane_smask);
__m256i ee_ee_rev_hi_n = _mm256_sign_epi16(ee_ee_rev_hi, lolane_smask);
// eeo0 eeo1 eeo2 eeo3 eeo0 eeo1 eeo2 eeo3 | eee3 eee2 eee1 eee0 eee3 eee2 eee1 eee0
__m256i eeo_eee_lo = _mm256_add_epi16 (ee_ee_lo, ee_ee_rev_lo_n);
__m256i eeo_eee_hi = _mm256_add_epi16 (ee_ee_hi, ee_ee_rev_hi_n);
__m256i eeo_eee_hi2 = get_overflows(ee_ee_lo, ee_ee_rev_lo, eeo_eee_lo, hilane_mask);
eeo_eee_hi = _mm256_add_epi16(eeo_eee_hi, eeo_eee_hi2);
// Multiply these guys by 256 and also the corresponding DCT coefficients
// by 256, to multiply their product by 65536. Neither these values nor
// the coeffs will exceed 255, so this is overflow safe.
__m256i eeo_eee_hi_shed = _mm256_slli_epi16(eeo_eee_hi, 8);
// Discard eeo's (low lane), duplicate the high lane, and reorder eee's
// in-lane. Finally invert the eee2/eee3 components in the high lane.
__m256i eee_lohi = _mm256_blend_epi32(eeo_eee_lo, eeo_eee_hi_shed, 0xc0);
eee_lohi = _mm256_permute4x64_epi64(eee_lohi, _MM_SHUFFLE(3, 2, 3, 2));
__m256i eee_lh_ordered = _mm256_shuffle_epi8(eee_lohi, eee_lohi_shuf);
__m256i eee_lh_final = _mm256_sign_epi16 (eee_lh_ordered, eee_eee_sgnmask);
// D00, D08, D10 and D18 in four parts fit for hadding
__m256i d00_d08 = _mm256_madd_epi16(eee_lh_final, dct_c0_db8);
__m256i d10_d18 = _mm256_madd_epi16(eee_lh_final, dct_c1_db8);
__m256i eeo_lohi = _mm256_unpacklo_epi64 (eeo_eee_lo, eeo_eee_hi_shed);
eeo_lohi = _mm256_permute4x64_epi64(eeo_lohi, _MM_SHUFFLE(1, 0, 1, 0));
__m256i d_db4[2] = {
_mm256_madd_epi16(eeo_lohi, dct_r4c_db4[0]),
_mm256_madd_epi16(eeo_lohi, dct_r4c_db4[1]),
};
__m256i db2_parts[4];
for (uint32_t j = 0; j < 4; j++) {
const __m256i dr2a = dct_r_db2[j];
__m256i lo = _mm256_madd_epi16(dr2a, eo_eo_lo);
__m256i hi = _mm256_madd_epi16(dr2a, eo_eo_hi);
hi = _mm256_slli_epi32(hi, 16);
db2_parts[j] = _mm256_add_epi32 (lo, hi);
}
__m256i odd_parts[16];
for (uint32_t j = 0; j < 16; j++) {
__m256i drow_lo = _mm256_load_si256((const __m256i *)(dct + (j * 2 + 1) * 32));
__m256i odds_lo = _mm256_madd_epi16(o_lo, drow_lo);
__m256i odds_hi = _mm256_madd_epi16(o_hi, drow_lo);
odds_hi = _mm256_slli_epi32(odds_hi, 16);
odd_parts[j] = _mm256_add_epi32 (odds_lo, odds_hi);
}
// Rearrange odds so that parts belonging to any single one are all inside
// one lane - combine 01 | 09 ; 03 | 0b ; 05 | 0d ; 07 | 0f
// 11 | 19 ; 13 | 1b ; 15 | 1d ; 17 | 1f
__m256i odd_parts2[8];
for (uint32_t j = 0; j < 8; j++) {
// Turn 0, 1, 2, 3, 4, 5, 6, 7 into:
// 0, 1, 2, 3, 8, 9, a, b
uint32_t j_lo = j & 0x03;
uint32_t j_hi = j & 0x04;
uint32_t id_lo = j_lo | (j_hi << 1);
uint32_t id_hi = id_lo | 4;
__m256i odd_lo = _mm256_permute2x128_si256(odd_parts[id_lo],
odd_parts[id_hi],
0x20);
__m256i odd_hi = _mm256_permute2x128_si256(odd_parts[id_lo],
odd_parts[id_hi],
0x31);
odd_parts2[j] = _mm256_add_epi32(odd_lo, odd_hi);
}
// First stage HADDs...
__m256i d0001_0809 = _mm256_hadd_epi32(d00_d08, odd_parts2[0]);
__m256i d1011_1819 = _mm256_hadd_epi32(d10_d18, odd_parts2[4]);
__m256i d0405_0c0d = _mm256_hadd_epi32(d_db4[0], odd_parts2[2]);
__m256i d1415_1c1d = _mm256_hadd_epi32(d_db4[1], odd_parts2[6]);
__m256i d0203_0a0b = _mm256_hadd_epi32(db2_parts[0], odd_parts2[1]);
__m256i d0607_0e0f = _mm256_hadd_epi32(db2_parts[1], odd_parts2[3]);
__m256i d1213_1a1b = _mm256_hadd_epi32(db2_parts[2], odd_parts2[5]);
__m256i d1617_1e1f = _mm256_hadd_epi32(db2_parts[3], odd_parts2[7]);
// .. and second stage
__m256i d0123_89ab_lo = _mm256_hadd_epi32(d0001_0809, d0203_0a0b);
__m256i d0123_89ab_hi = _mm256_hadd_epi32(d1011_1819, d1213_1a1b);
__m256i d4567_cdef_lo = _mm256_hadd_epi32(d0405_0c0d, d0607_0e0f);
__m256i d4567_cdef_hi = _mm256_hadd_epi32(d1415_1c1d, d1617_1e1f);
d0123_89ab_lo = truncate(d0123_89ab_lo, debias, shift);
d0123_89ab_hi = truncate(d0123_89ab_hi, debias, shift);
d4567_cdef_lo = truncate(d4567_cdef_lo, debias, shift);
d4567_cdef_hi = truncate(d4567_cdef_hi, debias, shift);
__m256i final_lo = _mm256_packs_epi32(d0123_89ab_lo, d4567_cdef_lo);
__m256i final_hi = _mm256_packs_epi32(d0123_89ab_hi, d4567_cdef_hi);
_mm256_store_si256(res_tp + (i * 2) + 0, final_lo);
_mm256_store_si256(res_tp + (i * 2) + 1, final_hi);
}
transpose_32x32((const int16_t *)res_tp, dst);
}
// 32x32 matrix multiplication with value clipping.
// Parameters: Two 32x32 matrices containing 16-bit values in consecutive addresses,
// destination for the result and the shift value for clipping.
@ -923,6 +1248,16 @@ static void mul_clip_matrix_32x32_avx2(const int16_t *left, const int16_t *right
}
}
static void matrix_dct_32x32_avx2(int8_t bitdepth, const int16_t *input, int16_t *output)
{
int32_t shift_1st = kvz_g_convert_to_bit[32] + 1 + (bitdepth - 8);
int32_t shift_2nd = kvz_g_convert_to_bit[32] + 8;
ALIGNED(64) int16_t tmp[32 * 32];
partial_butterfly_32_avx2(input, tmp, shift_1st);
partial_butterfly_32_avx2(tmp, output, shift_2nd);
}
// Macro that generates 2D transform functions with clipping values.
// Sets correct shift values and matrices according to transform type and
// block size. Performs matrix multiplication horizontally and vertically.
@ -964,8 +1299,9 @@ static void matrix_i ## type ## _## n ## x ## n ## _avx2(int8_t bitdepth, const
// TRANSFORM(dct, 16);
// ITRANSFORM(dct, 16);
// TRANSFORM(dct, 32);
// Generate all the transform functions
TRANSFORM(dct, 32);
ITRANSFORM(dct, 32);