[depquant] Only initialize rate_estimator when necessary

This commit is contained in:
Joose Sainio 2023-04-17 14:45:55 +03:00
parent 00f838306f
commit 00cc58bc55
6 changed files with 53 additions and 34 deletions

View file

@ -42,13 +42,6 @@
#include <immintrin.h>
#define sm_numCtxSetsSig 3
#define sm_numCtxSetsGtx 2
#define sm_maxNumSigSbbCtx 2
#define sm_maxNumSigCtx 12
#define sm_maxNumGtxCtx 21
#define SCALE_BITS 15
#define RICEMAX 32
static const int32_t g_goRiceBits[4][RICEMAX] = {
{ 32768, 65536, 98304, 131072, 163840, 196608, 262144, 262144, 327680, 327680, 327680, 327680, 393216, 393216, 393216, 393216, 393216, 393216, 393216, 393216, 458752, 458752, 458752, 458752, 458752, 458752, 458752, 458752, 458752, 458752, 458752, 458752},
@ -102,16 +95,6 @@ typedef struct
} common_context;
typedef struct
{
int32_t m_lastBitsX[TR_MAX_WIDTH];
int32_t m_lastBitsY[TR_MAX_WIDTH];
uint32_t m_sigSbbFracBits[sm_maxNumSigSbbCtx][2];
uint32_t m_sigFracBits[sm_numCtxSetsSig][sm_maxNumSigCtx][2];
int32_t m_gtxFracBits[sm_maxNumGtxCtx][6];
} rate_estimator;
typedef struct
{
int64_t m_rdCost;
@ -451,12 +434,12 @@ static void reset_common_context(common_context* ctx, const rate_estimator * rat
static void init_rate_esimator(rate_estimator * rate_estimator, const cabac_data_t * const ctx, color_t color)
{
const cabac_ctx_t * base_ctx = color == COLOR_Y ? ctx->ctx.sig_coeff_group_model : (ctx->ctx.sig_coeff_group_model + 2);
for (unsigned ctxId = 0; ctxId < sm_maxNumSigSbbCtx; ctxId++) {
for (unsigned ctxId = 0; ctxId < SM_MAX_NUM_SIG_SBB_CTX; ctxId++) {
rate_estimator->m_sigSbbFracBits[ctxId][0] = CTX_ENTROPY_BITS(&base_ctx[ctxId], 0);
rate_estimator->m_sigSbbFracBits[ctxId][1] = CTX_ENTROPY_BITS(&base_ctx[ctxId], 1);
}
unsigned numCtx = (color == COLOR_Y ? 12 : 8);
for (unsigned ctxSetId = 0; ctxSetId < sm_numCtxSetsSig; ctxSetId++) {
for (unsigned ctxSetId = 0; ctxSetId < SM_NUM_CTX_SETS_SIG; ctxSetId++) {
base_ctx = color == COLOR_Y ? ctx->ctx.cu_sig_model_luma[ctxSetId] : ctx->ctx.cu_sig_model_chroma[ctxSetId];
for (unsigned ctxId = 0; ctxId < numCtx; ctxId++) {
rate_estimator->m_sigFracBits[ctxSetId][ctxId][0] = CTX_ENTROPY_BITS(&base_ctx[ctxId], 0);
@ -2309,7 +2292,8 @@ int uvg_dep_quant(
} else {
dep_quant_context.m_quant = (quant_block*)&state->quant_blocks[0];
}
if (dep_quant_context.m_quant->needs_init) {
//TODO: no idea when it is safe not to reinit for inter
if (dep_quant_context.m_quant->needs_init || cur_tu->type == CU_INTER) {
init_quant_block(state, dep_quant_context.m_quant, cur_tu, log2_tr_width, log2_tr_height, compID, needs_block_size_trafo_scale, -1);
}
@ -2352,11 +2336,15 @@ int uvg_dep_quant(
}
//===== real init =====
rate_estimator rate_estimator;
init_rate_esimator(&rate_estimator, &state->search_cabac, compID);
xSetLastCoeffOffset(state, cur_tu, width, height, &rate_estimator, compID);
rate_estimator* rate_estimator = compID == COLOR_Y && cur_tu->type == CU_INTRA && cur_tu->intra.isp_mode != ISP_MODE_NO_ISP ?
&state->rate_estimator[3] : &state->rate_estimator[compID];
if(rate_estimator->needs_init || cur_tu->type == CU_INTER) {
init_rate_esimator(rate_estimator, &state->search_cabac, compID);
xSetLastCoeffOffset(state, cur_tu, width, height, rate_estimator, compID);
rate_estimator->needs_init = false;
}
reset_common_context(&dep_quant_context.m_common_context, &rate_estimator, (width * height) >> 4, numCoeff);
reset_common_context(&dep_quant_context.m_common_context, rate_estimator, (width * height) >> 4, numCoeff);
dep_quant_context.m_common_context.m_nbInfo = encoder->m_scanId2NbInfoOutArray[log2_tr_width][log2_tr_height];
@ -2367,9 +2355,9 @@ int uvg_dep_quant(
dep_quant_context.m_allStates.m_numSigSbb[k] = 0;
dep_quant_context.m_allStates.m_remRegBins[k] = 4; // just large enough for last scan pos
dep_quant_context.m_allStates.m_refSbbCtxId[k] = -1;
dep_quant_context.m_allStates.m_sigFracBits[k][0] = rate_estimator.m_sigFracBits[0][0][0];
dep_quant_context.m_allStates.m_sigFracBits[k][1] = rate_estimator.m_sigFracBits[0][0][1];
memcpy(dep_quant_context.m_allStates.m_coeffFracBits[k], rate_estimator.m_gtxFracBits[0], sizeof(dep_quant_context.m_allStates.m_coeffFracBits[k]));
dep_quant_context.m_allStates.m_sigFracBits[k][0] = rate_estimator->m_sigFracBits[0][0][0];
dep_quant_context.m_allStates.m_sigFracBits[k][1] = rate_estimator->m_sigFracBits[0][0][1];
memcpy(dep_quant_context.m_allStates.m_coeffFracBits[k], rate_estimator->m_gtxFracBits[0], sizeof(dep_quant_context.m_allStates.m_coeffFracBits[k]));
dep_quant_context.m_allStates.m_goRicePar[k] = 0;
dep_quant_context.m_allStates.m_goRiceZero[k] = 0;
@ -2378,7 +2366,7 @@ int uvg_dep_quant(
dep_quant_context.m_allStates.m_stateId[k] = k & 3;
for (int i = 0; i < (compID == COLOR_Y ? 12 : 8); ++i) {
memcpy(dep_quant_context.m_allStates.m_sigFracBitsArray[k][i], rate_estimator.m_sigFracBits[(k & 3 ? (k & 3) - 1 : 0)][i], sizeof(uint32_t) * 2);
memcpy(dep_quant_context.m_allStates.m_sigFracBitsArray[k][i], rate_estimator->m_sigFracBits[(k & 3 ? (k & 3) - 1 : 0)][i], sizeof(uint32_t) * 2);
}
}
@ -2388,19 +2376,19 @@ int uvg_dep_quant(
dep_quant_context.m_allStates.all_lt_four = false;
dep_quant_context.m_allStates.m_commonCtx = &dep_quant_context.m_common_context;
for (int i = 0; i < (compID == COLOR_Y ? 21 : 11); ++i) {
memcpy(dep_quant_context.m_allStates.m_gtxFracBitsArray[i], rate_estimator.m_gtxFracBits[i], sizeof(int32_t) * 6);
memcpy(dep_quant_context.m_allStates.m_gtxFracBitsArray[i], rate_estimator->m_gtxFracBits[i], sizeof(int32_t) * 6);
}
depquant_state_init(&dep_quant_context.m_startState, rate_estimator.m_sigFracBits[0][0], rate_estimator.m_gtxFracBits[0]);
depquant_state_init(&dep_quant_context.m_startState, rate_estimator->m_sigFracBits[0][0], rate_estimator->m_gtxFracBits[0]);
dep_quant_context.m_startState.effHeight = effectHeight;
dep_quant_context.m_startState.effWidth = effectWidth;
dep_quant_context.m_startState.m_stateId = 0;
dep_quant_context.m_startState.m_commonCtx = &dep_quant_context.m_common_context;
for (int i = 0; i < (compID == COLOR_Y ? 12 : 8); ++i) {
dep_quant_context.m_startState.m_sigFracBitsArray[i] = rate_estimator.m_sigFracBits[0][i];
dep_quant_context.m_startState.m_sigFracBitsArray[i] = rate_estimator->m_sigFracBits[0][i];
}
for (int i = 0; i < (compID == COLOR_Y ? 21 : 11); ++i) {
dep_quant_context.m_startState.m_gtxFracBitsArray[i] = rate_estimator.m_gtxFracBits[i];
dep_quant_context.m_startState.m_gtxFracBitsArray[i] = rate_estimator->m_gtxFracBits[i];
}
const uint32_t height_in_sbb = MAX(height >> 2, 1);
@ -2416,7 +2404,7 @@ int uvg_dep_quant(
init_quant_block(state, dep_quant_context.m_quant, cur_tu, log2_tr_width, log2_tr_height, compID, needs_block_size_trafo_scale, q_coeff[blkpos]);
xDecideAndUpdate(
&rate_estimator,
rate_estimator,
ctxs,
scan_info,
abs(srcCoeff[blkpos]),
@ -2433,7 +2421,7 @@ int uvg_dep_quant(
}
else {
xDecideAndUpdate(
&rate_estimator,
rate_estimator,
ctxs,
scan_info,
abs(srcCoeff[blkpos]),

View file

@ -36,6 +36,14 @@
#include "cu.h"
#include "global.h"
#define SM_NUM_CTX_SETS_SIG 3
#define SM_NUM_CTX_SETS_GTX 2
#define SM_MAX_NUM_SIG_SBB_CTX 2
#define SM_MAX_NUM_SIG_CTX 12
#define SM_MAX_NUM_GTX_CTX 21
#define SCALE_BITS 15
#define RICEMAX 32
typedef struct encoder_control_t encoder_control_t;
struct dep_quant_scan_info
@ -65,6 +73,17 @@ typedef struct
bool needs_init;
} quant_block;
typedef struct
{
int32_t m_lastBitsX[TR_MAX_WIDTH];
int32_t m_lastBitsY[TR_MAX_WIDTH];
uint32_t m_sigSbbFracBits[SM_MAX_NUM_SIG_SBB_CTX][2];
uint32_t m_sigFracBits[SM_NUM_CTX_SETS_SIG][SM_MAX_NUM_SIG_CTX][2];
int32_t m_gtxFracBits[SM_MAX_NUM_GTX_CTX][6];
bool needs_init;
} rate_estimator;
typedef struct
{
uint8_t num;

View file

@ -368,6 +368,7 @@ typedef struct encoder_state_t {
int8_t collocated_luma_mode;
quant_block quant_blocks[3]; // luma, ISP, chroma
rate_estimator rate_estimator[4]; // luma, cb, cr, isp
} encoder_state_t;
void uvg_encode_one_frame(encoder_state_t * const state, uvg_picture* frame);

View file

@ -1908,6 +1908,8 @@ void uvg_intra_recon_cu(
int split_type = search_data->pred_cu.intra.isp_mode;
int split_limit = uvg_get_isp_split_num(width, height, split_type, true);
state->quant_blocks[1].needs_init = true;
for (int i = 0; i < split_limit; ++i) {
cu_loc_t tu_loc;
uvg_get_isp_split_loc(&tu_loc, cu_loc->x, cu_loc->y, width, height, i, split_type, true);
@ -1917,6 +1919,7 @@ void uvg_intra_recon_cu(
if(tu_loc.x % 4 == 0) {
intra_recon_tb_leaf(state, &pu_loc, cu_loc, lcu, COLOR_Y, search_data);
}
state->rate_estimator[3].needs_init = true;
uvg_quantize_lcu_residual(state, true, false, false,
&tu_loc, cur_cu, lcu,
false, tree_type);
@ -2030,6 +2033,8 @@ double uvg_recon_and_estimate_cost_isp(encoder_state_t* const state,
if (tu_loc.x % 4 == 0) {
intra_recon_tb_leaf(state, &pu_loc, cu_loc, lcu, COLOR_Y, search_data);
}
state->rate_estimator[3].needs_init = true;
uvg_quantize_lcu_residual(state, true, false, false,
&tu_loc, &search_data->pred_cu, lcu,
false, UVG_LUMA_T);

View file

@ -1492,6 +1492,7 @@ int8_t uvg_search_intra_chroma_rdo(
double original_c_lambda = state->c_lambda;
state->quant_blocks[2].needs_init = true;
state->rate_estimator[1].needs_init = true;
for (int8_t mode_i = 0; mode_i < num_modes; ++mode_i) {
const uint8_t mode = chroma_data[mode_i].pred_cu.intra.mode_chroma;

View file

@ -468,6 +468,7 @@ static void quantize_chroma(
if (transform == DCT7_CHROMA) {
abs_sum = 0;
state->rate_estimator[2].needs_init = true;
uvg_dep_quant(
state,
cur_tu,
@ -1538,6 +1539,7 @@ void uvg_quantize_lcu_residual(
cu_loc_t split_cu_loc[4];
uint16_t child_cbfs[3];
const int split_count = uvg_get_split_locs(cu_loc, split, split_cu_loc,NULL);
for (int i = 0; i < split_count; ++i) {
uvg_quantize_lcu_residual(state, luma, chroma, 0, &split_cu_loc[i], NULL, lcu, early_skip, tree_type);
if(i != 0) {
@ -1558,11 +1560,14 @@ void uvg_quantize_lcu_residual(
uvg_cu_loc_ctor(&loc, x, y, width, height);
if (luma) {
state->quant_blocks[0].needs_init = true;
state->rate_estimator[0].needs_init = true;
quantize_tr_residual(state, COLOR_Y, &loc, cur_pu, lcu, early_skip, tree_type);
}
double c_lambda = state->c_lambda;
state->c_lambda = uvg_calculate_chroma_lambda(state, state->encoder_control->cfg.jccr, cur_pu->joint_cb_cr);
if (chroma) {
state->rate_estimator[2].needs_init = true;
if(state->encoder_control->cfg.dep_quant) {
cabac_data_t temp_cabac;
memcpy(&temp_cabac, &state->search_cabac, sizeof(cabac_data_t));