diff --git a/src/strategies/avx2/encode_coding_tree-avx2.c b/src/strategies/avx2/encode_coding_tree-avx2.c index b0666e8c..204bbe7d 100644 --- a/src/strategies/avx2/encode_coding_tree-avx2.c +++ b/src/strategies/avx2/encode_coding_tree-avx2.c @@ -26,6 +26,22 @@ #include "kvz_math.h" #include +/* + * NOTE: Returns 10, not 11 like SSE and AVX comparisons do, as the bit pattern + * implying greaterness + */ +static INLINE uint32_t _mm32_cmpgt_epu2(uint32_t a, uint32_t b) +{ + const uint32_t himask = 0xaaaaaaaa; + + uint32_t a_gt_b = __andn_u32(b, a); + uint32_t a_ne_b = a ^ b; + uint32_t a_gt_b_sh = a_gt_b << 1; + uint32_t lobit_tiebrk_hi = __andn_u32(a_ne_b, a_gt_b_sh); + uint32_t res = (a_gt_b | lobit_tiebrk_hi) & himask; + return res; +} + /** * \brief Context derivation process of coeff_abs_significant_flag, * parallelized to handle 16 coeffs at once @@ -433,22 +449,24 @@ void kvz_encode_coeff_nxn_avx2(encoder_state_t * const state, assert(scan_cg_last >= 0); ALIGNED(64) int16_t coeff_reord[LCU_WIDTH * LCU_WIDTH]; - for (int32_t i = scan_cg_last; i >= 0; i--) { - int32_t subpos = i * 16; - __m256i coeffs_r = scanord_read_vector(coeff, scan, scan_mode, subpos, width); - _mm256_store_si256((__m256i *)(coeff_reord + subpos), coeffs_r); + uint32_t pos_last, scan_pos_last; + + { + __m256i coeffs_r; + for (int32_t i = 0; i <= scan_cg_last; i++) { + int32_t subpos = i * 16; + coeffs_r = scanord_read_vector(coeff, scan, scan_mode, subpos, width); + _mm256_store_si256((__m256i *)(coeff_reord + subpos), coeffs_r); + } + + // Find the last coeff by going backwards in scan order. + uint32_t baseaddr = scan_cg_last * 16; + __m256i cur_coeffs_zeros = _mm256_cmpeq_epi16(coeffs_r, zero); + uint32_t nz_bytes = ~(_mm256_movemask_epi8(cur_coeffs_zeros)); + scan_pos_last = baseaddr + ((31 - _lzcnt_u32(nz_bytes)) >> 1); + pos_last = scan[scan_pos_last]; } - // Find the last coeff by going backwards in scan order. - uint32_t scan_pos_last; - uint32_t baseaddr = scan_cg_last * 16; - __m256i cur_coeffs = _mm256_loadu_si256((__m256i *)(coeff_reord + baseaddr)); - __m256i cur_coeffs_zeros = _mm256_cmpeq_epi16(cur_coeffs, zero); - uint32_t nz_bytes = ~(_mm256_movemask_epi8(cur_coeffs_zeros)); - scan_pos_last = baseaddr + ((31 - _lzcnt_u32(nz_bytes)) >> 1); - - int pos_last = scan[scan_pos_last]; - // transform skip flag if(width == 4 && encoder->cfg.trskip_enable) { cabac->cur_ctx = (type == 0) ? &(cabac->ctx.transform_skip_model_luma) : &(cabac->ctx.transform_skip_model_chroma); @@ -558,7 +576,7 @@ void kvz_encode_coeff_nxn_avx2(encoder_state_t * const state, } if (curr_sig) { - abs_coeff[num_non_zero] = abs_coeff_buf_sb[id]; + abs_coeff[num_non_zero] = abs_coeff_buf_sb[id]; coeff_signs = 2 * coeff_signs + curr_coeff_sign; num_non_zero++; } @@ -572,9 +590,9 @@ void kvz_encode_coeff_nxn_avx2(encoder_state_t * const state, && !encoder->cfg.lossless; uint32_t ctx_set = (i > 0 && type == 0) ? 2 : 0; cabac_ctx_t *base_ctx_mod; - int32_t num_c1_flag, first_c2_flag_idx, idx, first_coeff2; + int32_t num_c1_flag, first_c2_flag_idx, idx; - __m256i abs_coeffs = _mm256_loadu_si256((__m256i *)abs_coeff); + __m256i abs_coeffs = _mm256_load_si256((__m256i *)abs_coeff); __m256i coeffs_gt1 = _mm256_cmpgt_epi16(abs_coeffs, ones); __m256i coeffs_gt2 = _mm256_cmpgt_epi16(abs_coeffs, twos); uint32_t coeffs_gt1_bits = _mm256_movemask_epi8(coeffs_gt1); @@ -600,8 +618,8 @@ void kvz_encode_coeff_nxn_avx2(encoder_state_t * const state, */ const uint32_t c1s_pattern = 0xfffffffe; uint32_t n_nongt1_bits = _tzcnt_u32(coeffs_gt1_bits); - uint32_t c1s_nextiter = _bzhi_u32(c1s_pattern, n_nongt1_bits); - first_c2_flag_idx = n_nongt1_bits >> 1; + uint32_t c1s_nextiter = _bzhi_u32(c1s_pattern, n_nongt1_bits); + first_c2_flag_idx = n_nongt1_bits >> 1; c1 = 1; for (idx = 0; idx < num_c1_flag; idx++) { @@ -637,26 +655,65 @@ void kvz_encode_coeff_nxn_avx2(encoder_state_t * const state, CABAC_BINS_EP(cabac, coeff_signs, nnz, "coeff_sign_flag"); if (c1 == 0 || num_non_zero > C1FLAG_NUMBER) { - first_coeff2 = 1; + + const __m256i ones = _mm256_set1_epi16(1); + const __m256i threes = _mm256_set1_epi16(3); + + __m256i abs_coeffs_gt1 = _mm256_cmpgt_epi16 (abs_coeffs, ones); + uint32_t acgt1_bits = _mm256_movemask_epi8(abs_coeffs_gt1); + uint32_t first_acgt1_bpos = _tzcnt_u32(acgt1_bits); + + /* + * Extract low two bits (X and Y) from each coeff clipped at 3: + * abs_coeffs_max3: 0000 0000 0000 00XY + * abs_coeffs_tmp1: 0000 000X Y000 0000 + * abs_coeffs_tmp2: XXXX XXXX YYYY YYYY inverted + * + * abs_coeffs can be clipped to [0, 3] for this because it will only + * be compared whether it's >= X, where X is between 0 and 3 + */ + __m256i abs_coeffs_max3 = _mm256_min_epu16 (abs_coeffs, threes); + __m256i abs_coeffs_tmp1 = _mm256_slli_epi16 (abs_coeffs_max3, 7); + __m256i abs_coeffs_tmp2 = _mm256_cmpeq_epi8 (abs_coeffs_tmp1, zero); + uint32_t abs_coeffs_base4 = ~(_mm256_movemask_epi8(abs_coeffs_tmp2)); + + const uint32_t ones_base4 = 0x55555555; + const uint32_t twos_base4 = 0xaaaaaaaa; + + const uint32_t c1flag_number_mask_inv = 0xffffffff << (C1FLAG_NUMBER << 1); + const uint32_t c1flag_number_mask = ~c1flag_number_mask_inv; + + // The addition will not overflow between 2-bit atoms because + // first_coeff2s will only be 1 or 0, and the other addend is 2 + uint32_t first_coeff2s = _bzhi_u32(ones_base4, first_acgt1_bpos + 2); + uint32_t base_levels = first_coeff2s + twos_base4; + + base_levels &= c1flag_number_mask; + base_levels |= (ones_base4 & c1flag_number_mask_inv); + + uint32_t encode_decisions = _mm32_cmpgt_epu2(base_levels, abs_coeffs_base4); for (idx = 0; idx < num_non_zero; idx++) { - int32_t base_level = (idx < C1FLAG_NUMBER) ? (2 + first_coeff2) : 1; - if (abs_coeff[idx] >= base_level) { + uint32_t shamt = idx << 1; + uint32_t dont_encode_curr = (encode_decisions >> shamt) & 3; + int16_t base_level = (base_levels >> shamt) & 3; + + uint16_t curr_abs_coeff = abs_coeff[idx]; + + if (!dont_encode_curr) { + uint16_t level_diff = curr_abs_coeff - base_level; if (!cabac->only_count && (encoder->cfg.crypto_features & KVZ_CRYPTO_TRANSF_COEFFS)) { - kvz_cabac_write_coeff_remain_encry(state, cabac, abs_coeff[idx] - base_level, go_rice_param, base_level); + kvz_cabac_write_coeff_remain_encry(state, cabac, level_diff, go_rice_param, base_level); } else { - kvz_cabac_write_coeff_remain(cabac, abs_coeff[idx] - base_level, go_rice_param); + kvz_cabac_write_coeff_remain(cabac, level_diff, go_rice_param); } - if (abs_coeff[idx] > 3 * (1 << go_rice_param)) { + if (curr_abs_coeff > 3 * (1 << go_rice_param)) { go_rice_param = MIN(go_rice_param + 1, 4); } } - if (abs_coeff[idx] >= 2) { - first_coeff2 = 0; - } } } }