Redo 4x4 matrix multiplication

Pauli Oikkonen 2019-05-30 20:29:45 +03:00
parent 07970ea82f
commit 30ce461d98


@@ -47,54 +47,6 @@ extern const int16_t kvz_g_dct_32_t[32][32];
* \brief AVX2 transformations.
*/
// 4x4 matrix multiplication with value clipping.
// Parameters: Two 4x4 matrices containing 16-bit values in consecutive addresses,
// destination for the result and the shift value for clipping.
static void mul_clip_matrix_4x4_avx2(const int16_t *left, const int16_t *right, int16_t *dst, int32_t shift)
{
  __m256i b[2], a, result, even[2], odd[2];
  const int32_t add = 1 << (shift - 1);
  a = _mm256_loadu_si256((__m256i*) left);
  b[0] = _mm256_loadu_si256((__m256i*) right);
  // Interleave values in both 128-bit lanes
  b[0] = _mm256_unpacklo_epi16(b[0], _mm256_srli_si256(b[0], 8));
  b[1] = _mm256_permute2x128_si256(b[0], b[0], 1 + 16);
  b[0] = _mm256_permute2x128_si256(b[0], b[0], 0);
  // Fill both 128-bit lanes with the first pair of 16-bit factors in the lane.
  even[0] = _mm256_shuffle_epi32(a, 0);
  odd[0] = _mm256_shuffle_epi32(a, 1 + 4 + 16 + 64);
  // Multiply packed elements and sum pairs. Input 16-bit, output 32-bit.
  even[0] = _mm256_madd_epi16(even[0], b[0]);
  odd[0] = _mm256_madd_epi16(odd[0], b[1]);
  // Add the halves of the dot product and round.
  result = _mm256_add_epi32(even[0], odd[0]);
  result = _mm256_add_epi32(result, _mm256_set1_epi32(add));
  result = _mm256_srai_epi32(result, shift);
  // Repeat for the remaining parts
  even[1] = _mm256_shuffle_epi32(a, 2 + 8 + 32 + 128);
  odd[1] = _mm256_shuffle_epi32(a, 3 + 12 + 48 + 192);
  even[1] = _mm256_madd_epi16(even[1], b[0]);
  odd[1] = _mm256_madd_epi16(odd[1], b[1]);
  odd[1] = _mm256_add_epi32(even[1], odd[1]);
  odd[1] = _mm256_add_epi32(odd[1], _mm256_set1_epi32(add));
  odd[1] = _mm256_srai_epi32(odd[1], shift);
  // Truncate to 16-bit values
  result = _mm256_packs_epi32(result, odd[1]);
  _mm256_storeu_si256((__m256i*)dst, result);
}
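For reference, the old and the new version compute the same thing: a plain 4x4 product of row-major 16-bit matrices, rounded, shifted right and saturated back to 16 bits. A scalar sketch of that operation (the _ref name and the explicit clamping are illustrative, not part of the commit):

// Scalar reference of the clipped 4x4 multiply, for illustration only.
// Requires <stdint.h> for int16_t/INT16_MAX.
static void mul_clip_matrix_4x4_ref(const int16_t *left, const int16_t *right, int16_t *dst, int32_t shift)
{
  const int32_t add = 1 << (shift - 1);
  for (int i = 0; i < 4; i++) {
    for (int j = 0; j < 4; j++) {
      int32_t acc = 0;
      for (int k = 0; k < 4; k++) {
        acc += (int32_t)left[i * 4 + k] * right[k * 4 + j];
      }
      acc = (acc + add) >> shift;
      // The AVX2 code saturates with _mm256_packs_epi32; mirror that here.
      if (acc > INT16_MAX) acc = INT16_MAX;
      if (acc < INT16_MIN) acc = INT16_MIN;
      dst[i * 4 + j] = (int16_t)acc;
    }
  }
}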
static INLINE __m256i swap_lanes(__m256i v)
{
  return _mm256_permute4x64_epi64(v, _MM_SHUFFLE(1, 0, 3, 2));
@@ -106,6 +58,110 @@ static INLINE __m256i truncate(__m256i v, __m256i debias, int32_t shift)
  return _mm256_srai_epi32(truncable, shift);
}
// 4x4 matrix multiplication with value clipping.
// Parameters: Two 4x4 matrices containing 16-bit values in consecutive addresses,
// destination for the result and the shift value for clipping.
static __m256i mul_clip_matrix_4x4_avx2(const __m256i left, const __m256i right, int shift)
{
  const int32_t add = 1 << (shift - 1);
  const __m256i debias = _mm256_set1_epi32(add);

  // Matrices are row-major 4x4 int16, so one __m256i holds a whole matrix:
  // rows 0-1 in the low lane, rows 2-3 in the high lane. Copy the even rows
  // (0, 2) of right into both lanes of right_los and the odd rows (1, 3)
  // into right_his.
  __m256i right_los = _mm256_permute4x64_epi64(right, _MM_SHUFFLE(2, 0, 2, 0));
  __m256i right_his = _mm256_permute4x64_epi64(right, _MM_SHUFFLE(3, 1, 3, 1));

  // Interleave into column pairs: right_cols_up holds (r0j, r1j) and
  // right_cols_dn holds (r2j, r3j) for every column j, in both lanes.
  __m256i right_cols_up = _mm256_unpacklo_epi16(right_los, right_his);
  __m256i right_cols_dn = _mm256_unpackhi_epi16(right_los, right_his);

  // Broadcast one 32-bit pair of left factors across each lane: slice 1 has
  // (l00, l01) in the low lane and (l20, l21) in the high lane, slice 2 has
  // (l02, l03) and (l22, l23), and slices 3 and 4 do the same for rows 1 and 3.
  __m256i left_slice1 = _mm256_shuffle_epi32(left, _MM_SHUFFLE(0, 0, 0, 0));
  __m256i left_slice2 = _mm256_shuffle_epi32(left, _MM_SHUFFLE(1, 1, 1, 1));
  __m256i left_slice3 = _mm256_shuffle_epi32(left, _MM_SHUFFLE(2, 2, 2, 2));
  __m256i left_slice4 = _mm256_shuffle_epi32(left, _MM_SHUFFLE(3, 3, 3, 3));

  // Each madd produces half of a 32-bit dot product; the adds complete
  // result rows 0 and 2 (rows_up) and rows 1 and 3 (rows_dn).
  __m256i prod1 = _mm256_madd_epi16(left_slice1, right_cols_up);
  __m256i prod2 = _mm256_madd_epi16(left_slice2, right_cols_dn);
  __m256i prod3 = _mm256_madd_epi16(left_slice3, right_cols_up);
  __m256i prod4 = _mm256_madd_epi16(left_slice4, right_cols_dn);

  __m256i rows_up = _mm256_add_epi32(prod1, prod2);
  __m256i rows_dn = _mm256_add_epi32(prod3, prod4);

  // Round, shift and saturate back to 16 bits; packing lane-wise restores
  // row-major order (rows 0, 1 | 2, 3).
  __m256i rows_up_tr = truncate(rows_up, debias, shift);
  __m256i rows_dn_tr = truncate(rows_dn, debias, shift);

  __m256i result = _mm256_packs_epi32(rows_up_tr, rows_dn_tr);
  return result;
}
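The new version works register-to-register, so a caller that still has plain int16_t buffers needs a load and a store around it, exactly as the wrappers below do. A minimal adapter sketch matching the old pointer-based signature (the _ptr name is made up, not in the commit):

// Hypothetical pointer-based adapter around the register-based multiply.
static void mul_clip_matrix_4x4_ptr_avx2(const int16_t *left, const int16_t *right, int16_t *dst, int32_t shift)
{
  __m256i l = _mm256_loadu_si256((const __m256i *)left);  // whole 4x4 matrix, row-major
  __m256i r = _mm256_loadu_si256((const __m256i *)right);
  __m256i res = mul_clip_matrix_4x4_avx2(l, r, shift);
  _mm256_storeu_si256((__m256i *)dst, res);               // result is row-major as well
}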
static void matrix_dst_4x4_avx2(int8_t bitdepth, const int16_t *input, int16_t *output)
{
  int32_t shift_1st = kvz_g_convert_to_bit[4] + 1 + (bitdepth - 8);
  int32_t shift_2nd = kvz_g_convert_to_bit[4] + 8;
  const int16_t *tdst = &kvz_g_dst_4_t[0][0];
  const int16_t *dst = &kvz_g_dst_4 [0][0];
  __m256i tdst_v = _mm256_loadu_si256((const __m256i *) tdst);
  __m256i dst_v = _mm256_loadu_si256((const __m256i *) dst);
  __m256i in_v = _mm256_loadu_si256((const __m256i *)input);
  __m256i tmp = mul_clip_matrix_4x4_avx2(in_v, tdst_v, shift_1st);
  __m256i result = mul_clip_matrix_4x4_avx2(dst_v, tmp, shift_2nd);
  _mm256_storeu_si256((__m256i *)output, result);
}
static void matrix_idst_4x4_avx2(int8_t bitdepth, const int16_t *input, int16_t *output)
{
  int32_t shift_1st = 7;
  int32_t shift_2nd = 12 - (bitdepth - 8);
  const int16_t *tdst = &kvz_g_dst_4_t[0][0];
  const int16_t *dst = &kvz_g_dst_4 [0][0];
  __m256i tdst_v = _mm256_loadu_si256((const __m256i *)tdst);
  __m256i dst_v = _mm256_loadu_si256((const __m256i *) dst);
  __m256i in_v = _mm256_loadu_si256((const __m256i *)input);
  __m256i tmp = mul_clip_matrix_4x4_avx2(tdst_v, in_v, shift_1st);
  __m256i result = mul_clip_matrix_4x4_avx2(tmp, dst_v, shift_2nd);
  _mm256_storeu_si256((__m256i *)output, result);
}
static void matrix_dct_4x4_avx2(int8_t bitdepth, const int16_t *input, int16_t *output)
{
  int32_t shift_1st = kvz_g_convert_to_bit[4] + 1 + (bitdepth - 8);
  int32_t shift_2nd = kvz_g_convert_to_bit[4] + 8;
  const int16_t *tdct = &kvz_g_dct_4_t[0][0];
  const int16_t *dct = &kvz_g_dct_4 [0][0];
  __m256i tdct_v = _mm256_loadu_si256((const __m256i *) tdct);
  __m256i dct_v = _mm256_loadu_si256((const __m256i *) dct);
  __m256i in_v = _mm256_loadu_si256((const __m256i *)input);
  __m256i tmp = mul_clip_matrix_4x4_avx2(in_v, tdct_v, shift_1st);
  __m256i result = mul_clip_matrix_4x4_avx2(dct_v, tmp, shift_2nd);
  _mm256_storeu_si256((__m256i *)output, result);
}
static void matrix_idct_4x4_avx2(int8_t bitdepth, const int16_t *input, int16_t *output)
{
  int32_t shift_1st = 7;
  int32_t shift_2nd = 12 - (bitdepth - 8);
  const int16_t *tdct = &kvz_g_dct_4_t[0][0];
  const int16_t *dct = &kvz_g_dct_4 [0][0];
  __m256i tdct_v = _mm256_loadu_si256((const __m256i *)tdct);
  __m256i dct_v = _mm256_loadu_si256((const __m256i *) dct);
  __m256i in_v = _mm256_loadu_si256((const __m256i *)input);
  __m256i tmp = mul_clip_matrix_4x4_avx2(tdct_v, in_v, shift_1st);
  __m256i result = mul_clip_matrix_4x4_avx2(tmp, dct_v, shift_2nd);
  _mm256_storeu_si256((__m256i *)output, result);
}
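A quick way to sanity-check the four tailored functions is a forward/inverse round trip at bitdepth 8; with these shift values the reconstruction should match the input up to rounding. A test sketch under that assumption (the demo function and sample data are mine; the transform functions come from the file):

// Round-trip sketch: forward 4x4 DCT followed by the inverse, bitdepth 8.
// Assumes <stdio.h> and the functions above are in scope.
static void roundtrip_dct_4x4_demo(void)
{
  int16_t block[16], coeff[16], recon[16];
  for (int i = 0; i < 16; i++) {
    block[i] = (int16_t)(i * 4 - 32);  // arbitrary residual values
  }
  matrix_dct_4x4_avx2(8, block, coeff);
  matrix_idct_4x4_avx2(8, coeff, recon);
  for (int i = 0; i < 16; i++) {
    printf("%3d -> %3d\n", block[i], recon[i]);  // should agree up to rounding
  }
}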
static void mul_clip_matrix_8x8_avx2(const int16_t *left, const int16_t *right, int16_t *dst, const int32_t shift)
{
  const __m256i transp_mask = _mm256_broadcastsi128_si256(_mm_setr_epi8(0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15));
@@ -526,19 +582,18 @@ static void matrix_i ## type ## _## n ## x ## n ## _avx2(int8_t bitdepth, const
  mul_clip_matrix_ ## n ## x ## n ## _avx2(tmp, dct, output, shift_2nd);\
}\
// Generate all the transform functions
TRANSFORM(dst, 4);
TRANSFORM(dct, 4);
// Ha, we've got a tailored implementation for these
// TRANSFORM(dst, 4);
// ITRANSFORM(dst, 4);
// TRANSFORM(dct, 4);
// ITRANSFORM(dct, 4);
// TRANSFORM(dct, 8);
// ITRANSFORM(dct, 8);
// Generate all the transform functions
TRANSFORM(dct, 16);
TRANSFORM(dct, 32);
ITRANSFORM(dst, 4);
ITRANSFORM(dct, 4);
ITRANSFORM(dct, 16);
ITRANSFORM(dct, 32);
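For orientation, combining the hand-written 4x4 inverse above with the visible macro tail, ITRANSFORM(dct, 16) should generate a function of roughly this shape (a reconstruction, since the full macro body lies outside this hunk):

// Approximate shape of the generated 16x16 inverse DCT; reconstructed from
// the 4x4 functions and the macro tail, not copied from the macro definition.
static void matrix_idct_16x16_avx2(int8_t bitdepth, const int16_t *input, int16_t *output)
{
  int32_t shift_1st = 7;
  int32_t shift_2nd = 12 - (bitdepth - 8);
  const int16_t *tdct = &kvz_g_dct_16_t[0][0];
  const int16_t *dct = &kvz_g_dct_16[0][0];
  int16_t tmp[16 * 16];
  mul_clip_matrix_16x16_avx2(tdct, input, tmp, shift_1st);
  mul_clip_matrix_16x16_avx2(tmp, dct, output, shift_2nd);
}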