Probably correct RD cost calculation for all inter modes

This commit is contained in:
Joose Sainio 2022-01-28 12:26:12 +02:00
parent 1a9e54601f
commit 6d73db5a2a
4 changed files with 130 additions and 38 deletions

View file

@ -60,14 +60,6 @@
// Cost threshold for doing intra search in inter frames with --rd=0. // Cost threshold for doing intra search in inter frames with --rd=0.
static const int INTRA_THRESHOLD = 8; static const int INTRA_THRESHOLD = 8;
// Modify weight of luma SSD.
#ifndef LUMA_MULT
# define LUMA_MULT 0.8
#endif
// Modify weight of chroma SSD.
#ifndef CHROMA_MULT
# define CHROMA_MULT 1.5
#endif
static INLINE void copy_cu_info(int x_local, int y_local, int width, lcu_t *from, lcu_t *to) static INLINE void copy_cu_info(int x_local, int y_local, int width, lcu_t *from, lcu_t *to)
{ {
@ -216,16 +208,16 @@ static double cu_zero_coeff_cost(const encoder_state_t *state, lcu_t *work_tree,
const int chroma_index = (y_local / 2) * LCU_WIDTH_C + (x_local / 2); const int chroma_index = (y_local / 2) * LCU_WIDTH_C + (x_local / 2);
double ssd = 0.0; double ssd = 0.0;
ssd += LUMA_MULT * kvz_pixels_calc_ssd( ssd += KVZ_LUMA_MULT * kvz_pixels_calc_ssd(
&lcu->ref.y[luma_index], &lcu->rec.y[luma_index], &lcu->ref.y[luma_index], &lcu->rec.y[luma_index],
LCU_WIDTH, LCU_WIDTH, cu_width LCU_WIDTH, LCU_WIDTH, cu_width
); );
if (x % 8 == 0 && y % 8 == 0 && state->encoder_control->chroma_format != KVZ_CSP_400) { if (x % 8 == 0 && y % 8 == 0 && state->encoder_control->chroma_format != KVZ_CSP_400) {
ssd += CHROMA_MULT * kvz_pixels_calc_ssd( ssd += KVZ_CHROMA_MULT * kvz_pixels_calc_ssd(
&lcu->ref.u[chroma_index], &lcu->rec.u[chroma_index], &lcu->ref.u[chroma_index], &lcu->rec.u[chroma_index],
LCU_WIDTH_C, LCU_WIDTH_C, cu_width / 2 LCU_WIDTH_C, LCU_WIDTH_C, cu_width / 2
); );
ssd += CHROMA_MULT * kvz_pixels_calc_ssd( ssd += KVZ_CHROMA_MULT * kvz_pixels_calc_ssd(
&lcu->ref.v[chroma_index], &lcu->rec.v[chroma_index], &lcu->ref.v[chroma_index], &lcu->rec.v[chroma_index],
LCU_WIDTH_C, LCU_WIDTH_C, cu_width / 2 LCU_WIDTH_C, LCU_WIDTH_C, cu_width / 2
); );
@ -253,6 +245,7 @@ double kvz_cu_rd_cost_luma(const encoder_state_t *const state,
double *bit_cost) double *bit_cost)
{ {
const int width = LCU_WIDTH >> depth; const int width = LCU_WIDTH >> depth;
const int skip_residual_coding = pred_cu->skipped || (pred_cu->type == CU_INTER && pred_cu->cbf == 0);
// cur_cu is used for TU parameters. // cur_cu is used for TU parameters.
cu_info_t *const tr_cu = LCU_GET_CU_AT_PX(lcu, x_px, y_px); cu_info_t *const tr_cu = LCU_GET_CU_AT_PX(lcu, x_px, y_px);
@ -280,7 +273,8 @@ double kvz_cu_rd_cost_luma(const encoder_state_t *const state,
if (width <= TR_MAX_WIDTH if (width <= TR_MAX_WIDTH
&& width > TR_MIN_WIDTH && width > TR_MIN_WIDTH
&& !intra_split_flag && !intra_split_flag
&& MIN(tr_cu->tr_depth, depth) - tr_cu->depth < max_tr_depth) && MIN(tr_cu->tr_depth, depth) - tr_cu->depth < max_tr_depth
&& !skip_residual_coding)
{ {
cabac_ctx_t *ctx = &(cabac->ctx.trans_subdiv_model[5 - (6 - depth)]); cabac_ctx_t *ctx = &(cabac->ctx.trans_subdiv_model[5 - (6 - depth)]);
CABAC_FBITS_UPDATE(cabac, ctx, tr_depth > 0, tr_tree_bits, "tr_split_search"); CABAC_FBITS_UPDATE(cabac, ctx, tr_depth > 0, tr_tree_bits, "tr_split_search");
@ -300,7 +294,7 @@ double kvz_cu_rd_cost_luma(const encoder_state_t *const state,
} }
if (cabac->update && tr_cu->tr_depth == tr_cu->depth) { if (cabac->update && tr_cu->tr_depth == tr_cu->depth && !skip_residual_coding) {
// Because these need to be coded before the luma cbf they also need to be counted // Because these need to be coded before the luma cbf they also need to be counted
// before the cabac state changes. However, since this branch is only executed when // before the cabac state changes. However, since this branch is only executed when
// calculating the last RD cost it is not problem to include the chroma cbf costs in // calculating the last RD cost it is not problem to include the chroma cbf costs in
@ -340,7 +334,8 @@ double kvz_cu_rd_cost_luma(const encoder_state_t *const state,
width); width);
} }
{
if (!skip_residual_coding) {
int8_t luma_scan_mode = kvz_get_scan_order(pred_cu->type, pred_cu->intra.mode, depth); int8_t luma_scan_mode = kvz_get_scan_order(pred_cu->type, pred_cu->intra.mode, depth);
const coeff_t *coeffs = &lcu->coeff.y[xy_to_zorder(LCU_WIDTH, x_px, y_px)]; const coeff_t *coeffs = &lcu->coeff.y[xy_to_zorder(LCU_WIDTH, x_px, y_px)];
@ -349,7 +344,7 @@ double kvz_cu_rd_cost_luma(const encoder_state_t *const state,
} }
double bits = tr_tree_bits + coeff_bits; double bits = tr_tree_bits + coeff_bits;
return (double)ssd * LUMA_MULT + bits * state->lambda; return (double)ssd * KVZ_LUMA_MULT + bits * state->lambda;
} }
@ -362,6 +357,7 @@ double kvz_cu_rd_cost_chroma(const encoder_state_t *const state,
const vector2d_t lcu_px = { x_px / 2, y_px / 2 }; const vector2d_t lcu_px = { x_px / 2, y_px / 2 };
const int width = (depth <= MAX_DEPTH) ? LCU_WIDTH >> (depth + 1) : LCU_WIDTH >> depth; const int width = (depth <= MAX_DEPTH) ? LCU_WIDTH >> (depth + 1) : LCU_WIDTH >> depth;
cu_info_t *const tr_cu = LCU_GET_CU_AT_PX(lcu, x_px, y_px); cu_info_t *const tr_cu = LCU_GET_CU_AT_PX(lcu, x_px, y_px);
const int skip_residual_coding = pred_cu->skipped || (pred_cu->type == CU_INTER && pred_cu->cbf == 0);
double tr_tree_bits = 0; double tr_tree_bits = 0;
double coeff_bits = 0; double coeff_bits = 0;
@ -376,7 +372,7 @@ double kvz_cu_rd_cost_chroma(const encoder_state_t *const state,
} }
// See luma for why the second condition // See luma for why the second condition
if (depth < MAX_PU_DEPTH && (!state->search_cabac.update || tr_cu->tr_depth != tr_cu->depth)) { if (depth < MAX_PU_DEPTH && (!state->search_cabac.update || tr_cu->tr_depth != tr_cu->depth) && !skip_residual_coding) {
const int tr_depth = depth - pred_cu->depth; const int tr_depth = depth - pred_cu->depth;
cabac_data_t* cabac = (cabac_data_t*)&state->search_cabac; cabac_data_t* cabac = (cabac_data_t*)&state->search_cabac;
cabac_ctx_t *ctx = &(cabac->ctx.qt_cbf_model_chroma[tr_depth]); cabac_ctx_t *ctx = &(cabac->ctx.qt_cbf_model_chroma[tr_depth]);
@ -417,6 +413,7 @@ double kvz_cu_rd_cost_chroma(const encoder_state_t *const state,
ssd = ssd_u + ssd_v; ssd = ssd_u + ssd_v;
} }
if (!skip_residual_coding)
{ {
int8_t scan_order = kvz_get_scan_order(pred_cu->type, pred_cu->intra.mode_chroma, depth); int8_t scan_order = kvz_get_scan_order(pred_cu->type, pred_cu->intra.mode_chroma, depth);
const int index = xy_to_zorder(LCU_WIDTH_C, lcu_px.x, lcu_px.y); const int index = xy_to_zorder(LCU_WIDTH_C, lcu_px.x, lcu_px.y);
@ -427,7 +424,7 @@ double kvz_cu_rd_cost_chroma(const encoder_state_t *const state,
} }
double bits = tr_tree_bits + coeff_bits; double bits = tr_tree_bits + coeff_bits;
return (double)ssd * CHROMA_MULT + bits * state->lambda; return (double)ssd * KVZ_CHROMA_MULT + bits * state->lambda;
} }
static double cu_rd_cost_tr_split_accurate(const encoder_state_t* const state, static double cu_rd_cost_tr_split_accurate(const encoder_state_t* const state,
@ -553,7 +550,7 @@ static double cu_rd_cost_tr_split_accurate(const encoder_state_t* const state,
} }
*bit_cost += coeff_bits; *bit_cost += coeff_bits;
double bits = tr_tree_bits + coeff_bits; double bits = tr_tree_bits + coeff_bits;
return luma_ssd * LUMA_MULT + chroma_ssd * CHROMA_MULT + bits * state->lambda; return luma_ssd * KVZ_LUMA_MULT + chroma_ssd * KVZ_CHROMA_MULT + bits * state->lambda;
} }

View file

@ -46,6 +46,15 @@
#define MAX_UNIT_STATS_MAP_SIZE MAX(MAX_REF_PIC_COUNT, MRG_MAX_NUM_CANDS) #define MAX_UNIT_STATS_MAP_SIZE MAX(MAX_REF_PIC_COUNT, MRG_MAX_NUM_CANDS)
// Weight applied to luma SSD when distortion is combined with
// lambda-scaled bits into an RD cost. The #ifndef guard keeps any value
// that was already defined before this point.
#ifndef KVZ_LUMA_MULT
# define KVZ_LUMA_MULT 0.8
#endif
// Weight applied to chroma SSD in the same RD cost expressions.
// Like the luma weight, it is only defined here if not defined earlier.
#ifndef KVZ_CHROMA_MULT
# define KVZ_CHROMA_MULT 1.5
#endif
/** /**
* \brief Data collected during search processes. * \brief Data collected during search processes.
* *

View file

@ -1160,6 +1160,30 @@ static void search_frac(inter_search_info_t *info,
*best_bits = bitcost; *best_bits = bitcost;
} }
/**
 * \brief Sum the `skipped` flags of the left and above neighbours of a CU.
 *
 * The neighbours are read either from the LCU being searched or from a CU
 * array; giving both sources at once is rejected by an assert.
 *
 * \param x     x coordinate of the CU (presumably luma pixels in the
 *              picture - confirm against callers); no left neighbour is
 *              read when it is 0
 * \param y     y coordinate of the CU; no above neighbour is read when
 *              it is 0
 * \param lcu   LCU to read neighbours from, or NULL to use cu_a
 * \param cu_a  CU array to read neighbours from when lcu is NULL
 * \return      sum of the neighbours' `skipped` fields; callers in this
 *              file use it to index cu_skip_flag_model
 */
int kvz_get_skip_context(int x, int y, lcu_t* const lcu, cu_array_t* const cu_a) {
  assert(!lcu || !cu_a);

  if (!lcu) {
    // No LCU given: look the neighbours up in the CU array with the
    // coordinates exactly as passed in.
    int ctx = 0;
    if (x > 0) {
      ctx += kvz_cu_array_at_const(cu_a, x - 1, y)->skipped;
    }
    if (y > 0) {
      ctx += kvz_cu_array_at_const(cu_a, x, y - 1)->skipped;
    }
    return ctx;
  }

  // LCU given: neighbours are addressed with LCU-local coordinates, while
  // the "is there a neighbour" test still uses the coordinates as passed.
  const int local_x = SUB_SCU(x);
  const int local_y = SUB_SCU(y);
  int ctx = 0;
  if (x != 0) {
    ctx += LCU_GET_CU_AT_PX(lcu, local_x - 1, local_y)->skipped;
  }
  if (y != 0) {
    ctx += LCU_GET_CU_AT_PX(lcu, local_x, local_y - 1)->skipped;
  }
  return ctx;
}
/** /**
* \brief Calculate the scaled MV * \brief Calculate the scaled MV
*/ */
@ -1676,7 +1700,7 @@ static void search_pu_inter(encoder_state_t * const state,
double bits = merge_flag_cost + merge_idx + CTX_ENTROPY_FBITS(&(state->search_cabac.ctx.cu_merge_idx_ext_model), merge_idx != 0); double bits = merge_flag_cost + merge_idx + CTX_ENTROPY_FBITS(&(state->search_cabac.ctx.cu_merge_idx_ext_model), merge_idx != 0);
if(state->encoder_control->cfg.rdo >= 2) { if(state->encoder_control->cfg.rdo >= 2) {
kvz_cu_cost_inter_rd2(state, x, y, depth, lcu, &merge->cost[merge->size], &bits); kvz_cu_cost_inter_rd2(state, x, y, depth, &merge->unit[merge->size], lcu, &merge->cost[merge->size], &bits);
} }
else { else {
merge->cost[merge->size] = kvz_satd_any_size(width, height, merge->cost[merge->size] = kvz_satd_any_size(width, height,
@ -1773,10 +1797,6 @@ static void search_pu_inter(encoder_state_t * const state,
amvp[0].size > 0 ? amvp[0].keys[0] : 0, amvp[0].size > 0 ? amvp[0].keys[0] : 0,
amvp[1].size > 0 ? amvp[1].keys[0] : 0 amvp[1].size > 0 ? amvp[1].keys[0] : 0
}; };
if (state->encoder_control->cfg.rdo >= 2) {
kvz_cu_cost_inter_rd2(state, x, y, depth, lcu, &amvp[0].cost[best_keys[0]], &amvp[0].bits[best_keys[0]]);
kvz_cu_cost_inter_rd2(state, x, y, depth, lcu, &amvp[1].cost[best_keys[1]], &amvp[1].bits[best_keys[1]]);
}
cu_info_t *best_unipred[2] = { cu_info_t *best_unipred[2] = {
&amvp[0].unit[best_keys[0]], &amvp[0].unit[best_keys[0]],
@ -1808,6 +1828,11 @@ static void search_pu_inter(encoder_state_t * const state,
} }
} }
if (state->encoder_control->cfg.rdo >= 2) {
kvz_cu_cost_inter_rd2(state, x, y, depth, &amvp[0].unit[best_keys[0]], lcu, &amvp[0].cost[best_keys[0]], &amvp[0].bits[best_keys[0]]);
kvz_cu_cost_inter_rd2(state, x, y, depth, &amvp[1].unit[best_keys[1]], lcu, &amvp[1].cost[best_keys[1]], &amvp[1].bits[best_keys[1]]);
}
// Fractional-pixel motion estimation. // Fractional-pixel motion estimation.
// Refine the best PUs so far from both lists, if available. // Refine the best PUs so far from both lists, if available.
for (int list = 0; list < 2; ++list) { for (int list = 0; list < 2; ++list) {
@ -1859,7 +1884,7 @@ static void search_pu_inter(encoder_state_t * const state,
CU_SET_MV_CAND(unipred_pu, list, cu_mv_cand); CU_SET_MV_CAND(unipred_pu, list, cu_mv_cand);
if (state->encoder_control->cfg.rdo >= 2) { if (state->encoder_control->cfg.rdo >= 2) {
kvz_cu_cost_inter_rd2(state, x, y, depth, lcu, &frac_cost, &frac_bits); kvz_cu_cost_inter_rd2(state, x, y, depth, unipred_pu, lcu, &frac_cost, &frac_bits);
} }
amvp[list].cost[key] = frac_cost; amvp[list].cost[key] = frac_cost;
@ -1985,7 +2010,7 @@ static void search_pu_inter(encoder_state_t * const state,
assert(amvp[2].size <= MAX_UNIT_STATS_MAP_SIZE); assert(amvp[2].size <= MAX_UNIT_STATS_MAP_SIZE);
kvz_sort_keys_by_cost(&amvp[2]); kvz_sort_keys_by_cost(&amvp[2]);
if (state->encoder_control->cfg.rdo >= 2) { if (state->encoder_control->cfg.rdo >= 2) {
kvz_cu_cost_inter_rd2(state, x, y, depth, lcu, &amvp[2].cost[amvp[2].keys[0]], &amvp[2].bits[amvp[2].keys[0]]); kvz_cu_cost_inter_rd2(state, x, y, depth, &amvp[2].unit[amvp[2].keys[0]], lcu, &amvp[2].cost[amvp[2].keys[0]], &amvp[2].bits[amvp[2].keys[0]]);
} }
} }
@ -2012,39 +2037,96 @@ static void search_pu_inter(encoder_state_t * const state,
*/ */
// Computes an RD cost for the inter CU described by cur_cu at (x, y) and
// writes it to *inter_cost. *inter_bitcost is read as the bits already
// spent on the mode and is increased only on the paths noted below.
// NOTE(review): this listing shows two revisions side by side; lines that
// appear only once exist in just one of them (presumably either removed or
// added by the change - confirm against the actual file).
void kvz_cu_cost_inter_rd2(encoder_state_t * const state, void kvz_cu_cost_inter_rd2(encoder_state_t * const state,
int x, int y, int depth, int x, int y, int depth,
cu_info_t* cur_cu,
lcu_t *lcu, lcu_t *lcu,
double *inter_cost, double *inter_cost,
double* inter_bitcost){ double* inter_bitcost){
cu_info_t *cur_cu = LCU_GET_CU_AT_PX(lcu, SUB_SCU(x), SUB_SCU(y));
int tr_depth = MAX(1, depth); int tr_depth = MAX(1, depth);
if (cur_cu->part_size != SIZE_2Nx2N) { if (cur_cu->part_size != SIZE_2Nx2N) {
tr_depth = depth + 1; tr_depth = depth + 1;
} }
kvz_lcu_fill_trdepth(lcu, x, y, depth, tr_depth); kvz_lcu_fill_trdepth(lcu, x, y, depth, tr_depth);
const int x_px = SUB_SCU(x);
const int y_px = SUB_SCU(y);
const int width = LCU_WIDTH >> depth;
const bool reconstruct_chroma = state->encoder_control->chroma_format != KVZ_CSP_400; const bool reconstruct_chroma = state->encoder_control->chroma_format != KVZ_CSP_400;
kvz_inter_recon_cu(state, lcu, x, y, CU_WIDTH_FROM_DEPTH(depth), true, reconstruct_chroma); kvz_inter_recon_cu(state, lcu, x, y, CU_WIDTH_FROM_DEPTH(depth), true, reconstruct_chroma);
kvz_quantize_lcu_residual(state, true, reconstruct_chroma,
x, y, depth,
NULL,
lcu,
false);
// SSD between lcu->ref and lcu->rec, taken right after kvz_inter_recon_cu
// and before the later kvz_quantize_lcu_residual call; it feeds
// no_cbf_cost, the cost of coding this CU without residual.
int index = y_px * LCU_WIDTH + x_px;
double ssd = kvz_pixels_calc_ssd(&lcu->ref.y[index], &lcu->rec.y[index],
LCU_WIDTH, LCU_WIDTH,
width) * KVZ_LUMA_MULT;
if (reconstruct_chroma) {
// NOTE(review): this inner `index` shadows the luma `index` above.
// NOTE(review): the chroma planes are addressed with LCU_WIDTH_C stride
// and halved coordinates, yet the SSD size passed is `width`;
// cu_zero_coeff_cost passes cu_width / 2 for chroma - confirm.
int index = y_px / 2 * LCU_WIDTH_C + x_px / 2;
double ssd_u = kvz_pixels_calc_ssd(&lcu->ref.u[index], &lcu->rec.u[index],
LCU_WIDTH_C, LCU_WIDTH_C,
width);
double ssd_v = kvz_pixels_calc_ssd(&lcu->ref.v[index], &lcu->rec.v[index],
LCU_WIDTH_C, LCU_WIDTH_C,
width);
// NOTE(review): `ssd` already holds luma SSD * KVZ_LUMA_MULT, so the
// multiplication below scales the luma term by KVZ_CHROMA_MULT as well.
// Other cost functions in this change weight luma and chroma separately
// (luma_ssd * KVZ_LUMA_MULT + chroma_ssd * KVZ_CHROMA_MULT) - confirm.
ssd += ssd_u + ssd_v;
ssd *= KVZ_CHROMA_MULT;
}
double no_cbf_bits;
double bits = 0; double bits = 0;
int cbf = cbf_is_set_any(cur_cu->cbf, depth); int skip_context = kvz_get_skip_context(x, y, lcu, NULL);
// For a merged CU the no-residual alternative is priced as skip flag = 1
// and the with-residual case as skip flag = 0; otherwise the root cbf
// flag is priced as 0 and 1 respectively.
// NOTE(review): these contexts come from state->cabac, while other code in
// this change reads state->search_cabac - confirm which one is intended.
*inter_bitcost += CTX_ENTROPY_FBITS(&state->cabac.ctx.cu_qt_root_cbf_model, !!cbf); if (cur_cu->merged) {
no_cbf_bits = CTX_ENTROPY_FBITS(&state->cabac.ctx.cu_skip_flag_model[skip_context], 1);
bits += CTX_ENTROPY_FBITS(&state->cabac.ctx.cu_skip_flag_model[skip_context], 0);
}
else {
no_cbf_bits = CTX_ENTROPY_FBITS(&state->cabac.ctx.cu_qt_root_cbf_model, 0);
bits += CTX_ENTROPY_FBITS(&state->cabac.ctx.cu_qt_root_cbf_model, 1);
}
double no_cbf_cost = ssd + (no_cbf_bits + *inter_bitcost) * state->lambda;
kvz_quantize_lcu_residual(state, true, reconstruct_chroma,
x, y, depth,
NULL,
lcu,
false);
int cbf = cbf_is_set_any(cur_cu->cbf, depth);
// NOTE(review): temp_bits is handed to the luma/chroma cost functions but
// is never read afterwards in this function - confirm that is intended.
double temp_bits = 0;
if(cbf) { if(cbf) {
*inter_cost = kvz_cu_rd_cost_luma(state, SUB_SCU(x), SUB_SCU(y), depth, cur_cu, lcu, &bits); *inter_cost = kvz_cu_rd_cost_luma(state, x_px, y_px, depth, cur_cu, lcu, &temp_bits);
if (reconstruct_chroma) { if (reconstruct_chroma) {
*inter_cost += kvz_cu_rd_cost_chroma(state, SUB_SCU(x), SUB_SCU(y), depth, cur_cu, lcu, &bits); *inter_cost += kvz_cu_rd_cost_chroma(state, x_px, y_px, depth, cur_cu, lcu, &temp_bits);
} }
} }
else {
// If we have no coeffs after quant we already have the cost calculated
*inter_cost = no_cbf_cost;
// Only merged CUs get no_cbf_bits added to *inter_bitcost here.
if(cur_cu->merged) {
*inter_bitcost += no_cbf_bits;
}
return;
}
FILE_BITS(bits, x, y, depth, "inter rd 2 bits"); FILE_BITS(bits, x, y, depth, "inter rd 2 bits");
*inter_cost += *inter_bitcost * state->lambda; *inter_cost += (*inter_bitcost +bits )* state->lambda;
// NOTE(review): the `&& 0` makes this condition always false, so the
// switch to the no-residual cost below never runs; control always reaches
// the else-if, whose inner check repeats the same cur_cu->merged test.
if(no_cbf_cost < *inter_cost && 0) {
cur_cu->cbf = 0;
if (cur_cu->merged) {
cur_cu->skipped = 1;
}
kvz_inter_recon_cu(state, lcu, x, y, CU_WIDTH_FROM_DEPTH(depth), true, reconstruct_chroma);
*inter_cost = no_cbf_cost;
if (cur_cu->merged) {
*inter_bitcost += no_cbf_bits;
}
}
else if(cur_cu->merged) {
if (cur_cu->merged) {
*inter_bitcost += bits;
}
}
} }
@ -2268,6 +2350,7 @@ void kvz_search_cu_smp(encoder_state_t * const state,
if (state->encoder_control->cfg.rdo >= 2) { if (state->encoder_control->cfg.rdo >= 2) {
kvz_cu_cost_inter_rd2(state, kvz_cu_cost_inter_rd2(state,
x, y, depth, x, y, depth,
LCU_GET_CU_AT_PX(lcu, x_local, y_local),
lcu, lcu,
inter_cost, inter_cost,
inter_bitcost); inter_bitcost);

View file

@ -94,8 +94,11 @@ unsigned kvz_inter_satd_cost(const encoder_state_t* state,
int y); int y);
void kvz_cu_cost_inter_rd2(encoder_state_t* const state, void kvz_cu_cost_inter_rd2(encoder_state_t* const state,
int x, int y, int depth, int x, int y, int depth,
cu_info_t* cur_cu,
lcu_t* lcu, lcu_t* lcu,
double* inter_cost, double* inter_cost,
double* inter_bitcost); double* inter_bitcost);
// Returns the sum of the `skipped` flags of the left and above neighbours
// of (x, y), read from lcu when it is non-NULL and from cu_a otherwise.
// lcu and cu_a must not both be given (asserted in the definition).
int kvz_get_skip_context(int x, int y, lcu_t* const lcu, cu_array_t* const cu_a);
#endif // SEARCH_INTER_H_ #endif // SEARCH_INTER_H_