Remove left shift from planar half functions. Implement the left shift with madd. Planar preds of width 4, 8 and 16 should work now without overflows. Add loop unroll macros to vertical half functions. Will be added to hor half functions later.

This commit is contained in:
siivonek 2023-09-10 19:46:42 +03:00 committed by Joose Sainio
parent 0eb0f110c2
commit b02fb1b1af

View file

@ -48,6 +48,77 @@
#include "strategyselector.h" #include "strategyselector.h"
#include "strategies/missing-intel-intrinsics.h" #include "strategies/missing-intel-intrinsics.h"
// Y coord tables
// Y coordinate weight table for the vertical planar half, width 4.
// Each 32-byte row packs four y-lines of interleaved byte pairs (a, b) with
// a = line - 1 - y and b = y + 1 (so a + b == line), each pair repeated 4
// times to span a 4-pixel row. Rows are loaded as whole __m256i values and
// fed to _mm256_maddubs_epi16, replacing the per-iteration _mm256_setr_epi8
// of the same coefficients. Sections for line == 64/32/16/8/4/2 begin at the
// __m256i offsets noted in the inline comments.
ALIGNED(32) static const int8_t planar_avx2_ver_w4ys[1024] = {
  63,  1, 63,  1, 63,  1, 63,  1, 62,  2, 62,  2, 62,  2, 62,  2, 61,  3, 61,  3, 61,  3, 61,  3, 60,  4, 60,  4, 60,  4, 60,  4, // offset 0, line == 64
  59,  5, 59,  5, 59,  5, 59,  5, 58,  6, 58,  6, 58,  6, 58,  6, 57,  7, 57,  7, 57,  7, 57,  7, 56,  8, 56,  8, 56,  8, 56,  8,
  55,  9, 55,  9, 55,  9, 55,  9, 54, 10, 54, 10, 54, 10, 54, 10, 53, 11, 53, 11, 53, 11, 53, 11, 52, 12, 52, 12, 52, 12, 52, 12,
  51, 13, 51, 13, 51, 13, 51, 13, 50, 14, 50, 14, 50, 14, 50, 14, 49, 15, 49, 15, 49, 15, 49, 15, 48, 16, 48, 16, 48, 16, 48, 16,
  47, 17, 47, 17, 47, 17, 47, 17, 46, 18, 46, 18, 46, 18, 46, 18, 45, 19, 45, 19, 45, 19, 45, 19, 44, 20, 44, 20, 44, 20, 44, 20,
  43, 21, 43, 21, 43, 21, 43, 21, 42, 22, 42, 22, 42, 22, 42, 22, 41, 23, 41, 23, 41, 23, 41, 23, 40, 24, 40, 24, 40, 24, 40, 24,
  39, 25, 39, 25, 39, 25, 39, 25, 38, 26, 38, 26, 38, 26, 38, 26, 37, 27, 37, 27, 37, 27, 37, 27, 36, 28, 36, 28, 36, 28, 36, 28,
  35, 29, 35, 29, 35, 29, 35, 29, 34, 30, 34, 30, 34, 30, 34, 30, 33, 31, 33, 31, 33, 31, 33, 31, 32, 32, 32, 32, 32, 32, 32, 32,
  31, 33, 31, 33, 31, 33, 31, 33, 30, 34, 30, 34, 30, 34, 30, 34, 29, 35, 29, 35, 29, 35, 29, 35, 28, 36, 28, 36, 28, 36, 28, 36,
  27, 37, 27, 37, 27, 37, 27, 37, 26, 38, 26, 38, 26, 38, 26, 38, 25, 39, 25, 39, 25, 39, 25, 39, 24, 40, 24, 40, 24, 40, 24, 40,
  23, 41, 23, 41, 23, 41, 23, 41, 22, 42, 22, 42, 22, 42, 22, 42, 21, 43, 21, 43, 21, 43, 21, 43, 20, 44, 20, 44, 20, 44, 20, 44,
  19, 45, 19, 45, 19, 45, 19, 45, 18, 46, 18, 46, 18, 46, 18, 46, 17, 47, 17, 47, 17, 47, 17, 47, 16, 48, 16, 48, 16, 48, 16, 48,
  15, 49, 15, 49, 15, 49, 15, 49, 14, 50, 14, 50, 14, 50, 14, 50, 13, 51, 13, 51, 13, 51, 13, 51, 12, 52, 12, 52, 12, 52, 12, 52,
  11, 53, 11, 53, 11, 53, 11, 53, 10, 54, 10, 54, 10, 54, 10, 54,  9, 55,  9, 55,  9, 55,  9, 55,  8, 56,  8, 56,  8, 56,  8, 56,
   7, 57,  7, 57,  7, 57,  7, 57,  6, 58,  6, 58,  6, 58,  6, 58,  5, 59,  5, 59,  5, 59,  5, 59,  4, 60,  4, 60,  4, 60,  4, 60,
   3, 61,  3, 61,  3, 61,  3, 61,  2, 62,  2, 62,  2, 62,  2, 62,  1, 63,  1, 63,  1, 63,  1, 63,  0, 64,  0, 64,  0, 64,  0, 64,
  31,  1, 31,  1, 31,  1, 31,  1, 30,  2, 30,  2, 30,  2, 30,  2, 29,  3, 29,  3, 29,  3, 29,  3, 28,  4, 28,  4, 28,  4, 28,  4, // offset 16, line == 32
  27,  5, 27,  5, 27,  5, 27,  5, 26,  6, 26,  6, 26,  6, 26,  6, 25,  7, 25,  7, 25,  7, 25,  7, 24,  8, 24,  8, 24,  8, 24,  8,
  23,  9, 23,  9, 23,  9, 23,  9, 22, 10, 22, 10, 22, 10, 22, 10, 21, 11, 21, 11, 21, 11, 21, 11, 20, 12, 20, 12, 20, 12, 20, 12,
  19, 13, 19, 13, 19, 13, 19, 13, 18, 14, 18, 14, 18, 14, 18, 14, 17, 15, 17, 15, 17, 15, 17, 15, 16, 16, 16, 16, 16, 16, 16, 16,
  15, 17, 15, 17, 15, 17, 15, 17, 14, 18, 14, 18, 14, 18, 14, 18, 13, 19, 13, 19, 13, 19, 13, 19, 12, 20, 12, 20, 12, 20, 12, 20,
  11, 21, 11, 21, 11, 21, 11, 21, 10, 22, 10, 22, 10, 22, 10, 22,  9, 23,  9, 23,  9, 23,  9, 23,  8, 24,  8, 24,  8, 24,  8, 24,
   7, 25,  7, 25,  7, 25,  7, 25,  6, 26,  6, 26,  6, 26,  6, 26,  5, 27,  5, 27,  5, 27,  5, 27,  4, 28,  4, 28,  4, 28,  4, 28,
   3, 29,  3, 29,  3, 29,  3, 29,  2, 30,  2, 30,  2, 30,  2, 30,  1, 31,  1, 31,  1, 31,  1, 31,  0, 32,  0, 32,  0, 32,  0, 32,
  15,  1, 15,  1, 15,  1, 15,  1, 14,  2, 14,  2, 14,  2, 14,  2, 13,  3, 13,  3, 13,  3, 13,  3, 12,  4, 12,  4, 12,  4, 12,  4, // offset 24, line == 16
  11,  5, 11,  5, 11,  5, 11,  5, 10,  6, 10,  6, 10,  6, 10,  6,  9,  7,  9,  7,  9,  7,  9,  7,  8,  8,  8,  8,  8,  8,  8,  8,
   7,  9,  7,  9,  7,  9,  7,  9,  6, 10,  6, 10,  6, 10,  6, 10,  5, 11,  5, 11,  5, 11,  5, 11,  4, 12,  4, 12,  4, 12,  4, 12,
   3, 13,  3, 13,  3, 13,  3, 13,  2, 14,  2, 14,  2, 14,  2, 14,  1, 15,  1, 15,  1, 15,  1, 15,  0, 16,  0, 16,  0, 16,  0, 16,
   7,  1,  7,  1,  7,  1,  7,  1,  6,  2,  6,  2,  6,  2,  6,  2,  5,  3,  5,  3,  5,  3,  5,  3,  4,  4,  4,  4,  4,  4,  4,  4, // offset 28, line == 8
   3,  5,  3,  5,  3,  5,  3,  5,  2,  6,  2,  6,  2,  6,  2,  6,  1,  7,  1,  7,  1,  7,  1,  7,  0,  8,  0,  8,  0,  8,  0,  8,
   3,  1,  3,  1,  3,  1,  3,  1,  2,  2,  2,  2,  2,  2,  2,  2,  1,  3,  1,  3,  1,  3,  1,  3,  0,  4,  0,  4,  0,  4,  0,  4, // offset 30, line == 4
   1,  1,  1,  1,  1,  1,  1,  1,  0,  2,  0,  2,  0,  2,  0,  2,  1,  1,  1,  1,  1,  1,  1,  1,  0,  2,  0,  2,  0,  2,  0,  2, // offset 31, line == 2
};
// Y coordinate weight table for the wider vertical planar halves.
// Each 32-byte row packs two y-lines of interleaved byte pairs (a, b) with
// a = line - 1 - y and b = y + 1 (a + b == line), each pair repeated 8 times.
// Rows are loaded as __m256i values and used as _mm256_maddubs_epi16
// coefficients by intra_pred_planar_ver_w16.
// NOTE(review): only the line == 64 section (first 1024 of the declared 2048
// bytes) is initialized here; the trailing half is zero. ver_w16 still
// selects table offsets 16/24/28/30/31 for shorter lines, which here index
// into line == 64 data — confirm the line == 32/16/8/4/2 sections are added
// before those paths are exercised.
ALIGNED(32) static const int8_t planar_avx2_ver_w8ys[2048] = {
  63,  1, 63,  1, 63,  1, 63,  1, 63,  1, 63,  1, 63,  1, 63,  1, 62,  2, 62,  2, 62,  2, 62,  2, 62,  2, 62,  2, 62,  2, 62,  2,
  61,  3, 61,  3, 61,  3, 61,  3, 61,  3, 61,  3, 61,  3, 61,  3, 60,  4, 60,  4, 60,  4, 60,  4, 60,  4, 60,  4, 60,  4, 60,  4,
  59,  5, 59,  5, 59,  5, 59,  5, 59,  5, 59,  5, 59,  5, 59,  5, 58,  6, 58,  6, 58,  6, 58,  6, 58,  6, 58,  6, 58,  6, 58,  6,
  57,  7, 57,  7, 57,  7, 57,  7, 57,  7, 57,  7, 57,  7, 57,  7, 56,  8, 56,  8, 56,  8, 56,  8, 56,  8, 56,  8, 56,  8, 56,  8,
  55,  9, 55,  9, 55,  9, 55,  9, 55,  9, 55,  9, 55,  9, 55,  9, 54, 10, 54, 10, 54, 10, 54, 10, 54, 10, 54, 10, 54, 10, 54, 10,
  53, 11, 53, 11, 53, 11, 53, 11, 53, 11, 53, 11, 53, 11, 53, 11, 52, 12, 52, 12, 52, 12, 52, 12, 52, 12, 52, 12, 52, 12, 52, 12,
  51, 13, 51, 13, 51, 13, 51, 13, 51, 13, 51, 13, 51, 13, 51, 13, 50, 14, 50, 14, 50, 14, 50, 14, 50, 14, 50, 14, 50, 14, 50, 14,
  49, 15, 49, 15, 49, 15, 49, 15, 49, 15, 49, 15, 49, 15, 49, 15, 48, 16, 48, 16, 48, 16, 48, 16, 48, 16, 48, 16, 48, 16, 48, 16,
  47, 17, 47, 17, 47, 17, 47, 17, 47, 17, 47, 17, 47, 17, 47, 17, 46, 18, 46, 18, 46, 18, 46, 18, 46, 18, 46, 18, 46, 18, 46, 18,
  45, 19, 45, 19, 45, 19, 45, 19, 45, 19, 45, 19, 45, 19, 45, 19, 44, 20, 44, 20, 44, 20, 44, 20, 44, 20, 44, 20, 44, 20, 44, 20,
  43, 21, 43, 21, 43, 21, 43, 21, 43, 21, 43, 21, 43, 21, 43, 21, 42, 22, 42, 22, 42, 22, 42, 22, 42, 22, 42, 22, 42, 22, 42, 22,
  41, 23, 41, 23, 41, 23, 41, 23, 41, 23, 41, 23, 41, 23, 41, 23, 40, 24, 40, 24, 40, 24, 40, 24, 40, 24, 40, 24, 40, 24, 40, 24,
  39, 25, 39, 25, 39, 25, 39, 25, 39, 25, 39, 25, 39, 25, 39, 25, 38, 26, 38, 26, 38, 26, 38, 26, 38, 26, 38, 26, 38, 26, 38, 26,
  37, 27, 37, 27, 37, 27, 37, 27, 37, 27, 37, 27, 37, 27, 37, 27, 36, 28, 36, 28, 36, 28, 36, 28, 36, 28, 36, 28, 36, 28, 36, 28,
  35, 29, 35, 29, 35, 29, 35, 29, 35, 29, 35, 29, 35, 29, 35, 29, 34, 30, 34, 30, 34, 30, 34, 30, 34, 30, 34, 30, 34, 30, 34, 30,
  33, 31, 33, 31, 33, 31, 33, 31, 33, 31, 33, 31, 33, 31, 33, 31, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
  31, 33, 31, 33, 31, 33, 31, 33, 31, 33, 31, 33, 31, 33, 31, 33, 30, 34, 30, 34, 30, 34, 30, 34, 30, 34, 30, 34, 30, 34, 30, 34,
  29, 35, 29, 35, 29, 35, 29, 35, 29, 35, 29, 35, 29, 35, 29, 35, 28, 36, 28, 36, 28, 36, 28, 36, 28, 36, 28, 36, 28, 36, 28, 36,
  27, 37, 27, 37, 27, 37, 27, 37, 27, 37, 27, 37, 27, 37, 27, 37, 26, 38, 26, 38, 26, 38, 26, 38, 26, 38, 26, 38, 26, 38, 26, 38,
  25, 39, 25, 39, 25, 39, 25, 39, 25, 39, 25, 39, 25, 39, 25, 39, 24, 40, 24, 40, 24, 40, 24, 40, 24, 40, 24, 40, 24, 40, 24, 40,
  23, 41, 23, 41, 23, 41, 23, 41, 23, 41, 23, 41, 23, 41, 23, 41, 22, 42, 22, 42, 22, 42, 22, 42, 22, 42, 22, 42, 22, 42, 22, 42,
  21, 43, 21, 43, 21, 43, 21, 43, 21, 43, 21, 43, 21, 43, 21, 43, 20, 44, 20, 44, 20, 44, 20, 44, 20, 44, 20, 44, 20, 44, 20, 44,
  19, 45, 19, 45, 19, 45, 19, 45, 19, 45, 19, 45, 19, 45, 19, 45, 18, 46, 18, 46, 18, 46, 18, 46, 18, 46, 18, 46, 18, 46, 18, 46,
  17, 47, 17, 47, 17, 47, 17, 47, 17, 47, 17, 47, 17, 47, 17, 47, 16, 48, 16, 48, 16, 48, 16, 48, 16, 48, 16, 48, 16, 48, 16, 48,
  15, 49, 15, 49, 15, 49, 15, 49, 15, 49, 15, 49, 15, 49, 15, 49, 14, 50, 14, 50, 14, 50, 14, 50, 14, 50, 14, 50, 14, 50, 14, 50,
  13, 51, 13, 51, 13, 51, 13, 51, 13, 51, 13, 51, 13, 51, 13, 51, 12, 52, 12, 52, 12, 52, 12, 52, 12, 52, 12, 52, 12, 52, 12, 52,
  11, 53, 11, 53, 11, 53, 11, 53, 11, 53, 11, 53, 11, 53, 11, 53, 10, 54, 10, 54, 10, 54, 10, 54, 10, 54, 10, 54, 10, 54, 10, 54,
   9, 55,  9, 55,  9, 55,  9, 55,  9, 55,  9, 55,  9, 55,  9, 55,  8, 56,  8, 56,  8, 56,  8, 56,  8, 56,  8, 56,  8, 56,  8, 56,
   7, 57,  7, 57,  7, 57,  7, 57,  7, 57,  7, 57,  7, 57,  7, 57,  6, 58,  6, 58,  6, 58,  6, 58,  6, 58,  6, 58,  6, 58,  6, 58,
   5, 59,  5, 59,  5, 59,  5, 59,  5, 59,  5, 59,  5, 59,  5, 59,  4, 60,  4, 60,  4, 60,  4, 60,  4, 60,  4, 60,  4, 60,  4, 60,
   3, 61,  3, 61,  3, 61,  3, 61,  3, 61,  3, 61,  3, 61,  3, 61,  2, 62,  2, 62,  2, 62,  2, 62,  2, 62,  2, 62,  2, 62,  2, 62,
   1, 63,  1, 63,  1, 63,  1, 63,  1, 63,  1, 63,  1, 63,  1, 63,  0, 64,  0, 64,  0, 64,  0, 64,  0, 64,  0, 64,  0, 64,  0, 64,
};
/** /**
* \brief Generate angular predictions. * \brief Generate angular predictions.
* \param cu_loc CU location and size data. * \param cu_loc CU location and size data.
@ -633,13 +704,13 @@ static void intra_pred_planar_hor_w4(const uvg_pixel* ref, const int line, const
for (int i = 0, d = 0; i < line; i += 4, ++d) { for (int i = 0, d = 0; i < line; i += 4, ++d) {
// Handle 4 lines at a time // Handle 4 lines at a time
// TODO: setr is VERY SLOW, replace this
__m256i v_ref = _mm256_setr_epi16(ref[i + 1], ref[i + 1], ref[i + 1], ref[i + 1], ref[i + 2], ref[i + 2], ref[i + 2], ref[i + 2], __m256i v_ref = _mm256_setr_epi16(ref[i + 1], ref[i + 1], ref[i + 1], ref[i + 1], ref[i + 2], ref[i + 2], ref[i + 2], ref[i + 2],
ref[i + 3], ref[i + 3], ref[i + 3], ref[i + 3], ref[i + 4], ref[i + 4], ref[i + 4], ref[i + 4]); ref[i + 3], ref[i + 3], ref[i + 3], ref[i + 3], ref[i + 4], ref[i + 4], ref[i + 4], ref[i + 4]);
__m256i v_tmp = _mm256_mullo_epi16(v_ref, v_ref_coeff); __m256i v_tmp = _mm256_mullo_epi16(v_ref, v_ref_coeff);
v_tmp = _mm256_add_epi16(v_last_ref_mul, v_tmp); dst[d] = _mm256_add_epi16(v_last_ref_mul, v_tmp);
dst[d] = _mm256_slli_epi16(v_tmp, shift);
} }
} }
static void intra_pred_planar_hor_w8(const uvg_pixel* ref, const int line, const int shift, __m256i* dst) static void intra_pred_planar_hor_w8(const uvg_pixel* ref, const int line, const int shift, __m256i* dst)
@ -661,8 +732,7 @@ static void intra_pred_planar_hor_w8(const uvg_pixel* ref, const int line, const
__m256i v_tmp = _mm256_mullo_epi16(v_ref, v_ref_coeff); __m256i v_tmp = _mm256_mullo_epi16(v_ref, v_ref_coeff);
v_tmp = _mm256_add_epi16(v_last_ref_mul, v_tmp); dst[d] = _mm256_add_epi16(v_last_ref_mul, v_tmp);
dst[d] = _mm256_slli_epi16(v_tmp, shift);
} }
} }
static void intra_pred_planar_hor_w16(const uvg_pixel* ref, const int line, const int shift, __m256i* dst) static void intra_pred_planar_hor_w16(const uvg_pixel* ref, const int line, const int shift, __m256i* dst)
@ -679,8 +749,7 @@ static void intra_pred_planar_hor_w16(const uvg_pixel* ref, const int line, cons
__m256i v_tmp = _mm256_mullo_epi16(v_ref, v_ref_coeff); // TODO: the result is needed immediately after this. This leads to NOPs, consider doing multiple lines at a time __m256i v_tmp = _mm256_mullo_epi16(v_ref, v_ref_coeff); // TODO: the result is needed immediately after this. This leads to NOPs, consider doing multiple lines at a time
v_tmp = _mm256_add_epi16(v_last_ref_mul, v_tmp); dst[d] = _mm256_add_epi16(v_last_ref_mul, v_tmp);
dst[d] = _mm256_slli_epi16(v_tmp, shift);
} }
} }
static void intra_pred_planar_hor_w32(const uvg_pixel* ref, const int line, const int shift, __m256i* dst) {} static void intra_pred_planar_hor_w32(const uvg_pixel* ref, const int line, const int shift, __m256i* dst) {}
@ -691,30 +760,53 @@ static void intra_pred_planar_ver_w4(const uvg_pixel* ref, const int line, const
{ {
const __m256i v_last_ref = _mm256_set1_epi8(ref[line + 1]); const __m256i v_last_ref = _mm256_set1_epi8(ref[line + 1]);
// Overflow possible for this width if line > 32
const bool overflow = line > 32;
// Got four 8-bit references, or 32 bits of data. Duplicate to fill a whole 256-bit vector. // Got four 8-bit references, or 32 bits of data. Duplicate to fill a whole 256-bit vector.
const uint32_t* tmp = (const uint32_t*)&ref[1]; // Cast to 32 bit int to load 4 refs at the same time const uint32_t* tmp = (const uint32_t*)&ref[1]; // Cast to 32 bit int to load 4 refs at the same time
const __m256i v_ref = _mm256_set1_epi32(*tmp); const __m256i v_ref = _mm256_set1_epi32(*tmp);
// Handle 4 lines at a time const __m256i* v_ys = (const __m256i*)planar_avx2_ver_w4ys;
for (int y = 0, d = 0; y < line; y += 4, ++d) {
const int a1 = line - 1 - (y + 0);
const int a2 = line - 1 - (y + 1);
const int a3 = line - 1 - (y + 2);
const int a4 = line - 1 - (y + 3);
const int b1 = (y + 0) + 1;
const int b2 = (y + 1) + 1;
const int b3 = (y + 2) + 1;
const int b4 = (y + 3) + 1;
__m256i v_ys = _mm256_setr_epi8(a1, b1, a1, b1, a1, b1, a1, b1, // Table offset
a2, b2, a2, b2, a2, b2, a2, b2, int offset;
a3, b3, a3, b3, a3, b3, a3, b3, if (line == 64) {
a4, b4, a4, b4, a4, b4, a4, b4); // TODO: these could be loaded from a table offset = 0;
__m256i v_lo = _mm256_unpacklo_epi8(v_ref, v_last_ref);
__m256i v_madd_lo = _mm256_maddubs_epi16(v_lo, v_ys);
dst[d] = _mm256_slli_epi16(v_madd_lo, shift);
} }
else if (line == 32) {
offset = 16;
}
else if (line == 16) {
offset = 24;
}
else if (line == 8) {
offset = 28;
}
else { // Do not care about lines < 4 since they are illegal
offset = 30;
}
// Handle 4 lines at a time
#define UNROLL_LOOP(num) \
for (int y = 0, s = offset, d = 0; y < (num); y += 4, ++s, ++d) { \
__m256i v_lo = _mm256_unpacklo_epi8(v_ref, v_last_ref); \
dst[d] = _mm256_maddubs_epi16(v_lo, v_ys[s]); \
}
switch (line) {
case 1: UNROLL_LOOP(1); break;
case 2: UNROLL_LOOP(2); break;
case 4: UNROLL_LOOP(4); break;
case 8: UNROLL_LOOP(8); break;
case 16: UNROLL_LOOP(16); break;
case 32: UNROLL_LOOP(32); break;
case 64: UNROLL_LOOP(64); break;
default:
assert(false && "Invalid dimension.");
break;
}
#undef UNROLL_LOOP
} }
static void intra_pred_planar_ver_w8(const uvg_pixel* ref, const int line, const int shift, __m256i* dst) static void intra_pred_planar_ver_w8(const uvg_pixel* ref, const int line, const int shift, __m256i* dst)
{ {
@ -726,33 +818,57 @@ static void intra_pred_planar_ver_w8(const uvg_pixel* ref, const int line, const
v_ref = _mm256_inserti128_si256(v_ref, v_ref_raw, 1); v_ref = _mm256_inserti128_si256(v_ref, v_ref_raw, 1);
v_ref = _mm256_shuffle_epi32(v_ref, _MM_SHUFFLE(1, 1, 0, 0)); v_ref = _mm256_shuffle_epi32(v_ref, _MM_SHUFFLE(1, 1, 0, 0));
// Handle 4 lines at a time, unless line == 2 const __m256i* v_ys = (const __m256i*)planar_avx2_ver_w4ys;
for (int y = 0, d = 0; y < line; y += 4, d += 2) {
const int a1 = line - 1 - (y + 0);
const int b1 = (y + 0) + 1;
const int a2 = line - 1 - (y + 1);
const int b2 = (y + 1) + 1;
const int a3 = line - 1 - (y + 2);
const int b3 = (y + 2) + 1;
const int a4 = line - 1 - (y + 3);
const int b4 = (y + 3) + 1;
__m256i v_ys = _mm256_setr_epi8(a1, b1, a1, b1, a1, b1, a1, b1,
a2, b2, a2, b2, a2, b2, a2, b2,
a3, b3, a3, b3, a3, b3, a3, b3,
a4, b4, a4, b4, a4, b4, a4, b4); // TODO: these could be loaded from a table
__m256i v_lo = _mm256_unpacklo_epi8(v_ref, v_last_ref);
__m256i v_hi = _mm256_unpackhi_epi8(v_ref, v_last_ref);
__m256i v_madd_lo = _mm256_maddubs_epi16(v_lo, v_ys); // Table offset
__m256i v_madd_hi = _mm256_maddubs_epi16(v_hi, v_ys); int offset;
v_madd_lo = _mm256_slli_epi16(v_madd_lo, shift); if (line == 64) {
v_madd_hi = _mm256_slli_epi16(v_madd_hi, shift); offset = 0;
__m256i v_tmp0 = _mm256_permute2x128_si256(v_madd_lo, v_madd_hi, 0x20);
__m256i v_tmp1 = _mm256_permute2x128_si256(v_madd_lo, v_madd_hi, 0x31);
dst[d + 0] = _mm256_permute4x64_epi64(v_tmp0, _MM_SHUFFLE(3, 1, 2, 0));
dst[d + 1] = _mm256_permute4x64_epi64(v_tmp1, _MM_SHUFFLE(3, 1, 2, 0));
} }
else if (line == 32) {
offset = 16;
}
else if (line == 16) {
offset = 24;
}
else if (line == 8) {
offset = 28;
}
else if (line == 4) {
offset = 30;
}
else { // Do not care about line == 1 since it is illegal for this width
offset = 31;
}
// Handle 4 lines at a time
#define UNROLL_LOOP(num) \
for (int y = 0, s = offset, d = 0; y < (num); y += 4, ++s, d += 2) { \
__m256i v_lo = _mm256_unpacklo_epi8(v_ref, v_last_ref); \
__m256i v_hi = _mm256_unpackhi_epi8(v_ref, v_last_ref); \
\
__m256i v_madd_lo = _mm256_maddubs_epi16(v_lo, v_ys[s]); \
__m256i v_madd_hi = _mm256_maddubs_epi16(v_hi, v_ys[s]); \
__m256i v_tmp0 = _mm256_permute2x128_si256(v_madd_lo, v_madd_hi, 0x20); \
__m256i v_tmp1 = _mm256_permute2x128_si256(v_madd_lo, v_madd_hi, 0x31); \
\
dst[d + 0] = _mm256_permute4x64_epi64(v_tmp0, _MM_SHUFFLE(3, 1, 2, 0)); \
dst[d + 1] = _mm256_permute4x64_epi64(v_tmp1, _MM_SHUFFLE(3, 1, 2, 0)); \
}
switch (line) {
case 1: UNROLL_LOOP(1); break;
case 2: UNROLL_LOOP(2); break;
case 4: UNROLL_LOOP(4); break;
case 8: UNROLL_LOOP(8); break;
case 16: UNROLL_LOOP(16); break;
case 32: UNROLL_LOOP(32); break;
case 64: UNROLL_LOOP(64); break;
default:
assert(false && "Invalid dimension.");
break;
}
#undef UNROLL_LOOP
} }
static void intra_pred_planar_ver_w16(const uvg_pixel* ref, const int line, const int shift, __m256i* dst) static void intra_pred_planar_ver_w16(const uvg_pixel* ref, const int line, const int shift, __m256i* dst)
{ {
@ -763,26 +879,55 @@ static void intra_pred_planar_ver_w16(const uvg_pixel* ref, const int line, cons
__m256i v_ref = _mm256_castsi128_si256(v_ref_raw); __m256i v_ref = _mm256_castsi128_si256(v_ref_raw);
v_ref = _mm256_inserti128_si256(v_ref, v_ref_raw, 1); v_ref = _mm256_inserti128_si256(v_ref, v_ref_raw, 1);
// Handle 2 lines at a time const __m256i* v_ys = (const __m256i*)planar_avx2_ver_w8ys;
for (int y = 0; y < line; y += 2) {
const int a1 = line - 1 - (y + 0);
const int b1 = (y + 0) + 1;
const int a2 = line - 1 - (y + 1);
const int b2 = (y + 1) + 1;
__m256i v_ys = _mm256_setr_epi8(a1, b1, a1, b1, a1, b1, a1, b1,
a1, b1, a1, b1, a1, b1, a1, b1,
a2, b2, a2, b2, a2, b2, a2, b2,
a2, b2, a2, b2, a2, b2, a2, b2); // TODO: these could be loaded from a table
__m256i v_lo = _mm256_unpacklo_epi8(v_ref, v_last_ref);
__m256i v_hi = _mm256_unpackhi_epi8(v_ref, v_last_ref);
__m256i v_madd_lo = _mm256_maddubs_epi16(v_lo, v_ys); // Table offset
__m256i v_madd_hi = _mm256_maddubs_epi16(v_hi, v_ys); int offset;
v_madd_lo = _mm256_slli_epi16(v_madd_lo, shift); if (line == 64) {
v_madd_hi = _mm256_slli_epi16(v_madd_hi, shift); offset = 0;
dst[y + 0] = _mm256_permute2x128_si256(v_madd_lo, v_madd_hi, 0x20);
dst[y + 1] = _mm256_permute2x128_si256(v_madd_lo, v_madd_hi, 0x31);
} }
else if (line == 32) {
offset = 16;
}
else if (line == 16) {
offset = 24;
}
else if (line == 8) {
offset = 28;
}
else if (line == 4) {
offset = 30;
}
else { // Do not care about line == 1 since it is illegal for this width
offset = 31;
}
// These stay constant through the loop
const __m256i v_lo = _mm256_unpacklo_epi8(v_ref, v_last_ref);
const __m256i v_hi = _mm256_unpackhi_epi8(v_ref, v_last_ref);
// Handle 2 lines at a time
#define UNROLL_LOOP(num) \
for (int y = 0, s = offset; y < (num); y += 2, ++s) { \
__m256i v_madd_lo = _mm256_maddubs_epi16(v_lo, v_ys[s]); \
__m256i v_madd_hi = _mm256_maddubs_epi16(v_hi, v_ys[s]); \
dst[y + 0] = _mm256_permute2x128_si256(v_madd_lo, v_madd_hi, 0x20); \
dst[y + 1] = _mm256_permute2x128_si256(v_madd_lo, v_madd_hi, 0x31); \
}
switch (line) {
case 1: UNROLL_LOOP(1); break;
case 2: UNROLL_LOOP(2); break;
case 4: UNROLL_LOOP(4); break;
case 8: UNROLL_LOOP(8); break;
case 16: UNROLL_LOOP(16); break;
case 32: UNROLL_LOOP(32); break;
case 64: UNROLL_LOOP(64); break;
default:
assert(false && "Invalid dimension.");
break;
}
#undef UNROLL_LOOP
} }
static void intra_pred_planar_ver_w32(const uvg_pixel* ref, const int line, const int shift, __m256i* dst) {} static void intra_pred_planar_ver_w32(const uvg_pixel* ref, const int line, const int shift, __m256i* dst) {}
@ -802,7 +947,7 @@ void uvg_intra_pred_planar_avx2(const cu_loc_t* const cu_loc,
const int width = color == COLOR_Y ? cu_loc->width : cu_loc->chroma_width; const int width = color == COLOR_Y ? cu_loc->width : cu_loc->chroma_width;
const int height = color == COLOR_Y ? cu_loc->height : cu_loc->chroma_height; const int height = color == COLOR_Y ? cu_loc->height : cu_loc->chroma_height;
const int samples = width * height; const int samples = width * height;
const __m256i v_samples = _mm256_set1_epi16(samples); const __m256i v_samples = _mm256_set1_epi32(samples);
const int log2_width = uvg_g_convert_to_log2[width]; const int log2_width = uvg_g_convert_to_log2[width];
const int log2_height = uvg_g_convert_to_log2[height]; const int log2_height = uvg_g_convert_to_log2[height];
@ -821,11 +966,35 @@ void uvg_intra_pred_planar_avx2(const cu_loc_t* const cu_loc,
int16_t* hor_res = (int16_t*)v_pred_hor; int16_t* hor_res = (int16_t*)v_pred_hor;
int16_t* ver_res = (int16_t*)v_pred_ver; int16_t* ver_res = (int16_t*)v_pred_ver;
// Cast two 16-bit values to 32-bit and fill a 256-bit vector
int16_t tmp[2] = {height, width};
int32_t* tmp2 = (int32_t*)tmp;
const __m256i v_madd_shift = _mm256_set1_epi32(*tmp2);
__m256i v_res[64]; __m256i v_res[64];
for (int i = 0, d = 0; i < samples; i += 16, ++d) { // Old loop
/*for (int i = 0, d = 0; i < samples; i += 16, ++d) {
v_res[d] = _mm256_add_epi16(v_pred_ver[d], v_pred_hor[d]); v_res[d] = _mm256_add_epi16(v_pred_ver[d], v_pred_hor[d]);
v_res[d] = _mm256_add_epi16(v_res[d], v_samples); v_res[d] = _mm256_add_epi16(v_res[d], v_samples);
v_res[d] = _mm256_srli_epi16(v_res[d], shift_r); v_res[d] = _mm256_srli_epi16(v_res[d], shift_r);
}*/
// New loop
for (int i = 0, d = 0; i < samples; i += 16, ++d) {
__m256i v_lo = _mm256_unpacklo_epi16(v_pred_hor[d], v_pred_ver[d]);
__m256i v_hi = _mm256_unpackhi_epi16(v_pred_hor[d], v_pred_ver[d]);
// madd will extend the intermediate results to 32-bit to avoid overflows
__m256i v_madd_lo = _mm256_madd_epi16(v_lo, v_madd_shift);
__m256i v_madd_hi = _mm256_madd_epi16(v_hi, v_madd_shift);
v_madd_lo = _mm256_add_epi32(v_madd_lo, v_samples);
v_madd_hi = _mm256_add_epi32(v_madd_hi, v_samples);
v_madd_lo = _mm256_srli_epi32(v_madd_lo, shift_r);
v_madd_hi = _mm256_srli_epi32(v_madd_hi, shift_r);
v_res[d] = _mm256_packs_epi32(v_madd_lo, v_madd_hi);
} }
// debug // debug