New AVX2 block averaging *WIP*: missing small chroma blocks and SMP/AMP

Ari Lemmetti 2020-04-05 22:42:47 +03:00
parent ef69c65c58
commit 146298a0df


@@ -769,6 +769,313 @@ static unsigned pixels_calc_ssd_avx2(const uint8_t *const ref, const uint8_t *co
}
}
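// The scatter helpers below split one 32-byte YMM register into the rows of a
// strided 8-bit block. This one writes a 4x8 block: each 32-bit lane holds one
// 4-pixel row.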
static INLINE void scatter_ymm_4x8_8bit(kvz_pixel * dst, __m256i ymm, unsigned dst_stride)
{
__m128i ymm_lo = _mm256_castsi256_si128(ymm);
__m128i ymm_hi = _mm256_extracti128_si256(ymm, 1);
*(uint32_t *)dst = _mm_cvtsi128_si32(ymm_lo); dst += dst_stride;
*(uint32_t *)dst = _mm_extract_epi32(ymm_lo, 1); dst += dst_stride;
*(uint32_t *)dst = _mm_extract_epi32(ymm_lo, 2); dst += dst_stride;
*(uint32_t *)dst = _mm_extract_epi32(ymm_lo, 3); dst += dst_stride;
*(uint32_t *)dst = _mm_cvtsi128_si32(ymm_hi); dst += dst_stride;
*(uint32_t *)dst = _mm_extract_epi32(ymm_hi, 1); dst += dst_stride;
*(uint32_t *)dst = _mm_extract_epi32(ymm_hi, 2); dst += dst_stride;
*(uint32_t *)dst = _mm_extract_epi32(ymm_hi, 3);
}
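// 8x4 variant: each 64-bit half of a 128-bit lane holds one 8-pixel row,
// written with the double-precision low/high stores to get 64-bit moves.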
static INLINE void scatter_ymm_8x4_8bit(kvz_pixel *dst, __m256i ymm, unsigned dst_stride)
{
__m256d ymm_as_m256d = _mm256_castsi256_pd(ymm);
__m128d ymm_lo = _mm256_castpd256_pd128(ymm_as_m256d);
__m128d ymm_hi = _mm256_extractf128_pd(ymm_as_m256d, 1);
_mm_storel_pd((double*)dst, ymm_lo); dst += dst_stride;
_mm_storeh_pd((double*)dst, ymm_lo); dst += dst_stride;
_mm_storel_pd((double*)dst, ymm_hi); dst += dst_stride;
_mm_storeh_pd((double*)dst, ymm_hi);
}
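// 16x2 variant: each 128-bit half of the register holds one 16-pixel row.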
static INLINE void scatter_ymm_16x2_8bit(kvz_pixel *dst, __m256i ymm, unsigned dst_stride)
{
__m128i ymm_lo = _mm256_castsi256_si128(ymm);
__m128i ymm_hi = _mm256_extracti128_si256(ymm, 1);
_mm_storeu_si128((__m128i *)dst, ymm_lo); dst += dst_stride;
_mm_storeu_si128((__m128i *)dst, ymm_hi);
}
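// Average two plain 8-bit predictions, 32 pixels per iteration.
// _mm256_avg_epu8 computes (a + b + 1) >> 1, which is exactly the rounded
// scalar bi-prediction average for pixel inputs.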
static INLINE void bipred_average_px_px_template_avx2(kvz_pixel *dst,
kvz_pixel *px_L0,
kvz_pixel *px_L1,
unsigned pu_w,
unsigned pu_h,
unsigned dst_stride)
{
for (int i = 0; i < pu_w * pu_h; i += 32) {
int y = i / pu_w;
int x = i % pu_w;
__m256i sample_L0 = _mm256_loadu_si256((__m256i *)&px_L0[i]);
__m256i sample_L1 = _mm256_loadu_si256((__m256i *)&px_L1[i]);
__m256i avg = _mm256_avg_epu8(sample_L0, sample_L1);
switch (pu_w) {
case 4: scatter_ymm_4x8_8bit(&dst[y * dst_stride + x], avg, dst_stride); break;
case 8: scatter_ymm_8x4_8bit(&dst[y * dst_stride + x], avg, dst_stride); break;
case 16: scatter_ymm_16x2_8bit(&dst[y * dst_stride + x], avg, dst_stride); break;
case 32: // Same as case 64
case 64: _mm256_storeu_si256((__m256i*)&dst[y * dst_stride + x], avg); break;
default:
assert(0 && "Unexpected block width");
break;
}
}
}
static INLINE void bipred_average_px_px_avx2(kvz_pixel *dst,
kvz_pixel *px_L0,
kvz_pixel *px_L1,
unsigned pu_w,
unsigned pu_h,
unsigned dst_stride)
{
unsigned size = pu_w * pu_h;
bool multiple_of_32 = !(size % 32); // Currently unused (WIP).
if (MIN(pu_w, pu_h) >= 4) { // Blocks narrower or shorter than 4 are not handled yet (WIP).
switch (pu_w) {
case 4: bipred_average_px_px_template_avx2(dst, px_L0, px_L1, 4, pu_h, dst_stride); break;
case 8: bipred_average_px_px_template_avx2(dst, px_L0, px_L1, 8, pu_h, dst_stride); break;
case 16: bipred_average_px_px_template_avx2(dst, px_L0, px_L1, 16, pu_h, dst_stride); break;
case 32: bipred_average_px_px_template_avx2(dst, px_L0, px_L1, 32, pu_h, dst_stride); break;
case 64: bipred_average_px_px_template_avx2(dst, px_L0, px_L1, 64, pu_h, dst_stride); break;
default:
printf("W: %d\n", pu_w);
assert(0 && "Unexpected block width.");
break;
}
}
}
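// Average two intermediate-precision (ip) predictions and pack the result back
// to 8-bit pixels. Per-pixel scalar equivalent, with shift = 15 - KVZ_BIT_DEPTH:
//   dst = clamp_to_pixel((L0 + L1 + (1 << (shift - 1))) >> shift)
// where the clamping is handled by the packus steps below.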
static INLINE void bipred_average_ip_ip_template_avx2(kvz_pixel *dst,
kvz_pixel_ip *ip_L0,
kvz_pixel_ip *ip_L1,
unsigned pu_w,
unsigned pu_h,
unsigned dst_stride)
{
int32_t shift = 15 - KVZ_BIT_DEPTH; // TODO: defines
int32_t scalar_offset = 1 << (shift - 1);
__m256i offset = _mm256_set1_epi32(scalar_offset);
for (int i = 0; i < pu_w * pu_h; i += 32) {
int y = i / pu_w;
int x = i % pu_w;
__m256i sample_L0_01_16bit = _mm256_loadu_si256((__m256i*)&ip_L0[i]);
__m256i sample_L0_23_16bit = _mm256_loadu_si256((__m256i*)&ip_L0[i + 16]);
__m256i sample_L1_01_16bit = _mm256_loadu_si256((__m256i*)&ip_L1[i]);
__m256i sample_L1_23_16bit = _mm256_loadu_si256((__m256i*)&ip_L1[i + 16]);
__m256i sample_L0_L1_01_lo = _mm256_unpacklo_epi16(sample_L0_01_16bit, sample_L1_01_16bit);
__m256i sample_L0_L1_01_hi = _mm256_unpackhi_epi16(sample_L0_01_16bit, sample_L1_01_16bit);
__m256i sample_L0_L1_23_lo = _mm256_unpacklo_epi16(sample_L0_23_16bit, sample_L1_23_16bit);
__m256i sample_L0_L1_23_hi = _mm256_unpackhi_epi16(sample_L0_23_16bit, sample_L1_23_16bit);
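// Interleaving L0 with L1 and multiply-adding against a vector of ones sums
// each L0/L1 pair into a 32-bit lane, so the offset and shift below cannot
// overflow 16 bits.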
__m256i all_ones = _mm256_set1_epi16(1);
__m256i avg_01_lo = _mm256_madd_epi16(sample_L0_L1_01_lo, all_ones);
__m256i avg_01_hi = _mm256_madd_epi16(sample_L0_L1_01_hi, all_ones);
__m256i avg_23_lo = _mm256_madd_epi16(sample_L0_L1_23_lo, all_ones);
__m256i avg_23_hi = _mm256_madd_epi16(sample_L0_L1_23_hi, all_ones);
avg_01_lo = _mm256_add_epi32(avg_01_lo, offset);
avg_01_hi = _mm256_add_epi32(avg_01_hi, offset);
avg_23_lo = _mm256_add_epi32(avg_23_lo, offset);
avg_23_hi = _mm256_add_epi32(avg_23_hi, offset);
avg_01_lo = _mm256_srai_epi32(avg_01_lo, shift);
avg_01_hi = _mm256_srai_epi32(avg_01_hi, shift);
avg_23_lo = _mm256_srai_epi32(avg_23_lo, shift);
avg_23_hi = _mm256_srai_epi32(avg_23_hi, shift);
__m256i avg_01 = _mm256_packus_epi32(avg_01_lo, avg_01_hi);
__m256i avg_23 = _mm256_packus_epi32(avg_23_lo, avg_23_hi);
__m256i avg0213 = _mm256_packus_epi16(avg_01, avg_23);
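// packus works within 128-bit lanes, so the packed 8-pixel groups land in
// 0,2,1,3 quadword order (hence avg0213); the permute restores sequential order.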
__m256i avg = _mm256_permute4x64_epi64(avg0213, _MM_SHUFFLE(3,1,2,0));
switch (pu_w) {
case 4: scatter_ymm_4x8_8bit(&dst[y * dst_stride + x], avg, dst_stride); break;
case 8: scatter_ymm_8x4_8bit(&dst[y * dst_stride + x], avg, dst_stride); break;
case 16: scatter_ymm_16x2_8bit(&dst[y * dst_stride + x], avg, dst_stride); break;
case 32: // Same as case 64
case 64: _mm256_storeu_si256((__m256i *)&dst[y * dst_stride + x], avg); break;
default:
assert(0 && "Unexpected block width");
break;
}
}
}
static void bipred_average_ip_ip_avx2(kvz_pixel *dst,
kvz_pixel_ip *ip_L0,
kvz_pixel_ip *ip_L1,
unsigned pu_w,
unsigned pu_h,
unsigned dst_stride)
{
unsigned size = pu_w * pu_h;
bool multiple_of_32 = !(size % 32); // Currently unused (WIP).
if (MIN(pu_w, pu_h) >= 4) {
switch (pu_w) {
case 4: bipred_average_ip_ip_template_avx2(dst, ip_L0, ip_L1, 4, pu_h, dst_stride); break;
case 8: bipred_average_ip_ip_template_avx2(dst, ip_L0, ip_L1, 8, pu_h, dst_stride); break;
case 16: bipred_average_ip_ip_template_avx2(dst, ip_L0, ip_L1, 16, pu_h, dst_stride); break;
case 32: bipred_average_ip_ip_template_avx2(dst, ip_L0, ip_L1, 32, pu_h, dst_stride); break;
case 64: bipred_average_ip_ip_template_avx2(dst, ip_L0, ip_L1, 64, pu_h, dst_stride); break;
default:
assert(0 && "Unexpected block width.");
break;
}
}
}
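// Average one plain 8-bit prediction with one intermediate-precision (ip)
// prediction. The px samples are first scaled up to the ip scale so the same
// offset and shift as in the ip+ip path apply.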
static INLINE void bipred_average_px_ip_template_avx2(kvz_pixel *dst,
kvz_pixel *px,
kvz_pixel_ip *ip,
unsigned pu_w,
unsigned pu_h,
unsigned dst_stride)
{
int32_t shift = 15 - KVZ_BIT_DEPTH; // TODO: defines
int32_t scalar_offset = 1 << (shift - 1);
__m256i offset = _mm256_set1_epi32(scalar_offset);
for (int i = 0; i < pu_w * pu_h; i += 32) {
int y = i / pu_w;
int x = i % pu_w;
__m256i sample_px_01_16bit = _mm256_cvtepu8_epi16(_mm_loadu_si128((__m128i *)&px[i]));
__m256i sample_px_23_16bit = _mm256_cvtepu8_epi16(_mm_loadu_si128((__m128i *)&px[i + 16]));
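// Bring the 8-bit samples to the same scale as the ip samples by shifting
// left by 14 - KVZ_BIT_DEPTH.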
sample_px_01_16bit = _mm256_slli_epi16(sample_px_01_16bit, 14 - KVZ_BIT_DEPTH);
sample_px_23_16bit = _mm256_slli_epi16(sample_px_23_16bit, 14 - KVZ_BIT_DEPTH);
__m256i sample_ip_01_16bit = _mm256_loadu_si256((__m256i *)&ip[i]);
__m256i sample_ip_23_16bit = _mm256_loadu_si256((__m256i *)&ip[i + 16]);
__m256i sample_px_ip_01_lo = _mm256_unpacklo_epi16(sample_px_01_16bit, sample_ip_01_16bit);
__m256i sample_px_ip_01_hi = _mm256_unpackhi_epi16(sample_px_01_16bit, sample_ip_01_16bit);
__m256i sample_px_ip_23_lo = _mm256_unpacklo_epi16(sample_px_23_16bit, sample_ip_23_16bit);
__m256i sample_px_ip_23_hi = _mm256_unpackhi_epi16(sample_px_23_16bit, sample_ip_23_16bit);
__m256i all_ones = _mm256_set1_epi16(1);
__m256i avg_01_lo = _mm256_madd_epi16(sample_px_ip_01_lo, all_ones);
__m256i avg_01_hi = _mm256_madd_epi16(sample_px_ip_01_hi, all_ones);
__m256i avg_23_lo = _mm256_madd_epi16(sample_px_ip_23_lo, all_ones);
__m256i avg_23_hi = _mm256_madd_epi16(sample_px_ip_23_hi, all_ones);
avg_01_lo = _mm256_add_epi32(avg_01_lo, offset);
avg_01_hi = _mm256_add_epi32(avg_01_hi, offset);
avg_23_lo = _mm256_add_epi32(avg_23_lo, offset);
avg_23_hi = _mm256_add_epi32(avg_23_hi, offset);
avg_01_lo = _mm256_srai_epi32(avg_01_lo, shift);
avg_01_hi = _mm256_srai_epi32(avg_01_hi, shift);
avg_23_lo = _mm256_srai_epi32(avg_23_lo, shift);
avg_23_hi = _mm256_srai_epi32(avg_23_hi, shift);
__m256i avg_01 = _mm256_packus_epi32(avg_01_lo, avg_01_hi);
__m256i avg_23 = _mm256_packus_epi32(avg_23_lo, avg_23_hi);
__m256i avg0213 = _mm256_packus_epi16(avg_01, avg_23);
__m256i avg = _mm256_permute4x64_epi64(avg0213, _MM_SHUFFLE(3, 1, 2, 0));
switch (pu_w) {
case 4: scatter_ymm_4x8_8bit(&dst[y * dst_stride + x], avg, dst_stride); break;
case 8: scatter_ymm_8x4_8bit(&dst[y * dst_stride + x], avg, dst_stride); break;
case 16: scatter_ymm_16x2_8bit(&dst[y * dst_stride + x], avg, dst_stride); break;
case 32: // Same as case 64
case 64: _mm256_storeu_si256((__m256i *)&dst[y * dst_stride + x], avg); break;
default:
assert(0 && "Unexpected block width");
break;
}
}
}
static void bipred_average_px_ip_avx2(kvz_pixel *dst,
kvz_pixel *px,
kvz_pixel_ip *ip,
unsigned pu_w,
unsigned pu_h,
unsigned dst_stride)
{
unsigned size = pu_w * pu_h;
bool multiple_of_32 = !(size % 32); // Currently unused (WIP).
if (MIN(pu_w, pu_h) >= 4) {
switch (pu_w) {
case 4: bipred_average_px_ip_template_avx2(dst, px, ip, 4, pu_h, dst_stride); break;
case 8: bipred_average_px_ip_template_avx2(dst, px, ip, 8, pu_h, dst_stride); break;
case 16: bipred_average_px_ip_template_avx2(dst, px, ip, 16, pu_h, dst_stride); break;
case 32: bipred_average_px_ip_template_avx2(dst, px, ip, 32, pu_h, dst_stride); break;
case 64: bipred_average_px_ip_template_avx2(dst, px, ip, 64, pu_h, dst_stride); break;
default:
assert(0 && "Unexpected block width.");
break;
}
}
}
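// Pick the averaging kernel based on which reference-list predictions are in
// intermediate precision: bit 0 of ip_flags_Lx covers luma, bit 1 chroma.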
static void bipred_average_avx2(lcu_t *const lcu,
const yuv_t *const px_L0,
const yuv_t *const px_L1,
const yuv_ip_t *const ip_L0,
const yuv_ip_t *const ip_L1,
const unsigned pu_x,
const unsigned pu_y,
const unsigned pu_w,
const unsigned pu_h,
const unsigned ip_flags_L0,
const unsigned ip_flags_L1,
const bool predict_luma,
const bool predict_chroma) {
// After reconstruction, merge the predictors by taking an average of each pixel.
if (predict_luma) {
unsigned pb_offset = SUB_SCU(pu_y) * LCU_WIDTH + SUB_SCU(pu_x);
if (!(ip_flags_L0 & 1) && !(ip_flags_L1 & 1)) {
bipred_average_px_px_avx2(lcu->rec.y + pb_offset, px_L0->y, px_L1->y, pu_w, pu_h, LCU_WIDTH);
} else if ((ip_flags_L0 & 1) && (ip_flags_L1 & 1)) {
bipred_average_ip_ip_avx2(lcu->rec.y + pb_offset, ip_L0->y, ip_L1->y, pu_w, pu_h, LCU_WIDTH);
} else {
kvz_pixel *src_px = (ip_flags_L0 & 1) ? px_L1->y : px_L0->y;
kvz_pixel_ip *src_ip = (ip_flags_L0 & 1) ? ip_L0->y : ip_L1->y;
bipred_average_px_ip_avx2(lcu->rec.y + pb_offset, src_px, src_ip, pu_w, pu_h, LCU_WIDTH);
}
}
if (predict_chroma) {
unsigned pb_offset = SUB_SCU(pu_y) / 2 * LCU_WIDTH_C + SUB_SCU(pu_x) / 2;
unsigned pb_w = pu_w / 2;
unsigned pb_h = pu_h / 2;
if (!(ip_flags_L0 & 2) && !(ip_flags_L1 & 2)) {
bipred_average_px_px_avx2(lcu->rec.u + pb_offset, px_L0->u, px_L1->u, pb_w, pb_h, LCU_WIDTH_C);
bipred_average_px_px_avx2(lcu->rec.v + pb_offset, px_L0->v, px_L1->v, pb_w, pb_h, LCU_WIDTH_C);
} else if ((ip_flags_L0 & 2) && (ip_flags_L1 & 2)) {
bipred_average_ip_ip_avx2(lcu->rec.u + pb_offset, ip_L0->u, ip_L1->u, pb_w, pb_h, LCU_WIDTH_C);
bipred_average_ip_ip_avx2(lcu->rec.v + pb_offset, ip_L0->v, ip_L1->v, pb_w, pb_h, LCU_WIDTH_C);
} else {
kvz_pixel *src_px_u = (ip_flags_L0 & 2) ? px_L1->u : px_L0->u;
kvz_pixel_ip *src_ip_u = (ip_flags_L0 & 2) ? ip_L0->u : ip_L1->u;
kvz_pixel *src_px_v = (ip_flags_L0 & 2) ? px_L1->v : px_L0->v;
kvz_pixel_ip *src_ip_v = (ip_flags_L0 & 2) ? ip_L0->v : ip_L1->v;
bipred_average_px_ip_avx2(lcu->rec.u + pb_offset, src_px_u, src_ip_u, pb_w, pb_h, LCU_WIDTH_C);
bipred_average_px_ip_avx2(lcu->rec.v + pb_offset, src_px_v, src_ip_v, pb_w, pb_h, LCU_WIDTH_C);
}
}
}
static optimized_sad_func_ptr_t get_optimized_sad_avx2(int32_t width)
{
if (width == 0)
@@ -1043,7 +1350,7 @@ int kvz_strategy_register_picture_avx2(void* opaque, uint8_t bitdepth)
success &= kvz_strategyselector_register(opaque, "satd_any_size_quad", "avx2", 40, &satd_any_size_quad_avx2);
success &= kvz_strategyselector_register(opaque, "pixels_calc_ssd", "avx2", 40, &pixels_calc_ssd_avx2);
//success &= kvz_strategyselector_register(opaque, "bipred_average", "avx2", 40, &bipred_average_avx2);
success &= kvz_strategyselector_register(opaque, "bipred_average", "avx2", 40, &bipred_average_avx2);
success &= kvz_strategyselector_register(opaque, "get_optimized_sad", "avx2", 40, &get_optimized_sad_avx2);
success &= kvz_strategyselector_register(opaque, "ver_sad", "avx2", 40, &ver_sad_avx2);
success &= kvz_strategyselector_register(opaque, "hor_sad", "avx2", 40, &hor_sad_avx2);