mirror of
https://github.com/ultravideo/uvg266.git
synced 2024-11-24 02:24:07 +00:00
Redo 4x4 matrix multiplication
This commit is contained in:
parent
07970ea82f
commit
30ce461d98
|
@ -47,54 +47,6 @@ extern const int16_t kvz_g_dct_32_t[32][32];
|
|||
* \brief AVX2 transformations.
|
||||
*/
|
||||
|
||||
// 4x4 matrix multiplication with value clipping.
|
||||
// Parameters: Two 4x4 matrices containing 16-bit values in consecutive addresses,
|
||||
// destination for the result and the shift value for clipping.
|
||||
static void mul_clip_matrix_4x4_avx2(const int16_t *left, const int16_t *right, int16_t *dst, int32_t shift)
|
||||
{
|
||||
__m256i b[2], a, result, even[2], odd[2];
|
||||
|
||||
const int32_t add = 1 << (shift - 1);
|
||||
|
||||
a = _mm256_loadu_si256((__m256i*) left);
|
||||
b[0] = _mm256_loadu_si256((__m256i*) right);
|
||||
|
||||
// Interleave values in both 128-bit lanes
|
||||
b[0] = _mm256_unpacklo_epi16(b[0], _mm256_srli_si256(b[0], 8));
|
||||
b[1] = _mm256_permute2x128_si256(b[0], b[0], 1 + 16);
|
||||
b[0] = _mm256_permute2x128_si256(b[0], b[0], 0);
|
||||
|
||||
// Fill both 128-lanes with the first pair of 16-bit factors in the lane.
|
||||
even[0] = _mm256_shuffle_epi32(a, 0);
|
||||
odd[0] = _mm256_shuffle_epi32(a, 1 + 4 + 16 + 64);
|
||||
|
||||
// Multiply packed elements and sum pairs. Input 16-bit output 32-bit.
|
||||
even[0] = _mm256_madd_epi16(even[0], b[0]);
|
||||
odd[0] = _mm256_madd_epi16(odd[0], b[1]);
|
||||
|
||||
// Add the halves of the dot product and
|
||||
// round.
|
||||
result = _mm256_add_epi32(even[0], odd[0]);
|
||||
result = _mm256_add_epi32(result, _mm256_set1_epi32(add));
|
||||
result = _mm256_srai_epi32(result, shift);
|
||||
|
||||
//Repeat for the remaining parts
|
||||
even[1] = _mm256_shuffle_epi32(a, 2 + 8 + 32 + 128);
|
||||
odd[1] = _mm256_shuffle_epi32(a, 3 + 12 + 48 + 192);
|
||||
|
||||
even[1] = _mm256_madd_epi16(even[1], b[0]);
|
||||
odd[1] = _mm256_madd_epi16(odd[1], b[1]);
|
||||
|
||||
odd[1] = _mm256_add_epi32(even[1], odd[1]);
|
||||
odd[1] = _mm256_add_epi32(odd[1], _mm256_set1_epi32(add));
|
||||
odd[1] = _mm256_srai_epi32(odd[1], shift);
|
||||
|
||||
// Truncate to 16-bit values
|
||||
result = _mm256_packs_epi32(result, odd[1]);
|
||||
|
||||
_mm256_storeu_si256((__m256i*)dst, result);
|
||||
}
|
||||
|
||||
static INLINE __m256i swap_lanes(__m256i v)
|
||||
{
|
||||
return _mm256_permute4x64_epi64(v, _MM_SHUFFLE(1, 0, 3, 2));
|
||||
|
@ -106,6 +58,110 @@ static INLINE __m256i truncate(__m256i v, __m256i debias, int32_t shift)
|
|||
return _mm256_srai_epi32(truncable, shift);
|
||||
}
|
||||
|
||||
// 4x4 matrix multiplication with value clipping.
|
||||
// Parameters: Two 4x4 matrices containing 16-bit values in consecutive addresses,
|
||||
// destination for the result and the shift value for clipping.
|
||||
static __m256i mul_clip_matrix_4x4_avx2(const __m256i left, const __m256i right, int shift)
|
||||
{
|
||||
const int32_t add = 1 << (shift - 1);
|
||||
const __m256i debias = _mm256_set1_epi32(add);
|
||||
|
||||
__m256i right_los = _mm256_permute4x64_epi64(right, _MM_SHUFFLE(2, 0, 2, 0));
|
||||
__m256i right_his = _mm256_permute4x64_epi64(right, _MM_SHUFFLE(3, 1, 3, 1));
|
||||
|
||||
__m256i right_cols_up = _mm256_unpacklo_epi16(right_los, right_his);
|
||||
__m256i right_cols_dn = _mm256_unpackhi_epi16(right_los, right_his);
|
||||
|
||||
__m256i left_slice1 = _mm256_shuffle_epi32(left, _MM_SHUFFLE(0, 0, 0, 0));
|
||||
__m256i left_slice2 = _mm256_shuffle_epi32(left, _MM_SHUFFLE(1, 1, 1, 1));
|
||||
__m256i left_slice3 = _mm256_shuffle_epi32(left, _MM_SHUFFLE(2, 2, 2, 2));
|
||||
__m256i left_slice4 = _mm256_shuffle_epi32(left, _MM_SHUFFLE(3, 3, 3, 3));
|
||||
|
||||
__m256i prod1 = _mm256_madd_epi16(left_slice1, right_cols_up);
|
||||
__m256i prod2 = _mm256_madd_epi16(left_slice2, right_cols_dn);
|
||||
__m256i prod3 = _mm256_madd_epi16(left_slice3, right_cols_up);
|
||||
__m256i prod4 = _mm256_madd_epi16(left_slice4, right_cols_dn);
|
||||
|
||||
__m256i rows_up = _mm256_add_epi32(prod1, prod2);
|
||||
__m256i rows_dn = _mm256_add_epi32(prod3, prod4);
|
||||
|
||||
__m256i rows_up_tr = truncate(rows_up, debias, shift);
|
||||
__m256i rows_dn_tr = truncate(rows_dn, debias, shift);
|
||||
|
||||
__m256i result = _mm256_packs_epi32(rows_up_tr, rows_dn_tr);
|
||||
return result;
|
||||
}
|
||||
|
||||
static void matrix_dst_4x4_avx2(int8_t bitdepth, const int16_t *input, int16_t *output)
|
||||
{
|
||||
int32_t shift_1st = kvz_g_convert_to_bit[4] + 1 + (bitdepth - 8);
|
||||
int32_t shift_2nd = kvz_g_convert_to_bit[4] + 8;
|
||||
const int16_t *tdst = &kvz_g_dst_4_t[0][0];
|
||||
const int16_t *dst = &kvz_g_dst_4 [0][0];
|
||||
|
||||
__m256i tdst_v = _mm256_loadu_si256((const __m256i *) tdst);
|
||||
__m256i dst_v = _mm256_loadu_si256((const __m256i *) dst);
|
||||
__m256i in_v = _mm256_loadu_si256((const __m256i *)input);
|
||||
|
||||
__m256i tmp = mul_clip_matrix_4x4_avx2(in_v, tdst_v, shift_1st);
|
||||
__m256i result = mul_clip_matrix_4x4_avx2(dst_v, tmp, shift_2nd);
|
||||
|
||||
_mm256_storeu_si256((__m256i *)output, result);
|
||||
}
|
||||
|
||||
static void matrix_idst_4x4_avx2(int8_t bitdepth, const int16_t *input, int16_t *output)
|
||||
{
|
||||
int32_t shift_1st = 7;
|
||||
int32_t shift_2nd = 12 - (bitdepth - 8);
|
||||
|
||||
const int16_t *tdst = &kvz_g_dst_4_t[0][0];
|
||||
const int16_t *dst = &kvz_g_dst_4 [0][0];
|
||||
|
||||
__m256i tdst_v = _mm256_loadu_si256((const __m256i *)tdst);
|
||||
__m256i dst_v = _mm256_loadu_si256((const __m256i *) dst);
|
||||
__m256i in_v = _mm256_loadu_si256((const __m256i *)input);
|
||||
|
||||
__m256i tmp = mul_clip_matrix_4x4_avx2(tdst_v, in_v, shift_1st);
|
||||
__m256i result = mul_clip_matrix_4x4_avx2(tmp, dst_v, shift_2nd);
|
||||
|
||||
_mm256_storeu_si256((__m256i *)output, result);
|
||||
}
|
||||
|
||||
static void matrix_dct_4x4_avx2(int8_t bitdepth, const int16_t *input, int16_t *output)
|
||||
{
|
||||
int32_t shift_1st = kvz_g_convert_to_bit[4] + 1 + (bitdepth - 8);
|
||||
int32_t shift_2nd = kvz_g_convert_to_bit[4] + 8;
|
||||
const int16_t *tdct = &kvz_g_dct_4_t[0][0];
|
||||
const int16_t *dct = &kvz_g_dct_4 [0][0];
|
||||
|
||||
__m256i tdct_v = _mm256_loadu_si256((const __m256i *) tdct);
|
||||
__m256i dct_v = _mm256_loadu_si256((const __m256i *) dct);
|
||||
__m256i in_v = _mm256_loadu_si256((const __m256i *)input);
|
||||
|
||||
__m256i tmp = mul_clip_matrix_4x4_avx2(in_v, tdct_v, shift_1st);
|
||||
__m256i result = mul_clip_matrix_4x4_avx2(dct_v, tmp, shift_2nd);
|
||||
|
||||
_mm256_storeu_si256((__m256i *)output, result);
|
||||
}
|
||||
|
||||
static void matrix_idct_4x4_avx2(int8_t bitdepth, const int16_t *input, int16_t *output)
|
||||
{
|
||||
int32_t shift_1st = 7;
|
||||
int32_t shift_2nd = 12 - (bitdepth - 8);
|
||||
|
||||
const int16_t *tdct = &kvz_g_dct_4_t[0][0];
|
||||
const int16_t *dct = &kvz_g_dct_4 [0][0];
|
||||
|
||||
__m256i tdct_v = _mm256_loadu_si256((const __m256i *)tdct);
|
||||
__m256i dct_v = _mm256_loadu_si256((const __m256i *) dct);
|
||||
__m256i in_v = _mm256_loadu_si256((const __m256i *)input);
|
||||
|
||||
__m256i tmp = mul_clip_matrix_4x4_avx2(tdct_v, in_v, shift_1st);
|
||||
__m256i result = mul_clip_matrix_4x4_avx2(tmp, dct_v, shift_2nd);
|
||||
|
||||
_mm256_storeu_si256((__m256i *)output, result);
|
||||
}
|
||||
|
||||
static void mul_clip_matrix_8x8_avx2(const int16_t *left, const int16_t *right, int16_t *dst, const int32_t shift)
|
||||
{
|
||||
const __m256i transp_mask = _mm256_broadcastsi128_si256(_mm_setr_epi8(0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15));
|
||||
|
@ -526,19 +582,18 @@ static void matrix_i ## type ## _## n ## x ## n ## _avx2(int8_t bitdepth, const
|
|||
mul_clip_matrix_ ## n ## x ## n ## _avx2(tmp, dct, output, shift_2nd);\
|
||||
}\
|
||||
|
||||
// Generate all the transform functions
|
||||
TRANSFORM(dst, 4);
|
||||
TRANSFORM(dct, 4);
|
||||
|
||||
// Ha, we've got a tailored implementation for these
|
||||
// TRANSFORM(dst, 4);
|
||||
// ITRANSFORM(dst, 4);
|
||||
// TRANSFORM(dct, 4);
|
||||
// ITRANSFORM(dct, 4);
|
||||
// TRANSFORM(dct, 8);
|
||||
// ITRANSFORM(dct, 8);
|
||||
|
||||
// Generate all the transform functions
|
||||
TRANSFORM(dct, 16);
|
||||
TRANSFORM(dct, 32);
|
||||
|
||||
ITRANSFORM(dst, 4);
|
||||
ITRANSFORM(dct, 4);
|
||||
ITRANSFORM(dct, 16);
|
||||
ITRANSFORM(dct, 32);
|
||||
|
||||
|
|
Loading…
Reference in a new issue