diff --git a/src/strategies/avx2/intra-avx2.c b/src/strategies/avx2/intra-avx2.c index 3b9543ec..e935e8a2 100644 --- a/src/strategies/avx2/intra-avx2.c +++ b/src/strategies/avx2/intra-avx2.c @@ -48,6 +48,77 @@ #include "strategyselector.h" #include "strategies/missing-intel-intrinsics.h" +// Y coord tables +ALIGNED(32) static const int8_t planar_avx2_ver_w4ys[1024] = { + 63, 1, 63, 1, 63, 1, 63, 1, 62, 2, 62, 2, 62, 2, 62, 2, 61, 3, 61, 3, 61, 3, 61, 3, 60, 4, 60, 4, 60, 4, 60, 4, // offset 0, line == 64 + 59, 5, 59, 5, 59, 5, 59, 5, 58, 6, 58, 6, 58, 6, 58, 6, 57, 7, 57, 7, 57, 7, 57, 7, 56, 8, 56, 8, 56, 8, 56, 8, + 55, 9, 55, 9, 55, 9, 55, 9, 54, 10, 54, 10, 54, 10, 54, 10, 53, 11, 53, 11, 53, 11, 53, 11, 52, 12, 52, 12, 52, 12, 52, 12, + 51, 13, 51, 13, 51, 13, 51, 13, 50, 14, 50, 14, 50, 14, 50, 14, 49, 15, 49, 15, 49, 15, 49, 15, 48, 16, 48, 16, 48, 16, 48, 16, + 47, 17, 47, 17, 47, 17, 47, 17, 46, 18, 46, 18, 46, 18, 46, 18, 45, 19, 45, 19, 45, 19, 45, 19, 44, 20, 44, 20, 44, 20, 44, 20, + 43, 21, 43, 21, 43, 21, 43, 21, 42, 22, 42, 22, 42, 22, 42, 22, 41, 23, 41, 23, 41, 23, 41, 23, 40, 24, 40, 24, 40, 24, 40, 24, + 39, 25, 39, 25, 39, 25, 39, 25, 38, 26, 38, 26, 38, 26, 38, 26, 37, 27, 37, 27, 37, 27, 37, 27, 36, 28, 36, 28, 36, 28, 36, 28, + 35, 29, 35, 29, 35, 29, 35, 29, 34, 30, 34, 30, 34, 30, 34, 30, 33, 31, 33, 31, 33, 31, 33, 31, 32, 32, 32, 32, 32, 32, 32, 32, + 31, 33, 31, 33, 31, 33, 31, 33, 30, 34, 30, 34, 30, 34, 30, 34, 29, 35, 29, 35, 29, 35, 29, 35, 28, 36, 28, 36, 28, 36, 28, 36, + 27, 37, 27, 37, 27, 37, 27, 37, 26, 38, 26, 38, 26, 38, 26, 38, 25, 39, 25, 39, 25, 39, 25, 39, 24, 40, 24, 40, 24, 40, 24, 40, + 23, 41, 23, 41, 23, 41, 23, 41, 22, 42, 22, 42, 22, 42, 22, 42, 21, 43, 21, 43, 21, 43, 21, 43, 20, 44, 20, 44, 20, 44, 20, 44, + 19, 45, 19, 45, 19, 45, 19, 45, 18, 46, 18, 46, 18, 46, 18, 46, 17, 47, 17, 47, 17, 47, 17, 47, 16, 48, 16, 48, 16, 48, 16, 48, + 15, 49, 15, 49, 15, 49, 15, 49, 14, 50, 14, 50, 14, 50, 14, 50, 13, 51, 13, 51, 13, 51, 13, 51, 12, 52, 12, 52, 12, 52, 12, 52, + 11, 53, 11, 53, 11, 53, 11, 53, 10, 54, 10, 54, 10, 54, 10, 54, 9, 55, 9, 55, 9, 55, 9, 55, 8, 56, 8, 56, 8, 56, 8, 56, + 7, 57, 7, 57, 7, 57, 7, 57, 6, 58, 6, 58, 6, 58, 6, 58, 5, 59, 5, 59, 5, 59, 5, 59, 4, 60, 4, 60, 4, 60, 4, 60, + 3, 61, 3, 61, 3, 61, 3, 61, 2, 62, 2, 62, 2, 62, 2, 62, 1, 63, 1, 63, 1, 63, 1, 63, 0, 64, 0, 64, 0, 64, 0, 64, + 31, 1, 31, 1, 31, 1, 31, 1, 30, 2, 30, 2, 30, 2, 30, 2, 29, 3, 29, 3, 29, 3, 29, 3, 28, 4, 28, 4, 28, 4, 28, 4, // offset 16, line == 32 + 27, 5, 27, 5, 27, 5, 27, 5, 26, 6, 26, 6, 26, 6, 26, 6, 25, 7, 25, 7, 25, 7, 25, 7, 24, 8, 24, 8, 24, 8, 24, 8, + 23, 9, 23, 9, 23, 9, 23, 9, 22, 10, 22, 10, 22, 10, 22, 10, 21, 11, 21, 11, 21, 11, 21, 11, 20, 12, 20, 12, 20, 12, 20, 12, + 19, 13, 19, 13, 19, 13, 19, 13, 18, 14, 18, 14, 18, 14, 18, 14, 17, 15, 17, 15, 17, 15, 17, 15, 16, 16, 16, 16, 16, 16, 16, 16, + 15, 17, 15, 17, 15, 17, 15, 17, 14, 18, 14, 18, 14, 18, 14, 18, 13, 19, 13, 19, 13, 19, 13, 19, 12, 20, 12, 20, 12, 20, 12, 20, + 11, 21, 11, 21, 11, 21, 11, 21, 10, 22, 10, 22, 10, 22, 10, 22, 9, 23, 9, 23, 9, 23, 9, 23, 8, 24, 8, 24, 8, 24, 8, 24, + 7, 25, 7, 25, 7, 25, 7, 25, 6, 26, 6, 26, 6, 26, 6, 26, 5, 27, 5, 27, 5, 27, 5, 27, 4, 28, 4, 28, 4, 28, 4, 28, + 3, 29, 3, 29, 3, 29, 3, 29, 2, 30, 2, 30, 2, 30, 2, 30, 1, 31, 1, 31, 1, 31, 1, 31, 0, 32, 0, 32, 0, 32, 0, 32, + 15, 1, 15, 1, 15, 1, 15, 1, 14, 2, 14, 2, 14, 2, 14, 2, 13, 3, 13, 3, 13, 3, 13, 3, 12, 4, 12, 4, 12, 4, 12, 4, // offset 24, line == 16 + 11, 5, 11, 5, 11, 5, 11, 5, 10, 6, 10, 6, 10, 6, 10, 6, 9, 7, 9, 7, 9, 7, 9, 7, 8, 8, 8, 8, 8, 8, 8, 8, + 7, 9, 7, 9, 7, 9, 7, 9, 6, 10, 6, 10, 6, 10, 6, 10, 5, 11, 5, 11, 5, 11, 5, 11, 4, 12, 4, 12, 4, 12, 4, 12, + 3, 13, 3, 13, 3, 13, 3, 13, 2, 14, 2, 14, 2, 14, 2, 14, 1, 15, 1, 15, 1, 15, 1, 15, 0, 16, 0, 16, 0, 16, 0, 16, + 7, 1, 7, 1, 7, 1, 7, 1, 6, 2, 6, 2, 6, 2, 6, 2, 5, 3, 5, 3, 5, 3, 5, 3, 4, 4, 4, 4, 4, 4, 4, 4, // offset 28, line == 8 + 3, 5, 3, 5, 3, 5, 3, 5, 2, 6, 2, 6, 2, 6, 2, 6, 1, 7, 1, 7, 1, 7, 1, 7, 0, 8, 0, 8, 0, 8, 0, 8, + 3, 1, 3, 1, 3, 1, 3, 1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 3, 1, 3, 1, 3, 1, 3, 0, 4, 0, 4, 0, 4, 0, 4, // offset 30, line == 4 + 1, 1, 1, 1, 1, 1, 1, 1, 0, 2, 0, 2, 0, 2, 0, 2, 1, 1, 1, 1, 1, 1, 1, 1, 0, 2, 0, 2, 0, 2, 0, 2, // offset 31. line == 2 +}; + +ALIGNED(32) static const int8_t planar_avx2_ver_w8ys[2048] = { + 63, 1, 63, 1, 63, 1, 63, 1, 63, 1, 63, 1, 63, 1, 63, 1, 62, 2, 62, 2, 62, 2, 62, 2, 62, 2, 62, 2, 62, 2, 62, 2, + 61, 3, 61, 3, 61, 3, 61, 3, 61, 3, 61, 3, 61, 3, 61, 3, 60, 4, 60, 4, 60, 4, 60, 4, 60, 4, 60, 4, 60, 4, 60, 4, + 59, 5, 59, 5, 59, 5, 59, 5, 59, 5, 59, 5, 59, 5, 59, 5, 58, 6, 58, 6, 58, 6, 58, 6, 58, 6, 58, 6, 58, 6, 58, 6, + 57, 7, 57, 7, 57, 7, 57, 7, 57, 7, 57, 7, 57, 7, 57, 7, 56, 8, 56, 8, 56, 8, 56, 8, 56, 8, 56, 8, 56, 8, 56, 8, + 55, 9, 55, 9, 55, 9, 55, 9, 55, 9, 55, 9, 55, 9, 55, 9, 54, 10, 54, 10, 54, 10, 54, 10, 54, 10, 54, 10, 54, 10, 54, 10, + 53, 11, 53, 11, 53, 11, 53, 11, 53, 11, 53, 11, 53, 11, 53, 11, 52, 12, 52, 12, 52, 12, 52, 12, 52, 12, 52, 12, 52, 12, 52, 12, + 51, 13, 51, 13, 51, 13, 51, 13, 51, 13, 51, 13, 51, 13, 51, 13, 50, 14, 50, 14, 50, 14, 50, 14, 50, 14, 50, 14, 50, 14, 50, 14, + 49, 15, 49, 15, 49, 15, 49, 15, 49, 15, 49, 15, 49, 15, 49, 15, 48, 16, 48, 16, 48, 16, 48, 16, 48, 16, 48, 16, 48, 16, 48, 16, + 47, 17, 47, 17, 47, 17, 47, 17, 47, 17, 47, 17, 47, 17, 47, 17, 46, 18, 46, 18, 46, 18, 46, 18, 46, 18, 46, 18, 46, 18, 46, 18, + 45, 19, 45, 19, 45, 19, 45, 19, 45, 19, 45, 19, 45, 19, 45, 19, 44, 20, 44, 20, 44, 20, 44, 20, 44, 20, 44, 20, 44, 20, 44, 20, + 43, 21, 43, 21, 43, 21, 43, 21, 43, 21, 43, 21, 43, 21, 43, 21, 42, 22, 42, 22, 42, 22, 42, 22, 42, 22, 42, 22, 42, 22, 42, 22, + 41, 23, 41, 23, 41, 23, 41, 23, 41, 23, 41, 23, 41, 23, 41, 23, 40, 24, 40, 24, 40, 24, 40, 24, 40, 24, 40, 24, 40, 24, 40, 24, + 39, 25, 39, 25, 39, 25, 39, 25, 39, 25, 39, 25, 39, 25, 39, 25, 38, 26, 38, 26, 38, 26, 38, 26, 38, 26, 38, 26, 38, 26, 38, 26, + 37, 27, 37, 27, 37, 27, 37, 27, 37, 27, 37, 27, 37, 27, 37, 27, 36, 28, 36, 28, 36, 28, 36, 28, 36, 28, 36, 28, 36, 28, 36, 28, + 35, 29, 35, 29, 35, 29, 35, 29, 35, 29, 35, 29, 35, 29, 35, 29, 34, 30, 34, 30, 34, 30, 34, 30, 34, 30, 34, 30, 34, 30, 34, 30, + 33, 31, 33, 31, 33, 31, 33, 31, 33, 31, 33, 31, 33, 31, 33, 31, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, + 31, 33, 31, 33, 31, 33, 31, 33, 31, 33, 31, 33, 31, 33, 31, 33, 30, 34, 30, 34, 30, 34, 30, 34, 30, 34, 30, 34, 30, 34, 30, 34, + 29, 35, 29, 35, 29, 35, 29, 35, 29, 35, 29, 35, 29, 35, 29, 35, 28, 36, 28, 36, 28, 36, 28, 36, 28, 36, 28, 36, 28, 36, 28, 36, + 27, 37, 27, 37, 27, 37, 27, 37, 27, 37, 27, 37, 27, 37, 27, 37, 26, 38, 26, 38, 26, 38, 26, 38, 26, 38, 26, 38, 26, 38, 26, 38, + 25, 39, 25, 39, 25, 39, 25, 39, 25, 39, 25, 39, 25, 39, 25, 39, 24, 40, 24, 40, 24, 40, 24, 40, 24, 40, 24, 40, 24, 40, 24, 40, + 23, 41, 23, 41, 23, 41, 23, 41, 23, 41, 23, 41, 23, 41, 23, 41, 22, 42, 22, 42, 22, 42, 22, 42, 22, 42, 22, 42, 22, 42, 22, 42, + 21, 43, 21, 43, 21, 43, 21, 43, 21, 43, 21, 43, 21, 43, 21, 43, 20, 44, 20, 44, 20, 44, 20, 44, 20, 44, 20, 44, 20, 44, 20, 44, + 19, 45, 19, 45, 19, 45, 19, 45, 19, 45, 19, 45, 19, 45, 19, 45, 18, 46, 18, 46, 18, 46, 18, 46, 18, 46, 18, 46, 18, 46, 18, 46, + 17, 47, 17, 47, 17, 47, 17, 47, 17, 47, 17, 47, 17, 47, 17, 47, 16, 48, 16, 48, 16, 48, 16, 48, 16, 48, 16, 48, 16, 48, 16, 48, + 15, 49, 15, 49, 15, 49, 15, 49, 15, 49, 15, 49, 15, 49, 15, 49, 14, 50, 14, 50, 14, 50, 14, 50, 14, 50, 14, 50, 14, 50, 14, 50, + 13, 51, 13, 51, 13, 51, 13, 51, 13, 51, 13, 51, 13, 51, 13, 51, 12, 52, 12, 52, 12, 52, 12, 52, 12, 52, 12, 52, 12, 52, 12, 52, + 11, 53, 11, 53, 11, 53, 11, 53, 11, 53, 11, 53, 11, 53, 11, 53, 10, 54, 10, 54, 10, 54, 10, 54, 10, 54, 10, 54, 10, 54, 10, 54, + 9, 55, 9, 55, 9, 55, 9, 55, 9, 55, 9, 55, 9, 55, 9, 55, 8, 56, 8, 56, 8, 56, 8, 56, 8, 56, 8, 56, 8, 56, 8, 56, + 7, 57, 7, 57, 7, 57, 7, 57, 7, 57, 7, 57, 7, 57, 7, 57, 6, 58, 6, 58, 6, 58, 6, 58, 6, 58, 6, 58, 6, 58, 6, 58, + 5, 59, 5, 59, 5, 59, 5, 59, 5, 59, 5, 59, 5, 59, 5, 59, 4, 60, 4, 60, 4, 60, 4, 60, 4, 60, 4, 60, 4, 60, 4, 60, + 3, 61, 3, 61, 3, 61, 3, 61, 3, 61, 3, 61, 3, 61, 3, 61, 2, 62, 2, 62, 2, 62, 2, 62, 2, 62, 2, 62, 2, 62, 2, 62, + 1, 63, 1, 63, 1, 63, 1, 63, 1, 63, 1, 63, 1, 63, 1, 63, 0, 64, 0, 64, 0, 64, 0, 64, 0, 64, 0, 64, 0, 64, 0, 64, +}; + /** * \brief Generate angular predictions. * \param cu_loc CU locationand size data. @@ -633,13 +704,13 @@ static void intra_pred_planar_hor_w4(const uvg_pixel* ref, const int line, const for (int i = 0, d = 0; i < line; i += 4, ++d) { // Handle 4 lines at a time + // TODO: setr is VERY SLOW, replace this __m256i v_ref = _mm256_setr_epi16(ref[i + 1], ref[i + 1], ref[i + 1], ref[i + 1], ref[i + 2], ref[i + 2], ref[i + 2], ref[i + 2], ref[i + 3], ref[i + 3], ref[i + 3], ref[i + 3], ref[i + 4], ref[i + 4], ref[i + 4], ref[i + 4]); __m256i v_tmp = _mm256_mullo_epi16(v_ref, v_ref_coeff); - v_tmp = _mm256_add_epi16(v_last_ref_mul, v_tmp); - dst[d] = _mm256_slli_epi16(v_tmp, shift); + dst[d] = _mm256_add_epi16(v_last_ref_mul, v_tmp); } } static void intra_pred_planar_hor_w8(const uvg_pixel* ref, const int line, const int shift, __m256i* dst) @@ -661,8 +732,7 @@ static void intra_pred_planar_hor_w8(const uvg_pixel* ref, const int line, const __m256i v_tmp = _mm256_mullo_epi16(v_ref, v_ref_coeff); - v_tmp = _mm256_add_epi16(v_last_ref_mul, v_tmp); - dst[d] = _mm256_slli_epi16(v_tmp, shift); + dst[d] = _mm256_add_epi16(v_last_ref_mul, v_tmp); } } static void intra_pred_planar_hor_w16(const uvg_pixel* ref, const int line, const int shift, __m256i* dst) @@ -679,8 +749,7 @@ static void intra_pred_planar_hor_w16(const uvg_pixel* ref, const int line, cons __m256i v_tmp = _mm256_mullo_epi16(v_ref, v_ref_coeff); // TODO: the result is needed immediately after this. This leads to NOPs, consider doing multiple lines at a time - v_tmp = _mm256_add_epi16(v_last_ref_mul, v_tmp); - dst[d] = _mm256_slli_epi16(v_tmp, shift); + dst[d] = _mm256_add_epi16(v_last_ref_mul, v_tmp); } } static void intra_pred_planar_hor_w32(const uvg_pixel* ref, const int line, const int shift, __m256i* dst) {} @@ -691,30 +760,53 @@ static void intra_pred_planar_ver_w4(const uvg_pixel* ref, const int line, const { const __m256i v_last_ref = _mm256_set1_epi8(ref[line + 1]); + // Overflow possible for this width if line > 32 + const bool overflow = line > 32; + // Got four 8-bit references, or 32 bits of data. Duplicate to fill a whole 256-bit vector. const uint32_t* tmp = (const uint32_t*)&ref[1]; // Cast to 32 bit int to load 4 refs at the same time const __m256i v_ref = _mm256_set1_epi32(*tmp); - // Handle 4 lines at a time - for (int y = 0, d = 0; y < line; y += 4, ++d) { - const int a1 = line - 1 - (y + 0); - const int a2 = line - 1 - (y + 1); - const int a3 = line - 1 - (y + 2); - const int a4 = line - 1 - (y + 3); - const int b1 = (y + 0) + 1; - const int b2 = (y + 1) + 1; - const int b3 = (y + 2) + 1; - const int b4 = (y + 3) + 1; + const __m256i* v_ys = (const __m256i*)planar_avx2_ver_w4ys; - __m256i v_ys = _mm256_setr_epi8(a1, b1, a1, b1, a1, b1, a1, b1, - a2, b2, a2, b2, a2, b2, a2, b2, - a3, b3, a3, b3, a3, b3, a3, b3, - a4, b4, a4, b4, a4, b4, a4, b4); // TODO: these could be loaded from a table - __m256i v_lo = _mm256_unpacklo_epi8(v_ref, v_last_ref); - - __m256i v_madd_lo = _mm256_maddubs_epi16(v_lo, v_ys); - dst[d] = _mm256_slli_epi16(v_madd_lo, shift); + // Table offset + int offset; + if (line == 64) { + offset = 0; } + else if (line == 32) { + offset = 16; + } + else if (line == 16) { + offset = 24; + } + else if (line == 8) { + offset = 28; + } + else { // Do not care about lines < 4 since they are illegal + offset = 30; + } + + // Handle 4 lines at a time + #define UNROLL_LOOP(num) \ + for (int y = 0, s = offset, d = 0; y < (num); y += 4, ++s, ++d) { \ + __m256i v_lo = _mm256_unpacklo_epi8(v_ref, v_last_ref); \ + dst[d] = _mm256_maddubs_epi16(v_lo, v_ys[s]); \ + } + + switch (line) { + case 1: UNROLL_LOOP(1); break; + case 2: UNROLL_LOOP(2); break; + case 4: UNROLL_LOOP(4); break; + case 8: UNROLL_LOOP(8); break; + case 16: UNROLL_LOOP(16); break; + case 32: UNROLL_LOOP(32); break; + case 64: UNROLL_LOOP(64); break; + default: + assert(false && "Invalid dimension."); + break; + } + #undef UNROLL_LOOP } static void intra_pred_planar_ver_w8(const uvg_pixel* ref, const int line, const int shift, __m256i* dst) { @@ -726,33 +818,57 @@ static void intra_pred_planar_ver_w8(const uvg_pixel* ref, const int line, const v_ref = _mm256_inserti128_si256(v_ref, v_ref_raw, 1); v_ref = _mm256_shuffle_epi32(v_ref, _MM_SHUFFLE(1, 1, 0, 0)); - // Handle 4 lines at a time, unless line == 2 - for (int y = 0, d = 0; y < line; y += 4, d += 2) { - const int a1 = line - 1 - (y + 0); - const int b1 = (y + 0) + 1; - const int a2 = line - 1 - (y + 1); - const int b2 = (y + 1) + 1; - const int a3 = line - 1 - (y + 2); - const int b3 = (y + 2) + 1; - const int a4 = line - 1 - (y + 3); - const int b4 = (y + 3) + 1; - __m256i v_ys = _mm256_setr_epi8(a1, b1, a1, b1, a1, b1, a1, b1, - a2, b2, a2, b2, a2, b2, a2, b2, - a3, b3, a3, b3, a3, b3, a3, b3, - a4, b4, a4, b4, a4, b4, a4, b4); // TODO: these could be loaded from a table - __m256i v_lo = _mm256_unpacklo_epi8(v_ref, v_last_ref); - __m256i v_hi = _mm256_unpackhi_epi8(v_ref, v_last_ref); + const __m256i* v_ys = (const __m256i*)planar_avx2_ver_w4ys; - __m256i v_madd_lo = _mm256_maddubs_epi16(v_lo, v_ys); - __m256i v_madd_hi = _mm256_maddubs_epi16(v_hi, v_ys); - v_madd_lo = _mm256_slli_epi16(v_madd_lo, shift); - v_madd_hi = _mm256_slli_epi16(v_madd_hi, shift); - __m256i v_tmp0 = _mm256_permute2x128_si256(v_madd_lo, v_madd_hi, 0x20); - __m256i v_tmp1 = _mm256_permute2x128_si256(v_madd_lo, v_madd_hi, 0x31); - - dst[d + 0] = _mm256_permute4x64_epi64(v_tmp0, _MM_SHUFFLE(3, 1, 2, 0)); - dst[d + 1] = _mm256_permute4x64_epi64(v_tmp1, _MM_SHUFFLE(3, 1, 2, 0)); + // Table offset + int offset; + if (line == 64) { + offset = 0; } + else if (line == 32) { + offset = 16; + } + else if (line == 16) { + offset = 24; + } + else if (line == 8) { + offset = 28; + } + else if (line == 4) { + offset = 30; + } + else { // Do not care about line == 1 since it is illegal for this width + offset = 31; + } + + // Handle 4 lines at a time + #define UNROLL_LOOP(num) \ + for (int y = 0, s = offset, d = 0; y < (num); y += 4, ++s, d += 2) { \ + __m256i v_lo = _mm256_unpacklo_epi8(v_ref, v_last_ref); \ + __m256i v_hi = _mm256_unpackhi_epi8(v_ref, v_last_ref); \ + \ + __m256i v_madd_lo = _mm256_maddubs_epi16(v_lo, v_ys[s]); \ + __m256i v_madd_hi = _mm256_maddubs_epi16(v_hi, v_ys[s]); \ + __m256i v_tmp0 = _mm256_permute2x128_si256(v_madd_lo, v_madd_hi, 0x20); \ + __m256i v_tmp1 = _mm256_permute2x128_si256(v_madd_lo, v_madd_hi, 0x31); \ + \ + dst[d + 0] = _mm256_permute4x64_epi64(v_tmp0, _MM_SHUFFLE(3, 1, 2, 0)); \ + dst[d + 1] = _mm256_permute4x64_epi64(v_tmp1, _MM_SHUFFLE(3, 1, 2, 0)); \ + } + + switch (line) { + case 1: UNROLL_LOOP(1); break; + case 2: UNROLL_LOOP(2); break; + case 4: UNROLL_LOOP(4); break; + case 8: UNROLL_LOOP(8); break; + case 16: UNROLL_LOOP(16); break; + case 32: UNROLL_LOOP(32); break; + case 64: UNROLL_LOOP(64); break; + default: + assert(false && "Invalid dimension."); + break; + } + #undef UNROLL_LOOP } static void intra_pred_planar_ver_w16(const uvg_pixel* ref, const int line, const int shift, __m256i* dst) { @@ -763,26 +879,55 @@ static void intra_pred_planar_ver_w16(const uvg_pixel* ref, const int line, cons __m256i v_ref = _mm256_castsi128_si256(v_ref_raw); v_ref = _mm256_inserti128_si256(v_ref, v_ref_raw, 1); - // Handle 2 lines at a time - for (int y = 0; y < line; y += 2) { - const int a1 = line - 1 - (y + 0); - const int b1 = (y + 0) + 1; - const int a2 = line - 1 - (y + 1); - const int b2 = (y + 1) + 1; - __m256i v_ys = _mm256_setr_epi8(a1, b1, a1, b1, a1, b1, a1, b1, - a1, b1, a1, b1, a1, b1, a1, b1, - a2, b2, a2, b2, a2, b2, a2, b2, - a2, b2, a2, b2, a2, b2, a2, b2); // TODO: these could be loaded from a table - __m256i v_lo = _mm256_unpacklo_epi8(v_ref, v_last_ref); - __m256i v_hi = _mm256_unpackhi_epi8(v_ref, v_last_ref); + const __m256i* v_ys = (const __m256i*)planar_avx2_ver_w8ys; - __m256i v_madd_lo = _mm256_maddubs_epi16(v_lo, v_ys); - __m256i v_madd_hi = _mm256_maddubs_epi16(v_hi, v_ys); - v_madd_lo = _mm256_slli_epi16(v_madd_lo, shift); - v_madd_hi = _mm256_slli_epi16(v_madd_hi, shift); - dst[y + 0] = _mm256_permute2x128_si256(v_madd_lo, v_madd_hi, 0x20); - dst[y + 1] = _mm256_permute2x128_si256(v_madd_lo, v_madd_hi, 0x31); + // Table offset + int offset; + if (line == 64) { + offset = 0; } + else if (line == 32) { + offset = 16; + } + else if (line == 16) { + offset = 24; + } + else if (line == 8) { + offset = 28; + } + else if (line == 4) { + offset = 30; + } + else { // Do not care about line == 1 since it is illegal for this width + offset = 31; + } + + // These stay constant through the loop + const __m256i v_lo = _mm256_unpacklo_epi8(v_ref, v_last_ref); + const __m256i v_hi = _mm256_unpackhi_epi8(v_ref, v_last_ref); + + // Handle 2 lines at a time + #define UNROLL_LOOP(num) \ + for (int y = 0, s = offset; y < (num); y += 2, ++s) { \ + __m256i v_madd_lo = _mm256_maddubs_epi16(v_lo, v_ys[s]); \ + __m256i v_madd_hi = _mm256_maddubs_epi16(v_hi, v_ys[s]); \ + dst[y + 0] = _mm256_permute2x128_si256(v_madd_lo, v_madd_hi, 0x20); \ + dst[y + 1] = _mm256_permute2x128_si256(v_madd_lo, v_madd_hi, 0x31); \ + } + + switch (line) { + case 1: UNROLL_LOOP(1); break; + case 2: UNROLL_LOOP(2); break; + case 4: UNROLL_LOOP(4); break; + case 8: UNROLL_LOOP(8); break; + case 16: UNROLL_LOOP(16); break; + case 32: UNROLL_LOOP(32); break; + case 64: UNROLL_LOOP(64); break; + default: + assert(false && "Invalid dimension."); + break; + } + #undef UNROLL_LOOP } static void intra_pred_planar_ver_w32(const uvg_pixel* ref, const int line, const int shift, __m256i* dst) {} @@ -802,7 +947,7 @@ void uvg_intra_pred_planar_avx2(const cu_loc_t* const cu_loc, const int width = color == COLOR_Y ? cu_loc->width : cu_loc->chroma_width; const int height = color == COLOR_Y ? cu_loc->height : cu_loc->chroma_height; const int samples = width * height; - const __m256i v_samples = _mm256_set1_epi16(samples); + const __m256i v_samples = _mm256_set1_epi32(samples); const int log2_width = uvg_g_convert_to_log2[width]; const int log2_height = uvg_g_convert_to_log2[height]; @@ -821,11 +966,35 @@ void uvg_intra_pred_planar_avx2(const cu_loc_t* const cu_loc, int16_t* hor_res = (int16_t*)v_pred_hor; int16_t* ver_res = (int16_t*)v_pred_ver; + // Cast two 16-bit values to 32-bit and fill a 256-bit vector + int16_t tmp[2] = {height, width}; + int32_t* tmp2 = (int32_t*)tmp; + const __m256i v_madd_shift = _mm256_set1_epi32(*tmp2); + __m256i v_res[64]; - for (int i = 0, d = 0; i < samples; i += 16, ++d) { + // Old loop + /*for (int i = 0, d = 0; i < samples; i += 16, ++d) { v_res[d] = _mm256_add_epi16(v_pred_ver[d], v_pred_hor[d]); v_res[d] = _mm256_add_epi16(v_res[d], v_samples); v_res[d] = _mm256_srli_epi16(v_res[d], shift_r); + }*/ + + // New loop + for (int i = 0, d = 0; i < samples; i += 16, ++d) { + __m256i v_lo = _mm256_unpacklo_epi16(v_pred_hor[d], v_pred_ver[d]); + __m256i v_hi = _mm256_unpackhi_epi16(v_pred_hor[d], v_pred_ver[d]); + + // madd will extend the intermediate results to 32-bit to avoid overflows + __m256i v_madd_lo = _mm256_madd_epi16(v_lo, v_madd_shift); + __m256i v_madd_hi = _mm256_madd_epi16(v_hi, v_madd_shift); + + v_madd_lo = _mm256_add_epi32(v_madd_lo, v_samples); + v_madd_hi = _mm256_add_epi32(v_madd_hi, v_samples); + + v_madd_lo = _mm256_srli_epi32(v_madd_lo, shift_r); + v_madd_hi = _mm256_srli_epi32(v_madd_hi, shift_r); + + v_res[d] = _mm256_packs_epi32(v_madd_lo, v_madd_hi); } // debug