Redo sao_band_ddistortion_avx2

Avoid branching and do the entire thing on 32 pixels at once in YMMs.
Also make the sao_bands function parameter const.
This commit is contained in:
Pauli Oikkonen 2019-07-24 15:30:00 +03:00
parent 2827c3e3ab
commit c18adc5ee0
4 changed files with 195 additions and 93 deletions

View file

@ -24,6 +24,7 @@
#include <immintrin.h>
#include <nmmintrin.h>
#include "strategies/generic/sao_band_ddistortion.h"
#include "cu.h"
#include "encoder.h"
#include "encoderstate.h"
@ -436,8 +437,7 @@ static void sao_reconstruct_color_avx2(const encoder_control_t * const encoder,
bool use_8_elements = (block_width - x) >= 8;
switch (use_8_elements) {
case true:;
if (use_8_elements) {
const kvz_pixel *c_data = &rec_data[y * stride + x];
__m128i vector_a_epi8 = _mm_loadl_epi64((__m128i*)&c_data[a_ofs.y * stride + a_ofs.x]);
@ -465,9 +465,8 @@ static void sao_reconstruct_color_avx2(const encoder_control_t * const encoder,
// Store 64-bits from vector to memory
_mm_storel_epi64((__m128i*)&(new_rec_data[y * new_stride + x]), _mm256_castsi256_si128(temp_epi8));
break;
default:;
} else {
for (int i = x; i < (block_width); ++i) {
const kvz_pixel *c_data = &rec_data[y * stride + i];
@ -481,94 +480,177 @@ static void sao_reconstruct_color_avx2(const encoder_control_t * const encoder,
int eo_cat = sao_calc_eo_cat(a, b, c);
new_data[0] = (kvz_pixel)CLIP(0, (1 << KVZ_BIT_DEPTH) - 1, c_data[0] + sao->offsets[eo_cat + offset_v]);
}
break;
}
}
}
}
}
static int sao_band_ddistortion_avx2(const encoder_state_t * const state,
const kvz_pixel *orig_data,
const kvz_pixel *rec_data,
int block_width,
int block_height,
int band_pos,
int sao_bands[4])
static INLINE __m256i srli_epi8(__m256i v, const uint32_t shift)
{
int y, x;
int shift = state->encoder_control->bitdepth - 5;
int sum = 0;
const uint8_t hibit_mask = 0xff >> shift;
const __m256i hibit_mask_256 = _mm256_set1_epi8(hibit_mask);
__m256i sum_epi32 = { 0 };
__m256i v_shifted = _mm256_srli_epi32(v, shift);
__m256i v_masked = _mm256_and_si256 (v_shifted, hibit_mask_256);
__m256i band_pos_epi32 = _mm256_set1_epi32(band_pos);
for (y = 0; y < block_height; ++y) {
for (x = 0; x < block_width; x += 8) {
bool use_8_elements = (block_width - x) >= 8;
return v_masked;
}
switch (use_8_elements) {
case true:;
//int band = (rec_data[y * block_width + x] >> shift) - band_pos;
static INLINE void cvt_epu8_epi16(const __m256i v, __m256i *res_lo, __m256i *res_hi)
{
const __m256i zero = _mm256_setzero_si256();
*res_lo = _mm256_unpacklo_epi8(v, zero);
*res_hi = _mm256_unpackhi_epi8(v, zero);
}
__m256i band_epi32 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((__m128i*)&(rec_data[y * block_width + x])));
band_epi32 = _mm256_srli_epi32(band_epi32, shift);
band_epi32 = _mm256_sub_epi32(band_epi32, band_pos_epi32);
static INLINE void cvt_epi8_epi16(const __m256i v, __m256i *res_lo, __m256i *res_hi)
{
const __m256i zero = _mm256_setzero_si256();
__m256i signs = _mm256_cmpgt_epi8 (zero, v);
*res_lo = _mm256_unpacklo_epi8(v, signs);
*res_hi = _mm256_unpackhi_epi8(v, signs);
}
static int32_t sao_band_ddistortion_avx2(const encoder_state_t *state,
const uint8_t *orig_data,
const uint8_t *rec_data,
int32_t block_width,
int32_t block_height,
int32_t band_pos,
const int32_t sao_bands[4])
{
const uint32_t bitdepth = 8;
const uint32_t shift = bitdepth - 5;
__m256i vector_mask = _mm256_cmpeq_epi32(_mm256_and_si256(_mm256_set1_epi32(~3), band_epi32), _mm256_setzero_si256());
// Clamp band_pos to 32 from above. It'll be subtracted from the shifted
// rec_data values, which in 8-bit depth will always be clamped to [0, 31],
// so if it ever exceeds 32, all the band values will be negative and
// ignored. Ditto for less than -4.
__m128i bp_128 = _mm_cvtsi32_si128 (band_pos);
__m128i hilimit = _mm_cvtsi32_si128 (32);
__m128i lolimit = _mm_cvtsi32_si128 (-4);
__m256i offset_epi32 = _mm256_permutevar8x32_epi32(_mm256_castsi128_si256(_mm_loadu_si128((__m128i*)sao_bands)), band_epi32);
bp_128 = _mm_min_epi8 (bp_128, hilimit);
bp_128 = _mm_max_epi8 (bp_128, lolimit);
offset_epi32 = _mm256_and_si256(vector_mask, offset_epi32);
__m256i bp_256 = _mm256_broadcastb_epi8(bp_128);
__m256i orig_data_epi32 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((__m128i*)&(orig_data[y * block_width + x])));
__m256i rec_data_epi32 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((__m128i*)&(rec_data[y * block_width + x])));
__m256i diff_epi32 = _mm256_sub_epi32(orig_data_epi32, rec_data_epi32);
// LSBs of each SAO band dword, the band values must fit in 8 bits anyway
// (this will be checked later)
const __m128i sb_shufmask = _mm_set1_epi32(0x0c080400);
__m256i diff_minus_offset_epi32 = _mm256_sub_epi32(diff_epi32, offset_epi32);
__m128i sbs_32 = _mm_loadu_si128((const __m128i *)sao_bands);
__m256i temp_sum = _mm256_sub_epi32(_mm256_mullo_epi32(diff_minus_offset_epi32, diff_minus_offset_epi32), _mm256_mullo_epi32(diff_epi32, diff_epi32));
__m128i sbs_8 = _mm_shuffle_epi8 (sbs_32, sb_shufmask);
__m256i sb_256 = _mm256_broadcastsi128_si256 (sbs_8);
sum_epi32 = _mm256_add_epi32(sum_epi32, temp_sum);
// Compare most significant 25 bits of SAO bands to the sign bit to assert
// that the band is between -128 and 127 (only comparing 24 would fail to
// detect values of 128...255)
__m128i sb_ms25b = _mm_srai_epi32 (sbs_32, 7);
__m128i sb_signs = _mm_srai_epi32 (sbs_32, 31);
__m128i sbs_ok_v = _mm_cmpeq_epi32 (sb_ms25b, sb_signs);
uint16_t sbs_ok = _mm_movemask_epi8 (sbs_ok_v);
// These should trigger like, never, at least the later condition of block
// not being a multiple of 32 wide. Rather safe than sorry though, huge SAO
// bands are more tricky of these two because the algorithm needs a complete
// reimplementation to work on 16-bit values.
if (sbs_ok != 0xffff)
goto use_generic;
break;
// If VVC or something will start using SAO on blocks with width a multiple
// of 16, feel free to implement a XMM variant of this algorithm
if ((block_width & 31) != 0)
goto use_generic;
default:;
for (x; x < block_width; ++x) {
int band = (rec_data[y * block_width + x] >> shift) - band_pos;
int offset = 0;
if (band >= 0 && band < 4) {
offset = sao_bands[band];
}
if (offset != 0) {
int diff = orig_data[y * block_width + x] - rec_data[y * block_width + x];
// Offset is applied to reconstruction, so it is subtracted from diff.
sum += (diff - offset) * (diff - offset) - diff * diff;
}
const __m256i zero = _mm256_setzero_si256();
const __m256i threes = _mm256_set1_epi8 (3);
const __m256i negate_hiword = _mm256_set1_epi32(0xffff0001);
__m256i sum = _mm256_setzero_si256();
for (uint32_t y = 0; y < block_height; y++) {
for (uint32_t x = 0; x < block_width; x += 32) {
const int32_t curr_pos = y * block_width + x;
__m256i rd = _mm256_loadu_si256((const __m256i *)( rec_data + curr_pos));
__m256i orig = _mm256_loadu_si256((const __m256i *)(orig_data + curr_pos));
__m256i orig_lo, orig_hi, rd_lo, rd_hi;
cvt_epu8_epi16(orig, &orig_lo, &orig_hi);
cvt_epu8_epi16(rd, &rd_lo, &rd_hi);
// The shift will clamp band to 0...31; band_pos on the other
// hand is always between 0...32, so band will be -1...31. Anything
// below zero is ignored, so we can clamp band_pos to 32.
__m256i rd_divd = srli_epi8 (rd, shift);
__m256i band = _mm256_sub_epi8 (rd_divd, bp_256);
// Force all <0 or >3 bands to 0xff, which will zero the shuffle result
__m256i band_lt_0 = _mm256_cmpgt_epi8 (zero, band);
__m256i band_gt_3 = _mm256_cmpgt_epi8 (band, threes);
__m256i band_inv = _mm256_or_si256 (band_lt_0, band_gt_3);
band = _mm256_or_si256 (band, band_inv);
__m256i offsets = _mm256_shuffle_epi8 (sb_256, band);
__m256i offsets_lo, offsets_hi;
cvt_epi8_epi16(offsets, &offsets_lo, &offsets_hi);
__m256i offsets_0_lo = _mm256_cmpeq_epi16 (offsets_lo, zero);
__m256i offsets_0_hi = _mm256_cmpeq_epi16 (offsets_hi, zero);
__m256i diff_lo = _mm256_sub_epi16 (orig_lo, rd_lo);
__m256i diff_hi = _mm256_sub_epi16 (orig_hi, rd_hi);
__m256i delta_lo = _mm256_sub_epi16 (diff_lo, offsets_lo);
__m256i delta_hi = _mm256_sub_epi16 (diff_hi, offsets_hi);
diff_lo = _mm256_andnot_si256 (offsets_0_lo, diff_lo);
diff_hi = _mm256_andnot_si256 (offsets_0_hi, diff_hi);
delta_lo = _mm256_andnot_si256 (offsets_0_lo, delta_lo);
delta_hi = _mm256_andnot_si256 (offsets_0_hi, delta_hi);
__m256i dd0_lo = _mm256_unpacklo_epi16(delta_lo, diff_lo);
__m256i dd0_hi = _mm256_unpackhi_epi16(delta_lo, diff_lo);
__m256i dd1_lo = _mm256_unpacklo_epi16(delta_hi, diff_hi);
__m256i dd1_hi = _mm256_unpackhi_epi16(delta_hi, diff_hi);
__m256i dd0_lo_n = _mm256_sign_epi16 (dd0_lo, negate_hiword);
__m256i dd0_hi_n = _mm256_sign_epi16 (dd0_hi, negate_hiword);
__m256i dd1_lo_n = _mm256_sign_epi16 (dd1_lo, negate_hiword);
__m256i dd1_hi_n = _mm256_sign_epi16 (dd1_hi, negate_hiword);
__m256i sum0_lo = _mm256_madd_epi16 (dd0_lo, dd0_lo_n);
__m256i sum0_hi = _mm256_madd_epi16 (dd0_hi, dd0_hi_n);
__m256i sum1_lo = _mm256_madd_epi16 (dd1_lo, dd1_lo_n);
__m256i sum1_hi = _mm256_madd_epi16 (dd1_hi, dd1_hi_n);
__m256i sum0 = _mm256_add_epi32 (sum0_lo, sum0_hi);
__m256i sum1 = _mm256_add_epi32 (sum1_lo, sum1_hi);
__m256i curr_sum = _mm256_add_epi32 (sum0, sum1);
sum = _mm256_add_epi32 (sum, curr_sum);
}
}
// Horizontal sum of 8x32 YMM, nothing special here
__m256i sum2 = _mm256_permute4x64_epi64(sum, _MM_SHUFFLE(1, 0, 3, 2));
__m256i sum3 = _mm256_add_epi32 (sum, sum2);
__m256i sum4 = _mm256_shuffle_epi32 (sum3, _MM_SHUFFLE(1, 0, 3, 2));
__m256i sum5 = _mm256_add_epi32 (sum3, sum4);
__m256i sum6 = _mm256_shuffle_epi32 (sum5, _MM_SHUFFLE(2, 3, 0, 1));
__m256i sum7 = _mm256_add_epi32 (sum5, sum6);
__m128i sum8 = _mm256_castsi256_si128 (sum7);
int32_t sum9 = _mm_cvtsi128_si32 (sum8);
return sum9;
}
}
//Full horizontal sum
sum_epi32 = _mm256_add_epi32(sum_epi32, _mm256_castsi128_si256(_mm256_extracti128_si256(sum_epi32, 1)));
sum_epi32 = _mm256_add_epi32(sum_epi32, _mm256_shuffle_epi32(sum_epi32, _MM_SHUFFLE(1, 0, 3, 2)));
sum_epi32 = _mm256_add_epi32(sum_epi32, _mm256_shuffle_epi32(sum_epi32, _MM_SHUFFLE(0, 1, 0, 1)));
sum += _mm_cvtsi128_si32(_mm256_castsi256_si128(sum_epi32));
return sum;
use_generic:
return sao_band_ddistortion_generic(state, orig_data, rec_data, block_width,
block_height, band_pos, sao_bands);
}
#endif //COMPILE_INTEL_AVX2

View file

@ -19,6 +19,7 @@
****************************************************************************/
#include "strategies/generic/sao-generic.h"
#include "strategies/generic/sao_band_ddistortion.h"
#include "cu.h"
#include "encoder.h"
@ -156,35 +157,6 @@ static void sao_reconstruct_color_generic(const encoder_control_t * const encode
}
static int sao_band_ddistortion_generic(const encoder_state_t * const state,
const kvz_pixel *orig_data,
const kvz_pixel *rec_data,
int block_width,
int block_height,
int band_pos,
int sao_bands[4])
{
int y, x;
int shift = state->encoder_control->bitdepth-5;
int sum = 0;
for (y = 0; y < block_height; ++y) {
for (x = 0; x < block_width; ++x) {
int band = (rec_data[y * block_width + x] >> shift) - band_pos;
int offset = 0;
if (band >= 0 && band < 4) {
offset = sao_bands[band];
}
if (offset != 0) {
int diff = orig_data[y * block_width + x] - rec_data[y * block_width + x];
// Offset is applied to reconstruction, so it is subtracted from diff.
sum += (diff - offset) * (diff - offset) - diff * diff;
}
}
}
return sum;
}
int kvz_strategy_register_sao_generic(void* opaque, uint8_t bitdepth)
{

View file

@ -0,0 +1,48 @@
#ifndef SAO_BAND_DDISTORTION_H_
#define SAO_BAND_DDISTORTION_H_
// #include "encoder.h"
#include "encoderstate.h"
#include "kvazaar.h"
#include "sao.h"
static int sao_band_ddistortion_generic(const encoder_state_t * const state,
const kvz_pixel *orig_data,
const kvz_pixel *rec_data,
int block_width,
int block_height,
int band_pos,
const int sao_bands[4])
{
int y, x;
int shift = state->encoder_control->bitdepth-5;
int sum = 0;
for (y = 0; y < block_height; ++y) {
for (x = 0; x < block_width; ++x) {
const int32_t curr_pos = y * block_width + x;
kvz_pixel rec = rec_data[curr_pos];
kvz_pixel orig = orig_data[curr_pos];
int32_t band = (rec >> shift) - band_pos;
int32_t offset = 0;
if (band >= 0 && band <= 3) {
offset = sao_bands[band];
}
// Offset is applied to reconstruction, so it is subtracted from diff.
int32_t diff = orig - rec;
int32_t delta = diff - offset;
int32_t dmask = (offset == 0) ? -1 : 0;
diff &= ~dmask;
delta &= ~dmask;
sum += delta * delta - diff * diff;
}
}
return sum;
}
#endif

View file

@ -51,7 +51,7 @@ typedef void (sao_reconstruct_color_func)(const encoder_control_t * const encode
typedef int (sao_band_ddistortion_func)(const encoder_state_t * const state, const kvz_pixel *orig_data, const kvz_pixel *rec_data,
int block_width, int block_height,
int band_pos, int sao_bands[4]);
int band_pos, const int sao_bands[4]);
// Declare function pointers.
extern sao_edge_ddistortion_func * kvz_sao_edge_ddistortion;