[depquant] AoS -> SoA for Decision

This commit is contained in:
Joose Sainio 2023-04-05 11:17:16 +03:00
parent 26ef1dda09
commit 73442f1bba

View file

@ -95,9 +95,9 @@ typedef struct
typedef struct
{
int64_t rdCost;
coeff_t absLevel;
int prevId;
int64_t rdCost[8];
coeff_t absLevel[8];
int prevId[8];
} Decision;
@ -152,7 +152,7 @@ typedef struct
depquant_state* m_skipStates;
depquant_state m_startState;
quant_block m_quant;
Decision m_trellis[TR_MAX_WIDTH * TR_MAX_WIDTH][8];
Decision m_trellis[TR_MAX_WIDTH * TR_MAX_WIDTH];
} context_store;
@ -526,15 +526,22 @@ static void depquant_state_init(depquant_state* state, uint32_t sig_frac_bits[2]
state->m_sbbFracBits[1] = 0;
}
static INLINE void checkRdCostSkipSbbZeroOut(Decision *decision, const depquant_state * const state)
static INLINE void checkRdCostSkipSbbZeroOut(Decision * decision, const depquant_state * const state, int decision_id)
{
int64_t rdCost = state->m_rdCost + state->m_sbbFracBits[0];
decision->rdCost = rdCost;
decision->absLevel = 0;
decision->prevId = 4 + state->m_stateId;
decision->rdCost[decision_id] = rdCost;
decision->absLevel[decision_id] = 0;
decision->prevId[decision_id] = 4 + state->m_stateId;
}
static void checkRdCosts(const depquant_state * const state, const enum ScanPosType spt, const PQData *pqDataA, const PQData *pqDataB, Decision *decisionA, Decision *decisionB)
static void checkRdCosts(
const depquant_state * const state,
const enum ScanPosType spt,
const PQData *pqDataA,
const PQData *pqDataB,
Decision *decisions,
int decisionA,
int decisionB)
{
const int32_t* goRiceTab = g_goRiceBits[state->m_goRicePar];
int64_t rdCostA = state->m_rdCost + pqDataA->deltaDist;
@ -582,7 +589,7 @@ static void checkRdCosts(const depquant_state * const state, const enum ScanPosT
}
else
{
rdCostZ = decisionA->rdCost;
rdCostZ = decisions->rdCost[decisionA];
}
}
else
@ -597,38 +604,39 @@ static void checkRdCosts(const depquant_state * const state, const enum ScanPosT
: (pqDataB->absLevel < RICEMAX ? pqDataB->absLevel : RICEMAX - 1)];
rdCostZ += goRiceTab[state->m_goRiceZero];
}
if (rdCostA < decisionA->rdCost)
if (rdCostA < decisions->rdCost[decisionA])
{
decisionA->rdCost = rdCostA;
decisionA->absLevel = pqDataA->absLevel;
decisionA->prevId = state->m_stateId;
decisions->rdCost[decisionA] = rdCostA;
decisions->absLevel[decisionA] = pqDataA->absLevel;
decisions->prevId[decisionA] = state->m_stateId;
}
if (rdCostZ < decisionA->rdCost)
if (rdCostZ < decisions->rdCost[decisionA])
{
decisionA->rdCost = rdCostZ;
decisionA->absLevel = 0;
decisionA->prevId = state->m_stateId;
decisions->rdCost[decisionA] = rdCostZ;
decisions->absLevel[decisionA] = 0;
decisions->prevId[decisionA] = state->m_stateId;
}
if (rdCostB < decisionB->rdCost)
if (rdCostB < decisions->rdCost[decisionB])
{
decisionB->rdCost = rdCostB;
decisionB->absLevel = pqDataB->absLevel;
decisionB->prevId = state->m_stateId;
decisions->rdCost[decisionB] = rdCostB;
decisions->absLevel[decisionB] = pqDataB->absLevel;
decisions->prevId[decisionB] = state->m_stateId;
}
}
static INLINE void checkRdCostSkipSbb(const depquant_state* const state, Decision *decision)
static INLINE void checkRdCostSkipSbb(const depquant_state* const state, Decision * decisions, int decision_id)
{
int64_t rdCost = state->m_rdCost + state->m_sbbFracBits[0];
if (rdCost < decision->rdCost)
if (rdCost < decisions->rdCost[decision_id])
{
decision->rdCost = rdCost;
decision->absLevel = 0;
decision->prevId = 4 + state->m_stateId;
decisions->rdCost[decision_id] = rdCost;
decisions->absLevel[decision_id] = 0;
decisions->prevId[decision_id] = 4 + state->m_stateId;
}
}
static INLINE void checkRdCostStart(const depquant_state* const state, int32_t lastOffset, const PQData *pqData, Decision *decision)
static INLINE void checkRdCostStart(const depquant_state* const state, int32_t lastOffset, const PQData *pqData, Decision *decisions, int
decision_id)
{
int64_t rdCost = pqData->deltaDist + lastOffset;
if (pqData->absLevel < 4)
@ -640,11 +648,11 @@ static INLINE void checkRdCostStart(const depquant_state* const state, int32_t l
const coeff_t value = (pqData->absLevel - 4) >> 1;
rdCost += state->m_coeffFracBits[pqData->absLevel - (value << 1)] + g_goRiceBits[state->m_goRicePar][value < RICEMAX ? value : RICEMAX - 1];
}
if (rdCost < decision->rdCost)
if (rdCost < decisions->rdCost[decision_id])
{
decision->rdCost = rdCost;
decision->absLevel = pqData->absLevel;
decision->prevId = -1;
decisions->rdCost[decision_id] = rdCost;
decisions->absLevel[decision_id] = pqData->absLevel;
decisions->prevId[decision_id] = -1;
}
}
@ -672,9 +680,8 @@ static INLINE void preQuantCoeff(const quant_block * const qp, const coeff_t abs
}
#define DINIT(l,p) {INT64_MAX>>2,(l),(p)}
static const Decision startDec[8] = { DINIT(-1,-2),DINIT(-1,-2),DINIT(-1,-2),DINIT(-1,-2),DINIT(0,4),DINIT(0,5),DINIT(0,6),DINIT(0,7) };
#undef DINIT
static const Decision startDec = { .rdCost = {INT64_MAX >> 2, INT64_MAX >> 2, INT64_MAX >> 2, INT64_MAX >> 2, INT64_MAX >> 2, INT64_MAX >> 2, INT64_MAX >> 2, INT64_MAX >> 2},
.absLevel = {-1, -1, -1, -1, 0, 0, 0, 0}, .prevId = {-2, -2, -2, -2, 4, 5, 6, 7} };
static void xDecide(
@ -689,36 +696,36 @@ static void xDecide(
bool zeroOut,
coeff_t quanCoeff)
{
memcpy(decisions, startDec, 8 * sizeof(Decision));
memcpy(decisions, &startDec, sizeof(Decision));
if (zeroOut)
{
if (spt == SCAN_EOCSBB)
{
checkRdCostSkipSbbZeroOut(&decisions[0], &m_skipStates[0]);
checkRdCostSkipSbbZeroOut(&decisions[1], &m_skipStates[1]);
checkRdCostSkipSbbZeroOut(&decisions[2], &m_skipStates[2]);
checkRdCostSkipSbbZeroOut(&decisions[3], &m_skipStates[3]);
checkRdCostSkipSbbZeroOut(decisions, &m_skipStates[0], 0);
checkRdCostSkipSbbZeroOut(decisions, &m_skipStates[1], 1);
checkRdCostSkipSbbZeroOut(decisions, &m_skipStates[2],2);
checkRdCostSkipSbbZeroOut(decisions, &m_skipStates[3],3);
}
return;
}
PQData pqData[4];
preQuantCoeff(qp, absCoeff, pqData, quanCoeff);
checkRdCosts(&m_prevStates[0], spt, &pqData[0], &pqData[2], &decisions[0], &decisions[2]);
checkRdCosts(&m_prevStates[1], spt, &pqData[0], &pqData[2], &decisions[2], &decisions[0]);
checkRdCosts(&m_prevStates[2], spt, &pqData[3], &pqData[1], &decisions[1], &decisions[3]);
checkRdCosts(&m_prevStates[3], spt, &pqData[3], &pqData[1], &decisions[3], &decisions[1]);
checkRdCosts(&m_prevStates[0], spt, &pqData[0], &pqData[2], decisions, 0, 2);
checkRdCosts(&m_prevStates[1], spt, &pqData[0], &pqData[2], decisions,2, 0);
checkRdCosts(&m_prevStates[2], spt, &pqData[3], &pqData[1], decisions, 1,3);
checkRdCosts(&m_prevStates[3], spt, &pqData[3], &pqData[1], decisions, 3,1);
if (spt == SCAN_EOCSBB)
{
checkRdCostSkipSbb(&m_skipStates[0], &decisions[0]);
checkRdCostSkipSbb(&m_skipStates[1], &decisions[1]);
checkRdCostSkipSbb(&m_skipStates[2], &decisions[2]);
checkRdCostSkipSbb(&m_skipStates[3], &decisions[3]);
checkRdCostSkipSbb(&m_skipStates[0], decisions, 0);
checkRdCostSkipSbb(&m_skipStates[1], decisions, 1);
checkRdCostSkipSbb(&m_skipStates[2], decisions,2);
checkRdCostSkipSbb(&m_skipStates[3], decisions,3);
}
checkRdCostStart(m_startState, lastOffset, &pqData[0], &decisions[0]);
checkRdCostStart(m_startState, lastOffset, &pqData[2], &decisions[2]);
checkRdCostStart(m_startState, lastOffset, &pqData[0], decisions, 0);
checkRdCostStart(m_startState, lastOffset, &pqData[2], decisions, 2);
}
@ -834,22 +841,23 @@ static INLINE void updateStateEOS(
const uint32_t next_sbb_below,
const depquant_state* prevStates,
const depquant_state* skipStates,
const Decision *decision)
const Decision * decisions,
int decision_id)
{
state->m_rdCost = decision->rdCost;
if (decision->prevId > -2)
state->m_rdCost = decisions->rdCost[decision_id];
if (decisions->prevId[decision_id] > -2)
{
const depquant_state* prvState = 0;
if (decision->prevId >= 4)
if (decisions->prevId[decision_id] >= 4)
{
prvState = skipStates + (decision->prevId - 4);
prvState = skipStates + (decisions->prevId[decision_id] - 4);
state->m_numSigSbb = 0;
memset(state->m_absLevelsAndCtxInit, 0, 16 * sizeof(uint8_t));
}
else if (decision->prevId >= 0)
else if (decisions->prevId[decision_id] >= 0)
{
prvState = prevStates + decision->prevId;
state->m_numSigSbb = prvState->m_numSigSbb + !!decision->absLevel;
prvState = prevStates + decisions->prevId[decision_id];
state->m_numSigSbb = prvState->m_numSigSbb + !!decisions->absLevel[decision_id];
memcpy(state->m_absLevelsAndCtxInit, prvState->m_absLevelsAndCtxInit, 16 * sizeof(uint8_t));
}
else
@ -858,7 +866,7 @@ static INLINE void updateStateEOS(
memset(state->m_absLevelsAndCtxInit, 0, 16 * sizeof(uint8_t));
}
uint8_t* temp = (uint8_t*)(&state->m_absLevelsAndCtxInit[scan_pos & 15]);
*temp = (uint8_t)MIN(255, decision->absLevel);
*temp = (uint8_t)MIN(255, decisions->absLevel[decision_id]);
update_common_context(state->m_commonCtx, scan_pos, cg_pos, width_in_sbb, height_in_sbb, next_sbb_right, next_sbb_below,prvState, state);
@ -875,21 +883,23 @@ static INLINE void updateStateEOS(
static INLINE void updateState(
depquant_state* state,
int numIPos, const uint32_t scan_pos,
int numIPos,
const uint32_t scan_pos,
const depquant_state* prevStates,
const Decision* decision,
const Decision* decisions,
const uint32_t sigCtxOffsetNext,
const uint32_t gtxCtxOffsetNext,
const NbInfoSbb next_nb_info_ssb,
const int baseLevel,
const bool extRiceFlag) {
state->m_rdCost = decision->rdCost;
if (decision->prevId > -2)
const bool extRiceFlag,
int decision_id) {
state->m_rdCost = decisions->rdCost[decision_id];
if (decisions->prevId[decision_id] > -2)
{
if (decision->prevId >= 0)
if (decisions->prevId[decision_id] >= 0)
{
const depquant_state* prvState = prevStates + decision->prevId;
state->m_numSigSbb = prvState->m_numSigSbb + !!decision->absLevel;
const depquant_state* prvState = prevStates + decisions->prevId[decision_id];
state->m_numSigSbb = prvState->m_numSigSbb + !!decisions->absLevel[decision_id];
state->m_refSbbCtxId = prvState->m_refSbbCtxId;
state->m_sbbFracBits[0] = prvState->m_sbbFracBits[0];
state->m_sbbFracBits[1] = prvState->m_sbbFracBits[1];
@ -897,7 +907,7 @@ static INLINE void updateState(
state->m_goRicePar = prvState->m_goRicePar;
if (state->m_remRegBins >= 4)
{
state->m_remRegBins -= (decision->absLevel < 2 ? (unsigned)decision->absLevel : 3);
state->m_remRegBins -= (decisions->absLevel[decision_id] < 2 ? (unsigned)decisions->absLevel[decision_id] : 3);
}
memcpy(state->m_absLevelsAndCtxInit, prvState->m_absLevelsAndCtxInit, 48 * sizeof(uint8_t));
}
@ -906,12 +916,12 @@ static INLINE void updateState(
state->m_numSigSbb = 1;
state->m_refSbbCtxId = -1;
int ctxBinSampleRatio = 28; //(scanInfo.chType == CHANNEL_TYPE_LUMA) ? MAX_TU_LEVEL_CTX_CODED_BIN_CONSTRAINT_LUMA : MAX_TU_LEVEL_CTX_CODED_BIN_CONSTRAINT_CHROMA;
state->m_remRegBins = (state->effWidth * state->effHeight * ctxBinSampleRatio) / 16 - (decision->absLevel < 2 ? (unsigned)decision->absLevel : 3);
state->m_remRegBins = (state->effWidth * state->effHeight * ctxBinSampleRatio) / 16 - (decisions->absLevel[decision_id] < 2 ? (unsigned)decisions->absLevel[decision_id] : 3);
memset(state->m_absLevelsAndCtxInit, 0, 48 * sizeof(uint8_t));
}
uint8_t* levels = (uint8_t*)(state->m_absLevelsAndCtxInit);
levels[scan_pos & 15] = (uint8_t)MIN(255, decision->absLevel);
levels[scan_pos & 15] = (uint8_t)MIN(255, decisions->absLevel[decision_id]);
if (state->m_remRegBins >= 4)
{
@ -1076,7 +1086,7 @@ static void xDecideAndUpdate(
int effWidth,
int effHeight)
{
Decision* decisions = ctxs->m_trellis[scan_pos];
Decision* decisions = &ctxs->m_trellis[scan_pos];
SWAP(ctxs->m_currStates, ctxs->m_prevStates, depquant_state*);
enum ScanPosType spt = 0;
@ -1094,17 +1104,19 @@ static void xDecideAndUpdate(
if (scan_pos) {
if (!(scan_pos & 15)) {
SWAP(ctxs->m_common_context.m_currSbbCtx, ctxs->m_common_context.m_prevSbbCtx, SbbCtx*);
updateStateEOS(&ctxs->m_currStates[0], scan_pos, cg_pos, sigCtxOffsetNext, gtxCtxOffsetNext, width_in_sbb, height_in_sbb, next_sbb_right, next_sbb_below, ctxs->m_prevStates, ctxs->m_skipStates, &decisions[0]);
updateStateEOS(&ctxs->m_currStates[1], scan_pos, cg_pos, sigCtxOffsetNext, gtxCtxOffsetNext, width_in_sbb, height_in_sbb, next_sbb_right, next_sbb_below, ctxs->m_prevStates, ctxs->m_skipStates, &decisions[1]);
updateStateEOS(&ctxs->m_currStates[2], scan_pos, cg_pos, sigCtxOffsetNext, gtxCtxOffsetNext, width_in_sbb, height_in_sbb, next_sbb_right, next_sbb_below, ctxs->m_prevStates, ctxs->m_skipStates, &decisions[2]);
updateStateEOS(&ctxs->m_currStates[3], scan_pos, cg_pos, sigCtxOffsetNext, gtxCtxOffsetNext, width_in_sbb, height_in_sbb, next_sbb_right, next_sbb_below, ctxs->m_prevStates, ctxs->m_skipStates, &decisions[3]);
memcpy(decisions + 4, decisions, 4 * sizeof(Decision));
updateStateEOS(&ctxs->m_currStates[0], scan_pos, cg_pos, sigCtxOffsetNext, gtxCtxOffsetNext, width_in_sbb, height_in_sbb, next_sbb_right, next_sbb_below, ctxs->m_prevStates, ctxs->m_skipStates, decisions,0);
updateStateEOS(&ctxs->m_currStates[1], scan_pos, cg_pos, sigCtxOffsetNext, gtxCtxOffsetNext, width_in_sbb, height_in_sbb, next_sbb_right, next_sbb_below, ctxs->m_prevStates, ctxs->m_skipStates, decisions,1);
updateStateEOS(&ctxs->m_currStates[2], scan_pos, cg_pos, sigCtxOffsetNext, gtxCtxOffsetNext, width_in_sbb, height_in_sbb, next_sbb_right, next_sbb_below, ctxs->m_prevStates, ctxs->m_skipStates, decisions,2);
updateStateEOS(&ctxs->m_currStates[3], scan_pos, cg_pos, sigCtxOffsetNext, gtxCtxOffsetNext, width_in_sbb, height_in_sbb, next_sbb_right, next_sbb_below, ctxs->m_prevStates, ctxs->m_skipStates, decisions,3);
memcpy(decisions->prevId + 4, decisions->prevId, 4 * sizeof(int));
memcpy(decisions->absLevel + 4, decisions->absLevel, 4 * sizeof(coeff_t));
memcpy(decisions->rdCost + 4, decisions->rdCost, 4 * sizeof(int64_t));
} else if (!zeroOut) {
updateState(&ctxs->m_currStates[0], next_nb_info_ssb.num, scan_pos, ctxs->m_prevStates, &decisions[0], sigCtxOffsetNext, gtxCtxOffsetNext, next_nb_info_ssb, 4, false);
updateState(&ctxs->m_currStates[1], next_nb_info_ssb.num, scan_pos, ctxs->m_prevStates, &decisions[1], sigCtxOffsetNext, gtxCtxOffsetNext, next_nb_info_ssb, 4, false);
updateState(&ctxs->m_currStates[2], next_nb_info_ssb.num, scan_pos, ctxs->m_prevStates, &decisions[2], sigCtxOffsetNext, gtxCtxOffsetNext, next_nb_info_ssb, 4, false);
updateState(&ctxs->m_currStates[3], next_nb_info_ssb.num, scan_pos, ctxs->m_prevStates, &decisions[3], sigCtxOffsetNext, gtxCtxOffsetNext, next_nb_info_ssb, 4, false);
updateState(&ctxs->m_currStates[0], next_nb_info_ssb.num, scan_pos, ctxs->m_prevStates, decisions, sigCtxOffsetNext, gtxCtxOffsetNext, next_nb_info_ssb, 4, false, 0);
updateState(&ctxs->m_currStates[1], next_nb_info_ssb.num, scan_pos, ctxs->m_prevStates, decisions, sigCtxOffsetNext, gtxCtxOffsetNext, next_nb_info_ssb, 4, false, 1);
updateState(&ctxs->m_currStates[2], next_nb_info_ssb.num, scan_pos, ctxs->m_prevStates, decisions, sigCtxOffsetNext, gtxCtxOffsetNext, next_nb_info_ssb, 4, false, 2);
updateState(&ctxs->m_currStates[3], next_nb_info_ssb.num, scan_pos, ctxs->m_prevStates, decisions, sigCtxOffsetNext, gtxCtxOffsetNext, next_nb_info_ssb, 4, false, 3);
}
if (spt == SCAN_SOCSBB) {
@ -1308,27 +1320,32 @@ int uvg_dep_quant(
width,
height); //tu.cu->slice->getReverseLastSigCoeffFlag());
}
//printf("%d\n", scanIdx);
//for(int i = 0; i < 4; i++) {
// printf("%lld %hu %d\n", ctxs->m_trellis[scanIdx].rdCost[i], ctxs->m_trellis[scanIdx].absLevel[i], ctxs->m_trellis[scanIdx].prevId[i]);
//}
//printf("\n");
}
//===== find best path =====
Decision decision = {INT64_MAX, -1, -2};
int prev_id = -1;
int64_t minPathCost = 0;
for (int8_t stateId = 0; stateId < 4; stateId++) {
int64_t pathCost = dep_quant_context.m_trellis[0][stateId].rdCost;
int64_t pathCost = dep_quant_context.m_trellis[0].rdCost[stateId];
if (pathCost < minPathCost) {
decision.prevId = stateId;
prev_id = stateId;
minPathCost = pathCost;
}
}
//===== backward scanning =====
int scanIdx = 0;
for (; decision.prevId >= 0; scanIdx++) {
decision = dep_quant_context.m_trellis[scanIdx][decision.prevId];
for (; prev_id >= 0; scanIdx++) {
Decision temp = dep_quant_context.m_trellis[scanIdx];
int32_t blkpos = scan[scanIdx];
coeff_out[blkpos] = (srcCoeff[blkpos] < 0 ? -decision.absLevel : decision.absLevel);
*absSum += decision.absLevel;
coeff_out[blkpos] = (srcCoeff[blkpos] < 0 ? -temp.absLevel[prev_id] : temp.absLevel[prev_id]);
*absSum += temp.absLevel[prev_id];
prev_id = temp.prevId[prev_id];
}
return *absSum;
}