Implement a streamlined matrix-multiply 32x32 DCT

This commit is contained in:
Pauli Oikkonen 2019-10-10 18:32:39 +03:00
parent e9da2d851b
commit 043f53539f

View file

@ -605,14 +605,6 @@ static void transpose_16x16(const int16_t *src, int16_t *dst)
transpose_16x16_stride(src, dst, 0, 0);
}
static void transpose_32x32(const int16_t *src, int16_t *dst)
{
transpose_16x16_stride(src + 0, dst + 0, 1, 1);
transpose_16x16_stride(src + 16, dst + 16 * 32, 1, 1);
transpose_16x16_stride(src + 16 * 32, dst + 16, 1, 1);
transpose_16x16_stride(src + 16 * 33, dst + 16 * 33, 1, 1);
}
static __m256i truncate_inv(__m256i v, int32_t shift)
{
int32_t add = 1 << (shift - 1);
@ -855,313 +847,160 @@ static void matrix_dct_16x16_avx2(int8_t bitdepth, const int16_t *input, int16_t
matmul_16x16_a_bt (dct, tmpres, output, shift_2nd);
}
static __m256i get_overflows(const __m256i a,
const __m256i b,
const __m256i res,
const __m256i of_adjust_mask)
{
const __m256i ones = _mm256_set1_epi16(1);
__m256i src_signdiff = _mm256_xor_si256 (a, b);
__m256i a_r_signdiff = _mm256_xor_si256 (a, res);
__m256i of_possible = _mm256_xor_si256 (src_signdiff, of_adjust_mask);
__m256i overflows = _mm256_and_si256 (of_possible, a_r_signdiff);
overflows = _mm256_srai_epi16 (overflows, 15);
__m256i of_signs = _mm256_srai_epi16 (a, 15);
of_signs = _mm256_or_si256 (of_signs, ones);
return _mm256_and_si256 (overflows, of_signs);
}
/*
* You need more than 16 bits to store the result of signed 16b-16b operation,
* this one stores the low 16b in lo and high 16b in hi. The high 16 bits can
* only be -1, 0 or 1.
*
* of_possible_mask is either all zero bits for subtraction, or all ones for
* addition
*/
static void sub_16_16_hilo(const __m256i a, const __m256i b,
__m256i *lo, __m256i *hi)
{
const __m256i zero = _mm256_setzero_si256();
*lo = _mm256_sub_epi16(a, b);
*hi = get_overflows(a, b, *lo, zero);
}
static void add_16_16_hilo(const __m256i a, const __m256i b,
__m256i *lo, __m256i *hi)
{
const __m256i ff = _mm256_set1_epi8(-1);
*lo = _mm256_add_epi16(a, b);
*hi = get_overflows(a, b, *lo, ff);
}
static INLINE __m256i reverse_16x16b_in_lanes(const __m256i v)
{
const __m256i lanerev = _mm256_setr_epi16(0x0f0e, 0x0d0c, 0x0b0a, 0x0908,
0x0706, 0x0504, 0x0302, 0x0100,
0x0f0e, 0x0d0c, 0x0b0a, 0x0908,
0x0706, 0x0504, 0x0302, 0x0100);
return _mm256_shuffle_epi8(v, lanerev);
}
static INLINE __m256i reverse_16x16b(const __m256i v)
{
__m256i tmp = reverse_16x16b_in_lanes(v);
return _mm256_permute4x64_epi64(tmp, _MM_SHUFFLE(1, 0, 3, 2));
}
static INLINE __m256i m256_from_2xm128(const __m128i lo, const __m128i hi)
{
__m256i result = _mm256_castsi128_si256 (lo);
return _mm256_inserti128_si256(result, hi, 1);
}
// Get a vector consisting of the divisible-by-8 coeffs in DCT's first two
// columns, in order:
// DC00 DC00 DC00 DC00 DC10 DC10 DC10 DC10 | DC08 DC08 DC08 DC08 DC18 DC18 DC18 DC18
static __m256i get_dct_db8_vec(const int16_t *dct_t, uint32_t offset)
{
const __m256i reorder_mask = _mm256_setr_epi32(0x01000100, 0x01000100, 0x05040504, 0x05040504,
0x03020302, 0x03020302, 0x07060706, 0x07060706);
uint16_t coeff00 = (uint16_t)dct_t[offset + 0];
uint16_t coeff08 = (uint16_t)dct_t[offset + 8];
uint16_t coeff10 = (uint16_t)dct_t[offset + 32];
uint16_t coeff18 = (uint16_t)dct_t[offset + 40];
uint64_t col_db8_packed = (((uint64_t)coeff00) << 0) |
(((uint64_t)coeff08) << 16) |
(((uint64_t)coeff10) << 32) |
(((uint64_t)coeff18) << 48);
__m256i col_db8 = _mm256_set1_epi64x (col_db8_packed);
col_db8 = _mm256_shuffle_epi8(col_db8, reorder_mask);
__m256i col_db8_shifted = _mm256_slli_epi16 (col_db8, 8);
return _mm256_blend_epi32 (col_db8, col_db8_shifted, 0xaa);
}
// Get first 4 DCT coeffs from DCT rows (rowoff + 4) and (rowoff + 12),
// shifting copies of the coeffs 8 bits left to do half of the 65536-factor
// multiplication required for high words of EEO coeffs
static __m256i get_dct_db4_vec(const int16_t *dct, uint32_t rowoff)
{
ALIGNED(32) uint64_t buf[4];
uint64_t r4_0123 = *(uint64_t *)(dct + (rowoff + 0x04) * 32);
uint64_t rc_0123 = *(uint64_t *)(dct + (rowoff + 0x0c) * 32);
buf[0] = r4_0123;
buf[1] = r4_0123 << 8;
buf[2] = rc_0123;
buf[3] = rc_0123 << 8;
return _mm256_load_si256((const __m256i *)buf);
}
// Get first 8 coeffs from DCT rows (rowoff + 2) and (rowoff + 10) and return
// them in a single YMM
static __m256i get_dct_db2_vec(const int16_t *dct, uint32_t rowoff)
{
__m128i row2 = _mm_load_si128((const __m128i *)(dct + (rowoff + 2) * 32));
__m128i rowa = _mm_load_si128((const __m128i *)(dct + (rowoff + 10) * 32));
return m256_from_2xm128(row2, rowa);
}
static void partial_butterfly_32_avx2(const int16_t *src, int16_t *dst, int32_t shift)
void matmul_32x32_a_bt_t(const __m256i *a, const __m256i *b_t, __m256i *dst, const uint8_t shift)
{
const int32_t add = 1 << (shift - 1);
const __m256i debias = _mm256_set1_epi32(add);
const int16_t *dct = (const int16_t *)(&kvz_g_dct_32 [0][0]);
const int16_t *dct_t = (const int16_t *)(&kvz_g_dct_32_t[0][0]);
uint32_t i, j;
for (j = 0; j < 32; j++) {
uint32_t jd = j << 1;
__m256i bt_lo = b_t[jd + 0];
__m256i bt_hi = b_t[jd + 1];
const __m256i ff = _mm256_set1_epi32 (-1);
const __m256i ones = _mm256_set1_epi16 ( 1);
const __m128i ff_128 = _mm256_castsi256_si128 (ff);
const __m256i lolane_smask = _mm256_inserti128_si256(ones, ff_128, 0);
const __m256i hilane_mask = _mm256_cmpeq_epi16 (ones, lolane_smask);
// Each lane here is the 4 i32's that add up to one number in the result
// vector:
// 00 | 08
// 01 | 09
// ...
// 07 | 0f
// 10 | 18
// 11 | 19
// ...
// 17 | 1f
__m256i rp1[16];
for (i = 0; i < 16; i++) {
// Loop in order:
// 0, 1, 2, 3, 4, 5, 6, 7,
// 16, 17, 18, 19, 20, 21, 22, 23
uint32_t id = (i & 7) | ((i & ~7) << 1);
uint32_t off = id << 1;
const __m256i eee_lohi_shuf = _mm256_setr_epi32(0x01000706, 0x09080f0e, 0x03020504, 0x0b0a0d0c,
0x01000706, 0x09080f0e, 0x03020504, 0x0b0a0d0c);
__m256i a0_lo = a[off + 0];
__m256i a0_hi = a[off + 1];
__m256i a8_lo = a[off + 16];
__m256i a8_hi = a[off + 17];
const __m256i eee_eee_sgnmask = _mm256_setr_epi32(0x00010001, 0x00010001, 0x00010001, 0x00010001,
0xffff0001, 0xffff0001, 0xffff0001, 0xffff0001);
__m256i pr0_lo = _mm256_madd_epi16(bt_lo, a0_lo);
__m256i pr0_hi = _mm256_madd_epi16(bt_hi, a0_hi);
__m256i pr8_lo = _mm256_madd_epi16(bt_lo, a8_lo);
__m256i pr8_hi = _mm256_madd_epi16(bt_hi, a8_hi);
const __m256i dct_c0_db8 = get_dct_db8_vec(dct_t, 0);
const __m256i dct_c1_db8 = get_dct_db8_vec(dct_t, 16);
__m256i sum0 = _mm256_add_epi32 (pr0_lo, pr0_hi);
__m256i sum8 = _mm256_add_epi32 (pr8_lo, pr8_hi);
const __m256i dct_r4c_db4[2] = {
get_dct_db4_vec(dct, 0x00),
get_dct_db4_vec(dct, 0x10),
};
// Arrange all parts for one number to be inside one lane
__m256i s08_lo = _mm256_permute2x128_si256(sum0, sum8, 0x20);
__m256i s08_hi = _mm256_permute2x128_si256(sum0, sum8, 0x31);
const __m256i dct_r_db2[4] = {
get_dct_db2_vec(dct, 0),
get_dct_db2_vec(dct, 4),
get_dct_db2_vec(dct, 16),
get_dct_db2_vec(dct, 20),
};
__m256i res_tp[2 * 32];
for (uint32_t i = 0; i < 32; i++) {
__m256i lo = _mm256_load_si256((const __m256i *)src + 2 * i + 0);
__m256i hi = _mm256_load_si256((const __m256i *)src + 2 * i + 1);
__m256i hi_rev = reverse_16x16b(hi);
__m256i e_lo, e_hi, o_lo, o_hi;
add_16_16_hilo(lo, hi_rev, &e_lo, &e_hi);
sub_16_16_hilo(lo, hi_rev, &o_lo, &o_hi);
__m256i erev_lo = reverse_16x16b(e_lo);
__m256i erev_hi = reverse_16x16b(e_hi);
// Hack! Negate low lanes to do subtractions there, but retain non-negated
// low lanes for overflow detection because the original value is
// essentially what was subtracted there by adding its complement
__m256i erev_lo_n = _mm256_sign_epi16(erev_lo, lolane_smask);
__m256i erev_hi_n = _mm256_sign_epi16(erev_hi, lolane_smask);
// eo0 eo1 eo2 eo3 eo4 eo5 eo6 eo7 | ee7 ee6 ee5 ee4 ee3 ee2 ee1 ee0
__m256i eo_ee_lo = _mm256_add_epi16 (e_lo, erev_lo_n);
__m256i eo_ee_hi = _mm256_add_epi16 (e_hi, erev_hi_n);
__m256i eo_ee_hi2 = get_overflows (e_lo, erev_lo, eo_ee_lo, hilane_mask);
eo_ee_hi = _mm256_add_epi16 (eo_ee_hi, eo_ee_hi2);
__m256i eo_eo_lo = _mm256_permute4x64_epi64(eo_ee_lo, _MM_SHUFFLE(1, 0, 1, 0));
__m256i eo_eo_hi = _mm256_permute4x64_epi64(eo_ee_hi, _MM_SHUFFLE(1, 0, 1, 0));
// Rev: ee7 ee6 ee5 ee4 ee7 ee6 ee5 ee4 | ee3 ee2 ee1 ee0 ee3 ee2 ee1 ee0
// Fwd: ee0 ee1 ee2 ee3 ee0 ee1 ee2 ee3 | ee4 ee5 ee6 ee7 ee4 ee5 ee6 ee7
__m256i ee_ee_rev_lo = _mm256_permute4x64_epi64(eo_ee_lo, _MM_SHUFFLE(3, 3, 2, 2));
__m256i ee_ee_lo = _mm256_permute4x64_epi64(eo_ee_lo, _MM_SHUFFLE(2, 2, 3, 3));
ee_ee_lo = reverse_16x16b_in_lanes (ee_ee_lo);
__m256i ee_ee_rev_hi = _mm256_permute4x64_epi64(eo_ee_hi, _MM_SHUFFLE(3, 3, 2, 2));
__m256i ee_ee_hi = _mm256_permute4x64_epi64(eo_ee_hi, _MM_SHUFFLE(2, 2, 3, 3));
ee_ee_hi = reverse_16x16b_in_lanes (ee_ee_hi);
__m256i ee_ee_rev_lo_n = _mm256_sign_epi16(ee_ee_rev_lo, lolane_smask);
__m256i ee_ee_rev_hi_n = _mm256_sign_epi16(ee_ee_rev_hi, lolane_smask);
// eeo0 eeo1 eeo2 eeo3 eeo0 eeo1 eeo2 eeo3 | eee3 eee2 eee1 eee0 eee3 eee2 eee1 eee0
__m256i eeo_eee_lo = _mm256_add_epi16 (ee_ee_lo, ee_ee_rev_lo_n);
__m256i eeo_eee_hi = _mm256_add_epi16 (ee_ee_hi, ee_ee_rev_hi_n);
__m256i eeo_eee_hi2 = get_overflows(ee_ee_lo, ee_ee_rev_lo, eeo_eee_lo, hilane_mask);
eeo_eee_hi = _mm256_add_epi16(eeo_eee_hi, eeo_eee_hi2);
// Multiply these guys by 256 and also the corresponding DCT coefficients
// by 256, to multiply their product by 65536. Neither these values nor
// the coeffs will exceed 255, so this is overflow safe.
__m256i eeo_eee_hi_shed = _mm256_slli_epi16(eeo_eee_hi, 8);
// Discard eeo's (low lane), duplicate the high lane, and reorder eee's
// in-lane. Finally invert the eee2/eee3 components in the high lane.
__m256i eee_lohi = _mm256_blend_epi32(eeo_eee_lo, eeo_eee_hi_shed, 0xc0);
eee_lohi = _mm256_permute4x64_epi64(eee_lohi, _MM_SHUFFLE(3, 2, 3, 2));
__m256i eee_lh_ordered = _mm256_shuffle_epi8(eee_lohi, eee_lohi_shuf);
__m256i eee_lh_final = _mm256_sign_epi16 (eee_lh_ordered, eee_eee_sgnmask);
// D00, D08, D10 and D18 in four parts fit for hadding
__m256i d00_d08 = _mm256_madd_epi16(eee_lh_final, dct_c0_db8);
__m256i d10_d18 = _mm256_madd_epi16(eee_lh_final, dct_c1_db8);
__m256i eeo_lohi = _mm256_unpacklo_epi64 (eeo_eee_lo, eeo_eee_hi_shed);
eeo_lohi = _mm256_permute4x64_epi64(eeo_lohi, _MM_SHUFFLE(1, 0, 1, 0));
__m256i d_db4[2] = {
_mm256_madd_epi16(eeo_lohi, dct_r4c_db4[0]),
_mm256_madd_epi16(eeo_lohi, dct_r4c_db4[1]),
};
__m256i db2_parts[4];
for (uint32_t j = 0; j < 4; j++) {
const __m256i dr2a = dct_r_db2[j];
__m256i lo = _mm256_madd_epi16(dr2a, eo_eo_lo);
__m256i hi = _mm256_madd_epi16(dr2a, eo_eo_hi);
hi = _mm256_slli_epi32(hi, 16);
db2_parts[j] = _mm256_add_epi32 (lo, hi);
rp1[i] = _mm256_add_epi32(s08_lo, s08_hi);
}
__m256i odd_parts[16];
for (uint32_t j = 0; j < 16; j++) {
__m256i drow_lo = _mm256_load_si256((const __m256i *)(dct + (j * 2 + 1) * 32));
__m256i odds_lo = _mm256_madd_epi16(o_lo, drow_lo);
__m256i odds_hi = _mm256_madd_epi16(o_hi, drow_lo);
odds_hi = _mm256_slli_epi32(odds_hi, 16);
odd_parts[j] = _mm256_add_epi32 (odds_lo, odds_hi);
// 00 00 01 01 | 08 08 09 09
// 02 02 03 03 | 0a 0a 0b 0b
// 04 04 05 05 | 0c 0c 0d 0d
// 06 06 07 07 | 0e 0e 0f 0f
// 10 10 11 11 | 18 18 19 19
// ...
// 16 16 17 17 | 1e 1e 1f 1f
__m256i rp2[8];
for (i = 0; i < 8; i++) {
uint32_t id = i << 1;
rp2[i] = _mm256_hadd_epi32(rp1[id + 0], rp1[id + 1]);
}
// Rearrange odds so that parts belonging to any single one are all inside
// one lane - combine 01 | 09 ; 03 | 0b ; 05 | 0d ; 07 | 0f
// 11 | 19 ; 13 | 1b ; 15 | 1d ; 17 | 1f
__m256i odd_parts2[8];
for (uint32_t j = 0; j < 8; j++) {
// Turn 0, 1, 2, 3, 4, 5, 6, 7 into:
// 0, 1, 2, 3, 8, 9, a, b
uint32_t j_lo = j & 0x03;
uint32_t j_hi = j & 0x04;
uint32_t id_lo = j_lo | (j_hi << 1);
uint32_t id_hi = id_lo | 4;
__m256i odd_lo = _mm256_permute2x128_si256(odd_parts[id_lo],
odd_parts[id_hi],
0x20);
__m256i odd_hi = _mm256_permute2x128_si256(odd_parts[id_lo],
odd_parts[id_hi],
0x31);
odd_parts2[j] = _mm256_add_epi32(odd_lo, odd_hi);
// 00 01 02 03 | 08 09 0a 0b
// 04 05 06 07 | 0c 0d 0e 0f
// 10 11 12 13 | 18 19 1a 1b
// 14 15 16 17 | 1c 1d 1e 1f
__m256i rp3[4];
for (i = 0; i < 4; i++) {
uint32_t id = i << 1;
__m256i finals = _mm256_hadd_epi32(rp2[id + 0], rp2[id + 1]);
rp3[i] = truncate(finals, debias, shift);
}
// First stage HADDs...
__m256i d0001_0809 = _mm256_hadd_epi32(d00_d08, odd_parts2[0]);
__m256i d1011_1819 = _mm256_hadd_epi32(d10_d18, odd_parts2[4]);
__m256i d0405_0c0d = _mm256_hadd_epi32(d_db4[0], odd_parts2[2]);
__m256i d1415_1c1d = _mm256_hadd_epi32(d_db4[1], odd_parts2[6]);
__m256i d0203_0a0b = _mm256_hadd_epi32(db2_parts[0], odd_parts2[1]);
__m256i d0607_0e0f = _mm256_hadd_epi32(db2_parts[1], odd_parts2[3]);
__m256i d1213_1a1b = _mm256_hadd_epi32(db2_parts[2], odd_parts2[5]);
__m256i d1617_1e1f = _mm256_hadd_epi32(db2_parts[3], odd_parts2[7]);
// .. and second stage
__m256i d0123_89ab_lo = _mm256_hadd_epi32(d0001_0809, d0203_0a0b);
__m256i d0123_89ab_hi = _mm256_hadd_epi32(d1011_1819, d1213_1a1b);
__m256i d4567_cdef_lo = _mm256_hadd_epi32(d0405_0c0d, d0607_0e0f);
__m256i d4567_cdef_hi = _mm256_hadd_epi32(d1415_1c1d, d1617_1e1f);
d0123_89ab_lo = truncate(d0123_89ab_lo, debias, shift);
d0123_89ab_hi = truncate(d0123_89ab_hi, debias, shift);
d4567_cdef_lo = truncate(d4567_cdef_lo, debias, shift);
d4567_cdef_hi = truncate(d4567_cdef_hi, debias, shift);
__m256i final_lo = _mm256_packs_epi32(d0123_89ab_lo, d4567_cdef_lo);
__m256i final_hi = _mm256_packs_epi32(d0123_89ab_hi, d4567_cdef_hi);
_mm256_store_si256(res_tp + (i * 2) + 0, final_lo);
_mm256_store_si256(res_tp + (i * 2) + 1, final_hi);
dst[jd + 0] = _mm256_packs_epi32(rp3[0], rp3[1]);
dst[jd + 1] = _mm256_packs_epi32(rp3[2], rp3[3]);
}
}
void matmul_32x32_a_bt(const __m256i *a, const __m256i *b_t, __m256i *dst, const uint8_t shift)
{
const int32_t add = 1 << (shift - 1);
const __m256i debias = _mm256_set1_epi32(add);
uint32_t i, j;
for (j = 0; j < 32; j++) {
uint32_t jd = j << 1;
__m256i a_lo = a[jd + 0];
__m256i a_hi = a[jd + 1];
// Each lane here is the 4 i32's that add up to one number in the result
// vector:
// 00 | 08
// 01 | 09
// ...
// 07 | 0f
// 10 | 18
// 11 | 19
// ...
// 17 | 1f
__m256i rp1[16];
for (i = 0; i < 16; i++) {
// Loop in order:
// 0, 1, 2, 3, 4, 5, 6, 7,
// 16, 17, 18, 19, 20, 21, 22, 23
uint32_t id = (i & 7) | ((i & ~7) << 1);
uint32_t off = id << 1;
__m256i bt0_lo = b_t[off + 0];
__m256i bt0_hi = b_t[off + 1];
__m256i bt8_lo = b_t[off + 16];
__m256i bt8_hi = b_t[off + 17];
__m256i pr0_lo = _mm256_madd_epi16(a_lo, bt0_lo);
__m256i pr0_hi = _mm256_madd_epi16(a_hi, bt0_hi);
__m256i pr8_lo = _mm256_madd_epi16(a_lo, bt8_lo);
__m256i pr8_hi = _mm256_madd_epi16(a_hi, bt8_hi);
__m256i sum0 = _mm256_add_epi32 (pr0_lo, pr0_hi);
__m256i sum8 = _mm256_add_epi32 (pr8_lo, pr8_hi);
// Arrange all parts for one number to be inside one lane
__m256i s08_lo = _mm256_permute2x128_si256(sum0, sum8, 0x20);
__m256i s08_hi = _mm256_permute2x128_si256(sum0, sum8, 0x31);
rp1[i] = _mm256_add_epi32(s08_lo, s08_hi);
}
// 00 00 01 01 | 08 08 09 09
// 02 02 03 03 | 0a 0a 0b 0b
// 04 04 05 05 | 0c 0c 0d 0d
// 06 06 07 07 | 0e 0e 0f 0f
// 10 10 11 11 | 18 18 19 19
// ...
// 16 16 17 17 | 1e 1e 1f 1f
__m256i rp2[8];
for (i = 0; i < 8; i++) {
uint32_t id = i << 1;
rp2[i] = _mm256_hadd_epi32(rp1[id + 0], rp1[id + 1]);
}
// 00 01 02 03 | 08 09 0a 0b
// 04 05 06 07 | 0c 0d 0e 0f
// 10 11 12 13 | 18 19 1a 1b
// 14 15 16 17 | 1c 1d 1e 1f
__m256i rp3[4];
for (i = 0; i < 4; i++) {
uint32_t id = i << 1;
__m256i finals = _mm256_hadd_epi32(rp2[id + 0], rp2[id + 1]);
rp3[i] = truncate(finals, debias, shift);
}
dst[jd + 0] = _mm256_packs_epi32(rp3[0], rp3[1]);
dst[jd + 1] = _mm256_packs_epi32(rp3[2], rp3[3]);
}
transpose_32x32((const int16_t *)res_tp, dst);
}
// 32x32 matrix multiplication with value clipping.
@ -1257,27 +1096,29 @@ static void matrix_dct_32x32_avx2(int8_t bitdepth, const int16_t *input, int16_t
{
int32_t shift_1st = kvz_g_convert_to_bit[32] + 1 + (bitdepth - 8);
int32_t shift_2nd = kvz_g_convert_to_bit[32] + 8;
ALIGNED(64) int16_t tmp[32 * 32];
partial_butterfly_32_avx2(input, tmp, shift_1st);
partial_butterfly_32_avx2(tmp, output, shift_2nd);
const __m256i *dct = (const __m256i *)(&kvz_g_dct_32[0][0]);
const __m256i *inp = (const __m256i *)input;
__m256i *out = ( __m256i *)output;
/*
* Multiply input by the tranpose of DCT matrix into tmpres, and DCT matrix
* by tmpres - this is then our output matrix
*
* It's easier to implement an AVX2 matrix multiplication if you can multiply
* the left term with the transpose of the right term. Here things are stored
* row-wise, not column-wise, so we can effectively read DCT_T column-wise
* into YMM registers by reading DCT row-wise. Also because of this, the
* first multiplication is hacked to produce the transpose of the result
* instead, since it will be used in similar fashion as the right operand
* in the second multiplication.
*/
__m256i tmp[2 * 32];
matmul_32x32_a_bt_t(inp, dct, tmp, shift_1st);
matmul_32x32_a_bt (dct, tmp, out, shift_2nd);
}
// Macro that generates 2D transform functions with clipping values.
// Sets correct shift values and matrices according to transform type and
// block size. Performs matrix multiplication horizontally and vertically.
#define TRANSFORM(type, n) static void matrix_ ## type ## _ ## n ## x ## n ## _avx2(int8_t bitdepth, const int16_t *input, int16_t *output)\
{\
int32_t shift_1st = kvz_g_convert_to_bit[n] + 1 + (bitdepth - 8); \
int32_t shift_2nd = kvz_g_convert_to_bit[n] + 8; \
ALIGNED(64) int16_t tmp[n * n];\
const int16_t *tdct = &kvz_g_ ## type ## _ ## n ## _t[0][0];\
const int16_t *dct = &kvz_g_ ## type ## _ ## n [0][0];\
\
mul_clip_matrix_ ## n ## x ## n ## _avx2(input, tdct, tmp, shift_1st);\
mul_clip_matrix_ ## n ## x ## n ## _avx2(dct, tmp, output, shift_2nd);\
}\
// Macro that generates 2D inverse transform functions with clipping values.
// Sets correct shift values and matrices according to transform type and
// block size. Performs matrix multiplication horizontally and vertically.
@ -1303,7 +1144,6 @@ static void matrix_i ## type ## _## n ## x ## n ## _avx2(int8_t bitdepth, const
// ITRANSFORM(dct, 8);
// TRANSFORM(dct, 16);
// ITRANSFORM(dct, 16);
// TRANSFORM(dct, 32);
// Generate all the transform functions