Reorder parameters for 8x8 matrix multiplies

This commit is contained in:
Pauli Oikkonen 2019-06-06 12:25:35 +03:00
parent 292af62256
commit beb85ce9d6

View file

@@ -235,7 +235,7 @@ static void mul_clip_matrix_8x8_avx2(const int16_t *left, const int16_t *right,
// Multiplies A by B_T's transpose and stores result's transpose in output, // Multiplies A by B_T's transpose and stores result's transpose in output,
// which should be an array of 4 __m256i's // which should be an array of 4 __m256i's
static void matmul_8x8_a_bt_t(const int16_t *a, const int16_t *b_t, static void matmul_8x8_a_bt_t(const int16_t *a, const int16_t *b_t,
const int8_t shift, __m256i *output) __m256i *output, const int8_t shift)
{ {
const int32_t add = 1 << (shift - 1); const int32_t add = 1 << (shift - 1);
const __m256i debias = _mm256_set1_epi32(add); const __m256i debias = _mm256_set1_epi32(add);
@@ -294,7 +294,7 @@ static void matmul_8x8_a_bt_t(const int16_t *a, const int16_t *b_t,
// Multiplies A by B_T's transpose and stores result in output. // Multiplies A by B_T's transpose and stores result in output.
// B_T should be an array of 4 __m256i's; output is an int16_t buffer. // B_T should be an array of 4 __m256i's; output is an int16_t buffer.
static void matmul_8x8_a_bt(const int16_t *a, const __m256i *b_t, static void matmul_8x8_a_bt(const int16_t *a, const __m256i *b_t,
const int8_t shift, int16_t *output) int16_t *output, const int8_t shift)
{ {
const int32_t add = 1 << (shift - 1); const int32_t add = 1 << (shift - 1);
const __m256i debias = _mm256_set1_epi32(add); const __m256i debias = _mm256_set1_epi32(add);
@@ -370,8 +370,8 @@ static void matrix_dct_8x8_avx2(int8_t bitdepth, const int16_t *input, int16_t *
__m256i tmpres[4]; __m256i tmpres[4];
matmul_8x8_a_bt_t(input, dct, shift_1st, tmpres); matmul_8x8_a_bt_t(input, dct, tmpres, shift_1st);
matmul_8x8_a_bt (dct, tmpres, shift_2nd, output); matmul_8x8_a_bt (dct, tmpres, output, shift_2nd);
} }
static void matrix_idct_8x8_avx2(int8_t bitdepth, const int16_t *input, int16_t *output) static void matrix_idct_8x8_avx2(int8_t bitdepth, const int16_t *input, int16_t *output)