diff --git a/src/strategies/avx2/intra-avx2.c b/src/strategies/avx2/intra-avx2.c
index 5679715c..4b6f4d4e 100644
--- a/src/strategies/avx2/intra-avx2.c
+++ b/src/strategies/avx2/intra-avx2.c
@@ -641,6 +641,104 @@ static INLINE void pred_filtered_dc_4x4(const uint8_t *ref_top,
   _mm_storeu_si128((__m128i *)out_block, final);
 }
 
+static INLINE void pred_filtered_dc_8x8(const uint8_t *ref_top,
+                                        const uint8_t *ref_left,
+                                              uint8_t *out_block)
+{
+  const uint64_t rt_u64 = *(const uint64_t *)(ref_top  + 1);
+  const uint64_t rl_u64 = *(const uint64_t *)(ref_left + 1);
+
+  const __m128i zero128 = _mm_setzero_si128();
+  const __m256i twos    = _mm256_set1_epi8(2);
+
+  // DC multiplier is 2 at (0, 0), 3 at (*, 0) and (0, *), and 4 at (*, *).
+  // There is a constant addend of 2 on each pixel; use values from the twos
+  // register and multipliers of 1 for that, so that maddubs can do the whole
+  // (a * b) + c operation in one instruction.
+  const __m256i mult_up_lo = _mm256_setr_epi32(0x01030102, 0x01030103,
+                                               0x01030103, 0x01030103,
+                                               0x01040103, 0x01040104,
+                                               0x01040104, 0x01040104);
+
+  // The six lowest rows share the same multipliers, and their DC values and
+  // addends are the same too, so one register works for all of them.
+  const __m256i mult_rest = _mm256_permute4x64_epi64(mult_up_lo, _MM_SHUFFLE(3, 2, 3, 2));
+
+  // Every 8-pixel row starts with the next pixel of ref_left. Along with
+  // doing the shuffling, also expand u8->u16, i.e. move bytes 0 and 1 of
+  // ref_left to bit positions 0 and 128 in rl_up_lo, bytes 2 and 3 to
+  // rl_up_hi, etc. The places to be zeroed out are 0x80 instead of the
+  // usual 0xff, because this allows us to form the next mask on the fly by
+  // adding 0x02-bytes to the previous one while still keeping the highest
+  // bit set where things should be zeroed out.
+  const __m256i rl_shuf_up_lo = _mm256_setr_epi32(0x80808000, 0x80808080,
+                                                  0x80808080, 0x80808080,
+                                                  0x80808001, 0x80808080,
+                                                  0x80808080, 0x80808080);
+  // To avoid wasting memory or architectural registers, derive the other
+  // three shufmasks by addition; the compiler can hopefully interleave
+  // these adds with the shuffles so that only one register is live for the
+  // masks, and execute them well ahead of time since their regs can be
+  // renamed.
+  const __m256i rl_shuf_up_hi = _mm256_add_epi8 (rl_shuf_up_lo, twos);
+  const __m256i rl_shuf_dn_lo = _mm256_add_epi8 (rl_shuf_up_hi, twos);
+  const __m256i rl_shuf_dn_hi = _mm256_add_epi8 (rl_shuf_dn_lo, twos);
+
+  __m128i eight = _mm_cvtsi32_si128 (8);
+  __m128i rt    = _mm_cvtsi64_si128 (rt_u64);
+  __m128i rl    = _mm_cvtsi64_si128 (rl_u64);
+  __m128i rtrl  = _mm_unpacklo_epi64(rt, rl);
+
+  // Sum the 16 edge pixels with a SAD against zero: sad0 holds the two
+  // 8-byte partial sums, which are then added together along with the
+  // rounding offset of 8 before dividing by 16.
+  __m128i sad0  = _mm_sad_epu8      (rtrl, zero128);
+  __m128i sad1  = _mm_shuffle_epi32 (sad0, _MM_SHUFFLE(1, 0, 3, 2));
+  __m128i sad2  = _mm_add_epi64     (sad0, sad1);
+  __m128i sad3  = _mm_add_epi64     (sad2, eight);
+
+  __m128i dc_64 = _mm_srli_epi64    (sad3, 4);
+  __m256i dc_8  = _mm256_broadcastb_epi8(dc_64);
+
+  // Interleave DC with the constant addend 2 to feed maddubs.
+  __m256i dc_addend = _mm256_unpacklo_epi8(dc_8, twos);
+
+  __m256i dc_up_lo  = _mm256_maddubs_epi16(dc_addend, mult_up_lo);
+  __m256i dc_rest   = _mm256_maddubs_epi16(dc_addend, mult_rest);
+
+  // rt_dn is all zeros, as is rt_up_hi. This gets us the rl and rt parts in
+  // A|B, C|D order instead of the A|C, B|D order that could be packed
+  // straight into abcd order, so these need to be permuted before adding to
+  // the weighted DC values.
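+  // (rt holds the eight top pixels in its low quadword, so the u8->u16
+  // widening below places them in the low 128-bit lane (row 0) and leaves
+  // the high lane (row 1) zero.)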
+  __m256i rt_up_lo = _mm256_cvtepu8_epi16 (rt);
+
+  __m256i rlrlrlrl = _mm256_broadcastq_epi64(rl);
+  __m256i rl_up_lo = _mm256_shuffle_epi8    (rlrlrlrl, rl_shuf_up_lo);
+
+  // Everything from ref_top is zero except on the very first row, so the
+  // three remaining vectors carry only ref_left contributions.
+  __m256i rt_rl_up_hi = _mm256_shuffle_epi8 (rlrlrlrl, rl_shuf_up_hi);
+  __m256i rt_rl_dn_lo = _mm256_shuffle_epi8 (rlrlrlrl, rl_shuf_dn_lo);
+  __m256i rt_rl_dn_hi = _mm256_shuffle_epi8 (rlrlrlrl, rl_shuf_dn_hi);
+
+  __m256i rt_rl_up_lo = _mm256_add_epi16    (rt_up_lo, rl_up_lo);
+
+  __m256i rt_rl_up_lo_2 = _mm256_permute2x128_si256(rt_rl_up_lo, rt_rl_up_hi, 0x20);
+  __m256i rt_rl_up_hi_2 = _mm256_permute2x128_si256(rt_rl_up_lo, rt_rl_up_hi, 0x31);
+  __m256i rt_rl_dn_lo_2 = _mm256_permute2x128_si256(rt_rl_dn_lo, rt_rl_dn_hi, 0x20);
+  __m256i rt_rl_dn_hi_2 = _mm256_permute2x128_si256(rt_rl_dn_lo, rt_rl_dn_hi, 0x31);
+
+  __m256i up_lo = _mm256_add_epi16(rt_rl_up_lo_2, dc_up_lo);
+  __m256i up_hi = _mm256_add_epi16(rt_rl_up_hi_2, dc_rest);
+  __m256i dn_lo = _mm256_add_epi16(rt_rl_dn_lo_2, dc_rest);
+  __m256i dn_hi = _mm256_add_epi16(rt_rl_dn_hi_2, dc_rest);
+
+  // Divide by four with rounding (the +2 addend is already included), then
+  // narrow back to u8 and store all 64 result pixels.
+  up_lo = _mm256_srli_epi16(up_lo, 2);
+  up_hi = _mm256_srli_epi16(up_hi, 2);
+  dn_lo = _mm256_srli_epi16(dn_lo, 2);
+  dn_hi = _mm256_srli_epi16(dn_hi, 2);
+
+  __m256i res_up = _mm256_packus_epi16(up_lo, up_hi);
+  __m256i res_dn = _mm256_packus_epi16(dn_lo, dn_hi);
+
+  _mm256_storeu_si256(((__m256i *)out_block) + 0, res_up);
+  _mm256_storeu_si256(((__m256i *)out_block) + 1, res_dn);
+}
+
 /**
  * \brief Generate intra DC prediction with post filtering applied.
  * \param log2_width    Log2 of width, range 2..5.
@@ -655,6 +753,15 @@ static void kvz_intra_pred_filtered_dc_avx2(
   kvz_pixel *const out_block)
 {
   assert(log2_width >= 2 && log2_width <= 5);
+
+  if (log2_width == 2) {
+    pred_filtered_dc_4x4(ref_top, ref_left, out_block);
+    return;
+  } else if (log2_width == 3) {
+    pred_filtered_dc_8x8(ref_top, ref_left, out_block);
+    return;
+  }
+
   const int_fast8_t width = 1 << log2_width;
 
   const __m256i zero = _mm256_setzero_si256();
@@ -794,24 +901,41 @@ static void kvz_intra_pred_filtered_dc_avx2(
       out_block[y * width + x] = res >> 2;
     }
   }
+  /*
   if (width == 4) {
-    uint8_t tampio[16];
-    pred_filtered_dc_4x4(ref_top, ref_left, tampio);
+    uint8_t tmp[16];
+    pred_filtered_dc_4x4(ref_top, ref_left, tmp);
     for (int i = 0; i < 16; i++) {
-      if (tampio[i] != out_block[i]) {
+      if (tmp[i] != out_block[i]) {
         int j;
         printf("mults c: "); print_128_s(mults);
         printf("dv c: "); print_128_s(dv);
         printf("rits c: "); print_128_s(rits);
         printf("rils c: "); print_128_s(rits);
         asm("int $3");
-        pred_filtered_dc_4x4(ref_top, ref_left, tampio);
+        pred_filtered_dc_4x4(ref_top, ref_left, tmp);
         break;
       }
     }
   }
-  // asm("int $3");
-  return;
+  if (width == 8) {
+    uint8_t tmp[64];
+    pred_filtered_dc_8x8(ref_top, ref_left, tmp);
+    pred_filtered_dc_8x8(ref_top, ref_left, out_block);
+    for (int i = 0; i < 64; i++) {
+      if (tmp[i] != out_block[i]) {
+        int j;
+        printf("mults c: "); print_128_s(mults);
+        printf("dv c: "); print_128_s(dv);
+        printf("rits c: "); print_128_s(rits);
+        printf("rils c: "); print_128_s(rils);
+        asm("int $3");
+        pred_filtered_dc_8x8(ref_top, ref_left, tmp);
+        break;
+      }
+    }
+  }
+  */
 }
 
 #endif //COMPILE_INTEL_AVX2 && defined X86_64
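
For review purposes, here is the arithmetic the new 8x8 path computes, written
out as a minimal scalar sketch. It assumes the same conventions as the vector
code: 8-bit pixels, and the actual edge pixels starting at ref_top[1] /
ref_left[1] (the code reads from offset 1, suggesting index 0 holds the
corner). The function name is hypothetical and not part of the patch; it is
only meant as a cross-check, much like the commented-out verification block
above.

#include <stdint.h>

static void pred_filtered_dc_8x8_ref(const uint8_t *ref_top,
                                     const uint8_t *ref_left,
                                     uint8_t *out_block)
{
  // DC value: rounded mean of the 16 edge pixels, matching the SAD-based
  // sum and the (sum + 8) >> 4 in the vector code.
  uint32_t sum = 0;
  for (int i = 1; i <= 8; i++) {
    sum += ref_top[i] + ref_left[i];
  }
  const uint32_t dc = (sum + 8) >> 4;

  // Corner pixel: DC weight 2, plus both adjacent edge pixels.
  out_block[0] = (uint8_t)((ref_left[1] + 2 * dc + ref_top[1] + 2) >> 2);

  // Rest of the first row and first column: DC weight 3 plus one edge pixel.
  for (int i = 1; i < 8; i++) {
    out_block[i]     = (uint8_t)((ref_top [i + 1] + 3 * dc + 2) >> 2);
    out_block[8 * i] = (uint8_t)((ref_left[i + 1] + 3 * dc + 2) >> 2);
  }

  // Interior: (4 * dc + 2) >> 2 == dc, so the filter leaves plain DC there.
  for (int y = 1; y < 8; y++) {
    for (int x = 1; x < 8; x++) {
      out_block[8 * y + x] = (uint8_t)dc;
    }
  }
}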