Implement AVX2 8x8 filtered DC algorithm

Pauli Oikkonen 2019-11-26 17:10:44 +02:00
parent 5d9b7019ca
commit da370ea36d


@@ -641,6 +641,104 @@ static INLINE void pred_filtered_dc_4x4(const uint8_t *ref_top,
_mm_storeu_si128((__m128i *)out_block, final);
}
static INLINE void pred_filtered_dc_8x8(const uint8_t *ref_top,
const uint8_t *ref_left,
uint8_t *out_block)
{
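// Load the eight reference samples just past the top-left corner on each
// border (ref_top[1..8] and ref_left[1..8]) as raw 64-bit values.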
const uint64_t rt_u64 = *(const uint64_t *)(ref_top + 1);
const uint64_t rl_u64 = *(const uint64_t *)(ref_left + 1);
const __m128i zero128 = _mm_setzero_si128();
const __m256i twos = _mm256_set1_epi8(2);
// The DC multiplier is 2 at (0, 0), 3 at (*, 0) and (0, *), and 4 at (*, *).
// There is a constant addend of 2 on each pixel; take it from the twos
// register with a multiplier of 1, so that maddubs performs the whole
// (a*b)+c operation in one go.
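// For one 16-bit lane of row 0, for instance, the byte pair (dc, 2) times
// the multiplier pair (3, 1) yields dc*3 + 2*1: the weighted DC plus its
// rounding addend in a single instruction.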
const __m256i mult_up_lo = _mm256_setr_epi32(0x01030102, 0x01030103,
0x01030103, 0x01030103,
0x01040103, 0x01040104,
0x01040104, 0x01040104);
// The six lowest rows share the same multipliers, and their DC values and
// addends are also the same, so this one constant works for all of them.
const __m256i mult_rest = _mm256_permute4x64_epi64(mult_up_lo, _MM_SHUFFLE(3, 2, 3, 2));
// Every 8-pixel row starts with the next pixel of ref_left. Along with
// doing the shuffling, also expand u8->u16, i.e. move bytes 0 and 1 from
// ref_left to bit positions 0 and 128 in rl_up_lo, bytes 2 and 3 to
// rl_up_hi, etc. The bytes to be zeroed out are 0x80 instead of the usual
// 0xff, because this allows us to form new masks on the fly by adding
// 0x02-bytes to this mask while still retaining the highest bit as 1
// where things should be zeroed out.
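// For example, mask byte 0x00 + 0x02 = 0x02 selects the next-but-one
// ref_left byte, while 0x80 + 0x02 = 0x82 still has its top bit set, so
// pshufb keeps writing zeros in those positions.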
const __m256i rl_shuf_up_lo = _mm256_setr_epi32(0x80808000, 0x80808080,
0x80808080, 0x80808080,
0x80808001, 0x80808080,
0x80808080, 0x80808080);
// And don't waste memory or architectural registers: derive the remaining
// shuffle masks by addition, and hope the compiler places these adds in
// between the shuffles so only one register is ever live for the masks,
// and executes them way ahead of time because their destination registers
// can be renamed.
const __m256i rl_shuf_up_hi = _mm256_add_epi8 (rl_shuf_up_lo, twos);
const __m256i rl_shuf_dn_lo = _mm256_add_epi8 (rl_shuf_up_hi, twos);
const __m256i rl_shuf_dn_hi = _mm256_add_epi8 (rl_shuf_dn_lo, twos);
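// Compute dc = (sum(ref_top[1..8]) + sum(ref_left[1..8]) + 8) >> 4.
// sad_epu8 against zero horizontally sums each 8-byte half of rtrl into a
// 64-bit lane; the dword shuffle and add then fold the two partial sums
// together before the rounding add and shift.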
__m128i eight = _mm_cvtsi32_si128 (8);
__m128i rt = _mm_cvtsi64_si128 (rt_u64);
__m128i rl = _mm_cvtsi64_si128 (rl_u64);
__m128i rtrl = _mm_unpacklo_epi64 (rt, rl);
__m128i sad0 = _mm_sad_epu8 (rtrl, zero128);
__m128i sad1 = _mm_shuffle_epi32 (sad0, _MM_SHUFFLE(1, 0, 3, 2));
__m128i sad2 = _mm_add_epi64 (sad0, sad1);
__m128i sad3 = _mm_add_epi64 (sad2, eight);
__m128i dc_64 = _mm_srli_epi64 (sad3, 4);
__m256i dc_8 = _mm256_broadcastb_epi8(dc_64);
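// Interleaving dc with the twos gives every 16-bit lane the byte pair
// (dc, 2), matching the (mult, 1) pairs in the constants above.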
__m256i dc_addend = _mm256_unpacklo_epi8 (dc_8, twos);
__m256i dc_up_lo = _mm256_maddubs_epi16 (dc_addend, mult_up_lo);
__m256i dc_rest = _mm256_maddubs_epi16 (dc_addend, mult_rest);
// rt_dn is all zeros, as is rt_up_hi. This gets us the rl and rt parts in
// A|B, C|D order instead of the A|C, B|D order that could be packed
// straight into abcd order, so these need to be permuted before adding to
// the weighted DC values.
__m256i rt_up_lo = _mm256_cvtepu8_epi16 (rt);
__m256i rlrlrlrl = _mm256_broadcastq_epi64(rl);
__m256i rl_up_lo = _mm256_shuffle_epi8 (rlrlrlrl, rl_shuf_up_lo);
// Everything from ref_top is zero except on the very first row
__m256i rt_rl_up_hi = _mm256_shuffle_epi8 (rlrlrlrl, rl_shuf_up_hi);
__m256i rt_rl_dn_lo = _mm256_shuffle_epi8 (rlrlrlrl, rl_shuf_dn_lo);
__m256i rt_rl_dn_hi = _mm256_shuffle_epi8 (rlrlrlrl, rl_shuf_dn_hi);
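// In _mm256_permute2x128_si256, selector 0x20 concatenates the low
// 128-bit lanes of the two sources and 0x31 the high lanes, regrouping
// rows (0|1), (2|3) into (0|2), (1|3) so the in-lane packus at the end
// emits rows in 0, 1, 2, 3 order.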
__m256i rt_rl_up_lo = _mm256_add_epi16 (rt_up_lo, rl_up_lo);
__m256i rt_rl_up_lo_2 = _mm256_permute2x128_si256(rt_rl_up_lo, rt_rl_up_hi, 0x20);
__m256i rt_rl_up_hi_2 = _mm256_permute2x128_si256(rt_rl_up_lo, rt_rl_up_hi, 0x31);
__m256i rt_rl_dn_lo_2 = _mm256_permute2x128_si256(rt_rl_dn_lo, rt_rl_dn_hi, 0x20);
__m256i rt_rl_dn_hi_2 = _mm256_permute2x128_si256(rt_rl_dn_lo, rt_rl_dn_hi, 0x31);
__m256i up_lo = _mm256_add_epi16(rt_rl_up_lo_2, dc_up_lo);
__m256i up_hi = _mm256_add_epi16(rt_rl_up_hi_2, dc_rest);
__m256i dn_lo = _mm256_add_epi16(rt_rl_dn_lo_2, dc_rest);
__m256i dn_hi = _mm256_add_epi16(rt_rl_dn_hi_2, dc_rest);
up_lo = _mm256_srli_epi16(up_lo, 2);
up_hi = _mm256_srli_epi16(up_hi, 2);
dn_lo = _mm256_srli_epi16(dn_lo, 2);
dn_hi = _mm256_srli_epi16(dn_hi, 2);
__m256i res_up = _mm256_packus_epi16(up_lo, up_hi);
__m256i res_dn = _mm256_packus_epi16(dn_lo, dn_hi);
_mm256_storeu_si256(((__m256i *)out_block) + 0, res_up);
_mm256_storeu_si256(((__m256i *)out_block) + 1, res_dn);
}
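
#if 0
// Reference only, not part of the optimized path: a plain-C sketch of the
// filtered DC rule the vector code above implements, reconstructed from
// the multiplier comments (DC weight 2 at (0, 0), 3 on the rest of the
// first row and column, 4 elsewhere; constant addend 2; final shift by 2).
// The function name is made up for illustration.
static void pred_filtered_dc_8x8_ref(const uint8_t *ref_top,
                                     const uint8_t *ref_left,
                                     uint8_t *out_block)
{
  uint32_t sum = 8; // rounding offset for the 16-sample DC average
  for (int i = 1; i <= 8; i++) {
    sum += ref_top[i] + ref_left[i];
  }
  const uint16_t dc = sum >> 4;

  for (int y = 0; y < 8; y++) {
    for (int x = 0; x < 8; x++) {
      int mult = 4;
      if (x == 0 || y == 0) mult = 3;
      if (x == 0 && y == 0) mult = 2;
      uint16_t res = (uint16_t)(mult * dc + 2);
      if (y == 0) res += ref_top[x + 1];
      if (x == 0) res += ref_left[y + 1];
      out_block[y * 8 + x] = (uint8_t)(res >> 2);
    }
  }
}
#endif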
/**
* \brief Generate intra DC prediction with post filtering applied.
* \param log2_width Log2 of width, range 2..5.
@@ -655,6 +753,15 @@ static void kvz_intra_pred_filtered_dc_avx2(
kvz_pixel *const out_block)
{
assert(log2_width >= 2 && log2_width <= 5);
if (log2_width == 2) {
pred_filtered_dc_4x4(ref_top, ref_left, out_block);
return;
} else if (log2_width == 3) {
pred_filtered_dc_8x8(ref_top, ref_left, out_block);
return;
}
const int_fast8_t width = 1 << log2_width;
const __m256i zero = _mm256_setzero_si256();
@@ -794,24 +901,41 @@ static void kvz_intra_pred_filtered_dc_avx2(
out_block[y * width + x] = res >> 2;
}
}
/*
if (width == 4) {
uint8_t tmp[16];
pred_filtered_dc_4x4(ref_top, ref_left, tmp);
for (int i = 0; i < 16; i++) {
if (tmp[i] != out_block[i]) {
int j;
printf("mults c: "); print_128_s(mults);
printf("dv c: "); print_128_s(dv);
printf("rits c: "); print_128_s(rits);
printf("rils c: "); print_128_s(rits);
asm("int $3");
pred_filtered_dc_4x4(ref_top, ref_left, tmp);
break;
}
}
}
if (width == 8) {
uint8_t tmp[64];
pred_filtered_dc_8x8(ref_top, ref_left, tmp);
pred_filtered_dc_8x8(ref_top, ref_left, out_block);
for (int i = 0; i < 64; i++) {
if (tmp[i] != out_block[i]) {
int j;
printf("mults c: "); print_128_s(mults);
printf("dv c: "); print_128_s(dv);
printf("rits c: "); print_128_s(rits);
printf("rils c: "); print_128_s(rits);
asm("int $3");
pred_filtered_dc_8x8(ref_top, ref_left, tmp);
break;
}
}
}
*/
}
#endif //COMPILE_INTEL_AVX2 && defined X86_64 #endif //COMPILE_INTEL_AVX2 && defined X86_64