Streamline by-the-book 8x8 matrix multiplication

Also chop up the forward transform into two tailored multiply functions
This commit is contained in:
Pauli Oikkonen 2019-05-29 19:29:07 +03:00
parent 7ec7ab3361
commit 07970ea82f

View file

@ -95,65 +95,6 @@ static void mul_clip_matrix_4x4_avx2(const int16_t *left, const int16_t *right,
_mm256_storeu_si256((__m256i*)dst, result);
}
// 8x8 matrix multiplication with value clipping.
// Parameters: Two 8x8 matrices containing 16-bit values in consecutive addresses,
// destination for the result and the shift value for clipping.
//
static void mul_clip_matrix_8x8_avx2(const int16_t *left, const int16_t *right, int16_t *dst, const int32_t shift)
{
int i, j;
__m256i b[2], accu[8], even[2], odd[2];
const int32_t add = 1 << (shift - 1);
b[0] = _mm256_loadu_si256((__m256i*) right);
b[1] = _mm256_unpackhi_epi16(b[0], _mm256_castsi128_si256(_mm256_extracti128_si256(b[0], 1)));
b[0] = _mm256_unpacklo_epi16(b[0], _mm256_castsi128_si256(_mm256_extracti128_si256(b[0], 1)));
b[0] = _mm256_inserti128_si256(b[0], _mm256_castsi256_si128(b[1]), 1);
for (i = 0; i < 8; i += 2) {
even[0] = _mm256_set1_epi32(((int32_t*)left)[4 * i]);
even[0] = _mm256_madd_epi16(even[0], b[0]);
accu[i] = even[0];
odd[0] = _mm256_set1_epi32(((int32_t*)left)[4 * (i + 1)]);
odd[0] = _mm256_madd_epi16(odd[0], b[0]);
accu[i + 1] = odd[0];
}
for (j = 1; j < 4; ++j) {
b[0] = _mm256_loadu_si256((__m256i*)right + j);
b[1] = _mm256_unpackhi_epi16(b[0], _mm256_castsi128_si256(_mm256_extracti128_si256(b[0], 1)));
b[0] = _mm256_unpacklo_epi16(b[0], _mm256_castsi128_si256(_mm256_extracti128_si256(b[0], 1)));
b[0] = _mm256_inserti128_si256(b[0], _mm256_castsi256_si128(b[1]), 1);
for (i = 0; i < 8; i += 2) {
even[0] = _mm256_set1_epi32(((int32_t*)left)[4 * i + j]);
even[0] = _mm256_madd_epi16(even[0], b[0]);
accu[i] = _mm256_add_epi32(accu[i], even[0]);
odd[0] = _mm256_set1_epi32(((int32_t*)left)[4 * (i + 1) + j]);
odd[0] = _mm256_madd_epi16(odd[0], b[0]);
accu[i + 1] = _mm256_add_epi32(accu[i + 1], odd[0]);
}
}
for (i = 0; i < 8; i += 2) {
__m256i result, first_half, second_half;
first_half = _mm256_srai_epi32(_mm256_add_epi32(accu[i], _mm256_set1_epi32(add)), shift);
second_half = _mm256_srai_epi32(_mm256_add_epi32(accu[i + 1], _mm256_set1_epi32(add)), shift);
result = _mm256_permute4x64_epi64(_mm256_packs_epi32(first_half, second_half), 0 + 8 + 16 + 192);
_mm256_storeu_si256((__m256i*)dst + i / 2, result);
}
}
static INLINE __m256i swap_lanes(__m256i v)
{
return _mm256_permute4x64_epi64(v, _MM_SHUFFLE(1, 0, 3, 2));
@ -165,18 +106,83 @@ static INLINE __m256i truncate(__m256i v, __m256i debias, int32_t shift)
return _mm256_srai_epi32(truncable, shift);
}
static void matrix_dct_8x8_avx2(int8_t bitdepth, const int16_t *input, int16_t *output)
static void mul_clip_matrix_8x8_avx2(const int16_t *left, const int16_t *right, int16_t *dst, const int32_t shift)
{
int32_t shift_1st = kvz_g_convert_to_bit[8] + 1 + (bitdepth - 8);
int32_t shift_2nd = kvz_g_convert_to_bit[8] + 8;
const __m256i transp_mask = _mm256_broadcastsi128_si256(_mm_setr_epi8(0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15));
const int32_t add1 = 1 << (shift_1st - 1);
const __m256i debias1 = _mm256_set1_epi32(add1);
const int32_t add = 1 << (shift - 1);
const __m256i debias = _mm256_set1_epi32(add);
const int32_t add2 = 1 << (shift_2nd - 1);
const __m256i debias2 = _mm256_set1_epi32(add2);
__m256i left_dr[4] = {
_mm256_loadu_si256((const __m256i *)left + 0),
_mm256_loadu_si256((const __m256i *)left + 1),
_mm256_loadu_si256((const __m256i *)left + 2),
_mm256_loadu_si256((const __m256i *)left + 3),
};
__m256i right_dr[4] = {
_mm256_loadu_si256((const __m256i *)right + 0),
_mm256_loadu_si256((const __m256i *)right + 1),
_mm256_loadu_si256((const __m256i *)right + 2),
_mm256_loadu_si256((const __m256i *)right + 3),
};
const __m256i *dct = (__m256i *)&(kvz_g_dct_8[0][0]);
__m256i rdrs_rearr[8];
// Rearrange right matrix
for (int32_t dry = 0; dry < 4; dry++) {
__m256i rdr = right_dr[dry];
__m256i rdr_los = _mm256_permute4x64_epi64(rdr, _MM_SHUFFLE(2, 0, 2, 0));
__m256i rdr_his = _mm256_permute4x64_epi64(rdr, _MM_SHUFFLE(3, 1, 3, 1));
__m256i rdr_lo_rearr = _mm256_shuffle_epi8(rdr_los, transp_mask);
__m256i rdr_hi_rearr = _mm256_shuffle_epi8(rdr_his, transp_mask);
rdrs_rearr[dry * 2 + 0] = rdr_lo_rearr;
rdrs_rearr[dry * 2 + 1] = rdr_hi_rearr;
}
// Double-Row Y for destination matrix
for (int32_t dry = 0; dry < 4; dry++) {
__m256i ldr = left_dr[dry];
__m256i ldr_slice12 = _mm256_shuffle_epi32(ldr, _MM_SHUFFLE(0, 0, 0, 0));
__m256i ldr_slice34 = _mm256_shuffle_epi32(ldr, _MM_SHUFFLE(1, 1, 1, 1));
__m256i ldr_slice56 = _mm256_shuffle_epi32(ldr, _MM_SHUFFLE(2, 2, 2, 2));
__m256i ldr_slice78 = _mm256_shuffle_epi32(ldr, _MM_SHUFFLE(3, 3, 3, 3));
__m256i prod1 = _mm256_madd_epi16(ldr_slice12, rdrs_rearr[0]);
__m256i prod2 = _mm256_madd_epi16(ldr_slice12, rdrs_rearr[1]);
__m256i prod3 = _mm256_madd_epi16(ldr_slice34, rdrs_rearr[2]);
__m256i prod4 = _mm256_madd_epi16(ldr_slice34, rdrs_rearr[3]);
__m256i prod5 = _mm256_madd_epi16(ldr_slice56, rdrs_rearr[4]);
__m256i prod6 = _mm256_madd_epi16(ldr_slice56, rdrs_rearr[5]);
__m256i prod7 = _mm256_madd_epi16(ldr_slice78, rdrs_rearr[6]);
__m256i prod8 = _mm256_madd_epi16(ldr_slice78, rdrs_rearr[7]);
__m256i lo_1 = _mm256_add_epi32(prod1, prod3);
__m256i hi_1 = _mm256_add_epi32(prod2, prod4);
__m256i lo_2 = _mm256_add_epi32(prod5, prod7);
__m256i hi_2 = _mm256_add_epi32(prod6, prod8);
__m256i lo = _mm256_add_epi32(lo_1, lo_2);
__m256i hi = _mm256_add_epi32(hi_1, hi_2);
__m256i lo_tr = truncate(lo, debias, shift);
__m256i hi_tr = truncate(hi, debias, shift);
__m256i final_dr = _mm256_packs_epi32(lo_tr, hi_tr);
_mm256_storeu_si256((__m256i *)dst + dry, final_dr);
}
}
// Multiplies A by B_T's transpose and stores result's transpose in output,
// which should be an array of 4 __m256i's
static void matmul_8x8_a_bt_t(const int16_t *a, const int16_t *b_t,
const int8_t shift, __m256i *output)
{
const int32_t add = 1 << (shift - 1);
const __m256i debias = _mm256_set1_epi32(add);
// Keep upper row intact and swap neighboring 16-bit words in lower row
const __m256i shuf_lorow_mask =
@ -185,18 +191,113 @@ static void matrix_dct_8x8_avx2(int8_t bitdepth, const int16_t *input, int16_t *
18, 19, 16, 17, 22, 23, 20, 21,
26, 27, 24, 25, 30, 31, 28, 29);
__m256i tmpres[4];
const __m256i *b_t_256 = (const __m256i *)b_t;
// Dual Rows, because two 8x16b words fit in one YMM
__m256i i_dr_0 = _mm256_loadu_si256((__m256i *)input + 0);
__m256i i_dr_1 = _mm256_loadu_si256((__m256i *)input + 1);
__m256i i_dr_2 = _mm256_loadu_si256((__m256i *)input + 2);
__m256i i_dr_3 = _mm256_loadu_si256((__m256i *)input + 3);
__m256i a_dr_0 = _mm256_loadu_si256((__m256i *)a + 0);
__m256i a_dr_1 = _mm256_loadu_si256((__m256i *)a + 1);
__m256i a_dr_2 = _mm256_loadu_si256((__m256i *)a + 2);
__m256i a_dr_3 = _mm256_loadu_si256((__m256i *)a + 3);
__m256i i_dr_0_swp = swap_lanes(i_dr_0);
__m256i i_dr_1_swp = swap_lanes(i_dr_1);
__m256i i_dr_2_swp = swap_lanes(i_dr_2);
__m256i i_dr_3_swp = swap_lanes(i_dr_3);
__m256i a_dr_0_swp = swap_lanes(a_dr_0);
__m256i a_dr_1_swp = swap_lanes(a_dr_1);
__m256i a_dr_2_swp = swap_lanes(a_dr_2);
__m256i a_dr_3_swp = swap_lanes(a_dr_3);
for (int dry = 0; dry < 4; dry++) {
// Read dual columns of B matrix by reading rows of its transpose
__m256i b_dc = _mm256_loadu_si256(b_t_256 + dry);
__m256i prod0 = _mm256_madd_epi16(b_dc, a_dr_0);
__m256i prod0_swp = _mm256_madd_epi16(b_dc, a_dr_0_swp);
__m256i prod1 = _mm256_madd_epi16(b_dc, a_dr_1);
__m256i prod1_swp = _mm256_madd_epi16(b_dc, a_dr_1_swp);
__m256i prod2 = _mm256_madd_epi16(b_dc, a_dr_2);
__m256i prod2_swp = _mm256_madd_epi16(b_dc, a_dr_2_swp);
__m256i prod3 = _mm256_madd_epi16(b_dc, a_dr_3);
__m256i prod3_swp = _mm256_madd_epi16(b_dc, a_dr_3_swp);
__m256i hsum0 = _mm256_hadd_epi32(prod0, prod0_swp);
__m256i hsum1 = _mm256_hadd_epi32(prod1, prod1_swp);
__m256i hsum2 = _mm256_hadd_epi32(prod2, prod2_swp);
__m256i hsum3 = _mm256_hadd_epi32(prod3, prod3_swp);
__m256i hsum2c_0 = _mm256_hadd_epi32(hsum0, hsum1);
__m256i hsum2c_1 = _mm256_hadd_epi32(hsum2, hsum3);
__m256i hsum2c_0_tr = truncate(hsum2c_0, debias, shift);
__m256i hsum2c_1_tr = truncate(hsum2c_1, debias, shift);
__m256i tmp_dc = _mm256_packs_epi32(hsum2c_0_tr, hsum2c_1_tr);
output[dry] = _mm256_shuffle_epi8(tmp_dc, shuf_lorow_mask);
}
}
// Multiplies A by B_T's transpose and stores result in output
// which should be an array of 4 __m256i's
static void matmul_8x8_a_bt(const int16_t *a, const __m256i *b_t,
const int8_t shift, int16_t *output)
{
const int32_t add = 1 << (shift - 1);
const __m256i debias = _mm256_set1_epi32(add);
const __m256i shuf_lorow_mask =
_mm256_setr_epi8(0, 1, 2, 3, 4, 5, 6, 7,
8, 9, 10, 11, 12, 13, 14, 15,
18, 19, 16, 17, 22, 23, 20, 21,
26, 27, 24, 25, 30, 31, 28, 29);
const __m256i *a_256 = (const __m256i *)a;
__m256i b_dc_0 = b_t[0];
__m256i b_dc_1 = b_t[1];
__m256i b_dc_2 = b_t[2];
__m256i b_dc_3 = b_t[3];
__m256i b_dc_0_swp = swap_lanes(b_dc_0);
__m256i b_dc_1_swp = swap_lanes(b_dc_1);
__m256i b_dc_2_swp = swap_lanes(b_dc_2);
__m256i b_dc_3_swp = swap_lanes(b_dc_3);
for (int dry = 0; dry < 4; dry++) {
__m256i a_dr = _mm256_loadu_si256(a_256 + dry);
__m256i prod0 = _mm256_madd_epi16(a_dr, b_dc_0);
__m256i prod0_swp = _mm256_madd_epi16(a_dr, b_dc_0_swp);
__m256i prod1 = _mm256_madd_epi16(a_dr, b_dc_1);
__m256i prod1_swp = _mm256_madd_epi16(a_dr, b_dc_1_swp);
__m256i prod2 = _mm256_madd_epi16(a_dr, b_dc_2);
__m256i prod2_swp = _mm256_madd_epi16(a_dr, b_dc_2_swp);
__m256i prod3 = _mm256_madd_epi16(a_dr, b_dc_3);
__m256i prod3_swp = _mm256_madd_epi16(a_dr, b_dc_3_swp);
__m256i hsum0 = _mm256_hadd_epi32(prod0, prod0_swp);
__m256i hsum1 = _mm256_hadd_epi32(prod1, prod1_swp);
__m256i hsum2 = _mm256_hadd_epi32(prod2, prod2_swp);
__m256i hsum3 = _mm256_hadd_epi32(prod3, prod3_swp);
__m256i hsum2c_0 = _mm256_hadd_epi32(hsum0, hsum1);
__m256i hsum2c_1 = _mm256_hadd_epi32(hsum2, hsum3);
__m256i hsum2c_0_tr = truncate(hsum2c_0, debias, shift);
__m256i hsum2c_1_tr = truncate(hsum2c_1, debias, shift);
__m256i tmp_dr = _mm256_packs_epi32(hsum2c_0_tr, hsum2c_1_tr);
__m256i final_dr = _mm256_shuffle_epi8(tmp_dr, shuf_lorow_mask);
_mm256_storeu_si256((__m256i *)output + dry, final_dr);
}
}
static void matrix_dct_8x8_avx2(int8_t bitdepth, const int16_t *input, int16_t *output)
{
int32_t shift_1st = kvz_g_convert_to_bit[8] + 1 + (bitdepth - 8);
int32_t shift_2nd = kvz_g_convert_to_bit[8] + 8;
const int16_t *dct = &kvz_g_dct_8[0][0];
/*
* Multiply input by the tranpose of DCT matrix into tmpres, and DCT matrix
@ -210,75 +311,36 @@ static void matrix_dct_8x8_avx2(int8_t bitdepth, const int16_t *input, int16_t *
* instead, since it will be used in similar fashion as the right operand
* in the second multiplication.
*/
for (int dry = 0; dry < 4; dry++) {
// Read columns of DCT matrix's transpose by reading rows of DCT matrix
__m256i d_dr = _mm256_loadu_si256(dct + dry);
__m256i tmpres[4];
__m256i prod0 = _mm256_madd_epi16(d_dr, i_dr_0);
__m256i prod0_swp = _mm256_madd_epi16(d_dr, i_dr_0_swp);
__m256i prod1 = _mm256_madd_epi16(d_dr, i_dr_1);
__m256i prod1_swp = _mm256_madd_epi16(d_dr, i_dr_1_swp);
__m256i prod2 = _mm256_madd_epi16(d_dr, i_dr_2);
__m256i prod2_swp = _mm256_madd_epi16(d_dr, i_dr_2_swp);
__m256i prod3 = _mm256_madd_epi16(d_dr, i_dr_3);
__m256i prod3_swp = _mm256_madd_epi16(d_dr, i_dr_3_swp);
matmul_8x8_a_bt_t(input, dct, shift_1st, tmpres);
matmul_8x8_a_bt (dct, tmpres, shift_2nd, output);
}
__m256i hsum0 = _mm256_hadd_epi32(prod0, prod0_swp);
__m256i hsum1 = _mm256_hadd_epi32(prod1, prod1_swp);
__m256i hsum2 = _mm256_hadd_epi32(prod2, prod2_swp);
__m256i hsum3 = _mm256_hadd_epi32(prod3, prod3_swp);
static void matrix_idct_8x8_avx2(int8_t bitdepth, const int16_t *input, int16_t *output)
{
int32_t shift_1st = 7;
int32_t shift_2nd = 12 - (bitdepth - 8);
int16_t tmp[8 * 8];
__m256i hsum2c_0 = _mm256_hadd_epi32(hsum0, hsum1);
__m256i hsum2c_1 = _mm256_hadd_epi32(hsum2, hsum3);
const int16_t *tdct = &kvz_g_dct_8_t[0][0];
const int16_t *dct = &kvz_g_dct_8 [0][0];
__m256i hsum2c_0_tr = truncate(hsum2c_0, debias1, shift_1st);
__m256i hsum2c_1_tr = truncate(hsum2c_1, debias1, shift_1st);
mul_clip_matrix_8x8_avx2(tdct, input, tmp, shift_1st);
mul_clip_matrix_8x8_avx2(tmp, dct, output, shift_2nd);
__m256i tmp_dr = _mm256_packs_epi32(hsum2c_0_tr, hsum2c_1_tr);
tmpres[dry] = _mm256_shuffle_epi8(tmp_dr, shuf_lorow_mask);
}
__m256i t_dr_0 = tmpres[0];
__m256i t_dr_1 = tmpres[1];
__m256i t_dr_2 = tmpres[2];
__m256i t_dr_3 = tmpres[3];
__m256i t_dr_0_swp = swap_lanes(t_dr_0);
__m256i t_dr_1_swp = swap_lanes(t_dr_1);
__m256i t_dr_2_swp = swap_lanes(t_dr_2);
__m256i t_dr_3_swp = swap_lanes(t_dr_3);
for (int dry = 0; dry < 4; dry++) {
__m256i d_dr = _mm256_loadu_si256(dct + dry);
__m256i prod0 = _mm256_madd_epi16(d_dr, t_dr_0);
__m256i prod0_swp = _mm256_madd_epi16(d_dr, t_dr_0_swp);
__m256i prod1 = _mm256_madd_epi16(d_dr, t_dr_1);
__m256i prod1_swp = _mm256_madd_epi16(d_dr, t_dr_1_swp);
__m256i prod2 = _mm256_madd_epi16(d_dr, t_dr_2);
__m256i prod2_swp = _mm256_madd_epi16(d_dr, t_dr_2_swp);
__m256i prod3 = _mm256_madd_epi16(d_dr, t_dr_3);
__m256i prod3_swp = _mm256_madd_epi16(d_dr, t_dr_3_swp);
__m256i hsum0 = _mm256_hadd_epi32(prod0, prod0_swp);
__m256i hsum1 = _mm256_hadd_epi32(prod1, prod1_swp);
__m256i hsum2 = _mm256_hadd_epi32(prod2, prod2_swp);
__m256i hsum3 = _mm256_hadd_epi32(prod3, prod3_swp);
__m256i hsum2c_0 = _mm256_hadd_epi32(hsum0, hsum1);
__m256i hsum2c_1 = _mm256_hadd_epi32(hsum2, hsum3);
__m256i hsum2c_0_tr = truncate(hsum2c_0, debias2, shift_2nd);
__m256i hsum2c_1_tr = truncate(hsum2c_1, debias2, shift_2nd);
__m256i tmp_dr = _mm256_packs_epi32(hsum2c_0_tr, hsum2c_1_tr);
__m256i final_dr = _mm256_shuffle_epi8(tmp_dr, shuf_lorow_mask);
_mm256_storeu_si256(((__m256i *)output) + dry, final_dr);
}
/*
* Because:
* out = tdct * input * dct = tdct * (input * dct) = tdct * (input * transpose(tdct))
* This could almost be done this way:
*
* matmul_8x8_a_bt_t(input, tdct, debias1, shift_1st, tmp);
* matmul_8x8_a_bt (tdct, tmp, debias2, shift_2nd, output);
*
* But not really, since it will fall victim to some very occasional
* rounding errors. Sadly.
*/
}
// 16x16 matrix multiplication with value clipping.
@ -468,15 +530,15 @@ static void matrix_i ## type ## _## n ## x ## n ## _avx2(int8_t bitdepth, const
TRANSFORM(dst, 4);
TRANSFORM(dct, 4);
// Ha, we've got a tailored implementation for this
// Ha, we've got a tailored implementation for these
// TRANSFORM(dct, 8);
// ITRANSFORM(dct, 8);
TRANSFORM(dct, 16);
TRANSFORM(dct, 32);
ITRANSFORM(dst, 4);
ITRANSFORM(dct, 4);
ITRANSFORM(dct, 8);
ITRANSFORM(dct, 16);
ITRANSFORM(dct, 32);