From 3e66a897d490da5c297197fd326c79a58d037be9 Mon Sep 17 00:00:00 2001 From: Joose Sainio Date: Tue, 10 Jan 2023 15:32:07 +0200 Subject: [PATCH] [DepQuant] WIP: compiles --- src/dep_quant.c | 668 +++++++++++++++++++++++++++++++++--------------- src/dep_quant.h | 16 ++ src/encoder.c | 7 + src/encoder.h | 4 + src/rdo.c | 2 +- src/uvg266.h | 2 +- 6 files changed, 484 insertions(+), 215 deletions(-) diff --git a/src/dep_quant.c b/src/dep_quant.c index 776d482b..ff9f62be 100644 --- a/src/dep_quant.c +++ b/src/dep_quant.c @@ -64,92 +64,258 @@ static const uint32_t g_goRiceParsCoeff[32] = { 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, enum ScanPosType { SCAN_ISCSBB = 0, SCAN_SOCSBB = 1, SCAN_EOCSBB = 2 }; -typedef struct { - int m_QShift; +typedef struct +{ + int m_QShift; int64_t m_QAdd; int64_t m_QScale; - coeff_t m_maxQIdx; + coeff_t m_maxQIdx; coeff_t m_thresLast; - coeff_t m_thresSSbb; + coeff_t m_thresSSbb; // distortion normalization - int m_DistShift; + int m_DistShift; int64_t m_DistAdd; int64_t m_DistStepAdd; int64_t m_DistOrgFact; } quant_block; -typedef struct { - uint8_t num; - uint8_t inPos[5]; -} NbInfoSbb; -typedef struct { - uint16_t maxDist; - uint16_t num; - uint16_t outPos[5]; -} NbInfoOut; - -typedef struct { +typedef struct +{ uint8_t* sbbFlags; uint8_t* levels; } SbbCtx; - -typedef struct +typedef struct { - coeff_t absLevel; + coeff_t absLevel; int64_t deltaDist; -}PQData; +} PQData; -typedef struct { +typedef struct +{ int64_t rdCost; - coeff_t absLevel; - int prevId; + coeff_t absLevel; + int prevId; } Decision; -typedef struct { +typedef struct +{ const NbInfoOut* m_nbInfo; - uint32_t m_sbbFlagBits[2][2]; - SbbCtx m_allSbbCtx[8]; - SbbCtx* m_currSbbCtx; - SbbCtx* m_prevSbbCtx; - uint8_t m_memory[8 * (TR_MAX_WIDTH * TR_MAX_WIDTH + 1024)]; + uint32_t m_sbbFlagBits[2][2]; + SbbCtx m_allSbbCtx[8]; + SbbCtx* m_currSbbCtx; + SbbCtx* m_prevSbbCtx; + uint8_t m_memory[8 * (TR_MAX_WIDTH * TR_MAX_WIDTH + 1024)]; } common_context; -typedef struct +typedef struct { - int32_t m_lastBitsX[TR_MAX_WIDTH]; - int32_t m_lastBitsY[TR_MAX_WIDTH]; - uint32_t m_sigSbbFracBits[sm_maxNumSigSbbCtx][2]; - uint32_t m_sigFracBits[sm_numCtxSetsSig][sm_maxNumSigCtx][2]; - int32_t m_gtxFracBits[sm_maxNumGtxCtx][6]; - + int32_t m_lastBitsX[TR_MAX_WIDTH]; + int32_t m_lastBitsY[TR_MAX_WIDTH]; + uint32_t m_sigSbbFracBits[sm_maxNumSigSbbCtx][2]; + uint32_t m_sigFracBits[sm_numCtxSetsSig][sm_maxNumSigCtx][2]; + int32_t m_gtxFracBits[sm_maxNumGtxCtx][6]; } rate_estimator; -typedef struct { - int64_t m_rdCost; - uint16_t m_absLevelsAndCtxInit[24]; // 16x8bit for abs levels + 16x16bit for ctx init id - int8_t m_numSigSbb; - int m_remRegBins; - int8_t m_refSbbCtxId; - uint32_t m_sbbFracBits[2]; - uint32_t m_sigFracBits[2]; - int32_t m_coeffFracBits[6]; - int8_t m_goRicePar; - int8_t m_goRiceZero; - int8_t m_stateId; - const uint32_t* m_sigFracBitsArray; - const uint32_t* m_gtxFracBitsArray; - struct common_context* m_commonCtx; - +typedef struct +{ + int64_t m_rdCost; + uint16_t m_absLevelsAndCtxInit[24]; // 16x8bit for abs levels + 16x16bit for ctx init id + int8_t m_numSigSbb; + int m_remRegBins; + int8_t m_refSbbCtxId; + uint32_t m_sbbFracBits[2]; + uint32_t m_sigFracBits[2]; + int32_t m_coeffFracBits[6]; + int8_t m_goRicePar; + int8_t m_goRiceZero; + int8_t m_stateId; + const uint32_t* m_sigFracBitsArray[2]; + const uint32_t* m_gtxFracBitsArray[6]; + struct common_context* m_commonCtx; + unsigned effWidth; unsigned effHeight; } depquant_state; +typedef struct +{ + common_context m_common_context; + depquant_state m_allStates[12]; + depquant_state* m_currStates; + depquant_state* m_prevStates; + depquant_state* m_skipStates; + depquant_state m_startState; + quant_block m_quant; + Decision m_trellis[TR_MAX_WIDTH * TR_MAX_WIDTH][8]; +} context_store; + + +int uvg_init_nb_info(encoder_control_t * encoder) { + memset(encoder->m_scanId2NbInfoSbbArray, 0, sizeof(encoder->m_scanId2NbInfoSbbArray)); + memset(encoder->m_scanId2NbInfoOutArray, 0, sizeof(encoder->m_scanId2NbInfoOutArray)); + for (int hd = 0; hd <= 7; hd++) + { + + uint32_t raster2id[64 * 64] = {0}; + + for (int vd = 0; vd <= 7; vd++) + { + if ((hd == 0 && vd <= 1) || (hd <= 1 && vd == 0)) + { + continue; + } + const uint32_t blockWidth = (1 << hd); + const uint32_t blockHeight = (1 << vd); + const uint32_t log2CGWidth = g_log2_sbb_size[hd][vd][0]; + const uint32_t log2CGHeight = g_log2_sbb_size[hd][vd][1]; + const uint32_t groupWidth = 1 << log2CGWidth; + const uint32_t groupHeight = 1 << log2CGHeight; + const uint32_t groupSize = groupWidth * groupHeight; + const int scanType = SCAN_DIAG; + const uint32_t blkWidthIdx = hd; + const uint32_t blkHeightIdx = vd; + const uint32_t* scanId2RP = uvg_get_scan_order_table(SCAN_GROUP_4X4, scanType, blkWidthIdx, blkHeightIdx); + NbInfoSbb** sId2NbSbb = &encoder->m_scanId2NbInfoSbbArray[hd][vd]; + NbInfoOut** sId2NbOut = &encoder->m_scanId2NbInfoOutArray[hd][vd]; + // consider only non-zero-out region + const uint32_t blkWidthNZOut = MIN(32, blockWidth); + const uint32_t blkHeightNZOut = MIN(32, blockHeight); + const uint32_t totalValues = blkWidthNZOut * blkHeightNZOut; + + *sId2NbSbb = MALLOC(NbInfoSbb, totalValues); + if (*sId2NbSbb == NULL) { + return 0; + } + *sId2NbOut = MALLOC(NbInfoOut, totalValues); + if (*sId2NbOut == NULL) { + return 0; + } + + for (uint32_t scanId = 0; scanId < totalValues; scanId++) + { + raster2id[scanId2RP[scanId]] = scanId; + } + + for (unsigned scanId = 0; scanId < totalValues; scanId++) + { + const int rpos = scanId2RP[scanId]; + uint32_t pos_y = rpos >> hd; + uint32_t pos_x = rpos - (pos_y << hd); // TODO: height + { + //===== inside subband neighbours ===== + NbInfoSbb *nbSbb = &(*sId2NbSbb)[scanId]; + const int begSbb = scanId - (scanId & (groupSize - 1)); // first pos in current subblock + int cpos[5]; + + cpos[0] = (pos_x + 1 < blkWidthNZOut ? (raster2id[rpos + 1] < groupSize + begSbb ? raster2id[rpos + 1] - begSbb : 0) : 0); + cpos[1] = (pos_x + 2 < blkWidthNZOut ? (raster2id[rpos + 2] < groupSize + begSbb ? raster2id[rpos + 2] - begSbb : 0) : 0); + cpos[2] = (pos_x + 1 < blkWidthNZOut && pos_y + 1 < blkHeightNZOut ? (raster2id[rpos + 1 + blockWidth] < groupSize + begSbb ? raster2id[rpos + 1 + blockWidth] - begSbb : 0) : 0); + cpos[3] = (pos_y + 1 < blkHeightNZOut ? (raster2id[rpos + blockWidth] < groupSize + begSbb ? raster2id[rpos + blockWidth] - begSbb : 0) : 0); + cpos[4] = (pos_y + 2 < blkHeightNZOut ? (raster2id[rpos + 2 * blockWidth] < groupSize + begSbb ? raster2id[rpos + 2 * blockWidth] - begSbb : 0) : 0); + + for (nbSbb->num = 0; true; ) + { + int nk = -1; + for (int k = 0; k < 5; k++) + { + if (cpos[k] != 0 && (nk < 0 || cpos[k] < cpos[nk])) + { + nk = k; + } + } + if (nk < 0) + { + break; + } + nbSbb->inPos[nbSbb->num++] = (uint8_t)(cpos[nk]); + cpos[nk] = 0; + } + for (int k = nbSbb->num; k < 5; k++) + { + nbSbb->inPos[k] = 0; + } + } + { + //===== outside subband neighbours ===== + NbInfoOut *nbOut = &(*sId2NbOut)[scanId]; + const int begSbb = scanId - (scanId & (groupSize - 1)); // first pos in current subblock + int cpos[5]; + + cpos[0] = (pos_x + 1 < blkWidthNZOut ? (raster2id[rpos + 1] >= groupSize + begSbb ? raster2id[rpos + 1] : 0) : 0); + cpos[1] = (pos_x + 2 < blkWidthNZOut ? (raster2id[rpos + 2] >= groupSize + begSbb ? raster2id[rpos + 2] : 0) : 0); + cpos[2] = (pos_x + 1 < blkWidthNZOut && pos_y + 1 < blkHeightNZOut ? (raster2id[rpos + 1 + blockWidth] >= groupSize + begSbb ? raster2id[rpos + 1 + blockWidth] : 0) : 0); + cpos[3] = (pos_y + 1 < blkHeightNZOut ? (raster2id[rpos + blockWidth] >= groupSize + begSbb ? raster2id[rpos + blockWidth] : 0) : 0); + cpos[4] = (pos_y + 2 < blkHeightNZOut ? (raster2id[rpos + 2 * blockWidth] >= groupSize + begSbb ? raster2id[rpos + 2 * blockWidth] : 0) : 0); + + for (nbOut->num = 0; true; ) + { + int nk = -1; + for (int k = 0; k < 5; k++) + { + if (cpos[k] != 0 && (nk < 0 || cpos[k] < cpos[nk])) + { + nk = k; + } + } + if (nk < 0) + { + break; + } + nbOut->outPos[nbOut->num++] = (uint16_t)(cpos[nk]); + cpos[nk] = 0; + } + for (int k = nbOut->num; k < 5; k++) + { + nbOut->outPos[k] = 0; + } + nbOut->maxDist = (scanId == 0 ? 0 : (*sId2NbOut)[scanId - 1].maxDist); + for (int k = 0; k < nbOut->num; k++) + { + if (nbOut->outPos[k] > nbOut->maxDist) + { + nbOut->maxDist = nbOut->outPos[k]; + } + } + } + } + + // make it relative + for (unsigned scanId = 0; scanId < totalValues; scanId++) + { + NbInfoOut *nbOut = &(*sId2NbOut)[scanId]; + const int begSbb = scanId - (scanId & (groupSize - 1)); // first pos in current subblock + for (int k = 0; k < nbOut->num; k++) + { + nbOut->outPos[k] -= begSbb; + } + nbOut->maxDist -= scanId; + } + } + } + return 1; +} + +void uvg_dealloc_nb_info(encoder_control_t* encoder) { + + for (int hd = 0; hd <= 7; hd++) { + for (int vd = 0; vd <= 7; vd++) + { + if ((hd == 0 && vd <= 1) || (hd <= 1 && vd == 0)) + { + continue; + } + if(encoder->m_scanId2NbInfoOutArray[hd][vd]) FREE_POINTER(encoder->m_scanId2NbInfoOutArray[hd][vd]); + if(encoder->m_scanId2NbInfoOutArray[hd][vd]) FREE_POINTER(encoder->m_scanId2NbInfoSbbArray[hd][vd]); + } + } +} + static void init_quant_block( const encoder_state_t* state, @@ -207,7 +373,7 @@ static void init_quant_block( static void reset_common_context(common_context* ctx, const rate_estimator * rate_estimator, int numSbb, int num_coeff) { - memset(&ctx->m_nbInfo, 0, sizeof(ctx->m_nbInfo)); + //memset(&ctx->m_nbInfo, 0, sizeof(ctx->m_nbInfo)); memcpy(&ctx->m_sbbFlagBits, &rate_estimator->m_sigSbbFracBits, sizeof(rate_estimator->m_sigSbbFracBits)); const int chunkSize = numSbb + num_coeff; uint8_t* nextMem = ctx->m_memory; @@ -215,6 +381,8 @@ static void reset_common_context(common_context* ctx, const rate_estimator * rat ctx->m_allSbbCtx[k].sbbFlags = nextMem; ctx->m_allSbbCtx[k].levels = nextMem + numSbb; } + ctx->m_currSbbCtx = &ctx->m_allSbbCtx[0]; + ctx->m_prevSbbCtx = &ctx->m_allSbbCtx[4]; } static void init_rate_esimator(rate_estimator * rate_estimator, const cabac_data_t * const ctx, color_t color) @@ -569,119 +737,136 @@ unsigned templateAbsCompare(coeff_t sum) return g_riceShift[rangeIdx]; } -static INLINE void update_common_context(common_context * cc, const ScanInfo *scanInfo, const depquant_state* prevState, depquant_state *currState) +static INLINE void update_common_context( + common_context * cc, + const uint32_t scan_pos, + const uint32_t width_in_sbb, + const uint32_t height_in_sbb, + const int sigNSbb, + const depquant_state* prevState, + depquant_state *currState) { - uint8_t* sbbFlags = cc->m_currSbbCtx[currState->m_stateId].sbbFlags; - uint8_t* levels = cc->m_currSbbCtx[currState->m_stateId].levels; - size_t setCpSize = cc->m_nbInfo[scanInfo.scanIdx - 1].maxDist * sizeof(uint8_t); - if (prevState && prevState->m_refSbbCtxId >= 0) - { - memcpy(sbbFlags, cc->m_prevSbbCtx[prevState->m_refSbbCtxId].sbbFlags, scanInfo.numSbb * sizeof(uint8_t)); - memcpy(levels + scanInfo.scanIdx, cc->m_prevSbbCtx[prevState->m_refSbbCtxId].levels + scanInfo.scanIdx, setCpSize); - } - else - { - memset(sbbFlags, 0, scanInfo.numSbb * sizeof(uint8_t)); - memset(levels + scanInfo.scanIdx, 0, setCpSize); - } - sbbFlags[scanInfo.sbbPos] = !!currState->m_numSigSbb; - memcpy(levels + scanInfo.scanIdx, currState->m_absLevelsAndCtxInit, scanInfo.sbbSize * sizeof(uint8_t)); + const uint32_t numSbb = width_in_sbb * height_in_sbb; + uint8_t* sbbFlags = cc->m_currSbbCtx[currState->m_stateId].sbbFlags; + uint8_t* levels = cc->m_currSbbCtx[currState->m_stateId].levels; + size_t setCpSize = cc->m_nbInfo[scan_pos - 1].maxDist * sizeof(uint8_t); + if (prevState && prevState->m_refSbbCtxId >= 0) { + memcpy(sbbFlags, cc->m_prevSbbCtx[prevState->m_refSbbCtxId].sbbFlags, numSbb * sizeof(uint8_t)); + memcpy(levels + scan_pos, cc->m_prevSbbCtx[prevState->m_refSbbCtxId].levels + scan_pos, setCpSize); + } + else { + memset(sbbFlags, 0, numSbb * sizeof(uint8_t)); + memset(levels + scan_pos, 0, setCpSize); + } + sbbFlags[scan_pos >> 4] = !!currState->m_numSigSbb; + memcpy(levels + scan_pos, currState->m_absLevelsAndCtxInit, 16 * sizeof(uint8_t)); - const int sigNSbb = ((scanInfo.nextSbbRight ? sbbFlags[scanInfo.nextSbbRight] : false) || (scanInfo.nextSbbBelow ? sbbFlags[scanInfo.nextSbbBelow] : false) ? 1 : 0); - currState->m_numSigSbb = 0; - if (prevState) - { - currState->m_remRegBins = prevState->m_remRegBins; - } - else - { - int ctxBinSampleRatio = 28; // (scanInfo.chType == COLOR_Y) ? MAX_TU_LEVEL_CTX_CODED_BIN_CONSTRAINT_LUMA : MAX_TU_LEVEL_CTX_CODED_BIN_CONSTRAINT_CHROMA; - currState->m_remRegBins = (currState->effWidth * currState->effHeight * ctxBinSampleRatio) / 16; - } - currState->m_goRicePar = 0; - currState->m_refSbbCtxId = currState->m_stateId; - currState->m_sbbFracBits[0] = cc->m_sbbFlagBits[sigNSbb][0]; - currState->m_sbbFracBits[1] = cc->m_sbbFlagBits[sigNSbb][1]; + currState->m_numSigSbb = 0; + if (prevState) { + currState->m_remRegBins = prevState->m_remRegBins; + } + else { + int ctxBinSampleRatio = 28; + // (scanInfo.chType == COLOR_Y) ? MAX_TU_LEVEL_CTX_CODED_BIN_CONSTRAINT_LUMA : MAX_TU_LEVEL_CTX_CODED_BIN_CONSTRAINT_CHROMA; + currState->m_remRegBins = (currState->effWidth * currState->effHeight * ctxBinSampleRatio) / 16; + } + currState->m_goRicePar = 0; + currState->m_refSbbCtxId = currState->m_stateId; + currState->m_sbbFracBits[0] = cc->m_sbbFlagBits[sigNSbb][0]; + currState->m_sbbFracBits[1] = cc->m_sbbFlagBits[sigNSbb][1]; - uint16_t templateCtxInit[16]; - const int scanBeg = scanInfo.scanIdx - scanInfo.sbbSize; - const NbInfoOut* nbOut = cc->m_nbInfo + scanBeg; - const uint8_t* absLevels = levels + scanBeg; - for (int id = 0; id < scanInfo.sbbSize; id++, nbOut++) - { - if (nbOut->num) - { - coeff_t sumAbs = 0, sumAbs1 = 0, sumNum = 0; + uint16_t templateCtxInit[16]; + const int scanBeg = scan_pos - 16; + const NbInfoOut* nbOut = cc->m_nbInfo + scanBeg; + const uint8_t* absLevels = levels + scanBeg; + for (int id = 0; id < 16; id++, nbOut++) { + if (nbOut->num) { + coeff_t sumAbs = 0, sumAbs1 = 0, sumNum = 0; #define UPDATE(k) {coeff_t t=absLevels[nbOut->outPos[k]]; sumAbs+=t; sumAbs1+=MIN(4+(t&1),t); sumNum+=!!t; } - UPDATE(0); - if (nbOut->num > 1) - { - UPDATE(1); - if (nbOut->num > 2) - { - UPDATE(2); - if (nbOut->num > 3) - { - UPDATE(3); - if (nbOut->num > 4) - { - UPDATE(4); - } - } - } + UPDATE(0); + if (nbOut->num > 1) { + UPDATE(1); + if (nbOut->num > 2) { + UPDATE(2); + if (nbOut->num > 3) { + UPDATE(3); + if (nbOut->num > 4) { + UPDATE(4); } + } + } + } #undef UPDATE - templateCtxInit[id] = (uint16_t)(sumNum) + ((uint16_t)(sumAbs1) << 3) + ((uint16_t)MIN(127, sumAbs) << 8); - } - else - { - templateCtxInit[id] = 0; - } + templateCtxInit[id] = (uint16_t)(sumNum) + ((uint16_t)(sumAbs1) << 3) + ((uint16_t)MIN(127, sumAbs) << 8); } - memset(currState->m_absLevelsAndCtxInit, 0, 16 * sizeof(uint8_t)); - memcpy(currState->m_absLevelsAndCtxInit + 8, templateCtxInit, 16 * sizeof(uint16_t)); + else { + templateCtxInit[id] = 0; + } + } + memset(currState->m_absLevelsAndCtxInit, 0, 16 * sizeof(uint8_t)); + memcpy(currState->m_absLevelsAndCtxInit + 8, templateCtxInit, 16 * sizeof(uint16_t)); } -static INLINE void updateStateEOS(depquant_state * state, const ScanInfo *scanInfo, const depquant_state* prevStates, const depquant_state* skipStates, - const Decision *decision) +static INLINE void updateStateEOS( + depquant_state * state, + const uint32_t scan_pos, + const uint32_t sigCtxOffsetNext, + const uint32_t gtxCtxOffsetNext, + const uint32_t width_in_sbb, + const uint32_t height_in_sbb, + const uint32_t sigNSbb, + const depquant_state* prevStates, + const depquant_state* skipStates, + const Decision *decision) { state->m_rdCost = decision->rdCost; if (decision->prevId > -2) { - const depquant_state* prvState = 0; - if (decision->prevId >= 4) - { - prvState = skipStates + (decision->prevId - 4); - state->m_numSigSbb = 0; - memset(state->m_absLevelsAndCtxInit, 0, 16 * sizeof(uint8_t)); - } - else if (decision->prevId >= 0) - { - prvState = prevStates + decision->prevId; - state->m_numSigSbb = prvState->m_numSigSbb + !!decision->absLevel; - memcpy(state->m_absLevelsAndCtxInit, prvState->m_absLevelsAndCtxInit, 16 * sizeof(uint8_t)); - } - else - { - state->m_numSigSbb = 1; - memset(state->m_absLevelsAndCtxInit, 0, 16 * sizeof(uint8_t)); - } - reinterpret_cast(m_absLevelsAndCtxInit)[scanInfo.insidePos] = (uint8_t)MIN(255, decision->absLevel); + const depquant_state* prvState = 0; + if (decision->prevId >= 4) + { + prvState = skipStates + (decision->prevId - 4); + state->m_numSigSbb = 0; + memset(state->m_absLevelsAndCtxInit, 0, 16 * sizeof(uint8_t)); + } + else if (decision->prevId >= 0) + { + prvState = prevStates + decision->prevId; + state->m_numSigSbb = prvState->m_numSigSbb + !!decision->absLevel; + memcpy(state->m_absLevelsAndCtxInit, prvState->m_absLevelsAndCtxInit, 16 * sizeof(uint8_t)); + } + else + { + state->m_numSigSbb = 1; + memset(state->m_absLevelsAndCtxInit, 0, 16 * sizeof(uint8_t)); + } + uint8_t* temp = (uint8_t*)(state->m_absLevelsAndCtxInit[scan_pos & 15]); + *temp = (uint8_t)MIN(255, decision->absLevel); - update_common_context(state->m_commonCtx, scanInfo, prvState, state); - - coeff_t tinit = state->m_absLevelsAndCtxInit[8 + scanInfo.nextInsidePos]; - coeff_t sumNum = tinit & 7; - coeff_t sumAbs1 = (tinit >> 3) & 31; - coeff_t sumGt1 = sumAbs1 - sumNum; - state->m_sigFracBits = state->m_sigFracBitsArray[scanInfo.sigCtxOffsetNext + MIN((sumAbs1 + 1) >> 1, 3)]; - state->m_coeffFracBits = state->m_gtxFracBitsArray[scanInfo.gtxCtxOffsetNext + (sumGt1 < 4 ? sumGt1 : 4)]; + update_common_context(state->m_commonCtx, scan_pos, width_in_sbb, height_in_sbb, sigNSbb, prvState, state); + + coeff_t tinit = state->m_absLevelsAndCtxInit[8 + ((scan_pos - 1) & 15)]; + coeff_t sumNum = tinit & 7; + coeff_t sumAbs1 = (tinit >> 3) & 31; + coeff_t sumGt1 = sumAbs1 - sumNum; + state->m_sigFracBits[0] = state->m_sigFracBitsArray[sigCtxOffsetNext + MIN((sumAbs1 + 1) >> 1, 3)][0]; + state->m_sigFracBits[1] = state->m_sigFracBitsArray[sigCtxOffsetNext + MIN((sumAbs1 + 1) >> 1, 3)][1]; + + memcpy(state->m_coeffFracBits, state->m_gtxFracBitsArray[gtxCtxOffsetNext + (sumGt1 < 4 ? sumGt1 : 4)], sizeof(state->m_coeffFracBits)); } } -static INLINE void updateState(depquant_state* state, int numIPos, const ScanInfo scanInfo, const depquant_state *prevStates, const Decision *decision, const int baseLevel, const bool extRiceFlag) -{ +static INLINE void updateState( + depquant_state* state, + int numIPos, const uint32_t scan_pos, + const depquant_state* prevStates, + const Decision* decision, + const uint32_t sigCtxOffsetNext, + const uint32_t gtxCtxOffsetNext, + const NbInfoSbb next_nb_info_ssb, + const int baseLevel, + const bool extRiceFlag) { state->m_rdCost = decision->rdCost; if (decision->prevId > -2) { @@ -710,14 +895,14 @@ static INLINE void updateState(depquant_state* state, int numIPos, const ScanInf } uint8_t* levels = (uint8_t*)(state->m_absLevelsAndCtxInit); - levels[scanInfo.insidePos] = (uint8_t)MIN(255, decision->absLevel); + levels[scan_pos & 15] = (uint8_t)MIN(255, decision->absLevel); if (state->m_remRegBins >= 4) { - coeff_t tinit = state->m_absLevelsAndCtxInit[8 + scanInfo.nextInsidePos]; + coeff_t tinit = state->m_absLevelsAndCtxInit[8 + ((scan_pos - 1) & 15)]; coeff_t sumAbs1 = (tinit >> 3) & 31; coeff_t sumNum = tinit & 7; -#define UPDATE(k) {coeff_t t=levels[scanInfo.nextNbInfoSbb.inPos[k]]; sumAbs1+=MIN(4+(t&1),t); sumNum+=!!t; } +#define UPDATE(k) {coeff_t t=levels[next_nb_info_ssb.inPos[k]]; sumAbs1+=MIN(4+(t&1),t); sumNum+=!!t; } if (numIPos == 1) { UPDATE(0); @@ -750,13 +935,13 @@ static INLINE void updateState(depquant_state* state, int numIPos, const ScanInf } #undef UPDATE coeff_t sumGt1 = sumAbs1 - sumNum; - state->m_sigFracBits[0] = state->m_sigFracBitsArray[scanInfo.sigCtxOffsetNext + MIN((sumAbs1 + 1) >> 1, 3)][0]; - state->m_sigFracBits[1] = state->m_sigFracBitsArray[scanInfo.sigCtxOffsetNext + MIN((sumAbs1 + 1) >> 1, 3)][1]; - memcpy(state->m_coeffFracBits, &state->m_gtxFracBitsArray[scanInfo.gtxCtxOffsetNext + (sumGt1 < 4 ? sumGt1 : 4)], sizeof(state->m_coeffFracBits)); + state->m_sigFracBits[0] = state->m_sigFracBitsArray[sigCtxOffsetNext + MIN((sumAbs1 + 1) >> 1, 3)][0]; + state->m_sigFracBits[1] = state->m_sigFracBitsArray[sigCtxOffsetNext + MIN((sumAbs1 + 1) >> 1, 3)][1]; + memcpy(state->m_coeffFracBits, state->m_gtxFracBitsArray[gtxCtxOffsetNext + (sumGt1 < 4 ? sumGt1 : 4)], sizeof(state->m_coeffFracBits)); - coeff_t sumAbs = state->m_absLevelsAndCtxInit[8 + scanInfo.nextInsidePos] >> 8; -#define UPDATE(k) {coeff_t t=levels[scanInfo.nextNbInfoSbb.inPos[k]]; sumAbs+=t; } + coeff_t sumAbs = state->m_absLevelsAndCtxInit[8 + ((scan_pos - 1) & 15)] >> 8; +#define UPDATE(k) {coeff_t t=levels[next_nb_info_ssb.inPos[k]]; sumAbs+=t; } if (numIPos == 1) { UPDATE(0); @@ -804,8 +989,8 @@ static INLINE void updateState(depquant_state* state, int numIPos, const ScanInf } else { - coeff_t sumAbs = state->m_absLevelsAndCtxInit[8 + scanInfo.nextInsidePos] >> 8; -#define UPDATE(k) {coeff_t t=levels[scanInfo.nextNbInfoSbb.inPos[k]]; sumAbs+=t; } + coeff_t sumAbs = state->m_absLevelsAndCtxInit[8 + ((scan_pos - 1) & 15)] >> 8; +#define UPDATE(k) {coeff_t t=levels[next_nb_info_ssb.inPos[k]]; sumAbs+=t; } if (numIPos == 1) { UPDATE(0); @@ -856,37 +1041,56 @@ static INLINE void updateState(depquant_state* state, int numIPos, const ScanInf } static void xDecideAndUpdate( - const coeff_t absCoeff, - const ScanInfo scanInfo, - bool zeroOut, - coeff_t quantCoeff, - int effWidth, - int effHeight, - bool reverseLast, - Decision* decisions) + rate_estimator* re, + context_store* ctxs, + const coeff_t absCoeff, + const uint32_t scan_pos, + const uint32_t pos_x, + const uint32_t pos_y, + const uint32_t sigCtxOffsetNext, + const uint32_t gtxCtxOffsetNext, + const uint32_t width_in_sbb, + const uint32_t height_in_sbb, + const uint32_t sigNSbb, + const NbInfoSbb next_nb_info_ssb, + bool zeroOut, + coeff_t quantCoeff, + int effWidth, + int effHeight) { - std::swap(m_prevStates, m_currStates); + Decision* decisions = ctxs->m_trellis[scan_pos]; + SWAP(ctxs->m_currStates, ctxs->m_prevStates, depquant_state*); - xDecide(scanInfo.spt, absCoeff, lastOffset(scanInfo.scanIdx, effWidth, effHeight, reverseLast), decisions, zeroOut, quantCoeff); + enum ScanPosType spt = 0; + if ((scan_pos & 15) == 15 && scan_pos > 16 && scan_pos < effHeight * effWidth - 1) + { + spt = SCAN_SOCSBB; + } + else if ((scan_pos & 15) == 0 && scan_pos > 0 && scan_pos < effHeight * effWidth - 16) + { + spt = SCAN_EOCSBB; + } - if (scanInfo.scanIdx) { - if (scanInfo.eosbb) { - m_commonCtx.swap(); - updateStateEOS(&m_currStates[0], scanInfo, m_prevStates, m_skipStates, &decisions[0]); - updateStateEOS(&m_currStates[1], scanInfo, m_prevStates, m_skipStates, &decisions[1]); - updateStateEOS(&m_currStates[2], scanInfo, m_prevStates, m_skipStates, &decisions[2]); - updateStateEOS(&m_currStates[3], scanInfo, m_prevStates, m_skipStates, &decisions[3]); + xDecide(ctxs->m_skipStates, ctxs->m_prevStates, &ctxs->m_startState, &ctxs->m_quant, spt, absCoeff, re->m_lastBitsX[pos_x] + re->m_lastBitsY[pos_y], decisions, zeroOut, quantCoeff); + + if (scan_pos) { + if (!(scan_pos & 15)) { + SWAP(ctxs->m_common_context.m_currSbbCtx, ctxs->m_common_context.m_prevSbbCtx, SbbCtx*); + updateStateEOS(&ctxs->m_currStates[0], scan_pos, sigCtxOffsetNext, gtxCtxOffsetNext, width_in_sbb, height_in_sbb, sigNSbb, ctxs->m_prevStates, ctxs->m_skipStates, &decisions[0]); + updateStateEOS(&ctxs->m_currStates[1], scan_pos, sigCtxOffsetNext, gtxCtxOffsetNext, width_in_sbb, height_in_sbb, sigNSbb, ctxs->m_prevStates, ctxs->m_skipStates, &decisions[1]); + updateStateEOS(&ctxs->m_currStates[2], scan_pos, sigCtxOffsetNext, gtxCtxOffsetNext, width_in_sbb, height_in_sbb, sigNSbb, ctxs->m_prevStates, ctxs->m_skipStates, &decisions[2]); + updateStateEOS(&ctxs->m_currStates[3], scan_pos, sigCtxOffsetNext, gtxCtxOffsetNext, width_in_sbb, height_in_sbb, sigNSbb, ctxs->m_prevStates, ctxs->m_skipStates, &decisions[3]); memcpy(decisions + 4, decisions, 4 * sizeof(Decision)); } else if (!zeroOut) { - updateState(&m_currStates[0], scanInfo.nextNbInfoSbb.num, scanInfo, m_prevStates, decisions[0], m_baseLevel, m_extRiceRRCFlag); - updateState(&m_currStates[1], scanInfo.nextNbInfoSbb.num, scanInfo, m_prevStates, decisions[1], m_baseLevel, m_extRiceRRCFlag); - updateState(&m_currStates[2], scanInfo.nextNbInfoSbb.num, scanInfo, m_prevStates, decisions[2], m_baseLevel, m_extRiceRRCFlag); - updateState(&m_currStates[3], scanInfo.nextNbInfoSbb.num, scanInfo, m_prevStates, decisions[3], m_baseLevel, m_extRiceRRCFlag); + updateState(&ctxs->m_currStates[0], next_nb_info_ssb.num, scan_pos, ctxs->m_prevStates, &decisions[0], sigCtxOffsetNext, gtxCtxOffsetNext, next_nb_info_ssb, 4, false); + updateState(&ctxs->m_currStates[1], next_nb_info_ssb.num, scan_pos, ctxs->m_prevStates, &decisions[1], sigCtxOffsetNext, gtxCtxOffsetNext, next_nb_info_ssb, 4, false); + updateState(&ctxs->m_currStates[2], next_nb_info_ssb.num, scan_pos, ctxs->m_prevStates, &decisions[2], sigCtxOffsetNext, gtxCtxOffsetNext, next_nb_info_ssb, 4, false); + updateState(&ctxs->m_currStates[3], next_nb_info_ssb.num, scan_pos, ctxs->m_prevStates, &decisions[3], sigCtxOffsetNext, gtxCtxOffsetNext, next_nb_info_ssb, 4, false); } - if (scanInfo.spt == SCAN_SOCSBB) { - std::swap(m_prevStates, m_skipStates); + if (spt == SCAN_SOCSBB) { + SWAP(ctxs->m_skipStates, ctxs->m_prevStates, depquant_state*); } } } @@ -907,6 +1111,10 @@ uint8_t uvg_dep_quant( const encoder_control_t* const encoder = state->encoder_control; //===== reset / pre-init ===== const int baseLevel = 4; + context_store dep_quant_context; + dep_quant_context.m_currStates = &dep_quant_context.m_allStates[0]; + dep_quant_context.m_prevStates = &dep_quant_context.m_allStates[4]; + dep_quant_context.m_skipStates = &dep_quant_context.m_allStates[8]; const uint32_t width = compID == COLOR_Y ? cu_loc->width : cu_loc->chroma_width; const uint32_t height = compID == COLOR_Y ? cu_loc->height : cu_loc->chroma_height; @@ -925,6 +1133,7 @@ uint8_t uvg_dep_quant( const uint32_t log2_tr_width = uvg_g_convert_to_log2[width]; const uint32_t log2_tr_height = uvg_g_convert_to_log2[height]; const uint32_t* const scan = uvg_get_scan_order_table(SCAN_GROUP_4X4,0,log2_tr_width,log2_tr_height); + const uint32_t* const cg_scan = uvg_get_scan_order_table(SCAN_GROUP_UNGROUPED,0,log2_tr_width,log2_tr_height); int32_t qp_scaled = uvg_get_scaled_qp(compID, state->qp, (encoder->bitdepth - 8) * 6, encoder->qp_map[0]); qp_scaled = is_ts ? MAX(qp_scaled, 4 + 6 * MIN_QP_PRIME_TS) : qp_scaled; @@ -936,11 +1145,9 @@ uint8_t uvg_dep_quant( const int32_t transform_shift = MAX_TR_DYNAMIC_RANGE - encoder->bitdepth - ((log2_tr_height + log2_tr_width) >> 1) - needs_block_size_trafo_scale; //!< Represents scaling through forward transform const int64_t q_bits = QUANT_SHIFT + qp_scaled / 6 + (is_ts ? 0 : transform_shift ); const int32_t add = ((state->frame->slicetype == UVG_SLICE_I) ? 171 : 85) << (q_bits - 9); - - quant_block quant_block; - init_quant_block(state, &quant_block, cur_tu, log2_tr_width, log2_tr_height, compID, needs_block_size_trafo_scale, -1); - - Decision trellis[TR_MAX_WIDTH * TR_MAX_WIDTH][8]; + + init_quant_block(state, &dep_quant_context.m_quant, cur_tu, log2_tr_width, log2_tr_height, compID, needs_block_size_trafo_scale, -1); + //===== scaling matrix ==== //const int qpDQ = cQP.Qp + 1; //const int qpPer = qpDQ / 6; @@ -970,7 +1177,7 @@ uint8_t uvg_dep_quant( const int32_t default_quant_coeff = uvg_g_quant_scales[needs_block_size_trafo_scale][qp_scaled % 6]; const coeff_t thres = 4 << q_bits; for (; firstTestPos >= 0; firstTestPos--) { - coeff_t thresTmp = (enableScalingLists) ? (thres / (4 * q_coeff[scan[firstTestPos]])) :(thres / (4 * default_quant_coeff)); + coeff_t thresTmp = (enableScalingLists) ? (thres / (4 * q_coeff[scan[firstTestPos]])) : (thres / (4 * default_quant_coeff)); if (abs(srcCoeff[scan[firstTestPos]]) > thresTmp) { break; } @@ -983,46 +1190,81 @@ uint8_t uvg_dep_quant( rate_estimator rate_estimator; init_rate_esimator(&rate_estimator, &state->search_cabac, compID); xSetLastCoeffOffset(state, cur_tu, cu_loc, &rate_estimator, cbf_is_set(cur_tu->cbf, COLOR_U), compID); - common_context common_context; - reset_common_context(&common_context, &rate_estimator, (width * height) >> 4, numCoeff); - depquant_state all_state[12]; - depquant_state start_state; + reset_common_context(&dep_quant_context.m_common_context, &rate_estimator, (width * height) >> 4, numCoeff); + dep_quant_context.m_common_context.m_nbInfo = encoder->m_scanId2NbInfoOutArray[log2_tr_width][log2_tr_height]; + int effectHeight = MIN(32, effHeight); int effectWidth = MIN(32, effWidth); for (int k = 0; k < 12; k++) { - depquant_state_init(&all_state[k], rate_estimator.m_sigFracBits[0][0], rate_estimator.m_gtxFracBits[0]); - all_state[k].effHeight = effectHeight; - all_state[k].effWidth = effectWidth; + depquant_state_init(&dep_quant_context.m_allStates[k], rate_estimator.m_sigFracBits[0][0], rate_estimator.m_gtxFracBits[0]); + dep_quant_context.m_allStates[k].effHeight = effectHeight; + dep_quant_context.m_allStates[k].effWidth = effectWidth; } - depquant_state_init(&start_state, rate_estimator.m_sigFracBits[0][0], rate_estimator.m_gtxFracBits[0]); - start_state.effHeight = effectHeight; - start_state.effWidth = effectWidth; - + depquant_state_init(&dep_quant_context.m_startState, rate_estimator.m_sigFracBits[0][0], rate_estimator.m_gtxFracBits[0]); + dep_quant_context.m_startState.effHeight = effectHeight; + dep_quant_context.m_startState.effWidth = effectWidth; + + + const uint32_t height_in_sbb = MAX(height >> 2, 1); + const uint32_t width_in_sbb = MAX(width >> 2, 1); //===== populate trellis ===== for (int scanIdx = firstTestPos; scanIdx >= 0; scanIdx--) { - uint32_t scan_pos = scan[scanIdx]; + uint32_t blkpos = scan[scanIdx]; + uint32_t pos_y = blkpos >> log2_tr_width; + uint32_t pos_x = blkpos - (pos_y << log2_tr_width); + + uint32_t cg_blockpos = scanIdx ? cg_scan[(scanIdx -1) >> 4] : 0; + uint32_t cg_pos_y = cg_blockpos / height_in_sbb; + uint32_t cg_pos_x = cg_blockpos - (cg_pos_y * height_in_sbb); + uint32_t diag = cg_pos_y + cg_pos_x; + + uint32_t sig_ctx_offset = compID == COLOR_Y ? (diag < 2 ? 8 : diag < 5 ? 4 : 0) : (diag < 2 ? 4 : 0); + uint32_t gtx_ctx_offset = compID == COLOR_Y ? (diag < 1 ? 16 : diag < 3 ? 11 : diag < 10 ? 6 : 1) : (diag < 1 ? 6 : 1); + + uint32_t nextSbbRight = (cg_pos_x < width_in_sbb - 1 ? cg_blockpos + 1 : 0); + uint32_t nextSbbBelow = (cg_pos_y < height_in_sbb - 1 ? cg_blockpos + width_in_sbb : 0); + if (enableScalingLists) { - init_quant_block(state, &quant_block, cur_tu, log2_tr_width, log2_tr_height, compID, needs_block_size_trafo_scale, q_coeff[scan_pos]); + init_quant_block(state, &dep_quant_context.m_quant, cur_tu, log2_tr_width, log2_tr_height, compID, needs_block_size_trafo_scale, q_coeff[blkpos]); xDecideAndUpdate( - abs(srcCoeff[scan_pos]), - scanInfo, - (zeroOut && (scanInfo.posX >= effWidth || scanInfo.posY >= effHeight)), - q_coeff[scan_pos], + &rate_estimator, + &dep_quant_context, + abs(srcCoeff[blkpos]), + scanIdx, + pos_x, + pos_y, + sig_ctx_offset, + gtx_ctx_offset, + width_in_sbb, + height_in_sbb, + nextSbbRight || nextSbbBelow, + encoder->m_scanId2NbInfoSbbArray[log2_tr_width][log2_tr_height][scanIdx ? scanIdx - 1 : 0], + (zeroOut && (pos_x >= effWidth || pos_y >= effHeight)), + q_coeff[blkpos], effectWidth, - effectHeight, - false); //tu.cu->slice->getReverseLastSigCoeffFlag()); + effectHeight + ); //tu.cu->slice->getReverseLastSigCoeffFlag()); } else { xDecideAndUpdate( - abs(srcCoeff[scan_pos]), - scanInfo, - (zeroOut && (scanInfo.posX >= effWidth || scanInfo.posY >= effHeight)), + &rate_estimator, + &dep_quant_context, + abs(srcCoeff[blkpos]), + scanIdx, + pos_x, + pos_y, + sig_ctx_offset, + gtx_ctx_offset, + width_in_sbb, + height_in_sbb, + nextSbbRight || nextSbbBelow, + encoder->m_scanId2NbInfoSbbArray[log2_tr_width][log2_tr_height][scanIdx ? scanIdx - 1 : 0], + (zeroOut && (pos_x >= effWidth || pos_y >= effHeight)), default_quant_coeff, effectWidth, - effectHeight, - false); //tu.cu->slice->getReverseLastSigCoeffFlag()); + effectHeight); //tu.cu->slice->getReverseLastSigCoeffFlag()); } } @@ -1030,7 +1272,7 @@ uint8_t uvg_dep_quant( Decision decision = {INT64_MAX, -1, -2}; int64_t minPathCost = 0; for (int8_t stateId = 0; stateId < 4; stateId++) { - int64_t pathCost = trellis[0][stateId].rdCost; + int64_t pathCost = dep_quant_context.m_trellis[0][stateId].rdCost; if (pathCost < minPathCost) { decision.prevId = stateId; minPathCost = pathCost; @@ -1040,7 +1282,7 @@ uint8_t uvg_dep_quant( //===== backward scanning ===== int scanIdx = 0; for (; decision.prevId >= 0; scanIdx++) { - decision = trellis[scanIdx][decision.prevId]; + decision = dep_quant_context.m_trellis[scanIdx][decision.prevId]; int32_t blkpos = scan[scanIdx]; coeff_out[blkpos] = (srcCoeff[blkpos] < 0 ? -decision.absLevel : decision.absLevel); absSum += decision.absLevel; diff --git a/src/dep_quant.h b/src/dep_quant.h index 35fec0b5..0e1d20ca 100644 --- a/src/dep_quant.h +++ b/src/dep_quant.h @@ -35,6 +35,22 @@ #include "global.h" +typedef struct encoder_control_t encoder_control_t; +typedef struct +{ + uint8_t num; + uint8_t inPos[5]; +} NbInfoSbb; + +typedef struct +{ + uint16_t maxDist; + uint16_t num; + uint16_t outPos[5]; +} NbInfoOut; + +int uvg_init_nb_info(encoder_control_t* encoder); +void uvg_dealloc_nb_info(encoder_control_t* encoder); #endif diff --git a/src/encoder.c b/src/encoder.c index f3d7653a..56d03305 100644 --- a/src/encoder.c +++ b/src/encoder.c @@ -320,6 +320,13 @@ encoder_control_t* uvg_encoder_control_init(const uvg_config *const cfg) encoder->scaling_list.use_default_list = 1; } + if(cfg->dep_quant) { + if(!uvg_init_nb_info(encoder)) { + fprintf(stderr, "Could not initialize nb info.\n"); + goto init_failed; + } + } + // ROI / delta QP if (cfg->roi.file_path) { const char *mode[2] = { "r", "rb" }; diff --git a/src/encoder.h b/src/encoder.h index be835890..81b091b3 100644 --- a/src/encoder.h +++ b/src/encoder.h @@ -38,6 +38,7 @@ * Initialization of encoder_control_t. */ +#include "dep_quant.h" #include "global.h" // IWYU pragma: keep #include "uvg266.h" #include "scalinglist.h" @@ -98,6 +99,9 @@ typedef struct encoder_control_t //scaling list scaling_list_t scaling_list; + NbInfoSbb* m_scanId2NbInfoSbbArray[7 + 1][7 + 1]; + NbInfoOut* m_scanId2NbInfoOutArray[7 + 1][7 + 1]; + //spec: references to variables defined in Rec. ITU-T H.265 (04/2013) int8_t tiles_enable; /*!> log2_block_width; - uint32_t pos_x = blkpos - (pos_y << log2_block_width); // TODO: height + uint32_t pos_x = blkpos - (pos_y << log2_block_width); //===== quantization ===== // set coeff diff --git a/src/uvg266.h b/src/uvg266.h index fe6e2b0f..c71a835a 100644 --- a/src/uvg266.h +++ b/src/uvg266.h @@ -552,7 +552,7 @@ typedef struct uvg_config uint8_t intra_rough_search_levels; uint8_t ibc; /* \brief Intra Block Copy parameter */ - + uint8_t dep_quant; } uvg_config; /**