From c18adc5ee0294ff70d5e6b14241364d819f1b206 Mon Sep 17 00:00:00 2001 From: Pauli Oikkonen Date: Wed, 24 Jul 2019 15:30:00 +0300 Subject: [PATCH] Redo sao_band_ddistortion_avx2 Avoid branching and do the entire thing on 32 pixels at once in YMMs. Also make the sao_bands function parameter const. --- src/strategies/avx2/sao-avx2.c | 208 ++++++++++++------ src/strategies/generic/sao-generic.c | 30 +-- src/strategies/generic/sao_band_ddistortion.h | 48 ++++ src/strategies/strategies-sao.h | 2 +- 4 files changed, 195 insertions(+), 93 deletions(-) create mode 100644 src/strategies/generic/sao_band_ddistortion.h diff --git a/src/strategies/avx2/sao-avx2.c b/src/strategies/avx2/sao-avx2.c index c9969e8d..8ccd9ea8 100644 --- a/src/strategies/avx2/sao-avx2.c +++ b/src/strategies/avx2/sao-avx2.c @@ -24,6 +24,7 @@ #include #include +#include "strategies/generic/sao_band_ddistortion.h" #include "cu.h" #include "encoder.h" #include "encoderstate.h" @@ -436,8 +437,7 @@ static void sao_reconstruct_color_avx2(const encoder_control_t * const encoder, bool use_8_elements = (block_width - x) >= 8; - switch (use_8_elements) { - case true:; + if (use_8_elements) { const kvz_pixel *c_data = &rec_data[y * stride + x]; __m128i vector_a_epi8 = _mm_loadl_epi64((__m128i*)&c_data[a_ofs.y * stride + a_ofs.x]); @@ -465,9 +465,8 @@ static void sao_reconstruct_color_avx2(const encoder_control_t * const encoder, // Store 64-bits from vector to memory _mm_storel_epi64((__m128i*)&(new_rec_data[y * new_stride + x]), _mm256_castsi256_si128(temp_epi8)); - break; - default:; + } else { for (int i = x; i < (block_width); ++i) { const kvz_pixel *c_data = &rec_data[y * stride + i]; @@ -481,94 +480,177 @@ static void sao_reconstruct_color_avx2(const encoder_control_t * const encoder, int eo_cat = sao_calc_eo_cat(a, b, c); new_data[0] = (kvz_pixel)CLIP(0, (1 << KVZ_BIT_DEPTH) - 1, c_data[0] + sao->offsets[eo_cat + offset_v]); - } - break; } - - } - - - } } } -static int sao_band_ddistortion_avx2(const encoder_state_t * const state, - const kvz_pixel *orig_data, - const kvz_pixel *rec_data, - int block_width, - int block_height, - int band_pos, - int sao_bands[4]) +static INLINE __m256i srli_epi8(__m256i v, const uint32_t shift) { - int y, x; - int shift = state->encoder_control->bitdepth - 5; - int sum = 0; + const uint8_t hibit_mask = 0xff >> shift; + const __m256i hibit_mask_256 = _mm256_set1_epi8(hibit_mask); - __m256i sum_epi32 = { 0 }; + __m256i v_shifted = _mm256_srli_epi32(v, shift); + __m256i v_masked = _mm256_and_si256 (v_shifted, hibit_mask_256); - __m256i band_pos_epi32 = _mm256_set1_epi32(band_pos); - for (y = 0; y < block_height; ++y) { - for (x = 0; x < block_width; x += 8) { - bool use_8_elements = (block_width - x) >= 8; + return v_masked; +} - switch (use_8_elements) { - case true:; - //int band = (rec_data[y * block_width + x] >> shift) - band_pos; +static INLINE void cvt_epu8_epi16(const __m256i v, __m256i *res_lo, __m256i *res_hi) +{ + const __m256i zero = _mm256_setzero_si256(); + *res_lo = _mm256_unpacklo_epi8(v, zero); + *res_hi = _mm256_unpackhi_epi8(v, zero); +} - __m256i band_epi32 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((__m128i*)&(rec_data[y * block_width + x]))); - band_epi32 = _mm256_srli_epi32(band_epi32, shift); - band_epi32 = _mm256_sub_epi32(band_epi32, band_pos_epi32); +static INLINE void cvt_epi8_epi16(const __m256i v, __m256i *res_lo, __m256i *res_hi) +{ + const __m256i zero = _mm256_setzero_si256(); + __m256i signs = _mm256_cmpgt_epi8 (zero, v); + *res_lo = _mm256_unpacklo_epi8(v, signs); + *res_hi = _mm256_unpackhi_epi8(v, signs); +} +static int32_t sao_band_ddistortion_avx2(const encoder_state_t *state, + const uint8_t *orig_data, + const uint8_t *rec_data, + int32_t block_width, + int32_t block_height, + int32_t band_pos, + const int32_t sao_bands[4]) +{ + const uint32_t bitdepth = 8; + const uint32_t shift = bitdepth - 5; - __m256i vector_mask = _mm256_cmpeq_epi32(_mm256_and_si256(_mm256_set1_epi32(~3), band_epi32), _mm256_setzero_si256()); + // Clamp band_pos to 32 from above. It'll be subtracted from the shifted + // rec_data values, which in 8-bit depth will always be clamped to [0, 31], + // so if it ever exceeds 32, all the band values will be negative and + // ignored. Ditto for less than -4. + __m128i bp_128 = _mm_cvtsi32_si128 (band_pos); + __m128i hilimit = _mm_cvtsi32_si128 (32); + __m128i lolimit = _mm_cvtsi32_si128 (-4); - __m256i offset_epi32 = _mm256_permutevar8x32_epi32(_mm256_castsi128_si256(_mm_loadu_si128((__m128i*)sao_bands)), band_epi32); + bp_128 = _mm_min_epi8 (bp_128, hilimit); + bp_128 = _mm_max_epi8 (bp_128, lolimit); - offset_epi32 = _mm256_and_si256(vector_mask, offset_epi32); + __m256i bp_256 = _mm256_broadcastb_epi8(bp_128); - __m256i orig_data_epi32 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((__m128i*)&(orig_data[y * block_width + x]))); - __m256i rec_data_epi32 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((__m128i*)&(rec_data[y * block_width + x]))); - __m256i diff_epi32 = _mm256_sub_epi32(orig_data_epi32, rec_data_epi32); + // LSBs of each SAO band dword, the band values must fit in 8 bits anyway + // (this will be checked later) + const __m128i sb_shufmask = _mm_set1_epi32(0x0c080400); - __m256i diff_minus_offset_epi32 = _mm256_sub_epi32(diff_epi32, offset_epi32); + __m128i sbs_32 = _mm_loadu_si128((const __m128i *)sao_bands); - __m256i temp_sum = _mm256_sub_epi32(_mm256_mullo_epi32(diff_minus_offset_epi32, diff_minus_offset_epi32), _mm256_mullo_epi32(diff_epi32, diff_epi32)); + __m128i sbs_8 = _mm_shuffle_epi8 (sbs_32, sb_shufmask); + __m256i sb_256 = _mm256_broadcastsi128_si256 (sbs_8); - sum_epi32 = _mm256_add_epi32(sum_epi32, temp_sum); + // Compare most significant 25 bits of SAO bands to the sign bit to assert + // that the band is between -128 and 127 (only comparing 24 would fail to + // detect values of 128...255) + __m128i sb_ms25b = _mm_srai_epi32 (sbs_32, 7); + __m128i sb_signs = _mm_srai_epi32 (sbs_32, 31); + __m128i sbs_ok_v = _mm_cmpeq_epi32 (sb_ms25b, sb_signs); + uint16_t sbs_ok = _mm_movemask_epi8 (sbs_ok_v); + // These should trigger like, never, at least the later condition of block + // not being a multiple of 32 wide. Rather safe than sorry though, huge SAO + // bands are more tricky of these two because the algorithm needs a complete + // reimplementation to work on 16-bit values. + if (sbs_ok != 0xffff) + goto use_generic; - break; + // If VVC or something will start using SAO on blocks with width a multiple + // of 16, feel free to implement a XMM variant of this algorithm + if ((block_width & 31) != 0) + goto use_generic; - default:; - for (x; x < block_width; ++x) { - int band = (rec_data[y * block_width + x] >> shift) - band_pos; - int offset = 0; - if (band >= 0 && band < 4) { - offset = sao_bands[band]; - } - if (offset != 0) { - int diff = orig_data[y * block_width + x] - rec_data[y * block_width + x]; - // Offset is applied to reconstruction, so it is subtracted from diff. - sum += (diff - offset) * (diff - offset) - diff * diff; - } - } - } + const __m256i zero = _mm256_setzero_si256(); + const __m256i threes = _mm256_set1_epi8 (3); + const __m256i negate_hiword = _mm256_set1_epi32(0xffff0001); + __m256i sum = _mm256_setzero_si256(); + for (uint32_t y = 0; y < block_height; y++) { + for (uint32_t x = 0; x < block_width; x += 32) { + const int32_t curr_pos = y * block_width + x; + __m256i rd = _mm256_loadu_si256((const __m256i *)( rec_data + curr_pos)); + __m256i orig = _mm256_loadu_si256((const __m256i *)(orig_data + curr_pos)); + + __m256i orig_lo, orig_hi, rd_lo, rd_hi; + cvt_epu8_epi16(orig, &orig_lo, &orig_hi); + cvt_epu8_epi16(rd, &rd_lo, &rd_hi); + + // The shift will clamp band to 0...31; band_pos on the other + // hand is always between 0...32, so band will be -1...31. Anything + // below zero is ignored, so we can clamp band_pos to 32. + __m256i rd_divd = srli_epi8 (rd, shift); + __m256i band = _mm256_sub_epi8 (rd_divd, bp_256); + + // Force all <0 or >3 bands to 0xff, which will zero the shuffle result + __m256i band_lt_0 = _mm256_cmpgt_epi8 (zero, band); + __m256i band_gt_3 = _mm256_cmpgt_epi8 (band, threes); + __m256i band_inv = _mm256_or_si256 (band_lt_0, band_gt_3); + + band = _mm256_or_si256 (band, band_inv); + + __m256i offsets = _mm256_shuffle_epi8 (sb_256, band); + + __m256i offsets_lo, offsets_hi; + cvt_epi8_epi16(offsets, &offsets_lo, &offsets_hi); + + __m256i offsets_0_lo = _mm256_cmpeq_epi16 (offsets_lo, zero); + __m256i offsets_0_hi = _mm256_cmpeq_epi16 (offsets_hi, zero); + + __m256i diff_lo = _mm256_sub_epi16 (orig_lo, rd_lo); + __m256i diff_hi = _mm256_sub_epi16 (orig_hi, rd_hi); + + __m256i delta_lo = _mm256_sub_epi16 (diff_lo, offsets_lo); + __m256i delta_hi = _mm256_sub_epi16 (diff_hi, offsets_hi); + + diff_lo = _mm256_andnot_si256 (offsets_0_lo, diff_lo); + diff_hi = _mm256_andnot_si256 (offsets_0_hi, diff_hi); + delta_lo = _mm256_andnot_si256 (offsets_0_lo, delta_lo); + delta_hi = _mm256_andnot_si256 (offsets_0_hi, delta_hi); + + __m256i dd0_lo = _mm256_unpacklo_epi16(delta_lo, diff_lo); + __m256i dd0_hi = _mm256_unpackhi_epi16(delta_lo, diff_lo); + __m256i dd1_lo = _mm256_unpacklo_epi16(delta_hi, diff_hi); + __m256i dd1_hi = _mm256_unpackhi_epi16(delta_hi, diff_hi); + + __m256i dd0_lo_n = _mm256_sign_epi16 (dd0_lo, negate_hiword); + __m256i dd0_hi_n = _mm256_sign_epi16 (dd0_hi, negate_hiword); + __m256i dd1_lo_n = _mm256_sign_epi16 (dd1_lo, negate_hiword); + __m256i dd1_hi_n = _mm256_sign_epi16 (dd1_hi, negate_hiword); + + __m256i sum0_lo = _mm256_madd_epi16 (dd0_lo, dd0_lo_n); + __m256i sum0_hi = _mm256_madd_epi16 (dd0_hi, dd0_hi_n); + __m256i sum1_lo = _mm256_madd_epi16 (dd1_lo, dd1_lo_n); + __m256i sum1_hi = _mm256_madd_epi16 (dd1_hi, dd1_hi_n); + + __m256i sum0 = _mm256_add_epi32 (sum0_lo, sum0_hi); + __m256i sum1 = _mm256_add_epi32 (sum1_lo, sum1_hi); + __m256i curr_sum = _mm256_add_epi32 (sum0, sum1); + + sum = _mm256_add_epi32 (sum, curr_sum); } } + // Horizontal sum of 8x32 YMM, nothing special here + __m256i sum2 = _mm256_permute4x64_epi64(sum, _MM_SHUFFLE(1, 0, 3, 2)); + __m256i sum3 = _mm256_add_epi32 (sum, sum2); + __m256i sum4 = _mm256_shuffle_epi32 (sum3, _MM_SHUFFLE(1, 0, 3, 2)); + __m256i sum5 = _mm256_add_epi32 (sum3, sum4); + __m256i sum6 = _mm256_shuffle_epi32 (sum5, _MM_SHUFFLE(2, 3, 0, 1)); + __m256i sum7 = _mm256_add_epi32 (sum5, sum6); - //Full horizontal sum - sum_epi32 = _mm256_add_epi32(sum_epi32, _mm256_castsi128_si256(_mm256_extracti128_si256(sum_epi32, 1))); - sum_epi32 = _mm256_add_epi32(sum_epi32, _mm256_shuffle_epi32(sum_epi32, _MM_SHUFFLE(1, 0, 3, 2))); - sum_epi32 = _mm256_add_epi32(sum_epi32, _mm256_shuffle_epi32(sum_epi32, _MM_SHUFFLE(0, 1, 0, 1))); - sum += _mm_cvtsi128_si32(_mm256_castsi256_si128(sum_epi32)); - - return sum; + __m128i sum8 = _mm256_castsi256_si128 (sum7); + int32_t sum9 = _mm_cvtsi128_si32 (sum8); + return sum9; +use_generic: + return sao_band_ddistortion_generic(state, orig_data, rec_data, block_width, + block_height, band_pos, sao_bands); } #endif //COMPILE_INTEL_AVX2 diff --git a/src/strategies/generic/sao-generic.c b/src/strategies/generic/sao-generic.c index 46a6caf0..ff1a53b4 100644 --- a/src/strategies/generic/sao-generic.c +++ b/src/strategies/generic/sao-generic.c @@ -19,6 +19,7 @@ ****************************************************************************/ #include "strategies/generic/sao-generic.h" +#include "strategies/generic/sao_band_ddistortion.h" #include "cu.h" #include "encoder.h" @@ -156,35 +157,6 @@ static void sao_reconstruct_color_generic(const encoder_control_t * const encode } -static int sao_band_ddistortion_generic(const encoder_state_t * const state, - const kvz_pixel *orig_data, - const kvz_pixel *rec_data, - int block_width, - int block_height, - int band_pos, - int sao_bands[4]) -{ - int y, x; - int shift = state->encoder_control->bitdepth-5; - int sum = 0; - for (y = 0; y < block_height; ++y) { - for (x = 0; x < block_width; ++x) { - int band = (rec_data[y * block_width + x] >> shift) - band_pos; - int offset = 0; - if (band >= 0 && band < 4) { - offset = sao_bands[band]; - } - if (offset != 0) { - int diff = orig_data[y * block_width + x] - rec_data[y * block_width + x]; - // Offset is applied to reconstruction, so it is subtracted from diff. - sum += (diff - offset) * (diff - offset) - diff * diff; - } - } - } - - return sum; -} - int kvz_strategy_register_sao_generic(void* opaque, uint8_t bitdepth) { diff --git a/src/strategies/generic/sao_band_ddistortion.h b/src/strategies/generic/sao_band_ddistortion.h new file mode 100644 index 00000000..f5a68166 --- /dev/null +++ b/src/strategies/generic/sao_band_ddistortion.h @@ -0,0 +1,48 @@ +#ifndef SAO_BAND_DDISTORTION_H_ +#define SAO_BAND_DDISTORTION_H_ + +// #include "encoder.h" +#include "encoderstate.h" +#include "kvazaar.h" +#include "sao.h" + +static int sao_band_ddistortion_generic(const encoder_state_t * const state, + const kvz_pixel *orig_data, + const kvz_pixel *rec_data, + int block_width, + int block_height, + int band_pos, + const int sao_bands[4]) +{ + int y, x; + int shift = state->encoder_control->bitdepth-5; + int sum = 0; + for (y = 0; y < block_height; ++y) { + for (x = 0; x < block_width; ++x) { + const int32_t curr_pos = y * block_width + x; + + kvz_pixel rec = rec_data[curr_pos]; + kvz_pixel orig = orig_data[curr_pos]; + + int32_t band = (rec >> shift) - band_pos; + int32_t offset = 0; + if (band >= 0 && band <= 3) { + offset = sao_bands[band]; + } + // Offset is applied to reconstruction, so it is subtracted from diff. + + int32_t diff = orig - rec; + int32_t delta = diff - offset; + + int32_t dmask = (offset == 0) ? -1 : 0; + diff &= ~dmask; + delta &= ~dmask; + + sum += delta * delta - diff * diff; + } + } + + return sum; +} + +#endif diff --git a/src/strategies/strategies-sao.h b/src/strategies/strategies-sao.h index e469810e..0c58b719 100644 --- a/src/strategies/strategies-sao.h +++ b/src/strategies/strategies-sao.h @@ -51,7 +51,7 @@ typedef void (sao_reconstruct_color_func)(const encoder_control_t * const encode typedef int (sao_band_ddistortion_func)(const encoder_state_t * const state, const kvz_pixel *orig_data, const kvz_pixel *rec_data, int block_width, int block_height, - int band_pos, int sao_bands[4]); + int band_pos, const int sao_bands[4]); // Declare function pointers. extern sao_edge_ddistortion_func * kvz_sao_edge_ddistortion;