[mrl] Implement MRL bitcost calculation.

This commit is contained in:
siivonek 2021-12-09 16:43:25 +02:00
parent dea3ca12aa
commit 236265a1f4
3 changed files with 30 additions and 10 deletions

View file

@ -500,7 +500,7 @@ static double calc_mode_bits(const encoder_state_t *state,
kvz_intra_get_dir_luma_predictor(x, y, candidate_modes, cur_cu, left_cu, above_cu); kvz_intra_get_dir_luma_predictor(x, y, candidate_modes, cur_cu, left_cu, above_cu);
} }
double mode_bits = kvz_luma_mode_bits(state, cur_cu->intra.mode, candidate_modes); double mode_bits = kvz_luma_mode_bits(state, cur_cu->intra.mode, candidate_modes, cur_cu->intra.multi_ref_idx);
if (((depth == 4 && x % 8 && y % 8) || (depth != 4)) && state->encoder_control->chroma_format != KVZ_CSP_400) { if (((depth == 4 && x % 8 && y % 8) || (depth != 4)) && state->encoder_control->chroma_format != KVZ_CSP_400) {
mode_bits += kvz_chroma_mode_bits(state, cur_cu->intra.mode_chroma, cur_cu->intra.mode); mode_bits += kvz_chroma_mode_bits(state, cur_cu->intra.mode_chroma, cur_cu->intra.mode);

View file

@ -675,7 +675,7 @@ static int8_t search_intra_rough(encoder_state_t * const state,
// affecting the halving search. // affecting the halving search.
int lambda_cost = (int)(state->lambda_sqrt + 0.5); int lambda_cost = (int)(state->lambda_sqrt + 0.5);
for (int mode_i = 0; mode_i < modes_selected; ++mode_i) { for (int mode_i = 0; mode_i < modes_selected; ++mode_i) {
costs[mode_i] += lambda_cost * kvz_luma_mode_bits(state, modes[mode_i], intra_preds); costs[mode_i] += lambda_cost * kvz_luma_mode_bits(state, modes[mode_i], intra_preds, 0);
} }
#undef PARALLEL_BLKS #undef PARALLEL_BLKS
@ -750,7 +750,8 @@ static int8_t search_intra_rdo(encoder_state_t * const state,
} }
for(int rdo_mode = 0; rdo_mode < modes_to_check; rdo_mode ++) { for(int rdo_mode = 0; rdo_mode < modes_to_check; rdo_mode ++) {
int rdo_bitcost = kvz_luma_mode_bits(state, modes[rdo_mode], intra_preds); int rdo_bitcost = kvz_luma_mode_bits(state, modes[rdo_mode], intra_preds, multi_ref_idx);
costs[rdo_mode] = rdo_bitcost * (int)(state->lambda + 0.5); costs[rdo_mode] = rdo_bitcost * (int)(state->lambda + 0.5);
// Perform transform split search and save mode RD cost for the best one. // Perform transform split search and save mode RD cost for the best one.
@ -800,9 +801,9 @@ static int8_t search_intra_rdo(encoder_state_t * const state,
} }
double kvz_luma_mode_bits(const encoder_state_t *state, int8_t luma_mode, const int8_t *intra_preds) double kvz_luma_mode_bits(const encoder_state_t *state, int8_t luma_mode, const int8_t *intra_preds, const uint8_t multi_ref_idx)
{ {
double mode_bits; double mode_bits = 0.0;
int8_t mode_in_preds = -1; int8_t mode_in_preds = -1;
for (int i = 0; i < INTRA_MPM_COUNT; ++i) { for (int i = 0; i < INTRA_MPM_COUNT; ++i) {
@ -812,12 +813,31 @@ double kvz_luma_mode_bits(const encoder_state_t *state, int8_t luma_mode, const
} }
} }
const cabac_ctx_t *ctx = &(state->cabac.ctx.intra_luma_mpm_flag_model); bool enable_mrl = state->encoder_control->cfg.mrl;
mode_bits = CTX_ENTROPY_FBITS(ctx, mode_in_preds!=-1); uint8_t multi_ref_index = enable_mrl ? multi_ref_idx : 0;
if (mode_in_preds != -1) { const cabac_ctx_t* ctx = &(state->cabac.ctx.intra_luma_mpm_flag_model);
if (multi_ref_index == 0) {
mode_bits += CTX_ENTROPY_FBITS(ctx, mode_in_preds != -1);
}
// Add MRL bits.
if (enable_mrl && MAX_REF_LINE_IDX > 1) {
ctx = &(state->cabac.ctx.multi_ref_line[0]);
mode_bits += CTX_ENTROPY_FBITS(ctx, multi_ref_index != 0);
if (multi_ref_index != 0 && MAX_REF_LINE_IDX > 2) {
ctx = &(state->cabac.ctx.multi_ref_line[1]);
mode_bits += CTX_ENTROPY_FBITS(ctx, multi_ref_index != 1);
}
}
if (mode_in_preds != -1 || multi_ref_index != 0) {
ctx = &(state->cabac.ctx.luma_planar_model[0]); ctx = &(state->cabac.ctx.luma_planar_model[0]);
mode_bits += CTX_ENTROPY_FBITS(ctx, mode_in_preds>0); if (multi_ref_index == 0) {
mode_bits += CTX_ENTROPY_FBITS(ctx, mode_in_preds>0);
}
mode_bits += MIN(4.0,mode_in_preds); mode_bits += MIN(4.0,mode_in_preds);
} else { } else {
mode_bits += 6.0; mode_bits += 6.0;

View file

@ -45,7 +45,7 @@
double kvz_luma_mode_bits(const encoder_state_t *state, double kvz_luma_mode_bits(const encoder_state_t *state,
int8_t luma_mode, const int8_t *intra_preds); int8_t luma_mode, const int8_t *intra_preds, uint8_t multi_ref_idx);
double kvz_chroma_mode_bits(const encoder_state_t *state, double kvz_chroma_mode_bits(const encoder_state_t *state,
int8_t chroma_mode, int8_t luma_mode); int8_t chroma_mode, int8_t luma_mode);