[avx2] WIP update_states_avx2

This commit is contained in:
Joose Sainio 2023-04-10 15:31:05 +03:00
parent 04be92a8ec
commit 58a66c0654

View file

@ -157,7 +157,7 @@ typedef struct
int8_t m_goRicePar[12]; int8_t m_goRicePar[12];
int8_t m_goRiceZero[12]; int8_t m_goRiceZero[12];
int8_t m_stateId[12]; int8_t m_stateId[12];
uint32_t *m_sigFracBitsArray[12][12]; uint32_t m_sigFracBitsArray[12][12][2];
int32_t *m_gtxFracBitsArray[21]; int32_t *m_gtxFracBitsArray[21];
common_context* m_commonCtx; common_context* m_commonCtx;
@ -1240,6 +1240,510 @@ static INLINE void updateStateEOS(
state->m_gtxFracBitsArray[gtxCtxOffsetNext + (sumGt1 < 4 ? sumGt1 : 4)], sizeof(state->m_coeffFracBits[0])); state->m_gtxFracBitsArray[gtxCtxOffsetNext + (sumGt1 < 4 ? sumGt1 : 4)], sizeof(state->m_coeffFracBits[0]));
} }
} }
static INLINE void updateState(
context_store* ctxs,
int numIPos,
const uint32_t scan_pos,
const Decision* decisions,
const uint32_t sigCtxOffsetNext,
const uint32_t gtxCtxOffsetNext,
const NbInfoSbb next_nb_info_ssb,
const int baseLevel,
const bool extRiceFlag,
int decision_id);
static INLINE void update_states_avx2(
context_store* ctxs,
int numIPos,
const uint32_t scan_pos,
const Decision* decisions,
const uint32_t sigCtxOffsetNext,
const uint32_t gtxCtxOffsetNext,
const NbInfoSbb next_nb_info_ssb,
const int baseLevel,
const bool extRiceFlag)
{
all_depquant_states* state = &ctxs->m_allStates;
bool all_non_negative = true;
bool all_above_minus_two = true;
bool all_minus_one = true;
for (int i = 0; i < 4; ++i) {
all_non_negative &= decisions->prevId[i] >= 0;
all_above_minus_two &= decisions->prevId[i] > -2;
all_minus_one &= decisions->prevId[i] == -1;
}
int state_offset = ctxs->m_curr_state_offset;
if (all_above_minus_two) {
bool rem_reg_all_gte_4 = true;
bool rem_reg_all_lt4 = true;
__m128i abs_level = _mm_loadu_epi16(decisions->absLevel);
abs_level = _mm_cvtepi16_epi32(abs_level);
if (all_non_negative) {
__m128i prv_states = _mm_loadu_epi32(decisions->prevId);
__m128i prev_offset = _mm_set1_epi32(ctxs->m_prev_state_offset);
prv_states = _mm_add_epi32(prv_states, prev_offset);
//__m128i num_sig_sbb = _mm_i32gather_epi32(state->m_numSigSbb, prv_states, 1);
//__m128 mask = _mm_set_epi32(0xff, 0xff, 0xff, 0xff);
//num_sig_sbb
int32_t prv_states_scalar[4];
_mm_storeu_epi32(prv_states_scalar, prv_states);
int8_t sig_sbb[4] = {state->m_numSigSbb[prv_states_scalar[0]], state->m_numSigSbb[prv_states_scalar[1]], state->m_numSigSbb[prv_states_scalar[2]], state->m_numSigSbb[prv_states_scalar[3]]};
for (int i = 0; i < 4; ++i) {
sig_sbb[i] = sig_sbb[i] || decisions->absLevel[i];
}
memcpy(&state->m_numSigSbb[state_offset], sig_sbb, 4);
__m128i ref_sbb_ctx_idx = _mm_i32gather_epi32(state->m_refSbbCtxId, prv_states, 1);
__m128i control = _mm_setr_epi8(0, 4, 8, 12, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1);
ref_sbb_ctx_idx = _mm_shuffle_epi8(ref_sbb_ctx_idx, control);
int ref_sbb_ctx = _mm_extract_epi32(ref_sbb_ctx_idx, 0);
memcpy(&state->m_refSbbCtxId[state_offset], &ref_sbb_ctx, 4);
__m128i go_rice_par = _mm_i32gather_epi32(state->m_goRicePar, prv_states, 1);
go_rice_par = _mm_shuffle_epi8(go_rice_par, control);
int go_rice_par_i = _mm_extract_epi32(go_rice_par, 0);
memcpy(&state->m_goRicePar[state_offset], &go_rice_par_i, 4);
__m256i sbb_frac_bits = _mm256_i32gather_epi64(state->m_sbbFracBits, prv_states, 4);
_mm256_storeu_epi64(&state->m_sbbFracBits[state_offset][0], sbb_frac_bits);
__m128i rem_reg_bins = _mm_i32gather_epi32(state->m_remRegBins, prv_states, 4);
__m128i ones = _mm_set1_epi32(1);
rem_reg_bins = _mm_sub_epi32(rem_reg_bins, ones);
__m128i reg_bins_sub = _mm_set1_epi32(0);
__m128i abs_level_smaller_than_two = _mm_cmplt_epi32(abs_level, _mm_set1_epi32(2));
__m128i secondary = _mm_blendv_epi8(abs_level, _mm_set1_epi32(3), abs_level_smaller_than_two);
__m128i rem_reg_bins_smaller_than_four = _mm_cmplt_epi32(rem_reg_bins, _mm_set1_epi32(4));
reg_bins_sub = _mm_blendv_epi8(secondary, reg_bins_sub, rem_reg_bins_smaller_than_four);
rem_reg_bins = _mm_sub_epi32(rem_reg_bins, reg_bins_sub);
_mm_storeu_epi32(&state->m_remRegBins[state_offset], rem_reg_bins);
__m128i mask = _mm_cmpgt_epi32(rem_reg_bins, _mm_set1_epi32(3));
int bit_mask = _mm_movemask_epi8(mask);
rem_reg_all_gte_4 = (bit_mask == 0xFFFF);
mask = _mm_cmplt_epi32(rem_reg_bins, _mm_set1_epi32(4));
bit_mask = _mm_movemask_epi8(mask);
rem_reg_all_lt4 = (bit_mask == 0xFFFF);
for (int i = 0; i < 4; ++i) {
memcpy(state->m_absLevelsAndCtxInit[i], state->m_absLevelsAndCtxInit[prv_states_scalar[i]], 48 * sizeof(uint8_t));
}
}
else if (all_minus_one) {
memset(&state->m_numSigSbb[state_offset], 1, 4);
memset(&state->m_refSbbCtxId[state_offset], -1, 4);
const int a = (state->effWidth * state->effHeight * 28) / 16;
__m128i rem_reg_bins = _mm_set1_epi32(a);
__m128i sub = _mm_blendv_epi8(
abs_level,
_mm_set1_epi32(3),
_mm_cmplt_epi32(abs_level, _mm_set1_epi32(2))
);
rem_reg_bins = _mm_sub_epi32(rem_reg_bins, sub);
_mm_storeu_epi32(&state->m_remRegBins[state_offset], rem_reg_bins);
__m128i mask = _mm_cmpgt_epi32(rem_reg_bins, _mm_set1_epi32(3));
int bit_mask = _mm_movemask_epi8(mask);
rem_reg_all_gte_4 = (bit_mask == 0xFFFF);
mask = _mm_cmplt_epi32(rem_reg_bins, _mm_set1_epi32(4));
bit_mask = _mm_movemask_epi8(mask);
rem_reg_all_lt4 = (bit_mask == 0xFFFF);
memset(state->m_absLevelsAndCtxInit[state_offset], 0, 48 * sizeof(uint8_t) * 4);
}
else {
for (int i = 0; i< 4; ++i) {
const int decision_id = i;
const int state_id = state_offset + i;
if (decisions->prevId[decision_id] >= 0) {
const int prvState = ctxs->m_prev_state_offset + decisions->prevId[decision_id];
state->m_numSigSbb[state_id] = (state->m_numSigSbb[prvState]) || !!decisions->absLevel[decision_id];
state->m_refSbbCtxId[state_id] = state->m_refSbbCtxId[prvState];
state->m_sbbFracBits[state_id][0] = state->m_sbbFracBits[prvState][0];
state->m_sbbFracBits[state_id][1] = state->m_sbbFracBits[prvState][1];
state->m_remRegBins[state_id] = state->m_remRegBins[prvState] - 1;
state->m_goRicePar[state_id] = state->m_goRicePar[prvState];
if (state->m_remRegBins[state_id] >= 4) {
state->m_remRegBins[state_id] -= (decisions->absLevel[decision_id] < 2 ? (unsigned)decisions->absLevel[decision_id] : 3);
}
memcpy(state->m_absLevelsAndCtxInit[state_id], state->m_absLevelsAndCtxInit[prvState], 48 * sizeof(uint8_t));
} else {
state->m_numSigSbb[state_id] = 1;
state->m_refSbbCtxId[state_id] = -1;
int ctxBinSampleRatio = 28;
//(scanInfo.chType == CHANNEL_TYPE_LUMA) ? MAX_TU_LEVEL_CTX_CODED_BIN_CONSTRAINT_LUMA : MAX_TU_LEVEL_CTX_CODED_BIN_CONSTRAINT_CHROMA;
state->m_remRegBins[state_id] = (state->effWidth * state->effHeight * ctxBinSampleRatio) / 16 - (decisions->absLevel[decision_id] < 2 ? (unsigned)decisions->absLevel[decision_id] : 3);
memset(state->m_absLevelsAndCtxInit[state_id], 0, 48 * sizeof(uint8_t));
}
rem_reg_all_gte_4 &= state->m_remRegBins[state_id] >= 4;
rem_reg_all_lt4 &= state->m_remRegBins[state_id] < 4;
}
}
uint32_t level_offset = scan_pos & 15;
__m128i max_abs = _mm_min_epi32(abs_level, _mm_set1_epi32(255));
uint32_t max_abs_s[4];
_mm_storeu_epi32(max_abs_s, max_abs);
for (int i = 0; i < 4; ++i) {
uint8_t* levels = (uint8_t*)state->m_absLevelsAndCtxInit[state_offset + i];
levels[level_offset] = max_abs_s[i];
}
if (rem_reg_all_gte_4) {
const __m128i last_two_bytes = _mm_set1_epi32(0xffff);
const __m128i last_byte = _mm_set1_epi32(0xff);
const __m128i ones = _mm_set1_epi32(1);
const uint32_t tinit_offset = MIN(level_offset - 1u, 15u) + 8;
const __m128i levels_start_offsets = _mm_set_epi32(48 * 3, 48 * 2, 48 * 1, 48 * 0);
__m128i tinit = _mm_i32gather_epi32(
state->m_absLevelsAndCtxInit[state_offset],
_mm_add_epi32(levels_start_offsets, _mm_set1_epi32(tinit_offset)),
1);
tinit = _mm_and_epi32(tinit, last_two_bytes);
__m128i sum_abs1 = _mm_and_epi32(_mm_srli_epi32(tinit, 3), _mm_set1_epi32(31));
__m128i sum_num = _mm_and_epi32(tinit, _mm_set1_epi32(7));
uint8_t* levels = state->m_absLevelsAndCtxInit[state_offset];
switch (numIPos) {
case 5:
{
__m128i t = _mm_i32gather_epi32(
levels,
_mm_add_epi32(levels_start_offsets, _mm_set1_epi32(next_nb_info_ssb.inPos[4])),
1);
sum_abs1 = _mm_add_epi32(
sum_abs1,
_mm_and_epi32(t, ones));
sum_num = _mm_add_epi32(
sum_num,
_mm_min_epi32(_mm_and_epi32(t, last_byte), ones));
}
case 4:
{
__m128i t = _mm_i32gather_epi32(
levels,
_mm_add_epi32(levels_start_offsets, _mm_set1_epi32(next_nb_info_ssb.inPos[3])),
1);
sum_abs1 = _mm_add_epi32(
sum_abs1,
_mm_and_epi32(t, ones));
sum_num = _mm_add_epi32(
sum_num,
_mm_min_epi32(_mm_and_epi32(t, last_byte), ones));
}
case 3:
{
__m128i t = _mm_i32gather_epi32(
levels,
_mm_add_epi32(levels_start_offsets, _mm_set1_epi32(next_nb_info_ssb.inPos[2])),
1);
sum_abs1 = _mm_add_epi32(
sum_abs1,
_mm_and_epi32(t, ones));
sum_num = _mm_add_epi32(
sum_num,
_mm_min_epi32(_mm_and_epi32(t, last_byte), ones));
}
case 2:
{
__m128i t = _mm_i32gather_epi32(
levels,
_mm_add_epi32(levels_start_offsets, _mm_set1_epi32(next_nb_info_ssb.inPos[1])),
1);
sum_abs1 = _mm_add_epi32(
sum_abs1,
_mm_and_epi32(t, ones));
sum_num = _mm_add_epi32(
sum_num,
_mm_min_epi32(_mm_and_epi32(t, last_byte), ones));
}
case 1: {
__m128i t = _mm_i32gather_epi32(
levels,
_mm_add_epi32(levels_start_offsets, _mm_set1_epi32(next_nb_info_ssb.inPos[0])),
1);
sum_abs1 = _mm_add_epi32(
sum_abs1,
_mm_and_epi32(t, ones));
sum_num = _mm_add_epi32(
sum_num,
_mm_min_epi32(_mm_and_epi32(t, last_byte), ones));
} break;
default:
assert(0);
}
__m128i sum_gt1 = _mm_sub_epi32(sum_abs1, sum_num);
__m128i offsets = _mm_set_epi32(24 * 3, 24 * 2, 24 * 1, 24 * 0);
offsets = _mm_add_epi32(offsets, _mm_set1_epi32(sigCtxOffsetNext));
__m128i temp = _mm_min_epi32(
_mm_srli_epi32(_mm_add_epi32(sum_abs1, ones), 1),
_mm_set1_epi32(3));
offsets = _mm_add_epi32(offsets, temp);
__m256i sig_frac_bits = _mm256_i32gather_epi64(state->m_sigFracBitsArray[state_offset][0], offsets, 4);
_mm256_storeu_epi64(&state->m_sigFracBits[state_offset][0], sig_frac_bits);
sum_gt1 = _mm_min_epi32(sum_gt1, _mm_set1_epi32(4));
uint32_t sum_gt1_s[4];
_mm_storeu_epi32(sum_gt1_s, sum_gt1);
for (int i = 0; i < 4; ++i) {
memcpy(state->m_coeffFracBits[state_offset + i], state->m_gtxFracBitsArray[sum_gt1_s[i]], sizeof(state->m_coeffFracBits[0]));
}
__m128i sum_abs = _mm_srli_epi32(tinit, 8);
switch (numIPos) {
case 5:
{
__m128i t = _mm_i32gather_epi32(
levels,
_mm_add_epi32(levels_start_offsets, _mm_set1_epi32(next_nb_info_ssb.inPos[4])),
1);
t = _mm_and_epi32(t, last_byte);
sum_abs = _mm_add_epi32(sum_abs, t);
}
case 4:
{
__m128i t = _mm_i32gather_epi32(
levels,
_mm_add_epi32(levels_start_offsets, _mm_set1_epi32(next_nb_info_ssb.inPos[3])),
1);
t = _mm_and_epi32(t, last_byte);
sum_abs = _mm_add_epi32(sum_abs, t);
}
case 3:
{
__m128i t = _mm_i32gather_epi32(
levels,
_mm_add_epi32(levels_start_offsets, _mm_set1_epi32(next_nb_info_ssb.inPos[2])),
1);
t = _mm_and_epi32(t, last_byte);
sum_abs = _mm_add_epi32(sum_abs, t);
}
case 2:
{
__m128i t = _mm_i32gather_epi32(
levels,
_mm_add_epi32(levels_start_offsets, _mm_set1_epi32(next_nb_info_ssb.inPos[1])),
1);
t = _mm_and_epi32(t, last_byte);
sum_abs = _mm_add_epi32(sum_abs, t);
}
case 1:
{
__m128i t = _mm_i32gather_epi32(
levels,
_mm_add_epi32(levels_start_offsets, _mm_set1_epi32(next_nb_info_ssb.inPos[0])),
1);
t = _mm_and_epi32(t, last_byte);
sum_abs = _mm_add_epi32(sum_abs, t);
} break;
default:
assert(0);
}
if (extRiceFlag) {
assert(0 && "Not implemented for avx2");
} else {
__m128i sum_all = _mm_max_epi32(
_mm_min_epi32(
_mm_set1_epi32(31),
_mm_sub_epi32(sum_abs, _mm_set1_epi32(20))),
_mm_set1_epi32(0));
__m128i temp = _mm_i32gather_epi32(g_goRiceParsCoeff, sum_all, 4);
_mm_storeu_epi32(&state->m_goRicePar[state_offset], temp);
}
}
else if (rem_reg_all_lt4) {
uint8_t* levels = state->m_absLevelsAndCtxInit[state_offset];
const __m128i last_two_bytes = _mm_set1_epi32(0xffff);
const __m128i last_byte = _mm_set1_epi32(0xff);
const __m128i ones = _mm_set1_epi32(1);
const uint32_t tinit_offset = MIN(level_offset - 1u, 15u) + 8;
const __m128i levels_start_offsets = _mm_set_epi32(48 * 3, 48 * 2, 48 * 1, 48 * 0);
__m128i tinit = _mm_i32gather_epi32(
state->m_absLevelsAndCtxInit[state_offset],
_mm_add_epi32(levels_start_offsets, _mm_set1_epi32(tinit_offset)),
1);
tinit = _mm_and_epi32(tinit, last_two_bytes);
__m128i sum_abs = _mm_srli_epi32(tinit, 8);
switch (numIPos) {
case 5: {
__m128i t = _mm_i32gather_epi32(
levels,
_mm_add_epi32(levels_start_offsets, _mm_set1_epi32(next_nb_info_ssb.inPos[4])),
1);
t = _mm_and_epi32(t, last_byte);
sum_abs = _mm_add_epi32(sum_abs, t);
}
case 4: {
__m128i t = _mm_i32gather_epi32(
levels,
_mm_add_epi32(levels_start_offsets, _mm_set1_epi32(next_nb_info_ssb.inPos[3])),
1);
t = _mm_and_epi32(t, last_byte);
sum_abs = _mm_add_epi32(sum_abs, t);
}
case 3: {
__m128i t = _mm_i32gather_epi32(
levels,
_mm_add_epi32(levels_start_offsets, _mm_set1_epi32(next_nb_info_ssb.inPos[2])),
1);
t = _mm_and_epi32(t, last_byte);
sum_abs = _mm_add_epi32(sum_abs, t);
}
case 2: {
__m128i t = _mm_i32gather_epi32(
levels,
_mm_add_epi32(levels_start_offsets, _mm_set1_epi32(next_nb_info_ssb.inPos[1])),
1);
t = _mm_and_epi32(t, last_byte);
sum_abs = _mm_add_epi32(sum_abs, t);
}
case 1: {
__m128i t = _mm_i32gather_epi32(
levels,
_mm_add_epi32(levels_start_offsets, _mm_set1_epi32(next_nb_info_ssb.inPos[0])),
1);
t = _mm_and_epi32(t, last_byte);
sum_abs = _mm_add_epi32(sum_abs, t);
} break;
default:
assert(0);
}
if (extRiceFlag) {
assert(0 && "Not implemented for avx2");
} else {
__m128i sum_all = _mm_max_epi32(
_mm_min_epi32(
_mm_set1_epi32(31),
_mm_sub_epi32(sum_abs, _mm_set1_epi32(20))),
_mm_set1_epi32(0));
__m128i temp = _mm_i32gather_epi32(g_goRiceParsCoeff, sum_all, 4);
__m128i control = _mm_setr_epi8(0, 4, 8, 12, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1);
__m128i go_rice_par = _mm_shuffle_epi8(temp, control);
int go_rice_par_i = _mm_extract_epi32(go_rice_par, 0);
memcpy(&state->m_goRicePar[state_offset], &go_rice_par_i, 4);
__m128i go_rice_zero = _mm_set_epi32(2, 2, 1, 1);
go_rice_zero = _mm_sll_epi32(go_rice_zero, temp);
go_rice_zero = _mm_shuffle_epi8(go_rice_zero, control);
int go_rice_zero_i = _mm_extract_epi32(go_rice_par, 0);
memcpy(&state->m_goRiceZero[state_offset], &go_rice_zero_i, 4);
}
}
else {
for (int i = 0; i < 4; ++i) {
const int state_id = state_offset + i;
uint8_t* levels = (uint8_t*)(state->m_absLevelsAndCtxInit[state_id]);
if (state->m_remRegBins[state_id] >= 4) {
coeff_t tinit = state->m_absLevelsAndCtxInit[state_id][8 + ((scan_pos - 1) & 15)];
coeff_t sumAbs1 = (tinit >> 3) & 31;
coeff_t sumNum = tinit & 7;
#define UPDATE(k) \
{ \
coeff_t t = levels[next_nb_info_ssb.inPos[k]]; \
sumAbs1 += MIN(4 + (t & 1), t); \
sumNum += !!t; \
}
switch (numIPos) {
case 5: UPDATE(4);
case 4: UPDATE(3);
case 3: UPDATE(2);
case 2: UPDATE(1);
case 1: UPDATE(0); break;
default: assert(0);
}
#undef UPDATE
coeff_t sumGt1 = sumAbs1 - sumNum;
state->m_sigFracBits[state_id][0] = state->m_sigFracBitsArray[state_id][sigCtxOffsetNext + MIN((sumAbs1 + 1) >> 1, 3)][0];
state->m_sigFracBits[state_id][1] = state->m_sigFracBitsArray[state_id][sigCtxOffsetNext + MIN((sumAbs1 + 1) >> 1, 3)][1];
memcpy(state->m_coeffFracBits[state_id], state->m_gtxFracBitsArray[gtxCtxOffsetNext + (sumGt1 < 4 ? sumGt1 : 4)], sizeof(state->m_coeffFracBits[0]));
coeff_t sumAbs = state->m_absLevelsAndCtxInit[state_id][8 + ((scan_pos - 1) & 15)] >> 8;
#define UPDATE(k) \
{ \
coeff_t t = levels[next_nb_info_ssb.inPos[k]]; \
sumAbs += t; \
}
switch (numIPos) {
case 5: UPDATE(4);
case 4: UPDATE(3);
case 3: UPDATE(2);
case 2: UPDATE(1);
case 1: UPDATE(0); break;
default: assert(0);
}
#undef UPDATE
if (extRiceFlag) {
unsigned currentShift = templateAbsCompare(sumAbs);
sumAbs = sumAbs >> currentShift;
int sumAll = MAX(MIN(31, (int)sumAbs - (int)baseLevel), 0);
state->m_goRicePar[state_id] = g_goRiceParsCoeff[sumAll];
state->m_goRicePar[state_id] += currentShift;
} else {
int sumAll = MAX(MIN(31, (int)sumAbs - 4 * 5), 0);
state->m_goRicePar[state_id] = g_goRiceParsCoeff[sumAll];
}
} else {
coeff_t sumAbs = (state->m_absLevelsAndCtxInit[state_id][8 + ((scan_pos - 1) & 15)]) >> 8;
#define UPDATE(k) \
{ \
coeff_t t = levels[next_nb_info_ssb.inPos[k]]; \
sumAbs += t; \
}
switch (numIPos) {
case 5: UPDATE(4);
case 4: UPDATE(3);
case 3: UPDATE(2);
case 2: UPDATE(1);
case 1: UPDATE(0); break;
default: assert(0);
}
#undef UPDATE
if (extRiceFlag) {
unsigned currentShift = templateAbsCompare(sumAbs);
sumAbs = sumAbs >> currentShift;
sumAbs = MIN(31, sumAbs);
state->m_goRicePar[state_id] = g_goRiceParsCoeff[sumAbs];
state->m_goRicePar[state_id] += currentShift;
} else {
sumAbs = MIN(31, sumAbs);
state->m_goRicePar[state_id] = g_goRiceParsCoeff[sumAbs];
}
state->m_goRiceZero[state_id] = ((state_id & 3) < 2 ? 1 : 2) << state->m_goRicePar[state_id];
}
}
}
} else {
for (int i = 0; i < 4; ++i) {
updateState(
ctxs,
numIPos,
scan_pos,
decisions,
sigCtxOffsetNext,
gtxCtxOffsetNext,
next_nb_info_ssb,
baseLevel,
extRiceFlag,
i);
}
}
}
static INLINE void updateState( static INLINE void updateState(
context_store * ctxs, context_store * ctxs,
@ -1258,7 +1762,7 @@ static INLINE void updateState(
if (decisions->prevId[decision_id] > -2) { if (decisions->prevId[decision_id] > -2) {
if (decisions->prevId[decision_id] >= 0) { if (decisions->prevId[decision_id] >= 0) {
const int prvState = ctxs->m_prev_state_offset + decisions->prevId[decision_id]; const int prvState = ctxs->m_prev_state_offset + decisions->prevId[decision_id];
state->m_numSigSbb[state_id] = (state->m_numSigSbb[prvState]) + !!decisions->absLevel[decision_id]; state->m_numSigSbb[state_id] = (state->m_numSigSbb[prvState]) || !!decisions->absLevel[decision_id];
state->m_refSbbCtxId[state_id] = state->m_refSbbCtxId[prvState]; state->m_refSbbCtxId[state_id] = state->m_refSbbCtxId[prvState];
state->m_sbbFracBits[state_id][0] = state->m_sbbFracBits[prvState][0]; state->m_sbbFracBits[state_id][0] = state->m_sbbFracBits[prvState][0];
state->m_sbbFracBits[state_id][1] = state->m_sbbFracBits[prvState][1]; state->m_sbbFracBits[state_id][1] = state->m_sbbFracBits[prvState][1];
@ -1289,30 +1793,13 @@ static INLINE void updateState(
coeff_t sumAbs1 = (tinit >> 3) & 31; coeff_t sumAbs1 = (tinit >> 3) & 31;
coeff_t sumNum = tinit & 7; coeff_t sumNum = tinit & 7;
#define UPDATE(k) {coeff_t t=levels[next_nb_info_ssb.inPos[k]]; sumAbs1+=MIN(4+(t&1),t); sumNum+=!!t; } #define UPDATE(k) {coeff_t t=levels[next_nb_info_ssb.inPos[k]]; sumAbs1+=MIN(4+(t&1),t); sumNum+=!!t; }
if (numIPos == 1) { switch (numIPos) {
UPDATE(0); case 5: UPDATE(4);
} case 4: UPDATE(3);
else if (numIPos == 2) { case 3: UPDATE(2);
UPDATE(0); case 2: UPDATE(1);
UPDATE(1); case 1: UPDATE(0); break;
} default: assert(0);
else if (numIPos == 3) {
UPDATE(0);
UPDATE(1);
UPDATE(2);
}
else if (numIPos == 4) {
UPDATE(0);
UPDATE(1);
UPDATE(2);
UPDATE(3);
}
else if (numIPos == 5) {
UPDATE(0);
UPDATE(1);
UPDATE(2);
UPDATE(3);
UPDATE(4);
} }
#undef UPDATE #undef UPDATE
coeff_t sumGt1 = sumAbs1 - sumNum; coeff_t sumGt1 = sumAbs1 - sumNum;
@ -1326,30 +1813,13 @@ static INLINE void updateState(
coeff_t sumAbs = state->m_absLevelsAndCtxInit[state_id][8 + ((scan_pos - 1) & 15)] >> 8; coeff_t sumAbs = state->m_absLevelsAndCtxInit[state_id][8 + ((scan_pos - 1) & 15)] >> 8;
#define UPDATE(k) {coeff_t t=levels[next_nb_info_ssb.inPos[k]]; sumAbs+=t; } #define UPDATE(k) {coeff_t t=levels[next_nb_info_ssb.inPos[k]]; sumAbs+=t; }
if (numIPos == 1) { switch (numIPos) {
UPDATE(0); case 5: UPDATE(4);
} case 4: UPDATE(3);
else if (numIPos == 2) { case 3: UPDATE(2);
UPDATE(0); case 2: UPDATE(1);
UPDATE(1); case 1: UPDATE(0); break;
} default: assert(0);
else if (numIPos == 3) {
UPDATE(0);
UPDATE(1);
UPDATE(2);
}
else if (numIPos == 4) {
UPDATE(0);
UPDATE(1);
UPDATE(2);
UPDATE(3);
}
else if (numIPos == 5) {
UPDATE(0);
UPDATE(1);
UPDATE(2);
UPDATE(3);
UPDATE(4);
} }
#undef UPDATE #undef UPDATE
if (extRiceFlag) { if (extRiceFlag) {
@ -1367,30 +1837,13 @@ static INLINE void updateState(
else { else {
coeff_t sumAbs = (state->m_absLevelsAndCtxInit[state_id][8 + ((scan_pos - 1) & 15)]) >> 8; coeff_t sumAbs = (state->m_absLevelsAndCtxInit[state_id][8 + ((scan_pos - 1) & 15)]) >> 8;
#define UPDATE(k) {coeff_t t=levels[next_nb_info_ssb.inPos[k]]; sumAbs+=t; } #define UPDATE(k) {coeff_t t=levels[next_nb_info_ssb.inPos[k]]; sumAbs+=t; }
if (numIPos == 1) { switch (numIPos) {
UPDATE(0); case 5: UPDATE(4);
} case 4: UPDATE(3);
else if (numIPos == 2) { case 3: UPDATE(2);
UPDATE(0); case 2: UPDATE(1);
UPDATE(1); case 1: UPDATE(0); break;
} default: assert(0);
else if (numIPos == 3) {
UPDATE(0);
UPDATE(1);
UPDATE(2);
}
else if (numIPos == 4) {
UPDATE(0);
UPDATE(1);
UPDATE(2);
UPDATE(3);
}
else if (numIPos == 5) {
UPDATE(0);
UPDATE(1);
UPDATE(2);
UPDATE(3);
UPDATE(4);
} }
#undef UPDATE #undef UPDATE
if (extRiceFlag) { if (extRiceFlag) {
@ -1456,11 +1909,11 @@ static void xDecideAndUpdate(
memcpy(decisions->absLevel + 4, decisions->absLevel, 4 * sizeof(coeff_t)); memcpy(decisions->absLevel + 4, decisions->absLevel, 4 * sizeof(coeff_t));
memcpy(decisions->rdCost + 4, decisions->rdCost, 4 * sizeof(int64_t)); memcpy(decisions->rdCost + 4, decisions->rdCost, 4 * sizeof(int64_t));
} else if (!zeroOut) { } else if (!zeroOut) {
update_states_avx2(ctxs, next_nb_info_ssb.num, scan_pos, decisions, sigCtxOffsetNext, gtxCtxOffsetNext, next_nb_info_ssb, 4, false);
updateState(ctxs, next_nb_info_ssb.num, scan_pos, decisions, sigCtxOffsetNext, gtxCtxOffsetNext, next_nb_info_ssb, 4, false, 0); /* updateState(ctxs, next_nb_info_ssb.num, scan_pos, decisions, sigCtxOffsetNext, gtxCtxOffsetNext, next_nb_info_ssb, 4, false, 0);
updateState(ctxs, next_nb_info_ssb.num, scan_pos, decisions, sigCtxOffsetNext, gtxCtxOffsetNext, next_nb_info_ssb, 4, false, 1); updateState(ctxs, next_nb_info_ssb.num, scan_pos, decisions, sigCtxOffsetNext, gtxCtxOffsetNext, next_nb_info_ssb, 4, false, 1);
updateState(ctxs, next_nb_info_ssb.num, scan_pos, decisions, sigCtxOffsetNext, gtxCtxOffsetNext, next_nb_info_ssb, 4, false, 2); updateState(ctxs, next_nb_info_ssb.num, scan_pos, decisions, sigCtxOffsetNext, gtxCtxOffsetNext, next_nb_info_ssb, 4, false, 2);
updateState(ctxs, next_nb_info_ssb.num, scan_pos, decisions, sigCtxOffsetNext, gtxCtxOffsetNext, next_nb_info_ssb, 4, false, 3); updateState(ctxs, next_nb_info_ssb.num, scan_pos, decisions, sigCtxOffsetNext, gtxCtxOffsetNext, next_nb_info_ssb, 4, false, 3);*/
} }
if (spt == SCAN_SOCSBB) { if (spt == SCAN_SOCSBB) {
@ -1596,7 +2049,7 @@ int uvg_dep_quant(
dep_quant_context.m_allStates.m_stateId[k] = k & 3; dep_quant_context.m_allStates.m_stateId[k] = k & 3;
for (int i = 0; i < (compID == COLOR_Y ? 12 : 8); ++i) { for (int i = 0; i < (compID == COLOR_Y ? 12 : 8); ++i) {
dep_quant_context.m_allStates.m_sigFracBitsArray[k][i] = rate_estimator.m_sigFracBits[(k & 3 ? (k & 3) - 1 : 0)][i]; memcpy(dep_quant_context.m_allStates.m_sigFracBitsArray[k][i], rate_estimator.m_sigFracBits[(k & 3 ? (k & 3) - 1 : 0)][i], sizeof(uint32_t) * 2);
} }
} }