From 146298a0dfa3b502e50ac621c1070aaed7edd221 Mon Sep 17 00:00:00 2001
From: Ari Lemmetti
Date: Sun, 5 Apr 2020 22:42:47 +0300
Subject: [PATCH] New AVX2 block averaging

*WIP* missing small chroma block and SMP/AMP
---
 src/strategies/avx2/picture-avx2.c | 309 ++++++++++++++++++++++++++++-
 1 file changed, 308 insertions(+), 1 deletion(-)

diff --git a/src/strategies/avx2/picture-avx2.c b/src/strategies/avx2/picture-avx2.c
index 09b2b7c7..84a92f2c 100644
--- a/src/strategies/avx2/picture-avx2.c
+++ b/src/strategies/avx2/picture-avx2.c
@@ -769,6 +769,313 @@ static unsigned pixels_calc_ssd_avx2(const uint8_t *const ref, const uint8_t *co
   }
 }
 
+static INLINE void scatter_ymm_4x8_8bit(kvz_pixel * dst, __m256i ymm, unsigned dst_stride)
+{
+  __m128i ymm_lo = _mm256_castsi256_si128(ymm);
+  __m128i ymm_hi = _mm256_extracti128_si256(ymm, 1);
+  *(uint32_t *)dst = _mm_cvtsi128_si32(ymm_lo);    dst += dst_stride;
+  *(uint32_t *)dst = _mm_extract_epi32(ymm_lo, 1); dst += dst_stride;
+  *(uint32_t *)dst = _mm_extract_epi32(ymm_lo, 2); dst += dst_stride;
+  *(uint32_t *)dst = _mm_extract_epi32(ymm_lo, 3); dst += dst_stride;
+  *(uint32_t *)dst = _mm_cvtsi128_si32(ymm_hi);    dst += dst_stride;
+  *(uint32_t *)dst = _mm_extract_epi32(ymm_hi, 1); dst += dst_stride;
+  *(uint32_t *)dst = _mm_extract_epi32(ymm_hi, 2); dst += dst_stride;
+  *(uint32_t *)dst = _mm_extract_epi32(ymm_hi, 3);
+}
+
+static INLINE void scatter_ymm_8x4_8bit(kvz_pixel *dst, __m256i ymm, unsigned dst_stride)
+{
+  __m256d ymm_as_m256d = _mm256_castsi256_pd(ymm);
+  __m128d ymm_lo = _mm256_castpd256_pd128(ymm_as_m256d);
+  __m128d ymm_hi = _mm256_extractf128_pd(ymm_as_m256d, 1);
+  _mm_storel_pd((double*)dst, ymm_lo); dst += dst_stride;
+  _mm_storeh_pd((double*)dst, ymm_lo); dst += dst_stride;
+  _mm_storel_pd((double*)dst, ymm_hi); dst += dst_stride;
+  _mm_storeh_pd((double*)dst, ymm_hi);
+}
+
+static INLINE void scatter_ymm_16x2_8bit(kvz_pixel *dst, __m256i ymm, unsigned dst_stride)
+{
+  __m128i ymm_lo = _mm256_castsi256_si128(ymm);
+  __m128i ymm_hi = _mm256_extracti128_si256(ymm, 1);
+  _mm_storeu_si128((__m128i *)dst, ymm_lo); dst += dst_stride;
+  _mm_storeu_si128((__m128i *)dst, ymm_hi);
+}
+
+static INLINE void bipred_average_px_px_template_avx2(kvz_pixel *dst,
+                                                      kvz_pixel *px_L0,
+                                                      kvz_pixel *px_L1,
+                                                      unsigned pu_w,
+                                                      unsigned pu_h,
+                                                      unsigned dst_stride)
+{
+  for (int i = 0; i < pu_w * pu_h; i += 32) {
+    int y = i / pu_w;
+    int x = i % pu_w;
+    __m256i sample_L0 = _mm256_loadu_si256((__m256i *)&px_L0[i]);
+    __m256i sample_L1 = _mm256_loadu_si256((__m256i *)&px_L1[i]);
+    __m256i avg = _mm256_avg_epu8(sample_L0, sample_L1);
+
+    switch (pu_w) {
+      case 4:  scatter_ymm_4x8_8bit(&dst[y * dst_stride + x], avg, dst_stride);  break;
+      case 8:  scatter_ymm_8x4_8bit(&dst[y * dst_stride + x], avg, dst_stride);  break;
+      case 16: scatter_ymm_16x2_8bit(&dst[y * dst_stride + x], avg, dst_stride); break;
+      case 32: // Same as case 64
+      case 64: _mm256_storeu_si256((__m256i*)&dst[y * dst_stride + x], avg);     break;
+      default:
+        assert(0 && "Unexpected block width");
+        break;
+    }
+  }
+}
+
+static INLINE void bipred_average_px_px_avx2(kvz_pixel *dst,
+                                             kvz_pixel *px_L0,
+                                             kvz_pixel *px_L1,
+                                             unsigned pu_w,
+                                             unsigned pu_h,
+                                             unsigned dst_stride)
+{
+  unsigned size = pu_w * pu_h;
+  bool multiple_of_32 = !(size % 32);
+
+  if (MIN(pu_w, pu_h) >= 4) {
+    switch (pu_w) {
+      case 4:  bipred_average_px_px_template_avx2(dst, px_L0, px_L1, 4,  pu_h, dst_stride); break;
+      case 8:  bipred_average_px_px_template_avx2(dst, px_L0, px_L1, 8,  pu_h, dst_stride); break;
+      case 16: bipred_average_px_px_template_avx2(dst, px_L0, px_L1, 16, pu_h, dst_stride); break;
+      case 32: bipred_average_px_px_template_avx2(dst, px_L0, px_L1, 32, pu_h, dst_stride); break;
+      case 64: bipred_average_px_px_template_avx2(dst, px_L0, px_L1, 64, pu_h, dst_stride); break;
+      default:
+        printf("W: %d\n", pu_w);
+        assert(0 && "Unexpected block width.");
+        break;
+    }
+  }
+}
+
+static INLINE void bipred_average_ip_ip_template_avx2(kvz_pixel *dst,
+                                                      kvz_pixel_ip *ip_L0,
+                                                      kvz_pixel_ip *ip_L1,
+                                                      unsigned pu_w,
+                                                      unsigned pu_h,
+                                                      unsigned dst_stride)
+{
+  int32_t shift = 15 - KVZ_BIT_DEPTH; // TODO: defines
+  int32_t scalar_offset = 1 << (shift - 1);
+  __m256i offset = _mm256_set1_epi32(scalar_offset);
+
+  for (int i = 0; i < pu_w * pu_h; i += 32) {
+    int y = i / pu_w;
+    int x = i % pu_w;
+
+    __m256i sample_L0_01_16bit = _mm256_loadu_si256((__m256i*)&ip_L0[i]);
+    __m256i sample_L0_23_16bit = _mm256_loadu_si256((__m256i*)&ip_L0[i + 16]);
+    __m256i sample_L1_01_16bit = _mm256_loadu_si256((__m256i*)&ip_L1[i]);
+    __m256i sample_L1_23_16bit = _mm256_loadu_si256((__m256i*)&ip_L1[i + 16]);
+
+    __m256i sample_L0_L1_01_lo = _mm256_unpacklo_epi16(sample_L0_01_16bit, sample_L1_01_16bit);
+    __m256i sample_L0_L1_01_hi = _mm256_unpackhi_epi16(sample_L0_01_16bit, sample_L1_01_16bit);
+    __m256i sample_L0_L1_23_lo = _mm256_unpacklo_epi16(sample_L0_23_16bit, sample_L1_23_16bit);
+    __m256i sample_L0_L1_23_hi = _mm256_unpackhi_epi16(sample_L0_23_16bit, sample_L1_23_16bit);
+
+    __m256i all_ones = _mm256_set1_epi16(1);
+    __m256i avg_01_lo = _mm256_madd_epi16(sample_L0_L1_01_lo, all_ones);
+    __m256i avg_01_hi = _mm256_madd_epi16(sample_L0_L1_01_hi, all_ones);
+    __m256i avg_23_lo = _mm256_madd_epi16(sample_L0_L1_23_lo, all_ones);
+    __m256i avg_23_hi = _mm256_madd_epi16(sample_L0_L1_23_hi, all_ones);
+
+    avg_01_lo = _mm256_add_epi32(avg_01_lo, offset);
+    avg_01_hi = _mm256_add_epi32(avg_01_hi, offset);
+    avg_23_lo = _mm256_add_epi32(avg_23_lo, offset);
+    avg_23_hi = _mm256_add_epi32(avg_23_hi, offset);
+
+    avg_01_lo = _mm256_srai_epi32(avg_01_lo, shift);
+    avg_01_hi = _mm256_srai_epi32(avg_01_hi, shift);
+    avg_23_lo = _mm256_srai_epi32(avg_23_lo, shift);
+    avg_23_hi = _mm256_srai_epi32(avg_23_hi, shift);
+
+    __m256i avg_01 = _mm256_packus_epi32(avg_01_lo, avg_01_hi);
+    __m256i avg_23 = _mm256_packus_epi32(avg_23_lo, avg_23_hi);
+    __m256i avg0213 = _mm256_packus_epi16(avg_01, avg_23);
+    __m256i avg = _mm256_permute4x64_epi64(avg0213, _MM_SHUFFLE(3,1,2,0));
+
+    switch (pu_w) {
+      case 4:  scatter_ymm_4x8_8bit(&dst[y * dst_stride + x], avg, dst_stride);  break;
+      case 8:  scatter_ymm_8x4_8bit(&dst[y * dst_stride + x], avg, dst_stride);  break;
+      case 16: scatter_ymm_16x2_8bit(&dst[y * dst_stride + x], avg, dst_stride); break;
+      case 32: // Same as case 64
+      case 64: _mm256_storeu_si256((__m256i *)&dst[y * dst_stride + x], avg);    break;
+      default:
+        assert(0 && "Unexpected block width");
+        break;
+    }
+  }
+}
+
+static void bipred_average_ip_ip_avx2(kvz_pixel *dst,
+                                      kvz_pixel_ip *ip_L0,
+                                      kvz_pixel_ip *ip_L1,
+                                      unsigned pu_w,
+                                      unsigned pu_h,
+                                      unsigned dst_stride)
+{
+  unsigned size = pu_w * pu_h;
+  bool multiple_of_32 = !(size % 32);
+
+  if (MIN(pu_w, pu_h) >= 4) {
+    switch (pu_w) {
+      case 4:  bipred_average_ip_ip_template_avx2(dst, ip_L0, ip_L1, 4,  pu_h, dst_stride); break;
+      case 8:  bipred_average_ip_ip_template_avx2(dst, ip_L0, ip_L1, 8,  pu_h, dst_stride); break;
+      case 16: bipred_average_ip_ip_template_avx2(dst, ip_L0, ip_L1, 16, pu_h, dst_stride); break;
+      case 32: bipred_average_ip_ip_template_avx2(dst, ip_L0, ip_L1, 32, pu_h, dst_stride); break;
+      case 64: bipred_average_ip_ip_template_avx2(dst, ip_L0, ip_L1, 64, pu_h, dst_stride); break;
+      default:
+        assert(0 && "Unexpected block width.");
+        break;
+    }
+  }
+}
+
+static INLINE void bipred_average_px_ip_template_avx2(kvz_pixel *dst,
+                                                      kvz_pixel *px,
+                                                      kvz_pixel_ip *ip,
+                                                      unsigned pu_w,
+                                                      unsigned pu_h,
+                                                      unsigned dst_stride)
+{
+  int32_t shift = 15 - KVZ_BIT_DEPTH; // TODO: defines
+  int32_t scalar_offset = 1 << (shift - 1);
+  __m256i offset = _mm256_set1_epi32(scalar_offset);
+
+  for (int i = 0; i < pu_w * pu_h; i += 32) {
+    int y = i / pu_w;
+    int x = i % pu_w;
+
+    __m256i sample_px_01_16bit = _mm256_cvtepu8_epi16(_mm_loadu_si128((__m128i *)&px[i]));
+    __m256i sample_px_23_16bit = _mm256_cvtepu8_epi16(_mm_loadu_si128((__m128i *)&px[i + 16]));
+    sample_px_01_16bit = _mm256_slli_epi16(sample_px_01_16bit, 14 - KVZ_BIT_DEPTH);
+    sample_px_23_16bit = _mm256_slli_epi16(sample_px_23_16bit, 14 - KVZ_BIT_DEPTH);
+    __m256i sample_ip_01_16bit = _mm256_loadu_si256((__m256i *)&ip[i]);
+    __m256i sample_ip_23_16bit = _mm256_loadu_si256((__m256i *)&ip[i + 16]);
+
+    __m256i sample_px_ip_01_lo = _mm256_unpacklo_epi16(sample_px_01_16bit, sample_ip_01_16bit);
+    __m256i sample_px_ip_01_hi = _mm256_unpackhi_epi16(sample_px_01_16bit, sample_ip_01_16bit);
+    __m256i sample_px_ip_23_lo = _mm256_unpacklo_epi16(sample_px_23_16bit, sample_ip_23_16bit);
+    __m256i sample_px_ip_23_hi = _mm256_unpackhi_epi16(sample_px_23_16bit, sample_ip_23_16bit);
+
+    __m256i all_ones = _mm256_set1_epi16(1);
+    __m256i avg_01_lo = _mm256_madd_epi16(sample_px_ip_01_lo, all_ones);
+    __m256i avg_01_hi = _mm256_madd_epi16(sample_px_ip_01_hi, all_ones);
+    __m256i avg_23_lo = _mm256_madd_epi16(sample_px_ip_23_lo, all_ones);
+    __m256i avg_23_hi = _mm256_madd_epi16(sample_px_ip_23_hi, all_ones);
+
+    avg_01_lo = _mm256_add_epi32(avg_01_lo, offset);
+    avg_01_hi = _mm256_add_epi32(avg_01_hi, offset);
+    avg_23_lo = _mm256_add_epi32(avg_23_lo, offset);
+    avg_23_hi = _mm256_add_epi32(avg_23_hi, offset);
+
+    avg_01_lo = _mm256_srai_epi32(avg_01_lo, shift);
+    avg_01_hi = _mm256_srai_epi32(avg_01_hi, shift);
+    avg_23_lo = _mm256_srai_epi32(avg_23_lo, shift);
+    avg_23_hi = _mm256_srai_epi32(avg_23_hi, shift);
+
+    __m256i avg_01 = _mm256_packus_epi32(avg_01_lo, avg_01_hi);
+    __m256i avg_23 = _mm256_packus_epi32(avg_23_lo, avg_23_hi);
+    __m256i avg0213 = _mm256_packus_epi16(avg_01, avg_23);
+    __m256i avg = _mm256_permute4x64_epi64(avg0213, _MM_SHUFFLE(3, 1, 2, 0));
+
+    switch (pu_w) {
+      case 4:  scatter_ymm_4x8_8bit(&dst[y * dst_stride + x], avg, dst_stride);  break;
+      case 8:  scatter_ymm_8x4_8bit(&dst[y * dst_stride + x], avg, dst_stride);  break;
+      case 16: scatter_ymm_16x2_8bit(&dst[y * dst_stride + x], avg, dst_stride); break;
+      case 32: // Same as case 64
+      case 64: _mm256_storeu_si256((__m256i *)&dst[y * dst_stride + x], avg);    break;
+      default:
+        assert(0 && "Unexpected block width");
+        break;
+    }
+  }
+}
+
+static void bipred_average_px_ip_avx2(kvz_pixel *dst,
+                                      kvz_pixel *px,
+                                      kvz_pixel_ip *ip,
+                                      unsigned pu_w,
+                                      unsigned pu_h,
+                                      unsigned dst_stride)
+{
+  unsigned size = pu_w * pu_h;
+  bool multiple_of_32 = !(size % 32);
+
+  if (MIN(pu_w, pu_h) >= 4) {
+    switch (pu_w) {
+      case 4:  bipred_average_px_ip_template_avx2(dst, px, ip, 4,  pu_h, dst_stride); break;
+      case 8:  bipred_average_px_ip_template_avx2(dst, px, ip, 8,  pu_h, dst_stride); break;
+      case 16: bipred_average_px_ip_template_avx2(dst, px, ip, 16, pu_h, dst_stride); break;
+      case 32: bipred_average_px_ip_template_avx2(dst, px, ip, 32, pu_h, dst_stride); break;
+      case 64: bipred_average_px_ip_template_avx2(dst, px, ip, 64, pu_h, dst_stride); break;
+      default:
+        assert(0 && "Unexpected block width.");
+        break;
+    }
+  }
+}
+
+static void bipred_average_avx2(lcu_t *const lcu,
+                                const yuv_t *const px_L0,
+                                const yuv_t *const px_L1,
+                                const yuv_ip_t *const ip_L0,
+                                const yuv_ip_t *const ip_L1,
+                                const unsigned pu_x,
+                                const unsigned pu_y,
+                                const unsigned pu_w,
+                                const unsigned pu_h,
+                                const unsigned ip_flags_L0,
+                                const unsigned ip_flags_L1,
+                                const bool predict_luma,
+                                const bool predict_chroma) {
+
+  //After reconstruction, merge the predictors by taking an average of each pixel
+  if (predict_luma) {
+    unsigned pb_offset = SUB_SCU(pu_y) * LCU_WIDTH + SUB_SCU(pu_x);
+
+    if (!(ip_flags_L0 & 1) && !(ip_flags_L1 & 1)) {
+      bipred_average_px_px_avx2(lcu->rec.y + pb_offset, px_L0->y, px_L1->y, pu_w, pu_h, LCU_WIDTH);
+
+    } else if ((ip_flags_L0 & 1) && (ip_flags_L1 & 1)) {
+      bipred_average_ip_ip_avx2(lcu->rec.y + pb_offset, ip_L0->y, ip_L1->y, pu_w, pu_h, LCU_WIDTH);
+
+    } else {
+      kvz_pixel *src_px = (ip_flags_L0 & 1) ? px_L1->y : px_L0->y;
+      kvz_pixel_ip *src_ip = (ip_flags_L0 & 1) ? ip_L0->y : ip_L1->y;
+      bipred_average_px_ip_avx2(lcu->rec.y + pb_offset, src_px, src_ip, pu_w, pu_h, LCU_WIDTH);
+    }
+  }
+  if (predict_chroma) {
+    unsigned pb_offset = SUB_SCU(pu_y) / 2 * LCU_WIDTH_C + SUB_SCU(pu_x) / 2;
+    unsigned pb_w = pu_w / 2;
+    unsigned pb_h = pu_h / 2;
+
+    if (!(ip_flags_L0 & 2) && !(ip_flags_L1 & 2)) {
+      bipred_average_px_px_avx2(lcu->rec.u + pb_offset, px_L0->u, px_L1->u, pb_w, pb_h, LCU_WIDTH_C);
+      bipred_average_px_px_avx2(lcu->rec.v + pb_offset, px_L0->v, px_L1->v, pb_w, pb_h, LCU_WIDTH_C);
+
+    } else if ((ip_flags_L0 & 2) && (ip_flags_L1 & 2)) {
+      bipred_average_ip_ip_avx2(lcu->rec.u + pb_offset, ip_L0->u, ip_L1->u, pb_w, pb_h, LCU_WIDTH_C);
+      bipred_average_ip_ip_avx2(lcu->rec.v + pb_offset, ip_L0->v, ip_L1->v, pb_w, pb_h, LCU_WIDTH_C);
+
+    } else {
+      kvz_pixel *src_px_u = (ip_flags_L0 & 2) ? px_L1->u : px_L0->u;
+      kvz_pixel_ip *src_ip_u = (ip_flags_L0 & 2) ? ip_L0->u : ip_L1->u;
+      kvz_pixel *src_px_v = (ip_flags_L0 & 2) ? px_L1->v : px_L0->v;
+      kvz_pixel_ip *src_ip_v = (ip_flags_L0 & 2) ? ip_L0->v : ip_L1->v;
+      bipred_average_px_ip_avx2(lcu->rec.u + pb_offset, src_px_u, src_ip_u, pb_w, pb_h, LCU_WIDTH_C);
+      bipred_average_px_ip_avx2(lcu->rec.v + pb_offset, src_px_v, src_ip_v, pb_w, pb_h, LCU_WIDTH_C);
+    }
+  }
+}
+
 static optimized_sad_func_ptr_t get_optimized_sad_avx2(int32_t width)
 {
   if (width == 0)
@@ -1043,7 +1350,7 @@ int kvz_strategy_register_picture_avx2(void* opaque, uint8_t bitdepth)
     success &= kvz_strategyselector_register(opaque, "satd_any_size_quad", "avx2", 40, &satd_any_size_quad_avx2);
     success &= kvz_strategyselector_register(opaque, "pixels_calc_ssd", "avx2", 40, &pixels_calc_ssd_avx2);
-    //success &= kvz_strategyselector_register(opaque, "bipred_average", "avx2", 40, &bipred_average_avx2);
+    success &= kvz_strategyselector_register(opaque, "bipred_average", "avx2", 40, &bipred_average_avx2);
     success &= kvz_strategyselector_register(opaque, "get_optimized_sad", "avx2", 40, &get_optimized_sad_avx2);
     success &= kvz_strategyselector_register(opaque, "ver_sad", "avx2", 40, &ver_sad_avx2);
     success &= kvz_strategyselector_register(opaque, "hor_sad", "avx2", 40, &hor_sad_avx2);
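
The arithmetic in the three averaging paths is easiest to check against a scalar reference: _mm256_avg_epu8 is a rounded 8-bit average, while the ip/ip and px/ip paths apply the usual bi-prediction rounding with shift = 15 - KVZ_BIT_DEPTH and offset = 1 << (shift - 1), the packus steps clamping the result to the 8-bit range. Below is a minimal scalar sketch of that arithmetic, assuming 8-bit content; it is not part of the patch, and the type and helper names are stand-ins for illustration, not kvazaar API.

/* Scalar reference for the three averaging modes; assumes KVZ_BIT_DEPTH == 8. */
#include <stdint.h>
#include <stdio.h>

#define BIT_DEPTH 8              /* stand-in for KVZ_BIT_DEPTH */
typedef uint8_t pixel_t;         /* stand-in for kvz_pixel     */
typedef int16_t pixel_ip_t;      /* stand-in for kvz_pixel_ip  */

static uint8_t clip_to_pixel(int32_t v)
{
  return (uint8_t)(v < 0 ? 0 : (v > 255 ? 255 : v));
}

/* px/px: plain rounding average, same result per byte as _mm256_avg_epu8. */
static pixel_t average_px_px(pixel_t a, pixel_t b)
{
  return (pixel_t)((a + b + 1) >> 1);
}

/* ip/ip: both inputs are already at 14-bit intermediate precision. */
static pixel_t average_ip_ip(pixel_ip_t a, pixel_ip_t b)
{
  const int32_t shift  = 15 - BIT_DEPTH;    /* 7 for 8-bit content */
  const int32_t offset = 1 << (shift - 1);  /* 64                  */
  return clip_to_pixel((a + b + offset) >> shift);
}

/* px/ip: the 8-bit sample is first lifted to the 14-bit scale, which is what
 * the _mm256_slli_epi16 by (14 - KVZ_BIT_DEPTH) does in the patch. */
static pixel_t average_px_ip(pixel_t px, pixel_ip_t ip)
{
  const int32_t shift  = 15 - BIT_DEPTH;
  const int32_t offset = 1 << (shift - 1);
  const int32_t px14   = (int32_t)px << (14 - BIT_DEPTH);
  return clip_to_pixel((px14 + ip + offset) >> shift);
}

int main(void)
{
  /* 10 and 11 average to 11 with rounding in all three representations. */
  printf("%d %d %d\n",
         average_px_px(10, 11),
         average_ip_ip(10 << 6, 11 << 6),
         average_px_ip(10, 11 << 6));
  return 0;
}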
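
The final _mm256_permute4x64_epi64(avg0213, _MM_SHUFFLE(3,1,2,0)) is needed because _mm256_packus_epi32 and _mm256_packus_epi16 pack within 128-bit lanes: after the two packs the 32 result bytes sit in quadword order 0-7, 16-23, 8-15, 24-31, and the permute restores raster order. A small standalone check of that reordering follows; it is not part of the patch and assumes a compiler with AVX2 enabled (e.g. -mavx2).

#include <immintrin.h>
#include <stdint.h>
#include <stdio.h>

int main(void)
{
  /* 16-bit samples whose value equals their raster index. */
  int16_t in[32];
  for (int i = 0; i < 32; i++) in[i] = (int16_t)i;

  __m256i lo16 = _mm256_loadu_si256((const __m256i *)&in[0]);   /* samples 0..15  */
  __m256i hi16 = _mm256_loadu_si256((const __m256i *)&in[16]);  /* samples 16..31 */

  /* Per-lane pack leaves the quadwords as 0-7, 16-23, 8-15, 24-31 ... */
  __m256i packed = _mm256_packus_epi16(lo16, hi16);
  /* ... and the same permute as in the patch restores raster order.   */
  __m256i restored = _mm256_permute4x64_epi64(packed, _MM_SHUFFLE(3, 1, 2, 0));

  uint8_t out[32];
  _mm256_storeu_si256((__m256i *)out, restored);
  for (int i = 0; i < 32; i++) printf("%d ", out[i]);  /* prints 0 1 2 ... 31 */
  printf("\n");
  return 0;
}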
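
The scatter helpers exist because each YMM register holds 32 consecutive PU samples in raster order: for pu_w of 4, 8 or 16 that is an 8-, 4- or 2-row tile that has to be split into per-row stores, while for pu_w of 32 or 64 a single unaligned 32-byte store suffices (covering half a row in the 64-wide case). A scalar equivalent of the store pattern, with hypothetical names and not part of the patch:

#include <stdint.h>
#include <string.h>

/* Write one 32-pixel block of averaged samples into the strided destination.
 * dst already points at &dst[y * dst_stride + x], as in the patch. */
void scatter_32_pixels(uint8_t *dst, const uint8_t avg[32],
                       unsigned pu_w, unsigned dst_stride)
{
  const unsigned row_bytes = pu_w < 32 ? pu_w : 32;  /* 64-wide PUs take half a row */
  const unsigned rows = 32 / row_bytes;              /* 8, 4, 2 or 1 row(s)         */
  for (unsigned r = 0; r < rows; r++) {
    memcpy(dst + r * dst_stride, avg + r * row_bytes, row_bytes);
  }
}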