From 58a66c06545c4c1a9812a8ab77165b0a61570a4d Mon Sep 17 00:00:00 2001
From: Joose Sainio
Date: Mon, 10 Apr 2023 15:31:05 +0300
Subject: [PATCH] [avx2] WIP update_states_avx2

---
Review note: the hunks below were revised during review. The index line
still carries the original blob ids, and the whitespace of the context
lines was reconstructed, so apply with whitespace tolerance.

 src/dep_quant.c | 359 ++++++++++++++++++++++++++++++++++++++----------
 1 file changed, 279 insertions(+), 80 deletions(-)

diff --git a/src/dep_quant.c b/src/dep_quant.c
index ffd352b7..78a039bb 100644
--- a/src/dep_quant.c
+++ b/src/dep_quant.c
@@ -157,7 +157,7 @@ typedef struct
   int8_t m_goRicePar[12];
   int8_t m_goRiceZero[12];
   int8_t m_stateId[12];
-  uint32_t *m_sigFracBitsArray[12][12];
+  uint32_t m_sigFracBitsArray[12][12][2];
   int32_t *m_gtxFracBitsArray[21];
 
   common_context* m_commonCtx;
@@ -1240,6 +1240,260 @@ static INLINE void updateStateEOS(
         state->m_gtxFracBitsArray[gtxCtxOffsetNext + (sumGt1 < 4 ? sumGt1 : 4)], sizeof(state->m_coeffFracBits[0]));
   }
 }
+// Forward declaration: update_states_avx2() falls back to the scalar updateState().
+static INLINE void updateState(
+  context_store* ctxs,
+  int numIPos,
+  const uint32_t scan_pos,
+  const Decision* decisions,
+  const uint32_t sigCtxOffsetNext,
+  const uint32_t gtxCtxOffsetNext,
+  const NbInfoSbb next_nb_info_ssb,
+  const int baseLevel,
+  const bool extRiceFlag,
+  int decision_id);
+
+/**
+ * \brief Gather one level byte from each of four consecutive state rows.
+ *
+ * \param levels      First byte of the first state's row.
+ * \param row_offsets Byte offset of each state's row relative to levels.
+ * \param pos         Byte position inside a row.
+ * \return The four bytes zero-extended to 32-bit lanes.
+ *
+ * Each lane loads four bytes starting at the requested one; the three
+ * extra bytes are masked away.
+ */
+static INLINE __m128i gather_level_bytes(const uint8_t* levels, const __m128i row_offsets, const int pos)
+{
+  const __m128i raw = _mm_i32gather_epi32((const int*)levels, _mm_add_epi32(row_offsets, _mm_set1_epi32(pos)), 1);
+  return _mm_and_si128(raw, _mm_set1_epi32(0xff));
+}
+
+/**
+ * \brief Update the four current states at once after a decision.
+ *
+ * Vectorised counterpart of calling updateState() for decision ids 0..3.
+ * The vector paths are taken only when all four decisions have the same kind
+ * of predecessor and all four states end up on the same side of the
+ * "four regular bins left" limit; every other case runs scalar code.
+ *
+ * NOTE(review): the vector loads and stores assume that m_remRegBins has
+ * 32-bit elements and that one row of m_sbbFracBits and of m_sigFracBits
+ * is 8 bytes - confirm against the struct definitions.
+ */
+static INLINE void update_states_avx2(
+  context_store* ctxs,
+  int numIPos,
+  const uint32_t scan_pos,
+  const Decision* decisions,
+  const uint32_t sigCtxOffsetNext,
+  const uint32_t gtxCtxOffsetNext,
+  const NbInfoSbb next_nb_info_ssb,
+  const int baseLevel,
+  const bool extRiceFlag)
+{
+  all_depquant_states* state = &ctxs->m_allStates;
+
+  bool all_non_negative = true;
+  bool all_above_minus_two = true;
+  bool all_minus_one = true;
+  for (int i = 0; i < 4; ++i) {
+    all_non_negative &= decisions->prevId[i] >= 0;
+    all_above_minus_two &= decisions->prevId[i] > -2;
+    all_minus_one &= decisions->prevId[i] == -1;
+  }
+
+  if (!all_above_minus_two) {
+    // Some prevId is -2 or lower: let the scalar code handle all four states.
+    for (int i = 0; i < 4; ++i) {
+      updateState(ctxs, numIPos, scan_pos, decisions, sigCtxOffsetNext, gtxCtxOffsetNext, next_nb_info_ssb, baseLevel, extRiceFlag, i);
+    }
+    return;
+  }
+
+  const int state_offset = ctxs->m_curr_state_offset;
+  const int row_bytes = (int)sizeof(state->m_absLevelsAndCtxInit[0]);
+  const __m128i abs_level = _mm_setr_epi32(
+    decisions->absLevel[0], decisions->absLevel[1], decisions->absLevel[2], decisions->absLevel[3]);
+  // Regular bins charged for the level itself: absLevel < 2 ? absLevel : 3.
+  // _mm_blendv_epi8(a, b, mask) picks b where the mask is set.
+  const __m128i level_bins = _mm_blendv_epi8(
+    _mm_set1_epi32(3),
+    abs_level,
+    _mm_cmplt_epi32(abs_level, _mm_set1_epi32(2)));
+
+  if (all_non_negative) {
+    int32_t prv_states_scalar[4];
+    for (int i = 0; i < 4; ++i) {
+      const int prv_state = ctxs->m_prev_state_offset + decisions->prevId[i];
+      const int state_id = state_offset + i;
+      prv_states_scalar[i] = prv_state;
+      // Byte-sized fields are copied one at a time: a 32-bit gather would
+      // read three bytes beyond each requested element.
+      state->m_numSigSbb[state_id] = (state->m_numSigSbb[prv_state]) || !!decisions->absLevel[i];
+      state->m_refSbbCtxId[state_id] = state->m_refSbbCtxId[prv_state];
+      state->m_goRicePar[state_id] = state->m_goRicePar[prv_state];
+      memcpy(state->m_absLevelsAndCtxInit[state_id], state->m_absLevelsAndCtxInit[prv_state], row_bytes);
+    }
+    const __m128i prv_states = _mm_loadu_si128((const __m128i*)prv_states_scalar);
+
+    // Each state's pair of sbb frac bits is moved as one 64-bit element.
+    assert(sizeof(state->m_sbbFracBits[0]) == 8);
+    const __m256i sbb_frac_bits = _mm256_i32gather_epi64((const long long*)state->m_sbbFracBits, prv_states, 8);
+    _mm256_storeu_si256((__m256i*)&state->m_sbbFracBits[state_offset][0], sbb_frac_bits);
+
+    __m128i rem_reg_bins = _mm_i32gather_epi32((const int*)state->m_remRegBins, prv_states, 4);
+    rem_reg_bins = _mm_sub_epi32(rem_reg_bins, _mm_set1_epi32(1));
+    // The level bins are only charged while at least four regular bins remain.
+    const __m128i has_budget = _mm_cmpgt_epi32(rem_reg_bins, _mm_set1_epi32(3));
+    rem_reg_bins = _mm_sub_epi32(rem_reg_bins, _mm_and_si128(level_bins, has_budget));
+    _mm_storeu_si128((__m128i*)&state->m_remRegBins[state_offset], rem_reg_bins);
+  }
+  else if (all_minus_one) {
+    memset(&state->m_numSigSbb[state_offset], 1, 4);
+    memset(&state->m_refSbbCtxId[state_offset], -1, 4);
+    memset(state->m_absLevelsAndCtxInit[state_offset], 0, 4 * row_bytes);
+
+    const int initial_bins = (state->effWidth * state->effHeight * 28) / 16;
+    const __m128i rem_reg_bins = _mm_sub_epi32(_mm_set1_epi32(initial_bins), level_bins);
+    _mm_storeu_si128((__m128i*)&state->m_remRegBins[state_offset], rem_reg_bins);
+  }
+  else {
+    for (int i = 0; i < 4; ++i) {
+      const int state_id = state_offset + i;
+      if (decisions->prevId[i] >= 0) {
+        const int prvState = ctxs->m_prev_state_offset + decisions->prevId[i];
+        state->m_numSigSbb[state_id] = (state->m_numSigSbb[prvState]) || !!decisions->absLevel[i];
+        state->m_refSbbCtxId[state_id] = state->m_refSbbCtxId[prvState];
+        state->m_sbbFracBits[state_id][0] = state->m_sbbFracBits[prvState][0];
+        state->m_sbbFracBits[state_id][1] = state->m_sbbFracBits[prvState][1];
+        state->m_remRegBins[state_id] = state->m_remRegBins[prvState] - 1;
+        state->m_goRicePar[state_id] = state->m_goRicePar[prvState];
+        if (state->m_remRegBins[state_id] >= 4) {
+          state->m_remRegBins[state_id] -= (decisions->absLevel[i] < 2 ? (unsigned)decisions->absLevel[i] : 3);
+        }
+        memcpy(state->m_absLevelsAndCtxInit[state_id], state->m_absLevelsAndCtxInit[prvState], row_bytes);
+      } else {
+        state->m_numSigSbb[state_id] = 1;
+        state->m_refSbbCtxId[state_id] = -1;
+        int ctxBinSampleRatio = 28;
+        //(scanInfo.chType == CHANNEL_TYPE_LUMA) ? MAX_TU_LEVEL_CTX_CODED_BIN_CONSTRAINT_LUMA : MAX_TU_LEVEL_CTX_CODED_BIN_CONSTRAINT_CHROMA;
+        state->m_remRegBins[state_id] = (state->effWidth * state->effHeight * ctxBinSampleRatio) / 16 - (decisions->absLevel[i] < 2 ? (unsigned)decisions->absLevel[i] : 3);
+        memset(state->m_absLevelsAndCtxInit[state_id], 0, row_bytes);
+      }
+    }
+  }
+
+  // Every branch above stored the new counts, so classify them in one place.
+  const __m128i new_rem_reg_bins = _mm_loadu_si128((const __m128i*)&state->m_remRegBins[state_offset]);
+  const int gte4_mask = _mm_movemask_epi8(_mm_cmpgt_epi32(new_rem_reg_bins, _mm_set1_epi32(3)));
+  const bool rem_reg_all_gte_4 = gte4_mask == 0xFFFF;
+  const bool rem_reg_all_lt4 = gte4_mask == 0;
+
+  const uint32_t level_offset = scan_pos & 15;
+  for (int i = 0; i < 4; ++i) {
+    uint8_t* levels = (uint8_t*)state->m_absLevelsAndCtxInit[state_offset + i];
+    levels[level_offset] = (uint8_t)MIN(decisions->absLevel[i], 255);
+  }
+
+  if (!extRiceFlag && (rem_reg_all_gte_4 || rem_reg_all_lt4)) {
+    const uint8_t* levels = (const uint8_t*)state->m_absLevelsAndCtxInit[state_offset];
+    const __m128i row_offsets = _mm_setr_epi32(0, row_bytes, 2 * row_bytes, 3 * row_bytes);
+    const __m128i ones = _mm_set1_epi32(1);
+
+    // Read the context init words with plain indexing so that the element
+    // width of m_absLevelsAndCtxInit is respected.
+    int32_t tinit_s[4];
+    for (int i = 0; i < 4; ++i) {
+      tinit_s[i] = state->m_absLevelsAndCtxInit[state_offset + i][8 + ((scan_pos - 1) & 15)];
+    }
+    const __m128i tinit = _mm_loadu_si128((const __m128i*)tinit_s);
+    __m128i sum_abs = _mm_srli_epi32(tinit, 8);
+    __m128i sum_abs1 = _mm_and_si128(_mm_srli_epi32(tinit, 3), _mm_set1_epi32(31));
+    __m128i sum_num = _mm_and_si128(tinit, _mm_set1_epi32(7));
+
+    assert(numIPos >= 1 && numIPos <= 5);
+    for (int k = 0; k < numIPos; ++k) {
+      const __m128i t = gather_level_bytes(levels, row_offsets, next_nb_info_ssb.inPos[k]);
+      sum_abs = _mm_add_epi32(sum_abs, t);
+      // Same as the scalar sumAbs1 += MIN(4 + (t & 1), t); sumNum += !!t;
+      sum_abs1 = _mm_add_epi32(sum_abs1, _mm_min_epi32(_mm_add_epi32(_mm_set1_epi32(4), _mm_and_si128(t, ones)), t));
+      sum_num = _mm_add_epi32(sum_num, _mm_min_epi32(t, ones));
+    }
+
+    if (rem_reg_all_gte_4) {
+      // m_sigFracBitsArray[state][ctx] is a pair of uint32_t. With a gather
+      // scale of 4 one state row spans sig_row_len indices and every
+      // context takes two of them.
+      const int sig_row_len = (int)(sizeof(state->m_sigFracBitsArray[0]) / sizeof(uint32_t));
+      const __m128i sig_ctx = _mm_add_epi32(
+        _mm_set1_epi32(sigCtxOffsetNext),
+        _mm_min_epi32(_mm_srli_epi32(_mm_add_epi32(sum_abs1, ones), 1), _mm_set1_epi32(3)));
+      const __m128i sig_offsets = _mm_add_epi32(
+        _mm_setr_epi32(0, sig_row_len, 2 * sig_row_len, 3 * sig_row_len),
+        _mm_slli_epi32(sig_ctx, 1));
+      const __m256i sig_frac_bits = _mm256_i32gather_epi64(
+        (const long long*)state->m_sigFracBitsArray[state_offset][0], sig_offsets, 4);
+      _mm256_storeu_si256((__m256i*)&state->m_sigFracBits[state_offset][0], sig_frac_bits);
+
+      uint32_t sum_gt1_s[4];
+      _mm_storeu_si128((__m128i*)sum_gt1_s, _mm_min_epi32(_mm_sub_epi32(sum_abs1, sum_num), _mm_set1_epi32(4)));
+      for (int i = 0; i < 4; ++i) {
+        memcpy(state->m_coeffFracBits[state_offset + i], state->m_gtxFracBitsArray[gtxCtxOffsetNext + sum_gt1_s[i]], sizeof(state->m_coeffFracBits[0]));
+      }
+    }
+
+    // The Rice parameters are single bytes, so finish them with scalar lookups.
+    int32_t sum_abs_s[4];
+    _mm_storeu_si128((__m128i*)sum_abs_s, sum_abs);
+    for (int i = 0; i < 4; ++i) {
+      const int state_id = state_offset + i;
+      if (rem_reg_all_gte_4) {
+        state->m_goRicePar[state_id] = g_goRiceParsCoeff[MAX(MIN(31, sum_abs_s[i] - 4 * 5), 0)];
+      } else {
+        state->m_goRicePar[state_id] = g_goRiceParsCoeff[MIN(31, sum_abs_s[i])];
+        state->m_goRiceZero[state_id] = ((state_id & 3) < 2 ? 1 : 2) << state->m_goRicePar[state_id];
+      }
+    }
+  }
+  else {
+    // Mixed bin budgets or extRiceFlag: same per-state logic as updateState().
+    for (int i = 0; i < 4; ++i) {
+      const int state_id = state_offset + i;
+      const uint8_t* levels = (const uint8_t*)(state->m_absLevelsAndCtxInit[state_id]);
+      const int tinit = state->m_absLevelsAndCtxInit[state_id][8 + ((scan_pos - 1) & 15)];
+      coeff_t sumAbs1 = (tinit >> 3) & 31;
+      coeff_t sumNum = tinit & 7;
+      coeff_t sumAbs = tinit >> 8;
+      assert(numIPos >= 1 && numIPos <= 5);
+      for (int k = 0; k < numIPos; ++k) {
+        const coeff_t t = levels[next_nb_info_ssb.inPos[k]];
+        sumAbs1 += MIN(4 + (t & 1), t);
+        sumNum += !!t;
+        sumAbs += t;
+      }
+      unsigned currentShift = 0;
+      if (extRiceFlag) {
+        currentShift = templateAbsCompare(sumAbs);
+        sumAbs = sumAbs >> currentShift;
+      }
+      if (state->m_remRegBins[state_id] >= 4) {
+        const coeff_t sumGt1 = sumAbs1 - sumNum;
+        const int sig_ctx = sigCtxOffsetNext + MIN((sumAbs1 + 1) >> 1, 3);
+        state->m_sigFracBits[state_id][0] = state->m_sigFracBitsArray[state_id][sig_ctx][0];
+        state->m_sigFracBits[state_id][1] = state->m_sigFracBitsArray[state_id][sig_ctx][1];
+        memcpy(state->m_coeffFracBits[state_id], state->m_gtxFracBitsArray[gtxCtxOffsetNext + (sumGt1 < 4 ? sumGt1 : 4)], sizeof(state->m_coeffFracBits[0]));
+        const int sumAll = MAX(MIN(31, (int)sumAbs - (extRiceFlag ? (int)baseLevel : 4 * 5)), 0);
+        state->m_goRicePar[state_id] = g_goRiceParsCoeff[sumAll] + currentShift;
+      } else {
+        state->m_goRicePar[state_id] = g_goRiceParsCoeff[MIN(31, sumAbs)] + currentShift;
+        state->m_goRiceZero[state_id] = ((state_id & 3) < 2 ? 1 : 2) << state->m_goRicePar[state_id];
+      }
+    }
+  }
+}
+
 
 static INLINE void updateState(
   context_store * ctxs,
@@ -1258,7 +1512,7 @@ static INLINE void updateState(
   if (decisions->prevId[decision_id] > -2) {
     if (decisions->prevId[decision_id] >= 0) {
       const int prvState = ctxs->m_prev_state_offset + decisions->prevId[decision_id];
-      state->m_numSigSbb[state_id] = (state->m_numSigSbb[prvState]) + !!decisions->absLevel[decision_id];
+      state->m_numSigSbb[state_id] = (state->m_numSigSbb[prvState]) || !!decisions->absLevel[decision_id];
       state->m_refSbbCtxId[state_id] = state->m_refSbbCtxId[prvState];
       state->m_sbbFracBits[state_id][0] = state->m_sbbFracBits[prvState][0];
       state->m_sbbFracBits[state_id][1] = state->m_sbbFracBits[prvState][1];
@@ -1289,30 +1543,13 @@ static INLINE void updateState(
       coeff_t sumAbs1 = (tinit >> 3) & 31;
       coeff_t sumNum = tinit & 7;
 #define UPDATE(k) {coeff_t t=levels[next_nb_info_ssb.inPos[k]]; sumAbs1+=MIN(4+(t&1),t); sumNum+=!!t; }
-      if (numIPos == 1) {
-        UPDATE(0);
-      }
-      else if (numIPos == 2) {
-        UPDATE(0);
-        UPDATE(1);
-      }
-      else if (numIPos == 3) {
-        UPDATE(0);
-        UPDATE(1);
-        UPDATE(2);
-      }
-      else if (numIPos == 4) {
-        UPDATE(0);
-        UPDATE(1);
-        UPDATE(2);
-        UPDATE(3);
-      }
-      else if (numIPos == 5) {
-        UPDATE(0);
-        UPDATE(1);
-        UPDATE(2);
-        UPDATE(3);
-        UPDATE(4);
+      switch (numIPos) {
+        case 5: UPDATE(4); // fallthrough
+        case 4: UPDATE(3); // fallthrough
+        case 3: UPDATE(2); // fallthrough
+        case 2: UPDATE(1); // fallthrough
+        case 1: UPDATE(0); break;
+        default: assert(0);
       }
 #undef UPDATE
       coeff_t sumGt1 = sumAbs1 - sumNum;
@@ -1326,30 +1563,13 @@ static INLINE void updateState(
 
       coeff_t sumAbs = state->m_absLevelsAndCtxInit[state_id][8 + ((scan_pos - 1) & 15)] >> 8;
 #define UPDATE(k) {coeff_t t=levels[next_nb_info_ssb.inPos[k]]; sumAbs+=t; }
-      if (numIPos == 1) {
-        UPDATE(0);
-      }
-      else if (numIPos == 2) {
-        UPDATE(0);
-        UPDATE(1);
-      }
-      else if (numIPos == 3) {
-        UPDATE(0);
-        UPDATE(1);
-        UPDATE(2);
-      }
-      else if (numIPos == 4) {
-        UPDATE(0);
-        UPDATE(1);
-        UPDATE(2);
-        UPDATE(3);
-      }
-      else if (numIPos == 5) {
-        UPDATE(0);
-        UPDATE(1);
-        UPDATE(2);
-        UPDATE(3);
-        UPDATE(4);
+      switch (numIPos) {
+        case 5: UPDATE(4); // fallthrough
+        case 4: UPDATE(3); // fallthrough
+        case 3: UPDATE(2); // fallthrough
+        case 2: UPDATE(1); // fallthrough
+        case 1: UPDATE(0); break;
+        default: assert(0);
       }
 #undef UPDATE
       if (extRiceFlag) {
@@ -1367,30 +1587,13 @@ static INLINE void updateState(
     else {
       coeff_t sumAbs = (state->m_absLevelsAndCtxInit[state_id][8 + ((scan_pos - 1) & 15)]) >> 8;
 #define UPDATE(k) {coeff_t t=levels[next_nb_info_ssb.inPos[k]]; sumAbs+=t; }
-      if (numIPos == 1) {
-        UPDATE(0);
-      }
-      else if (numIPos == 2) {
-        UPDATE(0);
-        UPDATE(1);
-      }
-      else if (numIPos == 3) {
-        UPDATE(0);
-        UPDATE(1);
-        UPDATE(2);
-      }
-      else if (numIPos == 4) {
-        UPDATE(0);
-        UPDATE(1);
-        UPDATE(2);
-        UPDATE(3);
-      }
-      else if (numIPos == 5) {
-        UPDATE(0);
-        UPDATE(1);
-        UPDATE(2);
-        UPDATE(3);
-        UPDATE(4);
+      switch (numIPos) {
+        case 5: UPDATE(4); // fallthrough
+        case 4: UPDATE(3); // fallthrough
+        case 3: UPDATE(2); // fallthrough
+        case 2: UPDATE(1); // fallthrough
+        case 1: UPDATE(0); break;
+        default: assert(0);
       }
 #undef UPDATE
       if (extRiceFlag) {
@@ -1456,11 +1659,7 @@ static void xDecideAndUpdate(
     memcpy(decisions->absLevel + 4, decisions->absLevel, 4 * sizeof(coeff_t));
     memcpy(decisions->rdCost + 4, decisions->rdCost, 4 * sizeof(int64_t));
   } else if (!zeroOut) {
-
-    updateState(ctxs, next_nb_info_ssb.num, scan_pos, decisions, sigCtxOffsetNext, gtxCtxOffsetNext, next_nb_info_ssb, 4, false, 0);
-    updateState(ctxs, next_nb_info_ssb.num, scan_pos, decisions, sigCtxOffsetNext, gtxCtxOffsetNext, next_nb_info_ssb, 4, false, 1);
-    updateState(ctxs, next_nb_info_ssb.num, scan_pos, decisions, sigCtxOffsetNext, gtxCtxOffsetNext, next_nb_info_ssb, 4, false, 2);
-    updateState(ctxs, next_nb_info_ssb.num, scan_pos, decisions, sigCtxOffsetNext, gtxCtxOffsetNext, next_nb_info_ssb, 4, false, 3);
+    update_states_avx2(ctxs, next_nb_info_ssb.num, scan_pos, decisions, sigCtxOffsetNext, gtxCtxOffsetNext, next_nb_info_ssb, 4, false);
   }
 
   if (spt == SCAN_SOCSBB) {
@@ -1596,7 +1795,7 @@ int uvg_dep_quant(
 
     dep_quant_context.m_allStates.m_stateId[k] = k & 3;
     for (int i = 0; i < (compID == COLOR_Y ? 12 : 8); ++i) {
-      dep_quant_context.m_allStates.m_sigFracBitsArray[k][i] = rate_estimator.m_sigFracBits[(k & 3 ? (k & 3) - 1 : 0)][i];
+      memcpy(dep_quant_context.m_allStates.m_sigFracBitsArray[k][i], rate_estimator.m_sigFracBits[(k & 3 ? (k & 3) - 1 : 0)][i], sizeof(dep_quant_context.m_allStates.m_sigFracBitsArray[k][i]));
     }
   }
 