diff --git a/src/search.c b/src/search.c index 13fa66a7..93a72e35 100644 --- a/src/search.c +++ b/src/search.c @@ -500,7 +500,7 @@ static double calc_mode_bits(const encoder_state_t *state, kvz_intra_get_dir_luma_predictor(x, y, candidate_modes, cur_cu, left_cu, above_cu); } - double mode_bits = kvz_luma_mode_bits(state, cur_cu->intra.mode, candidate_modes); + double mode_bits = kvz_luma_mode_bits(state, cur_cu->intra.mode, candidate_modes, cur_cu->intra.multi_ref_idx); if (((depth == 4 && x % 8 && y % 8) || (depth != 4)) && state->encoder_control->chroma_format != KVZ_CSP_400) { mode_bits += kvz_chroma_mode_bits(state, cur_cu->intra.mode_chroma, cur_cu->intra.mode); diff --git a/src/search_intra.c b/src/search_intra.c index f5bd6ed2..8615565a 100644 --- a/src/search_intra.c +++ b/src/search_intra.c @@ -675,7 +675,7 @@ static int8_t search_intra_rough(encoder_state_t * const state, // affecting the halving search. int lambda_cost = (int)(state->lambda_sqrt + 0.5); for (int mode_i = 0; mode_i < modes_selected; ++mode_i) { - costs[mode_i] += lambda_cost * kvz_luma_mode_bits(state, modes[mode_i], intra_preds); + costs[mode_i] += lambda_cost * kvz_luma_mode_bits(state, modes[mode_i], intra_preds, 0); } #undef PARALLEL_BLKS @@ -750,7 +750,8 @@ static int8_t search_intra_rdo(encoder_state_t * const state, } for(int rdo_mode = 0; rdo_mode < modes_to_check; rdo_mode ++) { - int rdo_bitcost = kvz_luma_mode_bits(state, modes[rdo_mode], intra_preds); + int rdo_bitcost = kvz_luma_mode_bits(state, modes[rdo_mode], intra_preds, multi_ref_idx); + costs[rdo_mode] = rdo_bitcost * (int)(state->lambda + 0.5); // Perform transform split search and save mode RD cost for the best one. @@ -800,9 +801,9 @@ static int8_t search_intra_rdo(encoder_state_t * const state, } -double kvz_luma_mode_bits(const encoder_state_t *state, int8_t luma_mode, const int8_t *intra_preds) +double kvz_luma_mode_bits(const encoder_state_t *state, int8_t luma_mode, const int8_t *intra_preds, const uint8_t multi_ref_idx) { - double mode_bits; + double mode_bits = 0.0; int8_t mode_in_preds = -1; for (int i = 0; i < INTRA_MPM_COUNT; ++i) { @@ -812,12 +813,31 @@ double kvz_luma_mode_bits(const encoder_state_t *state, int8_t luma_mode, const } } - const cabac_ctx_t *ctx = &(state->cabac.ctx.intra_luma_mpm_flag_model); - mode_bits = CTX_ENTROPY_FBITS(ctx, mode_in_preds!=-1); + bool enable_mrl = state->encoder_control->cfg.mrl; + uint8_t multi_ref_index = enable_mrl ? multi_ref_idx : 0; - if (mode_in_preds != -1) { + const cabac_ctx_t* ctx = &(state->cabac.ctx.intra_luma_mpm_flag_model); + + if (multi_ref_index == 0) { + mode_bits += CTX_ENTROPY_FBITS(ctx, mode_in_preds != -1); + } + + // Add MRL bits. + if (enable_mrl && MAX_REF_LINE_IDX > 1) { + ctx = &(state->cabac.ctx.multi_ref_line[0]); + mode_bits += CTX_ENTROPY_FBITS(ctx, multi_ref_index != 0); + + if (multi_ref_index != 0 && MAX_REF_LINE_IDX > 2) { + ctx = &(state->cabac.ctx.multi_ref_line[1]); + mode_bits += CTX_ENTROPY_FBITS(ctx, multi_ref_index != 1); + } + } + + if (mode_in_preds != -1 || multi_ref_index != 0) { ctx = &(state->cabac.ctx.luma_planar_model[0]); - mode_bits += CTX_ENTROPY_FBITS(ctx, mode_in_preds>0); + if (multi_ref_index == 0) { + mode_bits += CTX_ENTROPY_FBITS(ctx, mode_in_preds>0); + } mode_bits += MIN(4.0,mode_in_preds); } else { mode_bits += 6.0; diff --git a/src/search_intra.h b/src/search_intra.h index f7ce24c0..4fc7210d 100644 --- a/src/search_intra.h +++ b/src/search_intra.h @@ -45,7 +45,7 @@ double kvz_luma_mode_bits(const encoder_state_t *state, - int8_t luma_mode, const int8_t *intra_preds); + int8_t luma_mode, const int8_t *intra_preds, uint8_t multi_ref_idx); double kvz_chroma_mode_bits(const encoder_state_t *state, int8_t chroma_mode, int8_t luma_mode);