Tidy the old AVX2 32x32 matrix multiply

It was actually a very good algorithm, just looked messy!
Pauli Oikkonen 2019-10-25 16:19:35 +03:00
parent 4a921cbdb5
commit 98ad78b333


@@ -847,244 +847,111 @@ static void matrix_dct_16x16_avx2(int8_t bitdepth, const int16_t *input, int16_t
matmul_16x16_a_bt (dct, tmpres, output, shift_2nd);
}
static void matmul_32x32_a_bt(const __m256i *__restrict a,
const __m256i *__restrict b_t,
__m256i *__restrict dst,
const uint8_t shift)
{
const int32_t add = 1 << (shift - 1);
const __m256i debias = _mm256_set1_epi32(add);
uint32_t i, j;
for (j = 0; j < 32; j++) {
uint32_t jd = j << 1;
__m256i a_lo = a[jd + 0];
__m256i a_hi = a[jd + 1];
__m256i res_32[4];
for (i = 0; i < 16; i += 4) {
// Loop in order:
// 0, 1, 2, 3, 4, 5, 6, 7,
// 16, 17, 18, 19, 20, 21, 22, 23
uint32_t id = (i & 7) | ((i & ~7) << 1);
uint32_t off = id << 1;
__m256i bt0_lo = b_t[off + 0];
__m256i bt0_hi = b_t[off + 1];
__m256i bt1_lo = b_t[off + 2];
__m256i bt1_hi = b_t[off + 3];
__m256i bt2_lo = b_t[off + 4];
__m256i bt2_hi = b_t[off + 5];
__m256i bt3_lo = b_t[off + 6];
__m256i bt3_hi = b_t[off + 7];
__m256i bt8_lo = b_t[off + 16];
__m256i bt8_hi = b_t[off + 17];
__m256i bt9_lo = b_t[off + 18];
__m256i bt9_hi = b_t[off + 19];
__m256i bta_lo = b_t[off + 20];
__m256i bta_hi = b_t[off + 21];
__m256i btb_lo = b_t[off + 22];
__m256i btb_hi = b_t[off + 23];
__m256i pr0_lo = _mm256_madd_epi16(a_lo, bt0_lo);
__m256i pr0_hi = _mm256_madd_epi16(a_hi, bt0_hi);
__m256i pr8_lo = _mm256_madd_epi16(a_lo, bt8_lo);
__m256i pr8_hi = _mm256_madd_epi16(a_hi, bt8_hi);
__m256i sum0 = _mm256_add_epi32 (pr0_lo, pr0_hi);
__m256i sum8 = _mm256_add_epi32 (pr8_lo, pr8_hi);
__m256i pr1_lo = _mm256_madd_epi16(a_lo, bt1_lo);
__m256i pr1_hi = _mm256_madd_epi16(a_hi, bt1_hi);
__m256i pr9_lo = _mm256_madd_epi16(a_lo, bt9_lo);
__m256i pr9_hi = _mm256_madd_epi16(a_hi, bt9_hi);
__m256i sum1 = _mm256_add_epi32 (pr1_lo, pr1_hi);
__m256i sum9 = _mm256_add_epi32 (pr9_lo, pr9_hi);
__m256i pr2_lo = _mm256_madd_epi16(a_lo, bt2_lo);
__m256i pr2_hi = _mm256_madd_epi16(a_hi, bt2_hi);
__m256i pra_lo = _mm256_madd_epi16(a_lo, bta_lo);
__m256i pra_hi = _mm256_madd_epi16(a_hi, bta_hi);
__m256i sum2 = _mm256_add_epi32 (pr2_lo, pr2_hi);
__m256i suma = _mm256_add_epi32 (pra_lo, pra_hi);
__m256i pr3_lo = _mm256_madd_epi16(a_lo, bt3_lo);
__m256i pr3_hi = _mm256_madd_epi16(a_hi, bt3_hi);
__m256i prb_lo = _mm256_madd_epi16(a_lo, btb_lo);
__m256i prb_hi = _mm256_madd_epi16(a_hi, btb_hi);
__m256i sum3 = _mm256_add_epi32 (pr3_lo, pr3_hi);
__m256i sumb = _mm256_add_epi32 (prb_lo, prb_hi);
// Arrange all parts for one number to be inside one lane
__m256i s08_lo = _mm256_permute2x128_si256(sum0, sum8, 0x20);
__m256i s08_hi = _mm256_permute2x128_si256(sum0, sum8, 0x31);
__m256i s19_lo = _mm256_permute2x128_si256(sum1, sum9, 0x20);
__m256i s19_hi = _mm256_permute2x128_si256(sum1, sum9, 0x31);
__m256i s2a_lo = _mm256_permute2x128_si256(sum2, suma, 0x20);
__m256i s2a_hi = _mm256_permute2x128_si256(sum2, suma, 0x31);
__m256i s3b_lo = _mm256_permute2x128_si256(sum3, sumb, 0x20);
__m256i s3b_hi = _mm256_permute2x128_si256(sum3, sumb, 0x31);
// Each lane here is the 4 i32's that add up to one number in the result
// vector:
// 00 | 08
// 01 | 09
// ...
// 07 | 0f
// 10 | 18
// 11 | 19
// ...
// 17 | 1f
__m256i rp1_0 = _mm256_add_epi32(s08_lo, s08_hi);
__m256i rp1_1 = _mm256_add_epi32(s19_lo, s19_hi);
__m256i rp1_2 = _mm256_add_epi32(s2a_lo, s2a_hi);
__m256i rp1_3 = _mm256_add_epi32(s3b_lo, s3b_hi);
// rp2_0 and rp2_1 values on different iterations
// 00 00 01 01 | 08 08 09 09
// 02 02 03 03 | 0a 0a 0b 0b
//
// 04 04 05 05 | 0c 0c 0d 0d
// 06 06 07 07 | 0e 0e 0f 0f
//
// 10 10 11 11 | 18 18 19 19
// 12 12 13 13 | 1a 1a 1b 1b
//
// 14 14 15 15 | 1c 1c 1d 1d
// 16 16 17 17 | 1e 1e 1f 1f
__m256i rp2_0 = _mm256_hadd_epi32(rp1_0, rp1_1);
__m256i rp2_1 = _mm256_hadd_epi32(rp1_2, rp1_3);
// rp3 values on different iters
// 00 01 02 03 | 08 09 0a 0b
// 04 05 06 07 | 0c 0d 0e 0f
// 10 11 12 13 | 18 19 1a 1b
// 14 15 16 17 | 1c 1d 1e 1f
__m256i rp3 = _mm256_hadd_epi32(rp2_0, rp2_1);
res_32[i >> 2] = truncate(rp3, debias, shift);
}
dst[jd + 0] = _mm256_packs_epi32(res_32[0], res_32[1]);
dst[jd + 1] = _mm256_packs_epi32(res_32[2], res_32[3]);
}
}
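For reference, here is a minimal scalar sketch of what matmul_32x32_a_bt computes, assuming row-major 32x32 int16_t matrices with the right operand already stored transposed; the _ref name and loop structure below are illustrative only, not part of the patch. Each output element is a dot product of two rows, rounded with the same add-and-shift seen above and saturated to int16_t the way _mm256_packs_epi32 saturates.

#include <stdint.h>

// Scalar model of the AVX2 kernel above (illustrative).
static void matmul_32x32_a_bt_ref(const int16_t *a, const int16_t *b_t,
                                  int16_t *dst, uint8_t shift)
{
  const int32_t add = 1 << (shift - 1);
  for (int i = 0; i < 32; i++) {
    for (int j = 0; j < 32; j++) {
      // b_t is the right operand pre-transposed, so this is a row-by-row
      // dot product; accumulate in 32 bits, as the madd/add chain does.
      int32_t sum = 0;
      for (int k = 0; k < 32; k++)
        sum += (int32_t)a[i * 32 + k] * (int32_t)b_t[j * 32 + k];
      int32_t v = (sum + add) >> shift;
      // _mm256_packs_epi32 saturates to the int16_t range.
      if (v > INT16_MAX) v = INT16_MAX;
      if (v < INT16_MIN) v = INT16_MIN;
      dst[i * 32 + j] = (int16_t)v;
    }
  }
}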
// 32x32 matrix multiplication with value clipping.
// Parameters: Two 32x32 matrices containing 16-bit values in consecutive addresses,
// destination for the result and the shift value for clipping.
static void mul_clip_matrix_32x32_avx2(const int16_t *left, const int16_t *right, int16_t *dst, const int32_t shift)
static void mul_clip_matrix_32x32_avx2(const int16_t *left,
const int16_t *right,
int16_t *dst,
const int32_t shift)
{
int i, j;
__m256i row[4], tmp[2], accu[32][4], even, odd;
const __m256i zero = _mm256_setzero_si256();
const int32_t stride = 16;
const int32_t add = 1 << (shift - 1);
const __m256i debias = _mm256_set1_epi32(add);
const int32_t add = 1 << (shift - 1);
const uint32_t *l_32 = (const uint32_t *)left;
const __m256i *r_v = (const __m256i *)right;
__m256i *dst_v = ( __m256i *)dst;
row[0] = _mm256_loadu_si256((__m256i*) right);
row[1] = _mm256_loadu_si256((__m256i*) right + 2);
tmp[0] = _mm256_unpacklo_epi16(row[0], row[1]);
tmp[1] = _mm256_unpackhi_epi16(row[0], row[1]);
row[0] = _mm256_permute2x128_si256(tmp[0], tmp[1], 0 + 32);
row[1] = _mm256_permute2x128_si256(tmp[0], tmp[1], 1 + 48);
size_t i, j;
row[2] = _mm256_loadu_si256((__m256i*) right + 1);
row[3] = _mm256_loadu_si256((__m256i*) right + 3);
tmp[0] = _mm256_unpacklo_epi16(row[2], row[3]);
tmp[1] = _mm256_unpackhi_epi16(row[2], row[3]);
row[2] = _mm256_permute2x128_si256(tmp[0], tmp[1], 0 + 32);
row[3] = _mm256_permute2x128_si256(tmp[0], tmp[1], 1 + 48);
for (i = 0; i < 32; i += 2) {
even = _mm256_set1_epi32(((int32_t*)left)[stride * i]);
accu[i][0] = _mm256_madd_epi16(even, row[0]);
accu[i][1] = _mm256_madd_epi16(even, row[1]);
accu[i][2] = _mm256_madd_epi16(even, row[2]);
accu[i][3] = _mm256_madd_epi16(even, row[3]);
odd = _mm256_set1_epi32(((int32_t*)left)[stride * (i + 1)]);
accu[i + 1][0] = _mm256_madd_epi16(odd, row[0]);
accu[i + 1][1] = _mm256_madd_epi16(odd, row[1]);
accu[i + 1][2] = _mm256_madd_epi16(odd, row[2]);
accu[i + 1][3] = _mm256_madd_epi16(odd, row[3]);
__m256i accu[128];
for (i = 0; i < 128; i++) {
accu[i] = zero;
}
for (j = 0; j < 64; j += 4) {
const __m256i r0 = r_v[j + 0];
const __m256i r1 = r_v[j + 1];
const __m256i r2 = r_v[j + 2];
const __m256i r3 = r_v[j + 3];
for (j = 4; j < 64; j += 4) {
__m256i r02l = _mm256_unpacklo_epi16(r0, r2);
__m256i r02h = _mm256_unpackhi_epi16(r0, r2);
__m256i r13l = _mm256_unpacklo_epi16(r1, r3);
__m256i r13h = _mm256_unpackhi_epi16(r1, r3);
row[0] = _mm256_loadu_si256((__m256i*)right + j);
row[1] = _mm256_loadu_si256((__m256i*)right + j + 2);
tmp[0] = _mm256_unpacklo_epi16(row[0], row[1]);
tmp[1] = _mm256_unpackhi_epi16(row[0], row[1]);
row[0] = _mm256_permute2x128_si256(tmp[0], tmp[1], 0 + 32);
row[1] = _mm256_permute2x128_si256(tmp[0], tmp[1], 1 + 48);
__m256i r0s = _mm256_permute2x128_si256(r02l, r02h, 0x20);
__m256i r1s = _mm256_permute2x128_si256(r02l, r02h, 0x31);
row[2] = _mm256_loadu_si256((__m256i*) right + j + 1);
row[3] = _mm256_loadu_si256((__m256i*) right + j + 3);
tmp[0] = _mm256_unpacklo_epi16(row[2], row[3]);
tmp[1] = _mm256_unpackhi_epi16(row[2], row[3]);
row[2] = _mm256_permute2x128_si256(tmp[0], tmp[1], 0 + 32);
row[3] = _mm256_permute2x128_si256(tmp[0], tmp[1], 1 + 48);
__m256i r2s = _mm256_permute2x128_si256(r13l, r13h, 0x20);
__m256i r3s = _mm256_permute2x128_si256(r13l, r13h, 0x31);
for (i = 0; i < 32; i += 2) {
size_t acc_base = i << 2;
even = _mm256_set1_epi32(((int32_t*)left)[stride * i + j / 4]);
accu[i][0] = _mm256_add_epi32(accu[i][0], _mm256_madd_epi16(even, row[0]));
accu[i][1] = _mm256_add_epi32(accu[i][1], _mm256_madd_epi16(even, row[1]));
accu[i][2] = _mm256_add_epi32(accu[i][2], _mm256_madd_epi16(even, row[2]));
accu[i][3] = _mm256_add_epi32(accu[i][3], _mm256_madd_epi16(even, row[3]));
uint32_t curr_e = l_32[(i + 0) * (32 / 2) + (j >> 2)];
uint32_t curr_o = l_32[(i + 1) * (32 / 2) + (j >> 2)];
odd = _mm256_set1_epi32(((int32_t*)left)[stride * (i + 1) + j / 4]);
accu[i + 1][0] = _mm256_add_epi32(accu[i + 1][0], _mm256_madd_epi16(odd, row[0]));
accu[i + 1][1] = _mm256_add_epi32(accu[i + 1][1], _mm256_madd_epi16(odd, row[1]));
accu[i + 1][2] = _mm256_add_epi32(accu[i + 1][2], _mm256_madd_epi16(odd, row[2]));
accu[i + 1][3] = _mm256_add_epi32(accu[i + 1][3], _mm256_madd_epi16(odd, row[3]));
__m256i even = _mm256_set1_epi32(curr_e);
__m256i odd = _mm256_set1_epi32(curr_o);
__m256i p_e0 = _mm256_madd_epi16(even, r0s);
__m256i p_e1 = _mm256_madd_epi16(even, r1s);
__m256i p_e2 = _mm256_madd_epi16(even, r2s);
__m256i p_e3 = _mm256_madd_epi16(even, r3s);
__m256i p_o0 = _mm256_madd_epi16(odd, r0s);
__m256i p_o1 = _mm256_madd_epi16(odd, r1s);
__m256i p_o2 = _mm256_madd_epi16(odd, r2s);
__m256i p_o3 = _mm256_madd_epi16(odd, r3s);
accu[acc_base + 0] = _mm256_add_epi32(p_e0, accu[acc_base + 0]);
accu[acc_base + 1] = _mm256_add_epi32(p_e1, accu[acc_base + 1]);
accu[acc_base + 2] = _mm256_add_epi32(p_e2, accu[acc_base + 2]);
accu[acc_base + 3] = _mm256_add_epi32(p_e3, accu[acc_base + 3]);
accu[acc_base + 4] = _mm256_add_epi32(p_o0, accu[acc_base + 4]);
accu[acc_base + 5] = _mm256_add_epi32(p_o1, accu[acc_base + 5]);
accu[acc_base + 6] = _mm256_add_epi32(p_o2, accu[acc_base + 6]);
accu[acc_base + 7] = _mm256_add_epi32(p_o3, accu[acc_base + 7]);
}
}
for (i = 0; i < 32; ++i) {
__m256i result, first_quarter, second_quarter, third_quarter, fourth_quarter;
for (i = 0; i < 32; i++) {
size_t acc_base = i << 2;
size_t dst_base = i << 1;
first_quarter = _mm256_srai_epi32(_mm256_add_epi32(accu[i][0], _mm256_set1_epi32(add)), shift);
second_quarter = _mm256_srai_epi32(_mm256_add_epi32(accu[i][1], _mm256_set1_epi32(add)), shift);
third_quarter = _mm256_srai_epi32(_mm256_add_epi32(accu[i][2], _mm256_set1_epi32(add)), shift);
fourth_quarter = _mm256_srai_epi32(_mm256_add_epi32(accu[i][3], _mm256_set1_epi32(add)), shift);
result = _mm256_permute4x64_epi64(_mm256_packs_epi32(first_quarter, second_quarter), 0 + 8 + 16 + 192);
_mm256_storeu_si256((__m256i*)dst + 2 * i, result);
result = _mm256_permute4x64_epi64(_mm256_packs_epi32(third_quarter, fourth_quarter), 0 + 8 + 16 + 192);
_mm256_storeu_si256((__m256i*)dst + 2 * i + 1, result);
__m256i q0 = truncate(accu[acc_base + 0], debias, shift);
__m256i q1 = truncate(accu[acc_base + 1], debias, shift);
__m256i q2 = truncate(accu[acc_base + 2], debias, shift);
__m256i q3 = truncate(accu[acc_base + 3], debias, shift);
__m256i h01 = _mm256_packs_epi32(q0, q1);
__m256i h23 = _mm256_packs_epi32(q2, q3);
h01 = _mm256_permute4x64_epi64(h01, _MM_SHUFFLE(3, 1, 2, 0));
h23 = _mm256_permute4x64_epi64(h23, _MM_SHUFFLE(3, 1, 2, 0));
_mm256_store_si256(dst_v + dst_base + 0, h01);
_mm256_store_si256(dst_v + dst_base + 1, h23);
}
}
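As a point of comparison with matmul_32x32_a_bt, a scalar sketch of the contract described in the comment above this function (again assuming <stdint.h>; names are illustrative): here the right operand is not pre-transposed, so the inner loop walks it column-wise, and "clipping" means the same int16_t saturation performed by _mm256_packs_epi32.

static void mul_clip_matrix_32x32_ref(const int16_t *left, const int16_t *right,
                                      int16_t *dst, int32_t shift)
{
  const int32_t add = 1 << (shift - 1);
  for (int i = 0; i < 32; i++) {
    for (int j = 0; j < 32; j++) {
      int32_t sum = 0;
      // Plain matrix product: right is read column-wise.
      for (int k = 0; k < 32; k++)
        sum += (int32_t)left[i * 32 + k] * (int32_t)right[k * 32 + j];
      int32_t v = (sum + add) >> shift;
      dst[i * 32 + j] = (int16_t)(v > INT16_MAX ? INT16_MAX
                                : v < INT16_MIN ? INT16_MIN : v);
    }
  }
}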
static void matrix_dct_32x32_avx2(int8_t bitdepth, const int16_t *input, int16_t *output)
{
int32_t shift_1st = kvz_g_convert_to_bit[32] + 1 + (bitdepth - 8);
int32_t shift_2nd = kvz_g_convert_to_bit[32] + 8;
const __m256i *dct = (const __m256i *)(&kvz_g_dct_32[0][0]);
const __m256i *inp = (const __m256i *)input;
__m256i *out = ( __m256i *)output;
/*
 * Multiply input by the transpose of the DCT matrix into tmpres, and the DCT
 * matrix by tmpres - this is then our output matrix
 *
 * It's easier to implement an AVX2 matrix multiplication if you can multiply
 * the left term with the transpose of the right term. Here things are stored
 * row-wise, not column-wise, so we can effectively read DCT_T column-wise
 * into YMM registers by reading DCT row-wise. Also because of this, the
 * first multiplication is hacked to produce the transpose of the result
 * instead, since it will be used in the same fashion as the right operand
 * in the second multiplication.
 */
__m256i tmp[2 * 32];
// Hack! (A * B^T)^T = B * A^T, so we can dispense with a separate
// transpose-producing multiply entirely
matmul_32x32_a_bt(dct, inp, tmp, shift_1st);
matmul_32x32_a_bt(dct, tmp, out, shift_2nd);
}
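The identity the comment leans on is easy to verify directly: element (i,j) of A * B^T is sum_k a[i][k] * b[j][k], which is exactly element (j,i) of B * A^T. A small self-contained check (the size N, the names and the rand() fill are purely illustrative):

#include <assert.h>
#include <stdint.h>
#include <stdlib.h>

enum { N = 4 };

// c = x * y^T, with x and y both stored row-major.
static void mul_x_yt(const int16_t *x, const int16_t *y, int32_t *c)
{
  for (int i = 0; i < N; i++)
    for (int j = 0; j < N; j++) {
      int32_t s = 0;
      for (int k = 0; k < N; k++)
        s += x[i * N + k] * y[j * N + k];
      c[i * N + j] = s;
    }
}

int main(void)
{
  int16_t a[N * N], b[N * N];
  int32_t ab_t[N * N], b_at[N * N];
  for (int i = 0; i < N * N; i++) {
    a[i] = (int16_t)(rand() % 256 - 128);
    b[i] = (int16_t)(rand() % 256 - 128);
  }
  mul_x_yt(a, b, ab_t); // A * B^T
  mul_x_yt(b, a, b_at); // B * A^T
  for (int i = 0; i < N; i++)
    for (int j = 0; j < N; j++)
      assert(ab_t[i * N + j] == b_at[j * N + i]); // (A * B^T)^T == B * A^T
  return 0;
}

This is why the first matmul_32x32_a_bt call can pass the input block as the "transposed" operand: dct * input^T is already the transpose of input * dct^T, which is precisely the layout the second call expects for its b_t argument.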
// Macro that generates 2D transform functions with clipping values.
// Sets correct shift values and matrices according to transform type and
// block size. Performs matrix multiplication horizontally and vertically.
#define TRANSFORM(type, n) static void matrix_ ## type ## _ ## n ## x ## n ## _avx2(int8_t bitdepth, const int16_t *input, int16_t *output)\
{\
int32_t shift_1st = kvz_g_convert_to_bit[n] + 1 + (bitdepth - 8); \
int32_t shift_2nd = kvz_g_convert_to_bit[n] + 8; \
ALIGNED(64) int16_t tmp[n * n];\
const int16_t *tdct = &kvz_g_ ## type ## _ ## n ## _t[0][0];\
const int16_t *dct = &kvz_g_ ## type ## _ ## n [0][0];\
\
mul_clip_matrix_ ## n ## x ## n ## _avx2(input, tdct, tmp, shift_1st);\
mul_clip_matrix_ ## n ## x ## n ## _avx2(dct, tmp, output, shift_2nd);\
}\
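To make the token pasting concrete, instantiating the macro stamps out a wrapper of the following shape (shown here for type = dct, n = 32; this is just the mechanical expansion, reformatted):

static void matrix_dct_32x32_avx2(int8_t bitdepth, const int16_t *input, int16_t *output)
{
  int32_t shift_1st = kvz_g_convert_to_bit[32] + 1 + (bitdepth - 8);
  int32_t shift_2nd = kvz_g_convert_to_bit[32] + 8;
  ALIGNED(64) int16_t tmp[32 * 32];
  const int16_t *tdct = &kvz_g_dct_32_t[0][0];
  const int16_t *dct  = &kvz_g_dct_32[0][0];

  mul_clip_matrix_32x32_avx2(input, tdct, tmp, shift_1st);
  mul_clip_matrix_32x32_avx2(dct, tmp, output, shift_2nd);
}

Since such an expansion defines the same symbol as a hand-written matrix_dct_32x32_avx2, only one of the two can be compiled in at a time.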
// Macro that generates 2D inverse transform functions with clipping values.
// Sets correct shift values and matrices according to transform type and
@@ -1111,10 +978,10 @@ static void matrix_i ## type ## _## n ## x ## n ## _avx2(int8_t bitdepth, const
// ITRANSFORM(dct, 8);
// TRANSFORM(dct, 16);
// ITRANSFORM(dct, 16);
// TRANSFORM(dct, 32);
// Generate all the transform functions
TRANSFORM(dct, 32);
ITRANSFORM(dct, 32);
#endif //COMPILE_INTEL_AVX2