32x32 filtered DC prediction in AVX2

This commit is contained in:
Pauli Oikkonen 2019-12-10 18:59:00 +02:00
parent fb2481b7e4
commit 169314de4f

View file

@ -821,6 +821,84 @@ static INLINE void pred_filtered_dc_16x16(const uint8_t *ref_top,
}
}
static INLINE void pred_filtered_dc_32x32(const uint8_t *ref_top,
const uint8_t *ref_left,
uint8_t *out_block)
{
const __m256i rt = _mm256_loadu_si256((const __m256i *)(ref_top + 1));
const __m256i rl = _mm256_loadu_si256((const __m256i *)(ref_left + 1));
const __m256i zero = _mm256_setzero_si256();
const __m256i twos = _mm256_set1_epi8(2);
const __m256i mult_r0lo = _mm256_setr_epi32(0x01030102, 0x01030103,
0x01030103, 0x01030103,
0x01030103, 0x01030103,
0x01030103, 0x01030103);
const __m256i mult_left = _mm256_set1_epi16(0x0103);
const __m256i lm8_bmask = cvt_u32_si256 (0xff);
const __m256i bshif_msk = _mm256_setr_epi32(0x04030201, 0x08070605,
0x0c0b0a09, 0x800f0e0d,
0x03020100, 0x07060504,
0x0b0a0908, 0x0f0e0d0c);
__m256i debias = cvt_u32_si256(32);
__m256i sad0_t = _mm256_sad_epu8 (rt, zero);
__m256i sad0_l = _mm256_sad_epu8 (rl, zero);
__m256i sad0 = _mm256_add_epi64 (sad0_t, sad0_l);
__m256i sad1 = _mm256_permute4x64_epi64(sad0, _MM_SHUFFLE(1, 0, 3, 2));
__m256i sad2 = _mm256_add_epi64 (sad0, sad1);
__m256i sad3 = _mm256_shuffle_epi32 (sad2, _MM_SHUFFLE(1, 0, 3, 2));
__m256i sad4 = _mm256_add_epi64 (sad2, sad3);
__m256i sad5 = _mm256_add_epi64 (sad4, debias);
__m256i dc_64 = _mm256_srli_epi64 (sad5, 6);
__m128i dc_64_ = _mm256_castsi256_si128 (dc_64);
__m256i dc_8 = _mm256_broadcastb_epi8 (dc_64_);
__m256i rtlo = _mm256_unpacklo_epi8 (rt, zero);
__m256i rllo = _mm256_unpacklo_epi8 (rl, zero);
__m256i rthi = _mm256_unpackhi_epi8 (rt, zero);
__m256i rlhi = _mm256_unpackhi_epi8 (rl, zero);
__m256i dc_addend = _mm256_unpacklo_epi8 (dc_8, twos);
__m256i r0lo = _mm256_maddubs_epi16 (dc_addend, mult_r0lo);
__m256i r0hi = _mm256_maddubs_epi16 (dc_addend, mult_left);
__m256i c0dc = r0hi;
r0lo = _mm256_add_epi16 (r0lo, rtlo);
r0hi = _mm256_add_epi16 (r0hi, rthi);
__m256i rlr0 = _mm256_blendv_epi8 (zero, rl, lm8_bmask);
r0lo = _mm256_add_epi16 (r0lo, rlr0);
r0lo = _mm256_srli_epi16 (r0lo, 2);
r0hi = _mm256_srli_epi16 (r0hi, 2);
__m256i r0 = _mm256_packus_epi16 (r0lo, r0hi);
_mm256_storeu_si256((__m256i *)out_block, r0);
__m256i c0lo = _mm256_add_epi16 (c0dc, rllo);
__m256i c0hi = _mm256_add_epi16 (c0dc, rlhi);
c0lo = _mm256_srli_epi16 (c0lo, 2);
c0hi = _mm256_srli_epi16 (c0hi, 2);
__m256i c0 = _mm256_packus_epi16 (c0lo, c0hi);
// r0 already handled!
for (uint32_t y = 1; y < 32; y++) {
if (y == 16) {
c0 = _mm256_permute4x64_epi64(c0, _MM_SHUFFLE(1, 0, 3, 2));
} else {
c0 = _mm256_shuffle_epi8 (c0, bshif_msk);
}
__m256i curr_row = _mm256_blendv_epi8 (dc_8, c0, lm8_bmask);
_mm256_storeu_si256(((__m256i *)out_block) + y, curr_row);
}
}
/**
* \brief Generage intra DC prediction with post filtering applied.
* \param log2_width Log2 of width, range 2..5.
@ -830,213 +908,22 @@ static INLINE void pred_filtered_dc_16x16(const uint8_t *ref_top,
*/
static void kvz_intra_pred_filtered_dc_avx2(
const int_fast8_t log2_width,
const kvz_pixel *const ref_top,
const kvz_pixel *const ref_left,
kvz_pixel *const out_block)
const kvz_pixel *ref_top,
const kvz_pixel *ref_left,
kvz_pixel *out_block)
{
assert(log2_width >= 2 && log2_width <= 5);
assert(sizeof(kvz_pixel) == sizeof(uint8_t));
if (log2_width == 2) {
pred_filtered_dc_4x4(ref_top, ref_left, out_block);
return;
} else if (log2_width == 3) {
pred_filtered_dc_8x8(ref_top, ref_left, out_block);
return;
} else if (log2_width == 4) {
pred_filtered_dc_16x16(ref_top, ref_left, out_block);
return;
} else if (log2_width == 5) {
pred_filtered_dc_32x32(ref_top, ref_left, out_block);
}
const int_fast8_t width = 1 << log2_width;
const __m256i zero = _mm256_setzero_si256();
const __m128i wid_v = _mm_cvtsi32_si128(width);
// Generate masks to load <width> first pixels using these. If log2_width
// is 5, start from offset 0.. if 4, offset 4, 3 -> offset 6, 2 -> 7
static const int32_t ldmasks[] = {
-1, -1, -1, -1, -1, -1, -1, -1,
0, 0, 0, 0, 0, 0, 0,
};
uint32_t l2w_dwords = log2_width - 2;
uint32_t ldm_id = (7 >> l2w_dwords) << l2w_dwords;
__m256i ldst_mask = _mm256_loadu_si256((const __m256i *)(ldmasks + ldm_id));
__m256i rt = _mm256_maskload_epi32((const int32_t *)(ref_top + 1), ldst_mask);
__m256i rl = _mm256_maskload_epi32((const int32_t *)(ref_left + 1), ldst_mask);
__m256i rts = _mm256_sad_epu8 (rt, zero);
__m256i rls = _mm256_sad_epu8 (rl, zero);
__m256i sum0 = _mm256_add_epi64 (rts, rls);
__m256i sum1 = _mm256_permute4x64_epi64(sum0, _MM_SHUFFLE(1, 0, 3, 2));
__m256i sum2 = _mm256_add_epi64 (sum0, sum1);
__m256i sum3 = _mm256_shuffle_epi32 (sum2, _MM_SHUFFLE(1, 0, 3, 2));
__m256i sum4 = _mm256_add_epi64 (sum2, sum3);
__m128i sum5 = _mm256_castsi256_si128 (sum4);
__m128i sum6 = _mm_add_epi64 (sum5, wid_v);
__m128i l2wp1 = _mm_cvtsi32_si128 (log2_width + 1);
__m128i dc_32 = _mm_srl_epi32 (sum6, l2wp1);
__m256i dc_16 = _mm256_broadcastw_epi16 (dc_32);
__m256i dc_8 = _mm256_broadcastb_epi8 (dc_32);
////////////////////////////////////////////////////////////////////
int_fast16_t sum = 0;
for (int_fast8_t i = 0; i < width; ++i) {
sum += ref_top[i + 1];
sum += ref_left[i + 1];
}
const kvz_pixel dc_val = (sum + width) >> (log2_width + 1);
// int32_t sum_s32 = _mm_cvtsi128_si32(_mm256_castsi256_si128(sum16));
// int16_t sum_s = (int16_t)sum_s32;
int32_t dc_s32 = _mm_cvtsi128_si32(_mm256_castsi256_si128(dc_8));
uint8_t dc_s = (uint8_t)dc_s32;
// assert(sum_s == sum);
assert(dc_val == dc_s);
////////////////////////////////////////////////////////////////////
const __m256i ones = _mm256_set1_epi8( 1);
const __m256i twos = _mm256_set1_epi8( 2);
const __m256i ff = _mm256_set1_epi8(-1);
__m256i ref_lefts = _mm256_maskload_epi32((const int32_t *)(ref_left + 1), ldst_mask);
__m256i ref_tops = _mm256_maskload_epi32((const int32_t *)(ref_top + 1), ldst_mask);
uint8_t mults[16];
uint8_t dv[16];
uint8_t rits[16];
uint8_t rils[16];
// Filter top-left with ([1 2 1] / 4), rest of the boundary with ([1 3] / 4)
for (int_fast8_t y = 0; y < width; ++y) {
__m256i rt_lo, rt_hi, rl_lo, rl_hi;
if (y == 0) {
__m256i rt_radd_l = _mm256_unpacklo_epi8(ref_tops, twos);
__m256i rt_radd_h = _mm256_unpackhi_epi8(ref_tops, twos);
rt_lo = _mm256_maddubs_epi16(rt_radd_l, ones);
rt_hi = _mm256_maddubs_epi16(rt_radd_h, ones);
} else {
rt_lo = zero;
rt_hi = zero;
}
uint32_t which_rl_u32 = y >> 2;
uint32_t which_rl_u8 = y & 3;
__m256i rl_u32_mask = _mm256_insert_epi32 (zero, which_rl_u32, 1);
__m256i rl_u8_mask = _mm256_insert_epi8 (ff, which_rl_u8, 1);
__m256i curr_rl_u32 = _mm256_permutevar8x32_epi32(ref_lefts, rl_u32_mask);
__m256i curr_rl = _mm256_shuffle_epi8 (curr_rl_u32, rl_u8_mask);
// print_256(curr_rl);
for (int_fast8_t x = 0; x < width; ++x) {
uint32_t rl_s;
uint32_t rt_s;
uint8_t mult_s;
int daa = x + y * width;
// DONE
if (x == 0)
rl_s = ref_left[y + 1];
else
rl_s = 0;
// /DONE
if (y == 0) {
// DONE
rt_s = ref_top[x + 1];
// rt_add_s = 2;
// /DONE
if (x == 0) {
mult_s = 2;
} else {
mult_s = 3;
}
} else {
// DONE
rt_s = 0;
// rt_add_s = 0;
// /DONE
if (x == 0) {
mult_s = 3;
} else {
mult_s = 4;
}
}
if (width == 4) {
mults[daa] = mult_s;
dv[daa] = dc_val;
rits[daa] = rt_s;
rils[daa] = rl_s;
}
uint16_t dc_multd = mult_s * dc_val;
uint16_t res = rt_s + rl_s + dc_multd + 2;
out_block[y * width + x] = res >> 2;
}
}
/*
if (width == 4) {
uint8_t tmp[16];
pred_filtered_dc_4x4(ref_top, ref_left, tmp);
for (int i = 0; i < 16; i++) {
if (tmp[i] != out_block[i]) {
int j;
printf("mults c: "); print_128_s(mults);
printf("dv c: "); print_128_s(dv);
printf("rits c: "); print_128_s(rits);
printf("rils c: "); print_128_s(rits);
asm("int $3");
pred_filtered_dc_4x4(ref_top, ref_left, tmp);
break;
}
}
}
if (width == 8) {
uint8_t tmp[64];
pred_filtered_dc_8x8(ref_top, ref_left, tmp);
pred_filtered_dc_8x8(ref_top, ref_left, out_block);
for (int i = 0; i < 64; i++) {
if (tmp[i] != out_block[i]) {
int j;
printf("mults c: "); print_128_s(mults);
printf("dv c: "); print_128_s(dv);
printf("rits c: "); print_128_s(rits);
printf("rils c: "); print_128_s(rits);
asm("int $3");
pred_filtered_dc_8x8(ref_top, ref_left, tmp);
break;
}
}
}
if (width == 16) {
uint8_t tmp[256];
pred_filtered_dc_16x16(ref_top, ref_left, tmp);
for (int i = 0; i < 256; i++) {
if (tmp[i] != out_block[i]) {
int j;
printf("mults c: "); print_128_s(mults);
printf("dv c: "); print_128_s(dv);
printf("rits c: "); print_128_s(rits);
printf("rils c: "); print_128_s(rits);
asm("int $3");
pred_filtered_dc_16x16(ref_top, ref_left, tmp);
break;
}
}
}
*/
}
#endif //COMPILE_INTEL_AVX2 && defined X86_64