From d71eb77d1a2fc1e89d2995fce193f5e68690a2f5 Mon Sep 17 00:00:00 2001 From: Joose Sainio Date: Wed, 12 Apr 2023 10:41:37 +0300 Subject: [PATCH] [avx2] update_states_avx2 working --- src/dep_quant.c | 159 +++++++++++++++++++++++++++++------------------- 1 file changed, 95 insertions(+), 64 deletions(-) diff --git a/src/dep_quant.c b/src/dep_quant.c index 78a039bb..270e0639 100644 --- a/src/dep_quant.c +++ b/src/dep_quant.c @@ -158,11 +158,14 @@ typedef struct int8_t m_goRiceZero[12]; int8_t m_stateId[12]; uint32_t m_sigFracBitsArray[12][12][2]; - int32_t *m_gtxFracBitsArray[21]; + int32_t m_gtxFracBitsArray[21][6]; common_context* m_commonCtx; unsigned effWidth; unsigned effHeight; + + bool all_gte_four; + bool all_lt_four; } all_depquant_states; typedef struct @@ -577,14 +580,8 @@ static void check_rd_costs_avx2(const all_depquant_states* const state, const en rd_cost_a = _mm256_add_epi64(rd_cost_a, pq_a_delta_dist); rd_cost_b = _mm256_add_epi64(rd_cost_b, pq_b_delta_dist); - bool all_over_or_four = true; - bool all_under_four = true; - for (int i = 0; i < 4; i++) { - all_over_or_four &= state->m_remRegBins[start + i] >= 4; - all_under_four &= state->m_remRegBins[start + i] < 4; - } - if (all_over_or_four) { + if (state->all_gte_four) { if (pqDataA->absLevel[0] < 4 && pqDataA->absLevel[3] < 4) { __m128i offsets = _mm_set_epi32(18 + pqDataA->absLevel[3], 12 + pqDataA->absLevel[3], 6 + pqDataA->absLevel[0], 0 + pqDataA->absLevel[0]); __m128i coeff_frac_bits = _mm_i32gather_epi32(&state->m_coeffFracBits[start][0], offsets, 4); @@ -737,7 +734,7 @@ static void check_rd_costs_avx2(const all_depquant_states* const state, const en _mm256_storeu_epi64(temp_rd_cost_a, rd_cost_a); _mm256_storeu_epi64(temp_rd_cost_b, rd_cost_b); _mm256_storeu_epi64(temp_rd_cost_z, rd_cost_z); - } else if (all_under_four) { + } else if (state->all_lt_four) { __m128i scale_bits = _mm_set1_epi32(1 << SCALE_BITS); __m128i max_rice = _mm_set1_epi32(31); __m128i go_rice_zero = _mm_cvtepi8_epi32(_mm_loadu_epi8(&state->m_goRiceZero[start])); @@ -1274,6 +1271,8 @@ static INLINE void update_states_avx2( all_minus_one &= decisions->prevId[i] == -1; } int state_offset = ctxs->m_curr_state_offset; + __m256i rd_cost = _mm256_loadu_epi64(decisions->rdCost); + _mm256_storeu_epi64(&ctxs->m_allStates.m_rdCost[state_offset], rd_cost); if (all_above_minus_two) { bool rem_reg_all_gte_4 = true; @@ -1312,7 +1311,7 @@ static INLINE void update_states_avx2( memcpy(&state->m_goRicePar[state_offset], &go_rice_par_i, 4); - __m256i sbb_frac_bits = _mm256_i32gather_epi64(state->m_sbbFracBits, prv_states, 4); + __m256i sbb_frac_bits = _mm256_i32gather_epi64(state->m_sbbFracBits, prv_states, 8); _mm256_storeu_epi64(&state->m_sbbFracBits[state_offset][0], sbb_frac_bits); __m128i rem_reg_bins = _mm_i32gather_epi32(state->m_remRegBins, prv_states, 4); @@ -1321,7 +1320,7 @@ static INLINE void update_states_avx2( __m128i reg_bins_sub = _mm_set1_epi32(0); __m128i abs_level_smaller_than_two = _mm_cmplt_epi32(abs_level, _mm_set1_epi32(2)); - __m128i secondary = _mm_blendv_epi8(abs_level, _mm_set1_epi32(3), abs_level_smaller_than_two); + __m128i secondary = _mm_blendv_epi8(_mm_set1_epi32(3), abs_level, abs_level_smaller_than_two); __m128i rem_reg_bins_smaller_than_four = _mm_cmplt_epi32(rem_reg_bins, _mm_set1_epi32(4)); reg_bins_sub = _mm_blendv_epi8(secondary, reg_bins_sub, rem_reg_bins_smaller_than_four); @@ -1336,7 +1335,7 @@ static INLINE void update_states_avx2( rem_reg_all_lt4 = (bit_mask == 0xFFFF); for (int i = 0; i < 4; ++i) { - memcpy(state->m_absLevelsAndCtxInit[i], state->m_absLevelsAndCtxInit[prv_states_scalar[i]], 48 * sizeof(uint8_t)); + memcpy(state->m_absLevelsAndCtxInit[state_offset + i], state->m_absLevelsAndCtxInit[prv_states_scalar[i]], 48 * sizeof(uint8_t)); } } else if (all_minus_one) { @@ -1347,8 +1346,8 @@ static INLINE void update_states_avx2( __m128i rem_reg_bins = _mm_set1_epi32(a); __m128i sub = _mm_blendv_epi8( - abs_level, _mm_set1_epi32(3), + abs_level, _mm_cmplt_epi32(abs_level, _mm_set1_epi32(2)) ); rem_reg_bins = _mm_sub_epi32(rem_reg_bins, sub); @@ -1400,18 +1399,20 @@ static INLINE void update_states_avx2( uint8_t* levels = (uint8_t*)state->m_absLevelsAndCtxInit[state_offset + i]; levels[level_offset] = max_abs_s[i]; } - + state->all_gte_four = rem_reg_all_gte_4; + state->all_lt_four = rem_reg_all_lt4; if (rem_reg_all_gte_4) { - const __m128i last_two_bytes = _mm_set1_epi32(0xffff); - const __m128i last_byte = _mm_set1_epi32(0xff); + const __m128i first_two_bytes = _mm_set1_epi32(0xffff); + const __m128i first_byte = _mm_set1_epi32(0xff); const __m128i ones = _mm_set1_epi32(1); const uint32_t tinit_offset = MIN(level_offset - 1u, 15u) + 8; const __m128i levels_start_offsets = _mm_set_epi32(48 * 3, 48 * 2, 48 * 1, 48 * 0); + const __m128i ctx_start_offsets = _mm_srli_epi32(levels_start_offsets, 1); __m128i tinit = _mm_i32gather_epi32( state->m_absLevelsAndCtxInit[state_offset], - _mm_add_epi32(levels_start_offsets, _mm_set1_epi32(tinit_offset)), - 1); - tinit = _mm_and_epi32(tinit, last_two_bytes); + _mm_add_epi32(ctx_start_offsets, _mm_set1_epi32(tinit_offset)), + 2); + tinit = _mm_and_epi32(tinit, first_two_bytes); __m128i sum_abs1 = _mm_and_epi32(_mm_srli_epi32(tinit, 3), _mm_set1_epi32(31)); __m128i sum_num = _mm_and_epi32(tinit, _mm_set1_epi32(7)); @@ -1423,12 +1424,18 @@ static INLINE void update_states_avx2( levels, _mm_add_epi32(levels_start_offsets, _mm_set1_epi32(next_nb_info_ssb.inPos[4])), 1); + t = _mm_and_epi32(t, first_byte); + __m128i min_arg = _mm_min_epi32( + _mm_add_epi32(_mm_set1_epi32(4), _mm_and_epi32(t, ones)), + t + ); sum_abs1 = _mm_add_epi32( sum_abs1, - _mm_and_epi32(t, ones)); + min_arg + ); sum_num = _mm_add_epi32( sum_num, - _mm_min_epi32(_mm_and_epi32(t, last_byte), ones)); + _mm_min_epi32(_mm_and_epi32(t, first_byte), ones)); } case 4: { @@ -1436,12 +1443,18 @@ static INLINE void update_states_avx2( levels, _mm_add_epi32(levels_start_offsets, _mm_set1_epi32(next_nb_info_ssb.inPos[3])), 1); + t = _mm_and_epi32(t, first_byte); + __m128i min_arg = _mm_min_epi32( + _mm_add_epi32(_mm_set1_epi32(4), _mm_and_epi32(t, ones)), + t + ); sum_abs1 = _mm_add_epi32( sum_abs1, - _mm_and_epi32(t, ones)); + min_arg + ); sum_num = _mm_add_epi32( sum_num, - _mm_min_epi32(_mm_and_epi32(t, last_byte), ones)); + _mm_min_epi32(_mm_and_epi32(t, first_byte), ones)); } case 3: { @@ -1449,12 +1462,18 @@ static INLINE void update_states_avx2( levels, _mm_add_epi32(levels_start_offsets, _mm_set1_epi32(next_nb_info_ssb.inPos[2])), 1); + t = _mm_and_epi32(t, first_byte); + __m128i min_arg = _mm_min_epi32( + _mm_add_epi32(_mm_set1_epi32(4), _mm_and_epi32(t, ones)), + t + ); sum_abs1 = _mm_add_epi32( sum_abs1, - _mm_and_epi32(t, ones)); + min_arg + ); sum_num = _mm_add_epi32( sum_num, - _mm_min_epi32(_mm_and_epi32(t, last_byte), ones)); + _mm_min_epi32(_mm_and_epi32(t, first_byte), ones)); } case 2: { @@ -1462,39 +1481,52 @@ static INLINE void update_states_avx2( levels, _mm_add_epi32(levels_start_offsets, _mm_set1_epi32(next_nb_info_ssb.inPos[1])), 1); + t = _mm_and_epi32(t, first_byte); + __m128i min_arg = _mm_min_epi32( + _mm_add_epi32(_mm_set1_epi32(4), _mm_and_epi32(t, ones)), + t + ); sum_abs1 = _mm_add_epi32( sum_abs1, - _mm_and_epi32(t, ones)); + min_arg + ); sum_num = _mm_add_epi32( sum_num, - _mm_min_epi32(_mm_and_epi32(t, last_byte), ones)); + _mm_min_epi32(_mm_and_epi32(t, first_byte), ones)); } case 1: { __m128i t = _mm_i32gather_epi32( levels, _mm_add_epi32(levels_start_offsets, _mm_set1_epi32(next_nb_info_ssb.inPos[0])), 1); + t = _mm_and_epi32(t, first_byte); + __m128i min_arg = _mm_min_epi32( + _mm_add_epi32(_mm_set1_epi32(4), _mm_and_epi32(t, ones)), + t + ); sum_abs1 = _mm_add_epi32( sum_abs1, - _mm_and_epi32(t, ones)); + min_arg + ); sum_num = _mm_add_epi32( sum_num, - _mm_min_epi32(_mm_and_epi32(t, last_byte), ones)); + _mm_min_epi32(_mm_and_epi32(t, first_byte), ones)); } break; default: assert(0); } __m128i sum_gt1 = _mm_sub_epi32(sum_abs1, sum_num); - __m128i offsets = _mm_set_epi32(24 * 3, 24 * 2, 24 * 1, 24 * 0); + __m128i offsets = _mm_set_epi32(12 * 3, 12 * 2, 12 * 1, 12 * 0); offsets = _mm_add_epi32(offsets, _mm_set1_epi32(sigCtxOffsetNext)); __m128i temp = _mm_min_epi32( _mm_srli_epi32(_mm_add_epi32(sum_abs1, ones), 1), _mm_set1_epi32(3)); offsets = _mm_add_epi32(offsets, temp); - __m256i sig_frac_bits = _mm256_i32gather_epi64(state->m_sigFracBitsArray[state_offset][0], offsets, 4); + __m256i sig_frac_bits = _mm256_i32gather_epi64(state->m_sigFracBitsArray[state_offset][0], offsets, 8); _mm256_storeu_epi64(&state->m_sigFracBits[state_offset][0], sig_frac_bits); sum_gt1 = _mm_min_epi32(sum_gt1, _mm_set1_epi32(4)); + sum_gt1 = _mm_add_epi32(sum_gt1, _mm_set1_epi32(gtxCtxOffsetNext)); uint32_t sum_gt1_s[4]; _mm_storeu_epi32(sum_gt1_s, sum_gt1); for (int i = 0; i < 4; ++i) { @@ -1509,7 +1541,7 @@ static INLINE void update_states_avx2( levels, _mm_add_epi32(levels_start_offsets, _mm_set1_epi32(next_nb_info_ssb.inPos[4])), 1); - t = _mm_and_epi32(t, last_byte); + t = _mm_and_epi32(t, first_byte); sum_abs = _mm_add_epi32(sum_abs, t); } case 4: @@ -1518,7 +1550,7 @@ static INLINE void update_states_avx2( levels, _mm_add_epi32(levels_start_offsets, _mm_set1_epi32(next_nb_info_ssb.inPos[3])), 1); - t = _mm_and_epi32(t, last_byte); + t = _mm_and_epi32(t, first_byte); sum_abs = _mm_add_epi32(sum_abs, t); } case 3: @@ -1527,7 +1559,7 @@ static INLINE void update_states_avx2( levels, _mm_add_epi32(levels_start_offsets, _mm_set1_epi32(next_nb_info_ssb.inPos[2])), 1); - t = _mm_and_epi32(t, last_byte); + t = _mm_and_epi32(t, first_byte); sum_abs = _mm_add_epi32(sum_abs, t); } case 2: @@ -1536,7 +1568,7 @@ static INLINE void update_states_avx2( levels, _mm_add_epi32(levels_start_offsets, _mm_set1_epi32(next_nb_info_ssb.inPos[1])), 1); - t = _mm_and_epi32(t, last_byte); + t = _mm_and_epi32(t, first_byte); sum_abs = _mm_add_epi32(sum_abs, t); } case 1: @@ -1545,7 +1577,7 @@ static INLINE void update_states_avx2( levels, _mm_add_epi32(levels_start_offsets, _mm_set1_epi32(next_nb_info_ssb.inPos[0])), 1); - t = _mm_and_epi32(t, last_byte); + t = _mm_and_epi32(t, first_byte); sum_abs = _mm_add_epi32(sum_abs, t); } break; default: @@ -1560,7 +1592,10 @@ static INLINE void update_states_avx2( _mm_sub_epi32(sum_abs, _mm_set1_epi32(20))), _mm_set1_epi32(0)); __m128i temp = _mm_i32gather_epi32(g_goRiceParsCoeff, sum_all, 4); - _mm_storeu_epi32(&state->m_goRicePar[state_offset], temp); + __m128i control = _mm_setr_epi8(0, 4, 8, 12, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1); + __m128i go_rice_par = _mm_shuffle_epi8(temp, control); + int go_rice_par_i = _mm_extract_epi32(go_rice_par, 0); + memcpy(&state->m_goRicePar[state_offset], &go_rice_par_i, 4); } } @@ -1571,10 +1606,11 @@ static INLINE void update_states_avx2( const __m128i ones = _mm_set1_epi32(1); const uint32_t tinit_offset = MIN(level_offset - 1u, 15u) + 8; const __m128i levels_start_offsets = _mm_set_epi32(48 * 3, 48 * 2, 48 * 1, 48 * 0); + const __m128i ctx_start_offsets = _mm_srli_epi32(levels_start_offsets, 1); __m128i tinit = _mm_i32gather_epi32( state->m_absLevelsAndCtxInit[state_offset], - _mm_add_epi32(levels_start_offsets, _mm_set1_epi32(tinit_offset)), - 1); + _mm_add_epi32(ctx_start_offsets, _mm_set1_epi32(tinit_offset)), + 2); tinit = _mm_and_epi32(tinit, last_two_bytes); __m128i sum_abs = _mm_srli_epi32(tinit, 8); switch (numIPos) { @@ -1624,22 +1660,19 @@ static INLINE void update_states_avx2( if (extRiceFlag) { assert(0 && "Not implemented for avx2"); } else { - __m128i sum_all = _mm_max_epi32( - _mm_min_epi32( - _mm_set1_epi32(31), - _mm_sub_epi32(sum_abs, _mm_set1_epi32(20))), - _mm_set1_epi32(0)); + __m128i sum_all = _mm_min_epi32(_mm_set1_epi32(31), sum_abs); __m128i temp = _mm_i32gather_epi32(g_goRiceParsCoeff, sum_all, 4); __m128i control = _mm_setr_epi8(0, 4, 8, 12, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1); __m128i go_rice_par = _mm_shuffle_epi8(temp, control); int go_rice_par_i = _mm_extract_epi32(go_rice_par, 0); memcpy(&state->m_goRicePar[state_offset], &go_rice_par_i, 4); - __m128i go_rice_zero = _mm_set_epi32(2, 2, 1, 1); - go_rice_zero = _mm_sll_epi32(go_rice_zero, temp); - go_rice_zero = _mm_shuffle_epi8(go_rice_zero, control); - int go_rice_zero_i = _mm_extract_epi32(go_rice_par, 0); - memcpy(&state->m_goRiceZero[state_offset], &go_rice_zero_i, 4); + + for (int i = 0; i < 4; ++i) { + state->m_goRiceZero[state_offset + i] = (i < 2 ? 1 : 2) << state->m_goRicePar[state_offset + i]; + + } + } } @@ -1729,6 +1762,8 @@ static INLINE void update_states_avx2( } } else { for (int i = 0; i < 4; ++i) { + state->all_gte_four = true; + state->all_lt_four = true; updateState( ctxs, numIPos, @@ -1758,7 +1793,7 @@ static INLINE void updateState( int decision_id) { all_depquant_states* state = &ctxs->m_allStates; int state_id = ctxs->m_curr_state_offset + decision_id; - state->m_rdCost[state_id] = decisions->rdCost[decision_id]; + // state->m_rdCost[state_id] = decisions->rdCost[decision_id]; if (decisions->prevId[decision_id] > -2) { if (decisions->prevId[decision_id] >= 0) { const int prvState = ctxs->m_prev_state_offset + decisions->prevId[decision_id]; @@ -1784,7 +1819,8 @@ static INLINE void updateState( decisions->absLevel[decision_id] < 2 ? (unsigned)decisions->absLevel[decision_id] : 3); memset(state->m_absLevelsAndCtxInit[state_id], 0, 48 * sizeof(uint8_t)); } - + state->all_gte_four &= state->m_remRegBins[state_id] >= 4; + state->all_lt_four &= state->m_remRegBins[state_id] < 4; uint8_t* levels = (uint8_t*)(state->m_absLevelsAndCtxInit[state_id]); levels[scan_pos & 15] = (uint8_t)MIN(255, decisions->absLevel[decision_id]); @@ -1860,6 +1896,10 @@ static INLINE void updateState( state->m_goRiceZero[state_id] = ((state_id & 3) < 2 ? 1 : 2) << state->m_goRicePar[state_id]; } } + else { + state->all_gte_four &= state->m_remRegBins[state_id] >= 4; + state->all_lt_four &= state->m_remRegBins[state_id] < 4; + } } static bool same[13]; @@ -1946,18 +1986,7 @@ int uvg_dep_quant( const uint32_t lfnstIdx = tree_type != UVG_CHROMA_T || compID == COLOR_Y ? cur_tu->lfnst_idx : cur_tu->cr_lfnst_idx; - - int8_t t[4] = {2, 2, 2, 2}; - __m128i pq_abs_a = _mm_set_epi32(16, 0, 16, 0); - __m128i go_rice_zero = _mm_cvtepi8_epi32(_mm_loadu_epi8(t)); - __m128i cmp = _mm_cmplt_epi32(go_rice_zero, pq_abs_a); - - __m128i max_rice = _mm_set1_epi32(15); - __m128i go_rice_smaller = _mm_min_epi32(pq_abs_a, max_rice); - - __m128i other = _mm_sub_epi32(pq_abs_a, _mm_set1_epi32(1)); - __m128i selected = _mm_blendv_epi8(go_rice_zero, other, cmp); - + const int numCoeff = width * height; memset(coeff_out, 0x00, width * height * sizeof(coeff_t)); @@ -2055,9 +2084,11 @@ int uvg_dep_quant( dep_quant_context.m_allStates.effHeight = effectHeight; dep_quant_context.m_allStates.effWidth = effectWidth; + dep_quant_context.m_allStates.all_gte_four = true; + dep_quant_context.m_allStates.all_lt_four = false; dep_quant_context.m_allStates.m_commonCtx = &dep_quant_context.m_common_context; for (int i = 0; i < (compID == COLOR_Y ? 21 : 11); ++i) { - dep_quant_context.m_allStates.m_gtxFracBitsArray[i] = rate_estimator.m_gtxFracBits[i]; + memcpy(dep_quant_context.m_allStates.m_gtxFracBitsArray[i], rate_estimator.m_gtxFracBits[i], sizeof(int32_t) * 6); } depquant_state_init(&dep_quant_context.m_startState, rate_estimator.m_sigFracBits[0][0], rate_estimator.m_gtxFracBits[0]);