diff --git a/src/strategies/avx2/encode_coding_tree-avx2.c b/src/strategies/avx2/encode_coding_tree-avx2.c index 9d81b5f1..bd84d05b 100644 --- a/src/strategies/avx2/encode_coding_tree-avx2.c +++ b/src/strategies/avx2/encode_coding_tree-avx2.c @@ -38,18 +38,25 @@ #endif // _andn_u32 /* - * NOTE: Returns 10, not 11 like SSE and AVX comparisons do, as the bit pattern - * implying greaterness + * NOTE: Unlike SSE/AVX comparisons that would return 11 or 00 for gt/lte, + * this'll use 1x and 0x as bit patterns (x: garbage). A couple extra + * instructions will get you 11 and 00 if you need to use this as a mask + * somewhere at some point, but we don't need this right now. + * + * I'd love to draw a logic circuit here to describe this, but I can't. Two + * 2-bit uints can be compared for greaterness by first comparing their high + * bits using AND-NOT; (x AND (NOT y)) == 1 if x > y. If A_hi > B_hi, A > B. + * If A_hi == B_hi AND A_lo > B_lo, A > B. Otherwise, A <= B. It's really + * simple when drawn on paper, but quite messy on a general-purpose ALU. But + * look, just five instructions! */ -static INLINE uint32_t _mm32_cmpgt_epu2(uint32_t a, uint32_t b) +static INLINE uint32_t u32vec_cmpgt_epu2(uint32_t a, uint32_t b) { - const uint32_t himask = 0xaaaaaaaa; - uint32_t a_gt_b = _andn_u32(b, a); uint32_t a_ne_b = a ^ b; uint32_t a_gt_b_sh = a_gt_b << 1; uint32_t lobit_tiebrk_hi = _andn_u32(a_ne_b, a_gt_b_sh); - uint32_t res = (a_gt_b | lobit_tiebrk_hi) & himask; + uint32_t res = a_gt_b | lobit_tiebrk_hi; return res; } @@ -619,17 +626,17 @@ void kvz_encode_coeff_nxn_avx2(encoder_state_t * const state, base_levels &= c1flag_number_mask; base_levels |= (ones_base4 & c1flag_number_mask_inv); - uint32_t encode_decisions = _mm32_cmpgt_epu2(base_levels, abs_coeffs_base4); + uint32_t encode_decisions = u32vec_cmpgt_epu2(base_levels, abs_coeffs_base4); for (idx = 0; idx < num_non_zero; idx++) { uint32_t shamt = idx << 1; - uint32_t dont_encode_curr = (encode_decisions >> shamt) & 3; + uint32_t dont_encode_curr = (encode_decisions >> shamt); int16_t base_level = (base_levels >> shamt) & 3; uint16_t curr_abs_coeff = abs_coeff[idx]; - if (!dont_encode_curr) { + if (!(dont_encode_curr & 2)) { uint16_t level_diff = curr_abs_coeff - base_level; if (!cabac->only_count && (encoder->cfg.crypto_features & KVZ_CRYPTO_TRANSF_COEFFS)) { kvz_cabac_write_coeff_remain_encry(state, cabac, level_diff, go_rice_param, base_level);