Remove totally unnecessary (A * B^T)^T 32x32 multiply

This commit is contained in:
Pauli Oikkonen 2019-10-15 15:18:42 +03:00
parent 043f53539f
commit a58608d0b8

View file

@ -847,85 +847,10 @@ static void matrix_dct_16x16_avx2(int8_t bitdepth, const int16_t *input, int16_t
matmul_16x16_a_bt (dct, tmpres, output, shift_2nd);
}
// Multiply two 32x32 int16 matrices, given A row-major and B transposed
// (b_t holds B^T row-major, i.e. B's columns), and store the TRANSPOSE of
// the product: dst = (A * B^T)^T. Each matrix row of 32 int16 values is a
// pair of __m256i vectors (lo = cols 0..15, hi = cols 16..31), so every
// matrix occupies 64 __m256i's. Intermediate sums are accumulated as i32,
// rounded and shifted right by `shift` via truncate(), then packed back to
// saturated i16 with _mm256_packs_epi32.
// NOTE(review): assumes shift >= 1; shift == 0 would make the rounding-bias
// computation below shift by -1 (UB) — confirm callers never pass 0.
void matmul_32x32_a_bt_t(const __m256i *a, const __m256i *b_t, __m256i *dst, const uint8_t shift)
{
// Rounding bias added before the right shift (round-half-up).
const int32_t add = 1 << (shift - 1);
const __m256i debias = _mm256_set1_epi32(add);
uint32_t i, j;
// One iteration per row j of B^T; produces dst column j, which is row j of
// the transposed result (two __m256i's: dst[2j] and dst[2j+1]).
for (j = 0; j < 32; j++) {
uint32_t jd = j << 1;
__m256i bt_lo = b_t[jd + 0];
__m256i bt_hi = b_t[jd + 1];
// Each lane here is the 4 i32's that add up to one number in the result
// vector:
// 00 | 08
// 01 | 09
// ...
// 07 | 0f
// 10 | 18
// 11 | 19
// ...
// 17 | 1f
__m256i rp1[16];
for (i = 0; i < 16; i++) {
// Loop in order:
// 0, 1, 2, 3, 4, 5, 6, 7,
// 16, 17, 18, 19, 20, 21, 22, 23
// id is the A-row handled in the low half of rp1[i]; row id+8 is
// handled in parallel (a[off + 16/17]) and lands in the high half.
uint32_t id = (i & 7) | ((i & ~7) << 1);
uint32_t off = id << 1;
__m256i a0_lo = a[off + 0];
__m256i a0_hi = a[off + 1];
__m256i a8_lo = a[off + 16];
__m256i a8_hi = a[off + 17];
// madd: multiply 16 pairs of i16 and add adjacent products -> 8 x i32
// partial dot products per vector.
__m256i pr0_lo = _mm256_madd_epi16(bt_lo, a0_lo);
__m256i pr0_hi = _mm256_madd_epi16(bt_hi, a0_hi);
__m256i pr8_lo = _mm256_madd_epi16(bt_lo, a8_lo);
__m256i pr8_hi = _mm256_madd_epi16(bt_hi, a8_hi);
__m256i sum0 = _mm256_add_epi32 (pr0_lo, pr0_hi);
__m256i sum8 = _mm256_add_epi32 (pr8_lo, pr8_hi);
// Arrange all parts for one number to be inside one lane
__m256i s08_lo = _mm256_permute2x128_si256(sum0, sum8, 0x20);
__m256i s08_hi = _mm256_permute2x128_si256(sum0, sum8, 0x31);
rp1[i] = _mm256_add_epi32(s08_lo, s08_hi);
}
// First horizontal-add level: pairwise reduce, two results interleaved
// per lane as shown below.
// 00 00 01 01 | 08 08 09 09
// 02 02 03 03 | 0a 0a 0b 0b
// 04 04 05 05 | 0c 0c 0d 0d
// 06 06 07 07 | 0e 0e 0f 0f
// 10 10 11 11 | 18 18 19 19
// ...
// 16 16 17 17 | 1e 1e 1f 1f
__m256i rp2[8];
for (i = 0; i < 8; i++) {
uint32_t id = i << 1;
rp2[i] = _mm256_hadd_epi32(rp1[id + 0], rp1[id + 1]);
}
// Second horizontal-add level: each i32 is now one finished dot product.
// 00 01 02 03 | 08 09 0a 0b
// 04 05 06 07 | 0c 0d 0e 0f
// 10 11 12 13 | 18 19 1a 1b
// 14 15 16 17 | 1c 1d 1e 1f
__m256i rp3[4];
for (i = 0; i < 4; i++) {
uint32_t id = i << 1;
__m256i finals = _mm256_hadd_epi32(rp2[id + 0], rp2[id + 1]);
// truncate() is a project helper; presumably adds debias and
// arithmetic-shifts right by `shift` — confirm against its definition.
rp3[i] = truncate(finals, debias, shift);
}
// Pack the 32 i32 results back into two vectors of 16 saturated i16's.
dst[jd + 0] = _mm256_packs_epi32(rp3[0], rp3[1]);
dst[jd + 1] = _mm256_packs_epi32(rp3[2], rp3[3]);
}
}
void matmul_32x32_a_bt(const __m256i *a, const __m256i *b_t, __m256i *dst, const uint8_t shift)
static void matmul_32x32_a_bt(const __m256i *__restrict a,
const __m256i *__restrict b_t,
__m256i *__restrict dst,
const uint8_t shift)
{
const int32_t add = 1 << (shift - 1);
const __m256i debias = _mm256_set1_epi32(add);
@ -1115,8 +1040,10 @@ static void matrix_dct_32x32_avx2(int8_t bitdepth, const int16_t *input, int16_t
*/
__m256i tmp[2 * 32];
matmul_32x32_a_bt_t(inp, dct, tmp, shift_1st);
matmul_32x32_a_bt (dct, tmp, out, shift_2nd);
// Hack! (A * B^T)^T = B * A^T, so we can skip the transpose-producing
// multiply entirely
matmul_32x32_a_bt(dct, inp, tmp, shift_1st);
matmul_32x32_a_bt(dct, tmp, out, shift_2nd);
}
// Macro that generates 2D inverse transform functions with clipping values.