[dep_quant] Change order of ctxInit

This commit is contained in:
Joose Sainio 2023-05-05 16:21:31 +03:00
parent a624988c91
commit d850c346d6
3 changed files with 75 additions and 71 deletions

View file

@ -570,10 +570,11 @@ static INLINE void update_common_context(
const int prev_state,
const int curr_state)
{
const uint32_t numSbb = width_in_sbb * height_in_sbb;
uint8_t* sbbFlags = cc->m_allSbbCtx[cc->m_curr_sbb_ctx_offset + (curr_state & 3)].sbbFlags;
uint8_t* levels = cc->m_allSbbCtx[cc->m_curr_sbb_ctx_offset + (curr_state & 3)].levels;
size_t setCpSize = cc->m_nbInfo[scan_pos - 1].maxDist * sizeof(uint8_t);
const uint32_t numSbb = width_in_sbb * height_in_sbb;
const int curr_state_without_offset = curr_state & 3;
uint8_t* sbbFlags = cc->m_allSbbCtx[cc->m_curr_sbb_ctx_offset + curr_state_without_offset].sbbFlags;
uint8_t* levels = cc->m_allSbbCtx[cc->m_curr_sbb_ctx_offset + curr_state_without_offset].levels;
size_t setCpSize = cc->m_nbInfo[scan_pos - 1].maxDist * sizeof(uint8_t);
if (prev_state != -1 && ctxs->m_allStates.m_refSbbCtxId[prev_state] >= 0) {
memcpy(sbbFlags, cc->m_allSbbCtx[cc->m_prev_sbb_ctx_offset + ctxs->m_allStates.m_refSbbCtxId[prev_state]].sbbFlags, numSbb * sizeof(uint8_t));
memcpy(levels + scan_pos, cc->m_allSbbCtx[cc->m_prev_sbb_ctx_offset + ctxs->m_allStates.m_refSbbCtxId[prev_state]].levels + scan_pos, setCpSize);
@ -596,11 +597,11 @@ static INLINE void update_common_context(
ctxs->m_allStates.m_remRegBins[curr_state] = (ctxs->m_allStates.effWidth * ctxs->m_allStates.effHeight * ctxBinSampleRatio) / 16;
}
ctxs->m_allStates.m_goRicePar[curr_state] = 0;
ctxs->m_allStates.m_refSbbCtxId[curr_state] = curr_state & 3;
ctxs->m_allStates.m_refSbbCtxId[curr_state] = curr_state_without_offset;
ctxs->m_allStates.m_sbbFracBits[curr_state][0] = cc->m_sbbFlagBits[sigNSbb][0];
ctxs->m_allStates.m_sbbFracBits[curr_state][1] = cc->m_sbbFlagBits[sigNSbb][1];
uint16_t *templateCtxInit = ctxs->m_allStates.m_ctxInit[curr_state];
uint16_t *templateCtxInit = ctxs->m_allStates.m_ctxInit[ctxs->m_curr_state_offset >> 2];
const int scanBeg = scan_pos - 16;
const NbInfoOut* nbOut = cc->m_nbInfo + scanBeg;
const uint8_t* absLevels = levels + scanBeg;
@ -622,10 +623,10 @@ static INLINE void update_common_context(
}
}
#undef UPDATE
templateCtxInit[id] = (uint16_t)(sumNum) + ((uint16_t)(sumAbs1) << 3) + ((uint16_t)MIN(127, sumAbs) << 8);
templateCtxInit[curr_state_without_offset + id * 4] = (uint16_t)(sumNum) + ((uint16_t)(sumAbs1) << 3) + ((uint16_t)MIN(127, sumAbs) << 8);
}
else {
templateCtxInit[id] = 0;
templateCtxInit[curr_state_without_offset + id * 4] = 0;
}
}
memset(ctxs->m_allStates.m_absLevels[curr_state], 0, 16 * sizeof(uint8_t));
@ -671,7 +672,7 @@ void uvg_dep_quant_update_state_eos(
update_common_context(ctxs, state->m_commonCtx, scan_pos, cg_pos, width_in_sbb, height_in_sbb, next_sbb_right,
next_sbb_below, prvState, ctxs->m_curr_state_offset + decision_id);
coeff_t tinit = state->m_ctxInit[curr_state_offset][((scan_pos - 1) & 15)];
coeff_t tinit = state->m_ctxInit[ctxs->m_curr_state_offset >> 2][((scan_pos - 1) & 15) * 4 + decision_id];
coeff_t sumNum = tinit & 7;
coeff_t sumAbs1 = (tinit >> 3) & 31;
coeff_t sumGt1 = sumAbs1 - sumNum;
@ -695,12 +696,13 @@ void uvg_dep_quant_update_state(
const int baseLevel,
const bool extRiceFlag,
int decision_id) {
all_depquant_states* state = &ctxs->m_allStates;
int state_id = ctxs->m_curr_state_offset + decision_id;
state->m_rdCost[state_id] = decisions->rdCost[decision_id];
if (decisions->prevId[decision_id] > -2) {
if (decisions->prevId[decision_id] >= 0) {
const int prvState = ctxs->m_prev_state_offset + decisions->prevId[decision_id];
all_depquant_states* state = &ctxs->m_allStates;
int state_id = ctxs->m_curr_state_offset + decision_id;
state->m_rdCost[state_id] = decisions->rdCost[decision_id];
int32_t prev_id_no_offset = decisions->prevId[decision_id];
if (prev_id_no_offset > -2) {
if (prev_id_no_offset >= 0) {
const int prvState = ctxs->m_prev_state_offset + prev_id_no_offset;
state->m_numSigSbb[state_id] = (state->m_numSigSbb[prvState]) || !!decisions->absLevel[decision_id];
state->m_refSbbCtxId[state_id] = state->m_refSbbCtxId[prvState];
state->m_sbbFracBits[state_id][0] = state->m_sbbFracBits[prvState][0];
@ -713,7 +715,9 @@ void uvg_dep_quant_update_state(
: 3);
}
memcpy(state->m_absLevels[state_id], state->m_absLevels[prvState], 16 * sizeof(uint8_t));
memcpy(state->m_ctxInit[state_id], state->m_ctxInit[prvState], 16 * sizeof(uint16_t));
for (int i = 0; i < 64; i += 4) {
state->m_ctxInit[ctxs->m_curr_state_offset >> 2][decision_id + i] = state->m_ctxInit[ctxs->m_prev_state_offset >> 2][prev_id_no_offset + i];
}
}
else {
state->m_numSigSbb[state_id] = 1;
@ -723,7 +727,9 @@ void uvg_dep_quant_update_state(
state->m_remRegBins[state_id] = (state->effWidth * state->effHeight * ctxBinSampleRatio) / 16 - (
decisions->absLevel[decision_id] < 2 ? (unsigned)decisions->absLevel[decision_id] : 3);
memset(state->m_absLevels[state_id], 0, 16 * sizeof(uint8_t));
memset(state->m_ctxInit[state_id], 0, 16 * sizeof(uint16_t));
for (int i = 0; i < 64; i += 4) {
state->m_ctxInit[ctxs->m_curr_state_offset >> 2][decision_id + i] = 0;
}
}
state->all_gte_four &= state->m_remRegBins[state_id] >= 4;
state->all_lt_four &= state->m_remRegBins[state_id] < 4;
@ -731,7 +737,7 @@ void uvg_dep_quant_update_state(
levels[scan_pos & 15] = (uint8_t)MIN(32, decisions->absLevel[decision_id]);
if (state->m_remRegBins[state_id] >= 4) {
coeff_t tinit = state->m_ctxInit[state_id][((scan_pos - 1) & 15)];
coeff_t tinit = state->m_ctxInit[ctxs->m_curr_state_offset >> 2][((scan_pos - 1) & 15) * 4 + decision_id];
coeff_t sumAbs1 = (tinit >> 3) & 31;
coeff_t sumNum = tinit & 7;
#define UPDATE(k) {coeff_t t=levels[next_nb_info_ssb.inPos[k]]; sumAbs1+=MIN(4+(t&1),t); sumNum+=!!t; }
@ -753,7 +759,7 @@ void uvg_dep_quant_update_state(
sizeof(state->m_coeffFracBits[0]));
coeff_t sumAbs = state->m_ctxInit[state_id][(scan_pos - 1) & 15] >> 8;
coeff_t sumAbs = state->m_ctxInit[ctxs->m_curr_state_offset >> 2][((scan_pos - 1) & 15) * 4 + decision_id] >> 8;
#define UPDATE(k) {coeff_t t=levels[next_nb_info_ssb.inPos[k]]; sumAbs+=t; }
switch (numIPos) {
case 5: UPDATE(4);
@ -777,7 +783,7 @@ void uvg_dep_quant_update_state(
}
}
else {
coeff_t sumAbs = (state->m_ctxInit[state_id][(scan_pos - 1) & 15]) >> 8;
coeff_t sumAbs = state->m_ctxInit[ctxs->m_curr_state_offset >> 2][((scan_pos - 1) & 15) * 4 + decision_id] >> 8;
#define UPDATE(k) {coeff_t t=levels[next_nb_info_ssb.inPos[k]]; sumAbs+=t; }
switch (numIPos) {
case 5: UPDATE(4);
@ -1055,7 +1061,10 @@ int uvg_dep_quant(
height,
compID != 0); //tu.cu->slice->getReverseLastSigCoeffFlag());
}
if(0){
for (int i = 0; i < 8; ++i) {
assert(ctxs->m_allStates.m_refSbbCtxId[i] < 5);
}
if(1){
printf("%d\n", scanIdx);
for (int i = 0; i < 4; i++) {
printf("%lld %hu %d\n", ctxs->m_trellis[scanIdx].rdCost[i], ctxs->m_trellis[scanIdx].absLevel[i], ctxs->m_trellis[scanIdx].prevId[i]);

View file

@ -150,7 +150,7 @@ typedef struct {
typedef struct {
int64_t ALIGNED(32) m_rdCost[12];
uint8_t ALIGNED(32) m_absLevels[12][16];
uint16_t ALIGNED(32) m_ctxInit[12][16];
uint16_t ALIGNED(32) m_ctxInit[3][16 * 4];
int8_t ALIGNED(16) m_numSigSbb[12];
int ALIGNED(32) m_remRegBins[12];
int8_t ALIGNED(16) m_refSbbCtxId[12];

View file

@ -637,7 +637,6 @@ static void update_state_eos_avx2(context_store* ctxs, const uint32_t scan_pos,
temp[id % 4] = 0;
if (id % 4 == 3) {
all[id / 4] = _mm256_loadu_si256((__m256i const*)temp);
all[id / 4] = _mm256_shuffle_epi8(all[id / 4], v_shuffle);
}
continue;
}
@ -726,35 +725,14 @@ static void update_state_eos_avx2(context_store* ctxs, const uint32_t scan_pos,
temp[id % 4] = _mm_extract_epi64(shuffled_template_ctx_init, 0);
if (id % 4 == 3) {
all[id / 4] = _mm256_loadu_si256((__m256i const*)temp);
all[id / 4] = _mm256_shuffle_epi8(all[id / 4], v_shuffle);
last = template_ctx_init;
}
}
__m256i* v_src_tmp = all;
__m256i v_tmp[4];
v_tmp[0] = _mm256_permute2x128_si256(v_src_tmp[0], v_src_tmp[1], 0x20);
v_tmp[1] = _mm256_permute2x128_si256(v_src_tmp[0], v_src_tmp[1], 0x31);
v_tmp[2] = _mm256_permute2x128_si256(v_src_tmp[2], v_src_tmp[3], 0x20);
v_tmp[3] = _mm256_permute2x128_si256(v_src_tmp[2], v_src_tmp[3], 0x31);
__m256i v_tmp16_lo[2];
__m256i v_tmp16_hi[2];
v_tmp16_lo[0] = _mm256_unpacklo_epi32(v_tmp[0], v_tmp[1]);
v_tmp16_lo[1] = _mm256_unpacklo_epi32(v_tmp[2], v_tmp[3]);
v_tmp16_hi[0] = _mm256_unpackhi_epi32(v_tmp[0], v_tmp[1]);
v_tmp16_hi[1] = _mm256_unpackhi_epi32(v_tmp[2], v_tmp[3]);
v_tmp[0] = _mm256_permute4x64_epi64(v_tmp16_lo[0], _MM_SHUFFLE(3, 1, 2, 0));
v_tmp[1] = _mm256_permute4x64_epi64(v_tmp16_lo[1], _MM_SHUFFLE(3, 1, 2, 0));
v_tmp[2] = _mm256_permute4x64_epi64(v_tmp16_hi[0], _MM_SHUFFLE(3, 1, 2, 0));
v_tmp[3] = _mm256_permute4x64_epi64(v_tmp16_hi[1], _MM_SHUFFLE(3, 1, 2, 0));
_mm256_storeu_si256((__m256i*)(state->m_ctxInit[state_offset]), _mm256_permute2x128_si256(v_tmp[0], v_tmp[1], 0x20));
_mm256_storeu_si256((__m256i*)(state->m_ctxInit[state_offset + 1]), _mm256_permute2x128_si256(v_tmp[0], v_tmp[1], 0x31));
_mm256_storeu_si256((__m256i*)(state->m_ctxInit[state_offset + 2]), _mm256_permute2x128_si256(v_tmp[2], v_tmp[3], 0x20));
_mm256_storeu_si256((__m256i*)(state->m_ctxInit[state_offset + 3]), _mm256_permute2x128_si256(v_tmp[2], v_tmp[3], 0x31));
_mm256_storeu_si256((__m256i*)(&state->m_ctxInit[state_offset >> 2][0]), all[0]);
_mm256_storeu_si256((__m256i*)(&state->m_ctxInit[state_offset >> 2][16]), all[1]);
_mm256_storeu_si256((__m256i*)(&state->m_ctxInit[state_offset >> 2][32]), all[2]);
_mm256_storeu_si256((__m256i*)(&state->m_ctxInit[state_offset >> 2][48]), all[3]);
for (int i = 0; i < 4; ++i) {
memset(state->m_absLevels[state_offset + i], 0, 16);
@ -836,9 +814,9 @@ static INLINE void update_states_avx2(
__m128i abs_level = _mm_load_si128((__m128i const*)decisions->absLevel);
if (all_non_negative) {
__m128i prv_states = _mm_load_si128((__m128i const*)decisions->prevId);
__m128i prv_states_o = _mm_load_si128((__m128i const*)decisions->prevId);
__m128i prev_offset = _mm_set1_epi32(ctxs->m_prev_state_offset);
prv_states = _mm_add_epi32(prv_states, prev_offset);
__m128i prv_states = _mm_add_epi32(prv_states_o, prev_offset);
__m128i control = _mm_setr_epi8(0, 4, 8, 12, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1);
__m128i shuffled_prev_states = _mm_shuffle_epi8(prv_states, control);
@ -887,8 +865,20 @@ static INLINE void update_states_avx2(
int32_t prv_states_scalar[4];
_mm_storeu_si128((__m128i*)prv_states_scalar, prv_states);
for (int i = 0; i < 4; ++i) {
memcpy(state->m_absLevels[state_offset + i], state->m_absLevels[prv_states_scalar[i]], 16 * sizeof(uint8_t));
memcpy(state->m_ctxInit[state_offset + i], state->m_ctxInit[prv_states_scalar[i]], 16 * sizeof(uint16_t));
memcpy(state->m_absLevels[state_offset + i], state->m_absLevels[prv_states_scalar[i]], 16 * sizeof(uint8_t));
}
__m256i prev_state_full = _mm256_load_si256((__m256i const*)decisions->prevId);
__m256i shuffle_mask = _mm256_setr_epi8(0, 0, 4, 4,8, 8, 12, 12, 0, 0, 4, 4, 8, 8, 12, 12,0, 0, 0, 0,0, 0, 0, 0,16, 16, 16, 16, 16, 16, 16, 16);
prev_state_full = _mm256_shuffle_epi8(prev_state_full, shuffle_mask);
prev_state_full = _mm256_permute4x64_epi64(prev_state_full, 0);
prev_state_full = _mm256_slli_epi16(prev_state_full, 1);
__m256i temp_add = _mm256_setr_epi8(0, 1, 0, 1, 0, 1, 0, 1, 8, 9, 8, 9, 8, 9, 8, 9, 16, 17, 16, 17,16, 17,16, 17, 24, 25,24,25,24,25,24,25);
prev_state_full = _mm256_add_epi8(prev_state_full, temp_add);
for (int i = 0; i < 64; i += (256 / 8 / sizeof(uint16_t))) {
__m256i data = _mm256_load_si256((__m256i*)(&state->m_ctxInit[(ctxs->m_prev_state_offset >> 2)][i]));
data = _mm256_shuffle_epi8(data, prev_state_full);
_mm256_store_si256((__m256i*)(&state->m_ctxInit[(state_offset >> 2)][i]), data);
}
}
else if (all_minus_one) {
@ -914,7 +904,7 @@ static INLINE void update_states_avx2(
rem_reg_all_lt4 = (bit_mask == 0xFFFF);
memset(state->m_absLevels[state_offset], 0, 16 * sizeof(uint8_t) * 4);
memset(state->m_ctxInit[state_offset], 0, 16 * sizeof(uint16_t) * 4);
memset(state->m_ctxInit[state_offset >> 2], 0, 16 * sizeof(uint16_t) * 4);
}
else {
@ -933,7 +923,9 @@ static INLINE void update_states_avx2(
state->m_remRegBins[state_id] -= (decisions->absLevel[decision_id] < 2 ? (unsigned)decisions->absLevel[decision_id] : 3);
}
memcpy(state->m_absLevels[state_id], state->m_absLevels[prvState], 16 * sizeof(uint8_t));
memcpy(state->m_ctxInit[state_id], state->m_ctxInit[prvState], 16 * sizeof(uint16_t));
for (int k = 0; k < 16; ++k) {
state->m_ctxInit[state_offset >> 2][k * 4 + i] = state->m_ctxInit[ctxs->m_prev_state_offset >> 2][k * 4 + decisions->prevId[decision_id]];
}
} else {
state->m_numSigSbb[state_id] = 1;
state->m_refSbbCtxId[state_id] = -1;
@ -941,7 +933,9 @@ static INLINE void update_states_avx2(
//(scanInfo.chType == CHANNEL_TYPE_LUMA) ? MAX_TU_LEVEL_CTX_CODED_BIN_CONSTRAINT_LUMA : MAX_TU_LEVEL_CTX_CODED_BIN_CONSTRAINT_CHROMA;
state->m_remRegBins[state_id] = (state->effWidth * state->effHeight * ctxBinSampleRatio) / 16 - (decisions->absLevel[decision_id] < 2 ? (unsigned)decisions->absLevel[decision_id] : 3);
memset(state->m_absLevels[state_id], 0, 16 * sizeof(uint8_t));
memset(state->m_ctxInit[state_id], 0, 16 * sizeof(uint16_t));
for (int k = 0; k < 16; ++k) {
state->m_ctxInit[state_offset >> 2][k * 4 + i] = 0;
}
}
rem_reg_all_gte_4 &= state->m_remRegBins[state_id] >= 4;
rem_reg_all_lt4 &= state->m_remRegBins[state_id] < 4;
@ -958,16 +952,12 @@ static INLINE void update_states_avx2(
state->all_gte_four = rem_reg_all_gte_4;
state->all_lt_four = rem_reg_all_lt4;
if (rem_reg_all_gte_4) {
const __m128i first_two_bytes = _mm_set1_epi32(0xffff);
const __m128i first_byte = _mm_set1_epi32(0xff);
const __m128i ones = _mm_set1_epi32(1);
const uint32_t tinit_offset = MIN(level_offset - 1u, 15u);
const __m128i levels_start_offsets = _mm_set_epi32(16 * 3, 16 * 2, 16 * 1, 16 * 0);
__m128i tinit = _mm_i32gather_epi32(
(int *)state->m_ctxInit[state_offset],
_mm_add_epi32(levels_start_offsets, _mm_set1_epi32(tinit_offset)),
2);
tinit = _mm_and_si128(tinit, first_two_bytes);
__m128i tinit = _mm_loadu_si128((__m128i*)(&state->m_ctxInit[state_offset >> 2][tinit_offset * 4]));
tinit = _mm_cvtepi16_epi32(tinit);
__m128i sum_abs1 = _mm_and_si128(_mm_srli_epi32(tinit, 3), _mm_set1_epi32(31));
__m128i sum_num = _mm_and_si128(tinit, _mm_set1_epi32(7));
@ -1149,15 +1139,11 @@ static INLINE void update_states_avx2(
else if (rem_reg_all_lt4) {
const __m128i first_byte = _mm_set1_epi32(0xff);
uint8_t* levels = (uint8_t*)state->m_absLevels[state_offset];
const __m128i last_two_bytes = _mm_set1_epi32(0xffff);
const __m128i last_byte = _mm_set1_epi32(0xff);
const uint32_t tinit_offset = MIN(level_offset - 1u, 15u);
const __m128i levels_start_offsets = _mm_set_epi32(16 * 3, 16 * 2, 16 * 1, 16 * 0);
__m128i tinit = _mm_i32gather_epi32(
(int*)state->m_ctxInit[state_offset],
_mm_add_epi32(levels_start_offsets, _mm_set1_epi32(tinit_offset)),
2);
tinit = _mm_and_si128(tinit, last_two_bytes);
__m128i tinit = _mm_loadu_si128((__m128i*)(&state->m_ctxInit[state_offset >> 2][tinit_offset * 4]));
tinit = _mm_cvtepi16_epi32(tinit);
__m128i sum_abs = _mm_srli_epi32(tinit, 8);
sum_abs = _mm_min_epi32(sum_abs, _mm_set1_epi32(51));
switch (numIPos) {
@ -1225,7 +1211,7 @@ static INLINE void update_states_avx2(
const int state_id = state_offset + i;
uint8_t* levels = (uint8_t*)(state->m_absLevels[state_id]);
if (state->m_remRegBins[state_id] >= 4) {
coeff_t tinit = state->m_ctxInit[state_id][((scan_pos - 1) & 15)];
coeff_t tinit = state->m_ctxInit[state_offset >> 2][((scan_pos - 1) & 15) * 4 + i];
coeff_t sumAbs1 = (tinit >> 3) & 31;
coeff_t sumNum = tinit & 7;
#define UPDATE(k) \
@ -1249,7 +1235,7 @@ static INLINE void update_states_avx2(
memcpy(state->m_coeffFracBits[state_id], state->m_gtxFracBitsArray[gtxCtxOffsetNext + (sumGt1 < 4 ? sumGt1 : 4)], sizeof(state->m_coeffFracBits[0]));
coeff_t sumAbs = state->m_ctxInit[state_id][((scan_pos - 1) & 15)] >> 8;
coeff_t sumAbs = state->m_ctxInit[state_offset >> 2][((scan_pos - 1) & 15) * 4 + i] >> 8;
#define UPDATE(k) \
{ \
coeff_t t = levels[next_nb_info_ssb.inPos[k]]; \
@ -1271,7 +1257,7 @@ static INLINE void update_states_avx2(
state->m_goRicePar[state_id] = g_goRiceParsCoeff[sumAll];
}
} else {
coeff_t sumAbs = (state->m_ctxInit[state_id][((scan_pos - 1) & 15)]) >> 8;
coeff_t sumAbs = (state->m_ctxInit[state_offset >> 2][((scan_pos - 1) & 15) * 4 + i]) >> 8;
#define UPDATE(k) \
{ \
coeff_t t = levels[next_nb_info_ssb.inPos[k]]; \
@ -1355,6 +1341,15 @@ void uvg_dep_quant_decide_and_update_avx2(
} else if (!zeroOut) {
update_states_avx2(ctxs, next_nb_info_ssb.num, scan_pos, decisions, scan_info->sig_ctx_offset[is_chroma], scan_info->gtx_ctx_offset[is_chroma], next_nb_info_ssb, 4, false);
}
//for (int i = 0; i<4; i++) {
// for (int k = 0; k < 16; ++k) {
// printf(
// "%3d ",
// ctxs->m_allStates.m_ctxInit[ctxs->m_curr_state_offset / 4][k * 4 + i]);
// }
// printf("\n");
//}
//printf("\n");
if (spt == SCAN_SOCSBB) {
SWAP(ctxs->m_skip_state_offset, ctxs->m_prev_state_offset, int);