diff --git a/src/strategies/avx2/dct-avx2.c b/src/strategies/avx2/dct-avx2.c index fc35b903..bc12bb6f 100644 --- a/src/strategies/avx2/dct-avx2.c +++ b/src/strategies/avx2/dct-avx2.c @@ -399,6 +399,130 @@ static void matrix_idct_8x8_avx2(int8_t bitdepth, const int16_t *input, int16_t */ } +static void matmul_16x16_a_bt_t(const int16_t *a, const int16_t *b_t, int16_t *output, const int8_t shift) +{ + const int32_t add = 1 << (shift - 1); + const __m256i debias = _mm256_set1_epi32(add); + + for (int32_t x = 0; x < 16; x++) { + __m256i bt_c = _mm256_loadu_si256((const __m256i *)b_t + x); + + __m256i results_32[2]; + + // First Row Offset + for (int32_t fro = 0; fro < 2; fro++) { + // Read first rows 0, 1, 2, 3, 8, 9, 10, 11, and then next 4 + __m256i a_r0 = _mm256_loadu_si256((const __m256i *)a + fro * 4 + 0); + __m256i a_r1 = _mm256_loadu_si256((const __m256i *)a + fro * 4 + 1); + __m256i a_r2 = _mm256_loadu_si256((const __m256i *)a + fro * 4 + 2); + __m256i a_r3 = _mm256_loadu_si256((const __m256i *)a + fro * 4 + 3); + __m256i a_r8 = _mm256_loadu_si256((const __m256i *)a + fro * 4 + 8); + __m256i a_r9 = _mm256_loadu_si256((const __m256i *)a + fro * 4 + 9); + __m256i a_r10 = _mm256_loadu_si256((const __m256i *)a + fro * 4 + 10); + __m256i a_r11 = _mm256_loadu_si256((const __m256i *)a + fro * 4 + 11); + + __m256i p0 = _mm256_madd_epi16(bt_c, a_r0); + __m256i p1 = _mm256_madd_epi16(bt_c, a_r1); + __m256i p2 = _mm256_madd_epi16(bt_c, a_r2); + __m256i p3 = _mm256_madd_epi16(bt_c, a_r3); + __m256i p8 = _mm256_madd_epi16(bt_c, a_r8); + __m256i p9 = _mm256_madd_epi16(bt_c, a_r9); + __m256i p10 = _mm256_madd_epi16(bt_c, a_r10); + __m256i p11 = _mm256_madd_epi16(bt_c, a_r11); + + // Combine low lanes from P0 and P8, high lanes from them, and the same + // with P1:P9 and so on + __m256i p0l = _mm256_permute2x128_si256(p0, p8, 0x20); + __m256i p0h = _mm256_permute2x128_si256(p0, p8, 0x31); + __m256i p1l = _mm256_permute2x128_si256(p1, p9, 0x20); + __m256i p1h = _mm256_permute2x128_si256(p1, p9, 0x31); + __m256i p2l = _mm256_permute2x128_si256(p2, p10, 0x20); + __m256i p2h = _mm256_permute2x128_si256(p2, p10, 0x31); + __m256i p3l = _mm256_permute2x128_si256(p3, p11, 0x20); + __m256i p3h = _mm256_permute2x128_si256(p3, p11, 0x31); + + __m256i s0 = _mm256_add_epi32(p0l, p0h); + __m256i s1 = _mm256_add_epi32(p1l, p1h); + __m256i s2 = _mm256_add_epi32(p2l, p2h); + __m256i s3 = _mm256_add_epi32(p3l, p3h); + + __m256i s4 = _mm256_unpacklo_epi64(s0, s1); + __m256i s5 = _mm256_unpackhi_epi64(s0, s1); + __m256i s6 = _mm256_unpacklo_epi64(s2, s3); + __m256i s7 = _mm256_unpackhi_epi64(s2, s3); + + __m256i s8 = _mm256_add_epi32(s4, s5); + __m256i s9 = _mm256_add_epi32(s6, s7); + + __m256i res = _mm256_hadd_epi32(s8, s9); + results_32[fro] = truncate(res, debias, shift); + } + __m256i final_col = _mm256_packs_epi32(results_32[0], results_32[1]); + _mm256_storeu_si256((__m256i *)output + x, final_col); + } +} + +static void matmul_16x16_a_bt(const int16_t *a, const int16_t *b_t, int16_t *output, const int8_t shift) +{ + const int32_t add = 1 << (shift - 1); + const __m256i debias = _mm256_set1_epi32(add); + + for (int32_t y = 0; y < 16; y++) { + __m256i a_r = _mm256_loadu_si256((const __m256i *)a + y); + __m256i results_32[2]; + + for (int32_t fco = 0; fco < 2; fco++) { + // Read first cols 0, 1, 2, 3, 8, 9, 10, 11, and then next 4 + __m256i bt_c0 = _mm256_loadu_si256((const __m256i *)b_t + fco * 4 + 0); + __m256i bt_c1 = _mm256_loadu_si256((const __m256i *)b_t + fco * 4 + 1); + __m256i bt_c2 = _mm256_loadu_si256((const __m256i *)b_t + fco * 4 + 2); + __m256i bt_c3 = _mm256_loadu_si256((const __m256i *)b_t + fco * 4 + 3); + __m256i bt_c8 = _mm256_loadu_si256((const __m256i *)b_t + fco * 4 + 8); + __m256i bt_c9 = _mm256_loadu_si256((const __m256i *)b_t + fco * 4 + 9); + __m256i bt_c10 = _mm256_loadu_si256((const __m256i *)b_t + fco * 4 + 10); + __m256i bt_c11 = _mm256_loadu_si256((const __m256i *)b_t + fco * 4 + 11); + + __m256i p0 = _mm256_madd_epi16(a_r, bt_c0); + __m256i p1 = _mm256_madd_epi16(a_r, bt_c1); + __m256i p2 = _mm256_madd_epi16(a_r, bt_c2); + __m256i p3 = _mm256_madd_epi16(a_r, bt_c3); + __m256i p8 = _mm256_madd_epi16(a_r, bt_c8); + __m256i p9 = _mm256_madd_epi16(a_r, bt_c9); + __m256i p10 = _mm256_madd_epi16(a_r, bt_c10); + __m256i p11 = _mm256_madd_epi16(a_r, bt_c11); + + // Combine low lanes from P0 and P8, high lanes from them, and the same + // with P1:P9 and so on + __m256i p0l = _mm256_permute2x128_si256(p0, p8, 0x20); + __m256i p0h = _mm256_permute2x128_si256(p0, p8, 0x31); + __m256i p1l = _mm256_permute2x128_si256(p1, p9, 0x20); + __m256i p1h = _mm256_permute2x128_si256(p1, p9, 0x31); + __m256i p2l = _mm256_permute2x128_si256(p2, p10, 0x20); + __m256i p2h = _mm256_permute2x128_si256(p2, p10, 0x31); + __m256i p3l = _mm256_permute2x128_si256(p3, p11, 0x20); + __m256i p3h = _mm256_permute2x128_si256(p3, p11, 0x31); + + __m256i s0 = _mm256_add_epi32(p0l, p0h); + __m256i s1 = _mm256_add_epi32(p1l, p1h); + __m256i s2 = _mm256_add_epi32(p2l, p2h); + __m256i s3 = _mm256_add_epi32(p3l, p3h); + + __m256i s4 = _mm256_unpacklo_epi64(s0, s1); + __m256i s5 = _mm256_unpackhi_epi64(s0, s1); + __m256i s6 = _mm256_unpacklo_epi64(s2, s3); + __m256i s7 = _mm256_unpackhi_epi64(s2, s3); + + __m256i s8 = _mm256_add_epi32(s4, s5); + __m256i s9 = _mm256_add_epi32(s6, s7); + + __m256i res = _mm256_hadd_epi32(s8, s9); + results_32[fco] = truncate(res, debias, shift); + } + __m256i final_col = _mm256_packs_epi32(results_32[0], results_32[1]); + _mm256_storeu_si256((__m256i *)output + y, final_col); + } +} + // 16x16 matrix multiplication with value clipping. // Parameters: Two 16x16 matrices containing 16-bit values in consecutive addresses, // destination for the result and the shift value for clipping. @@ -462,6 +586,31 @@ static void mul_clip_matrix_16x16_avx2(const int16_t *left, const int16_t *right } } +static void matrix_dct_16x16_avx2(int8_t bitdepth, const int16_t *input, int16_t *output) +{ + int32_t shift_1st = kvz_g_convert_to_bit[16] + 1 + (bitdepth - 8); + int32_t shift_2nd = kvz_g_convert_to_bit[16] + 8; + + const int16_t *dct = &kvz_g_dct_16[0][0]; + + /* + * Multiply input by the tranpose of DCT matrix into tmpres, and DCT matrix + * by tmpres - this is then our output matrix + * + * It's easier to implement an AVX2 matrix multiplication if you can multiply + * the left term with the transpose of the right term. Here things are stored + * row-wise, not column-wise, so we can effectively read DCT_T column-wise + * into YMM registers by reading DCT row-wise. Also because of this, the + * first multiplication is hacked to produce the transpose of the result + * instead, since it will be used in similar fashion as the right operand + * in the second multiplication. + */ + + int16_t tmpres[16 * 16]; + matmul_16x16_a_bt_t(input, dct, tmpres, shift_1st); + matmul_16x16_a_bt (dct, tmpres, output, shift_2nd); +} + // 32x32 matrix multiplication with value clipping. // Parameters: Two 32x32 matrices containing 16-bit values in consecutive addresses, // destination for the result and the shift value for clipping. @@ -589,9 +738,9 @@ static void matrix_i ## type ## _## n ## x ## n ## _avx2(int8_t bitdepth, const // ITRANSFORM(dct, 4); // TRANSFORM(dct, 8); // ITRANSFORM(dct, 8); +// TRANSFORM(dct, 16); // Generate all the transform functions -TRANSFORM(dct, 16); TRANSFORM(dct, 32); ITRANSFORM(dct, 16);