[DepQuant] WIP: doesn't crash but bitstream is illegal and quality a lot worse

This commit is contained in:
Joose Sainio 2023-01-11 14:17:18 +02:00 committed by Marko Viitanen
parent bfa699fac6
commit 5236bc93be
5 changed files with 155 additions and 48 deletions

View file

@ -69,9 +69,9 @@ typedef struct
int m_QShift;
int64_t m_QAdd;
int64_t m_QScale;
coeff_t m_maxQIdx;
coeff_t m_thresLast;
coeff_t m_thresSSbb;
int64_t m_maxQIdx;
int64_t m_thresLast;
int64_t m_thresSSbb;
// distortion normalization
int m_DistShift;
int64_t m_DistAdd;
@ -135,9 +135,9 @@ typedef struct
int8_t m_goRicePar;
int8_t m_goRiceZero;
int8_t m_stateId;
const uint32_t* m_sigFracBitsArray[2];
const uint32_t* m_gtxFracBitsArray[6];
struct common_context* m_commonCtx;
uint32_t *m_sigFracBitsArray[12];
int32_t *m_gtxFracBitsArray[21];
common_context* m_commonCtx;
unsigned effWidth;
unsigned effHeight;
@ -159,12 +159,12 @@ typedef struct
int uvg_init_nb_info(encoder_control_t * encoder) {
memset(encoder->m_scanId2NbInfoSbbArray, 0, sizeof(encoder->m_scanId2NbInfoSbbArray));
memset(encoder->m_scanId2NbInfoOutArray, 0, sizeof(encoder->m_scanId2NbInfoOutArray));
for (int hd = 0; hd <= 7; hd++)
for (int hd = 0; hd <= 6; hd++)
{
uint32_t raster2id[64 * 64] = {0};
for (int vd = 0; vd <= 7; vd++)
for (int vd = 0; vd <= 6; vd++)
{
if ((hd == 0 && vd <= 1) || (hd <= 1 && vd == 0))
{
@ -317,6 +317,21 @@ void uvg_dealloc_nb_info(encoder_control_t* encoder) {
}
static INLINE int ceil_log2(uint64_t x)
{
static const uint64_t t[6] = { 0xFFFFFFFF00000000ull, 0x00000000FFFF0000ull, 0x000000000000FF00ull, 0x00000000000000F0ull, 0x000000000000000Cull, 0x0000000000000002ull };
int y = (((x & (x - 1)) == 0) ? 0 : 1);
int j = 32;
for (int i = 0; i < 6; i++)
{
int k = (((x & t[i]) == 0) ? 0 : j);
y += k;
x >>= k;
j >>= 1;
}
return y;
}
static void init_quant_block(
const encoder_state_t* state,
quant_block* qp,
@ -349,8 +364,8 @@ static void init_quant_block(
maxLog2TrDynamicRange + 1,
8 * sizeof(int) + invShift - IQUANT_SHIFT - 1);
qp->m_maxQIdx = (1 << (qIdxBD - 1)) - 4;
qp->m_thresLast = (coeff_t)(((int64_t)(4) << qp->m_QShift));
qp->m_thresSSbb = (coeff_t)(((int64_t)(3) << qp->m_QShift));
qp->m_thresLast = (((int64_t)(4) << (int64_t)qp->m_QShift));
qp->m_thresSSbb = (((int64_t)(3) << (int64_t)qp->m_QShift));
// distortion calculation parameters
const int64_t qScale = (gValue == -1) ? qp->m_QScale : gValue;
const int nomDShift =
@ -363,8 +378,7 @@ static void init_quant_block(
1.0 / ((double)((int64_t)(1) << (-nomDShift)) * qScale2 * lambda) :
(double)((int64_t)(1) << nomDShift) / (qScale2 * lambda));
const int64_t pow2dfShift = (int64_t)(nomDistFactor * qScale2) + 1;
assert(pow2dfShift > 0xfffffffll);
const int dfShift = uvg_math_ceil_log2(pow2dfShift);
const int dfShift = ceil_log2(pow2dfShift);
qp->m_DistShift = 62 + qp->m_QShift - 2 * maxLog2TrDynamicRange - dfShift;
qp->m_DistAdd = ((int64_t)(1) << qp->m_DistShift) >> 1;
qp->m_DistStepAdd = (int64_t)(nomDistFactor * (double)((int64_t)(1) << (qp->m_DistShift + qp->m_QShift)) + .5);
@ -404,8 +418,8 @@ static void init_rate_esimator(rate_estimator * rate_estimator, const cabac_data
numCtx = (color == COLOR_Y? 21 : 11);
for (unsigned ctxId = 0; ctxId < numCtx; ctxId++) {
const cabac_ctx_t * par_ctx = color == COLOR_Y ? &ctx->ctx.cu_parity_flag_model_luma[ctxId] : &ctx->ctx.cu_parity_flag_model_chroma[ctxId];
const cabac_ctx_t * gt1_ctx = color == COLOR_Y ? &ctx->ctx.cu_gtx_flag_model_luma[0][ctxId] : &ctx->ctx.cu_gtx_flag_model_chroma[0][ctxId];
const cabac_ctx_t * gt2_ctx = color == COLOR_Y ? &ctx->ctx.cu_gtx_flag_model_luma[1][ctxId] : &ctx->ctx.cu_gtx_flag_model_chroma[1][ctxId];
const cabac_ctx_t * gt2_ctx = color == COLOR_Y ? &ctx->ctx.cu_gtx_flag_model_luma[0][ctxId] : &ctx->ctx.cu_gtx_flag_model_chroma[0][ctxId];
const cabac_ctx_t * gt1_ctx = color == COLOR_Y ? &ctx->ctx.cu_gtx_flag_model_luma[1][ctxId] : &ctx->ctx.cu_gtx_flag_model_chroma[1][ctxId];
int32_t* cb = &rate_estimator->m_gtxFracBits[ctxId];
int32_t par0 = (1 << SCALE_BITS) + (int32_t)CTX_ENTROPY_BITS(par_ctx, 0);
@ -423,7 +437,8 @@ static void init_rate_esimator(rate_estimator * rate_estimator, const cabac_data
static void xSetLastCoeffOffset(
const encoder_state_t* const state,
const cu_info_t* const cur_tu,
const cu_loc_t* const cu_loc,
const int width,
const int height,
rate_estimator* rate_estimator,
const bool cb_cbf,
const color_t compID)
@ -438,7 +453,7 @@ static void init_rate_esimator(rate_estimator * rate_estimator, const cabac_data
if (useIntraSubPartitions) {
bool rootCbfSoFar = false;
bool isLastSubPartition = false; //TODO: isp check
uint32_t nTus = uvg_get_isp_split_num(cu_loc->width, cu_loc->height, cur_tu->intra.isp_mode, true);
uint32_t nTus = uvg_get_isp_split_num(width, height, cur_tu->intra.isp_mode, true);
if (isLastSubPartition) {
//TransformUnit* tuPointer = tu.cu->firstTU;
//for (int tuIdx = 0; tuIdx < nTus - 1; tuIdx++) {
@ -477,7 +492,7 @@ static const unsigned prefixCtx[] = {0, 0, 0, 3, 6, 10, 15, 21};
for (unsigned xy = 0; xy < 2; xy++) {
int32_t bitOffset = (xy ? cbfDeltaBits : 0);
int32_t* lastBits = (xy ? rate_estimator->m_lastBitsY : rate_estimator->m_lastBitsX);
const unsigned size = (xy ? (compID == COLOR_Y ? cu_loc->height : cu_loc->chroma_height) : (compID == COLOR_Y ? cu_loc->width : cu_loc->chroma_width));
const unsigned size = (xy ? (height) : (width));
const unsigned log2Size = uvg_math_ceil_log2(size);
const bool useYCtx = (xy != 0);
const cabac_ctx_t* const ctxSetLast = useYCtx ?
@ -504,15 +519,18 @@ static const unsigned prefixCtx[] = {0, 0, 0, 3, 6, 10, 15, 21};
static void depquant_state_init(depquant_state* state, uint32_t sig_frac_bits[2], uint32_t gtx_frac_bits[6])
{
state->m_rdCost = INT64_MAX;
state->m_rdCost = INT64_MAX >> 1;
state->m_numSigSbb = 0;
state->m_remRegBins = 4; // just large enough for last scan pos
state->m_refSbbCtxId = -1;
state->m_sigFracBits[0] = sig_frac_bits[0];
state->m_sigFracBits[1] = sig_frac_bits[1];
memcpy(state->m_coeffFracBits, gtx_frac_bits, sizeof(gtx_frac_bits));
memcpy(state->m_coeffFracBits, gtx_frac_bits, sizeof(state->m_coeffFracBits));
state->m_goRicePar = 0;
state->m_goRiceZero = 0;
state->m_sbbFracBits[0] = 0;
state->m_sbbFracBits[1] = 0;
}
static INLINE void checkRdCostSkipSbbZeroOut(Decision *decision, const depquant_state * const state)
@ -841,7 +859,7 @@ static INLINE void updateStateEOS(
state->m_numSigSbb = 1;
memset(state->m_absLevelsAndCtxInit, 0, 16 * sizeof(uint8_t));
}
uint8_t* temp = (uint8_t*)(state->m_absLevelsAndCtxInit[scan_pos & 15]);
uint8_t* temp = (uint8_t*)(&state->m_absLevelsAndCtxInit[scan_pos & 15]);
*temp = (uint8_t)MIN(255, decision->absLevel);
update_common_context(state->m_commonCtx, scan_pos, width_in_sbb, height_in_sbb, sigNSbb, prvState, state);
@ -1099,7 +1117,8 @@ static void xDecideAndUpdate(
int uvg_dep_quant(
const encoder_state_t* const state,
const cu_info_t* const cur_tu,
const cu_loc_t* const cu_loc,
const int width,
const int height,
const coeff_t* srcCoeff,
coeff_t* coeff_out,
const color_t compID,
@ -1115,8 +1134,6 @@ int uvg_dep_quant(
dep_quant_context.m_prevStates = &dep_quant_context.m_allStates[4];
dep_quant_context.m_skipStates = &dep_quant_context.m_allStates[8];
const uint32_t width = compID == COLOR_Y ? cu_loc->width : cu_loc->chroma_width;
const uint32_t height = compID == COLOR_Y ? cu_loc->height : cu_loc->chroma_height;
const uint32_t lfnstIdx = tree_type != UVG_CHROMA_T || compID == COLOR_Y ?
cur_tu->lfnst_idx :
cur_tu->cr_lfnst_idx;
@ -1173,8 +1190,8 @@ int uvg_dep_quant(
height >= 4) {
firstTestPos =((width == 4 && height == 4) || (width == 8 && height == 8)) ? 7 : 15;
}
const int32_t default_quant_coeff = uvg_g_quant_scales[needs_block_size_trafo_scale][qp_scaled % 6];
const coeff_t thres = 4 << q_bits;
const int32_t default_quant_coeff = dep_quant_context.m_quant.m_QScale;
const int32_t thres = dep_quant_context.m_quant.m_thresLast;
for (; firstTestPos >= 0; firstTestPos--) {
coeff_t thresTmp = (enableScalingLists) ? (thres / (4 * q_coeff[scan[firstTestPos]])) : (thres / (4 * default_quant_coeff));
if (abs(srcCoeff[scan[firstTestPos]]) > thresTmp) {
@ -1188,7 +1205,7 @@ int uvg_dep_quant(
//===== real init =====
rate_estimator rate_estimator;
init_rate_esimator(&rate_estimator, &state->search_cabac, compID);
xSetLastCoeffOffset(state, cur_tu, cu_loc, &rate_estimator, cbf_is_set(cur_tu->cbf, COLOR_U), compID);
xSetLastCoeffOffset(state, cur_tu, width, height, &rate_estimator, cbf_is_set(cur_tu->cbf, COLOR_U), compID);
reset_common_context(&dep_quant_context.m_common_context, &rate_estimator, (width * height) >> 4, numCoeff);
dep_quant_context.m_common_context.m_nbInfo = encoder->m_scanId2NbInfoOutArray[log2_tr_width][log2_tr_height];
@ -1200,10 +1217,27 @@ int uvg_dep_quant(
depquant_state_init(&dep_quant_context.m_allStates[k], rate_estimator.m_sigFracBits[0][0], rate_estimator.m_gtxFracBits[0]);
dep_quant_context.m_allStates[k].effHeight = effectHeight;
dep_quant_context.m_allStates[k].effWidth = effectWidth;
dep_quant_context.m_allStates[k].m_commonCtx = &dep_quant_context.m_common_context;
int i1 = (k & 3) ? (k & 3) - 1 : 0;
dep_quant_context.m_allStates[k].m_stateId = i1;
for (int i = 0; i < (compID == COLOR_Y ? 12 : 8); ++i) {
dep_quant_context.m_allStates[k].m_sigFracBitsArray[i] = rate_estimator.m_sigFracBits[i1][i];
}
for (int i = 0; i < (compID == COLOR_Y ? 21 : 11); ++i) {
dep_quant_context.m_allStates[k].m_gtxFracBitsArray[i] = rate_estimator.m_gtxFracBits[i];
}
}
depquant_state_init(&dep_quant_context.m_startState, rate_estimator.m_sigFracBits[0][0], rate_estimator.m_gtxFracBits[0]);
dep_quant_context.m_startState.effHeight = effectHeight;
dep_quant_context.m_startState.effWidth = effectWidth;
dep_quant_context.m_startState.m_stateId = 0;
dep_quant_context.m_startState.m_commonCtx = &dep_quant_context.m_common_context;
for (int i = 0; i < (compID == COLOR_Y ? 12 : 8); ++i) {
dep_quant_context.m_startState.m_sigFracBitsArray[i] = rate_estimator.m_sigFracBits[0][i];
}
for (int i = 0; i < (compID == COLOR_Y ? 21 : 11); ++i) {
dep_quant_context.m_startState.m_gtxFracBitsArray[i] = rate_estimator.m_gtxFracBits[i];
}
const uint32_t height_in_sbb = MAX(height >> 2, 1);
@ -1292,16 +1326,15 @@ int uvg_dep_quant(
void uvg_dep_quant_dequant(
const encoder_state_t* const state,
const cu_info_t* const cur_tu,
const cu_loc_t* const cu_loc,
const int block_type,
const int width,
const int height,
const color_t compID,
coeff_t* quant_coeff,
coeff_t * coeff,
bool enableScalingLists)
{
const encoder_control_t* const encoder = state->encoder_control;
const uint32_t width = compID == COLOR_Y ? cu_loc->width : cu_loc->chroma_width;
const uint32_t height = compID == COLOR_Y ? cu_loc->height : cu_loc->chroma_height;
const int numCoeff = width * height;
@ -1339,7 +1372,7 @@ void uvg_dep_quant_dequant(
int shift = IQUANT_SHIFT + 1 - qpPer - transformShift + (enableScalingLists ? 4 : 0);
int invQScale = uvg_g_inv_quant_scales[needs_block_size_trafo_scale ? 1 : 0][qpRem];
int add = (shift < 0) ? 0 : ((1 << shift) >> 1);
int32_t scalinglist_type = (cur_tu->type == CU_INTRA ? 0 : 3) + (int8_t)(compID);
int32_t scalinglist_type = (block_type == CU_INTRA ? 0 : 3) + (int8_t)(compID);
const int32_t* dequant_coef = encoder->scaling_list.de_quant_coeff[log2_tr_width][log2_tr_height][scalinglist_type][qpDQ % 6];
//----- dequant coefficients -----

View file

@ -57,8 +57,9 @@ void uvg_dealloc_nb_info(encoder_control_t* encoder);
void uvg_dep_quant_dequant(
const encoder_state_t* const state,
const cu_info_t* const cur_tu,
const cu_loc_t* const cu_loc,
const int block_type,
const int width,
const int height,
const color_t compID,
coeff_t* quant_coeff,
coeff_t* coeff,
@ -67,7 +68,8 @@ void uvg_dep_quant_dequant(
int uvg_dep_quant(
const encoder_state_t* const state,
const cu_info_t* const cur_tu,
const cu_loc_t* const cu_loc,
const int width,
const int height,
const coeff_t* srcCoeff,
coeff_t* coeff_out,
const color_t compID,

View file

@ -795,6 +795,9 @@ static double qp_to_lambda(encoder_state_t* const state, int qp)
state->frame->QP + 2 + frame_allocation,
est_qp);
}
if(state->encoder_control->cfg.dep_quant) {
est_lambda *= pow(2, 0.25 / 3.0);
}
state->lambda = est_lambda;
state->lambda_sqrt = sqrt(est_lambda);
@ -820,7 +823,11 @@ static double qp_to_lambda(encoder_state_t* const state, int qp)
// Since this value will be later combined with qp_pred, clip to half of that instead to be safe
state->qp = CLIP(state->frame->QP + UVG_QP_DELTA_MIN / 2, state->frame->QP + UVG_QP_DELTA_MAX / 2, state->qp);
state->qp = CLIP_TO_QP(state->qp);
state->lambda = qp_to_lambda(state, state->qp);
double to_lambda = qp_to_lambda(state, state->qp);
if (state->encoder_control->cfg.dep_quant) {
to_lambda *= pow(2, 0.25 / 3.0);
}
state->lambda = to_lambda;
state->lambda_sqrt = sqrt(state->lambda);
ctu->adjust_lambda = state->lambda;
@ -1103,7 +1110,12 @@ void uvg_set_lcu_lambda_and_qp(encoder_state_t * const state,
pos.x = 0;
}
state->qp = CLIP_TO_QP(state->frame->QP + dqp);
state->lambda = qp_to_lambda(state, state->qp);
double to_lambda = qp_to_lambda(state, state->qp);
if (state->encoder_control->cfg.dep_quant) {
to_lambda *= pow(2, 0.25 / 3.0);
}
state->lambda = to_lambda;
state->lambda_sqrt = sqrt(state->lambda);
}
else if (ctrl->cfg.target_bitrate > 0) {
@ -1138,6 +1150,9 @@ void uvg_set_lcu_lambda_and_qp(encoder_state_t * const state,
state->frame->lambda * 1.5874010519681994,
lambda);
lambda = clip_lambda(lambda);
if (state->encoder_control->cfg.dep_quant) {
lambda *= pow(2, 0.25 / 3.0);
}
state->lambda = lambda;
state->lambda_sqrt = sqrt(lambda);
@ -1145,8 +1160,13 @@ void uvg_set_lcu_lambda_and_qp(encoder_state_t * const state,
} else {
state->qp = state->frame->QP;
state->lambda = state->frame->lambda;
state->lambda_sqrt = sqrt(state->frame->lambda);
double lambda = state->frame->lambda;
if (state->encoder_control->cfg.dep_quant) {
lambda *= pow(2, 0.25 / 3.0);
}
state->lambda = lambda;
state->lambda_sqrt = sqrt(lambda);
}
lcu->lambda = state->lambda;
@ -1170,7 +1190,11 @@ void uvg_set_lcu_lambda_and_qp(encoder_state_t * const state,
// Since this value will be later combined with qp_pred, clip to half of that instead to be safe
state->qp = CLIP(state->frame->QP + UVG_QP_DELTA_MIN / 2, state->frame->QP + UVG_QP_DELTA_MAX / 2, state->qp);
state->qp = CLIP_TO_QP(state->qp);
state->lambda = qp_to_lambda(state, state->qp);
double to_lambda = qp_to_lambda(state, state->qp);
if (state->encoder_control->cfg.dep_quant) {
to_lambda *= pow(2, 0.25 / 3.0);
}
state->lambda = to_lambda;
state->lambda_sqrt = sqrt(state->lambda);
lcu->adjust_lambda = state->lambda;

View file

@ -707,8 +707,21 @@ int uvg_quantize_residual_avx2(encoder_state_t *const state,
}
// Quantize coeffs. (coeff -> coeff_out)
if (state->encoder_control->cfg.rdoq_enable &&
int abs_sum = 0;
if(!use_trskip && state->encoder_control->cfg.dep_quant) {
uvg_dep_quant(
state,
cur_cu,
width,
height,
coeff,
coeff_out,
color,
tree_type,
&abs_sum,
state->encoder_control->cfg.scaling_list);
}
else if (state->encoder_control->cfg.rdoq_enable &&
(width > 4 || !state->encoder_control->cfg.rdoq_skip) && !use_trskip)
{
uvg_rdoq(state, coeff, coeff_out, width, height, color,
@ -792,6 +805,10 @@ int uvg_quantize_residual_avx2(encoder_state_t *const state,
void uvg_dequant_avx2(const encoder_state_t * const state, coeff_t *q_coef, coeff_t *coef, int32_t width, int32_t height,color_t color, int8_t block_type, int8_t transform_skip)
{
const encoder_control_t * const encoder = state->encoder_control;
if (encoder->cfg.dep_quant) {
uvg_dep_quant_dequant(state, block_type, width, height, color, q_coef, coef, encoder->cfg.scaling_list);
return;
}
int32_t shift,add,coeff_q;
int32_t n;
const uint32_t log2_tr_width = uvg_g_convert_to_log2[width];

View file

@ -316,8 +316,21 @@ int uvg_quant_cbcr_residual_generic(
if(lfnst_idx) {
uvg_fwd_lfnst(cur_cu, width, height, COLOR_UV, lfnst_idx, coeff, tree_type, state->collocated_luma_mode);
}
if (state->encoder_control->cfg.rdoq_enable &&
int abs_sum = 0;
if (!false && state->encoder_control->cfg.dep_quant) {
uvg_dep_quant(
state,
cur_cu,
width,
height,
coeff,
coeff_out,
COLOR_U,
tree_type,
&abs_sum,
state->encoder_control->cfg.scaling_list);
}
else if (state->encoder_control->cfg.rdoq_enable &&
(width > 4 || !state->encoder_control->cfg.rdoq_skip))
{
uvg_rdoq(state, coeff, coeff_out, width, height, cur_cu->joint_cb_cr == 1 ? COLOR_V : COLOR_U,
@ -497,7 +510,21 @@ int uvg_quantize_residual_generic(encoder_state_t *const state,
// Quantize coeffs. (coeff -> coeff_out)
if (state->encoder_control->cfg.rdoq_enable &&
int abs_sum = 0;
if (!false && state->encoder_control->cfg.dep_quant) {
uvg_dep_quant(
state,
cur_cu,
width,
height,
coeff,
coeff_out,
COLOR_U,
tree_type,
&abs_sum,
state->encoder_control->cfg.scaling_list);
}
else if (state->encoder_control->cfg.rdoq_enable &&
(width > 4 || !state->encoder_control->cfg.rdoq_skip) && !use_trskip)
{
uvg_rdoq(state, coeff, coeff_out, width, height, color,
@ -591,6 +618,10 @@ int uvg_quantize_residual_generic(encoder_state_t *const state,
void uvg_dequant_generic(const encoder_state_t * const state, coeff_t *q_coef, coeff_t *coef, int32_t width, int32_t height,color_t color, int8_t block_type, int8_t transform_skip)
{
const encoder_control_t * const encoder = state->encoder_control;
if(encoder->cfg.dep_quant) {
uvg_dep_quant_dequant(state, block_type, width, height, color, q_coef, coef, encoder->cfg.scaling_list);
return;
}
int32_t shift,add,coeff_q;
int32_t n;
const uint32_t log2_tr_width = uvg_g_convert_to_log2[width];