go_rice_param calculation fix

This commit is contained in:
Joose Sainio 2021-03-19 10:28:42 +02:00
parent 8049ebb597
commit 1fd583eae0

View file

@ -45,6 +45,10 @@
const uint32_t kvz_g_go_rice_range[5] = { 7, 14, 26, 46, 78 }; const uint32_t kvz_g_go_rice_range[5] = { 7, 14, 26, 46, 78 };
const uint32_t kvz_g_go_rice_prefix_len[5] = { 8, 7, 6, 5, 4 }; const uint32_t kvz_g_go_rice_prefix_len[5] = { 8, 7, 6, 5, 4 };
const uint32_t g_auiGoRiceParsCoeff[32] =
{
0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3
};
/** /**
* Entropy bits to estimate coded bits in RDO / RDOQ (From HM 12.0) * Entropy bits to estimate coded bits in RDO / RDOQ (From HM 12.0)
@ -646,6 +650,33 @@ void kvz_rdoq_sign_hiding(
} }
} }
unsigned templateAbsSum(const coeff_t* coeff, int baseLevel, uint32_t posX, uint32_t posY, uint32_t width, uint32_t height)
{
const coeff_t* pData = coeff + posX + posY * width;
coeff_t sum = 0;
if (posX < width - 1)
{
sum += abs(pData[1]);
if (posX < width - 2)
{
sum += abs(pData[2]);
}
if (posY < height - 1)
{
sum += abs(pData[width + 1]);
}
}
if (posY < height - 1)
{
sum += abs(pData[width]);
if (posY < height - 2)
{
sum += abs(pData[width << 1]);
}
}
return MAX(MIN(sum - 5 * baseLevel, 31), 0);
}
/** RDOQ with CABAC /** RDOQ with CABAC
* \returns void * \returns void
@ -786,8 +817,14 @@ void kvz_rdoq(encoder_state_t * const state, coeff_t *coef, coeff_t *dest_coeff,
double err = (double)level_double; double err = (double)level_double;
cost_coeff0[scanpos] = err * err * temp; cost_coeff0[scanpos] = err * err * temp;
block_uncoded_cost += cost_coeff0[ scanpos ]; block_uncoded_cost += cost_coeff0[ scanpos ];
uint32_t pos_y = blkpos >> log2_block_size;
uint32_t pos_x = blkpos - (pos_y << log2_block_size);
//===== coefficient level estimation ===== //===== coefficient level estimation =====
int32_t level; int32_t level;
if (reg_bins < 4) {
int sumAll = templateAbsSum(coef, 4, pos_x, pos_y, width, height);
go_rice_param = g_auiGoRiceParsCoeff[sumAll];
}
uint16_t gt1_ctx = ctx_set; uint16_t gt1_ctx = ctx_set;
uint16_t gt2_ctx = ctx_set; uint16_t gt2_ctx = ctx_set;
@ -798,8 +835,6 @@ void kvz_rdoq(encoder_state_t * const state, coeff_t *coef, coeff_t *dest_coeff,
level_double, max_abs_level, 0, gt1_ctx, gt2_ctx, par_ctx, go_rice_param, level_double, max_abs_level, 0, gt1_ctx, gt2_ctx, par_ctx, go_rice_param,
reg_bins, q_bits, temp, 1, type ); reg_bins, q_bits, temp, 1, type );
} else { } else {
uint32_t pos_y = blkpos >> log2_block_size;
uint32_t pos_x = blkpos - ( pos_y << log2_block_size );
uint16_t ctx_sig = kvz_context_get_sig_ctx_idx_abs(coef, pos_x, pos_y, width, height, type, &temp_diag, &temp_sum); uint16_t ctx_sig = kvz_context_get_sig_ctx_idx_abs(coef, pos_x, pos_y, width, height, type, &temp_diag, &temp_sum);
if (temp_diag != -1) { if (temp_diag != -1) {
ctx_set = (MIN(temp_sum, 4) + 1) + (!temp_diag ? ((type == 0) ? 15 : 5) : (type == 0) ? temp_diag < 3 ? 10 : (temp_diag < 10 ? 5 : 0) : 0); ctx_set = (MIN(temp_sum, 4) + 1) + (!temp_diag ? ((type == 0) ? 15 : 5) : (type == 0) ? temp_diag < 3 ? 10 : (temp_diag < 10 ? 5 : 0) : 0);
@ -836,12 +871,12 @@ void kvz_rdoq(encoder_state_t * const state, coeff_t *coef, coeff_t *dest_coeff,
dest_coeff[blkpos] = (coeff_t)level; dest_coeff[blkpos] = (coeff_t)level;
base_cost += cost_coeff[scanpos]; base_cost += cost_coeff[scanpos];
base_level = 4; //base_level = 4;
if (level >= base_level) { //if (level >= base_level) {
if(level > 3*(1<<go_rice_param)) { // if(level > 3*(1<<go_rice_param)) {
go_rice_param = MIN(go_rice_param + 1, 4); // go_rice_param = MIN(go_rice_param + 1, 4);
} // }
} //}
//===== context set update ===== //===== context set update =====
if ((scanpos % SCAN_SET_SIZE == 0) && scanpos > 0) { if ((scanpos % SCAN_SET_SIZE == 0) && scanpos > 0) {
@ -851,6 +886,8 @@ void kvz_rdoq(encoder_state_t * const state, coeff_t *coef, coeff_t *dest_coeff,
} }
else if (reg_bins >= 4) { else if (reg_bins >= 4) {
reg_bins -= (level < 2 ? level : 3) + (scanpos != last_scanpos); reg_bins -= (level < 2 ? level : 3) + (scanpos != last_scanpos);
int sumAll = templateAbsSum(coef, 4, pos_x, pos_y, width, height);
go_rice_param = g_auiGoRiceParsCoeff[sumAll];
} }
rd_stats.sig_cost += cost_sig[scanpos]; rd_stats.sig_cost += cost_sig[scanpos];