From b8a8bce55a1b9f54cf5afbffd833d2300b118212 Mon Sep 17 00:00:00 2001 From: siivonek Date: Wed, 26 Jan 2022 13:34:39 +0200 Subject: [PATCH] [mip] Fix MIP bit cost calculation. --- src/search.c | 2 +- src/search_intra.c | 41 ++++++++++++++++++++++++++++++++++++----- src/search_intra.h | 2 +- 3 files changed, 38 insertions(+), 7 deletions(-) diff --git a/src/search.c b/src/search.c index faa8b900..7db7d5b9 100644 --- a/src/search.c +++ b/src/search.c @@ -505,7 +505,7 @@ static double calc_mode_bits(const encoder_state_t *state, } // MIP_TODO: calculation of MIP mode cost if this CU has MIP enabled. - double mode_bits = kvz_luma_mode_bits(state, cur_cu->intra.mode, candidate_modes, cur_cu->intra.multi_ref_idx, 0); + double mode_bits = kvz_luma_mode_bits(state, cur_cu->intra.mode, candidate_modes, cur_cu->intra.multi_ref_idx, 0, 0); if (((depth == 4 && x % 8 && y % 8) || (depth != 4)) && state->encoder_control->chroma_format != KVZ_CSP_400) { mode_bits += kvz_chroma_mode_bits(state, cur_cu->intra.mode_chroma, cur_cu->intra.mode); diff --git a/src/search_intra.c b/src/search_intra.c index dc0d8d6e..6800bfef 100644 --- a/src/search_intra.c +++ b/src/search_intra.c @@ -679,7 +679,7 @@ static int8_t search_intra_rough(encoder_state_t * const state, // affecting the halving search. int lambda_cost = (int)(state->lambda_sqrt + 0.5); for (int mode_i = 0; mode_i < modes_selected; ++mode_i) { - costs[mode_i] += lambda_cost * kvz_luma_mode_bits(state, modes[mode_i], intra_preds, 0, 0); + costs[mode_i] += lambda_cost * kvz_luma_mode_bits(state, modes[mode_i], intra_preds, 0, 0, 0); } #undef PARALLEL_BLKS @@ -756,10 +756,40 @@ static int8_t search_intra_rdo(encoder_state_t * const state, } // MIP_TODO: implement this inside the standard intra for loop. Code duplication is bad. + // MIP_TODO: deriving mip flag context id could be done in it's own function since the exact same code is used in encode_coding_tree.c // MIP search const int transp_off = num_mip_modes >> 1; for (uint8_t mip_mode = 0; mip_mode < num_mip_modes; ++mip_mode) { - int rdo_bitcost = kvz_luma_mode_bits(state, mip_modes[mip_mode], intra_preds, 0, num_mip_modes); + // Derive mip flag context id + uint8_t ctx_id = 0; + const videoframe_t* const frame = state->tile->frame; + const vector2d_t lcu_px = { SUB_SCU(x_px), SUB_SCU(y_px) }; + cu_info_t* cur_cu; + cur_cu = LCU_GET_CU_AT_PX(lcu, lcu_px.x, lcu_px.y); + const int cu_width = width; + const int cu_height = cu_width; // TODO: height for non-square blocks + const int pu_x = PU_GET_X(cur_cu->part_size, cu_width, x_px, 0); + const int pu_y = PU_GET_Y(cur_cu->part_size, cu_width, y_px, 0); + const cu_info_t* left_pu = NULL; + const cu_info_t* above_pu = NULL; + + if (pu_x > 0) { + assert(pu_x >> 2 > 0); + left_pu = kvz_cu_array_at_const(frame->cu_array, pu_x - 1, pu_y + cu_width - 1); + } + if (left_pu != NULL) { + ctx_id = left_pu->intra.mip_flag ? 1 : 0; + } + // Don't take the above PU across the LCU boundary. + if (pu_y % LCU_WIDTH > 0 && pu_y > 0) { + assert(pu_y >> 2 > 0); + above_pu = kvz_cu_array_at_const(frame->cu_array, pu_x + cu_width - 1, pu_y - 1); + } + if (above_pu != NULL) { + ctx_id += above_pu->intra.mip_flag ? 1 : 0; + } + ctx_id = (cu_width > 2 * cu_height || cu_height > 2 * cu_width) ? 3 : ctx_id; + int rdo_bitcost = kvz_luma_mode_bits(state, mip_modes[mip_mode], intra_preds, 0, num_mip_modes, ctx_id); mip_costs[mip_mode] = rdo_bitcost * (int)(state->lambda + 0.5); // MIP_TODO: check if this is also correct in the case when MIP is used. @@ -797,7 +827,7 @@ static int8_t search_intra_rdo(encoder_state_t * const state, } for(int rdo_mode = 0; rdo_mode < modes_to_check; rdo_mode ++) { - int rdo_bitcost = kvz_luma_mode_bits(state, modes[rdo_mode], intra_preds, multi_ref_idx, 0); + int rdo_bitcost = kvz_luma_mode_bits(state, modes[rdo_mode], intra_preds, multi_ref_idx, 0, 0); costs[rdo_mode] = rdo_bitcost * (int)(state->lambda + 0.5); @@ -874,7 +904,7 @@ static int8_t search_intra_rdo(encoder_state_t * const state, } -double kvz_luma_mode_bits(const encoder_state_t *state, int8_t luma_mode, const int8_t *intra_preds, const uint8_t multi_ref_idx, const uint8_t num_mip_modes) +double kvz_luma_mode_bits(const encoder_state_t *state, int8_t luma_mode, const int8_t *intra_preds, const uint8_t multi_ref_idx, const uint8_t num_mip_modes, int mip_flag_ctx_id) { double mode_bits = 0.0; @@ -897,8 +927,9 @@ double kvz_luma_mode_bits(const encoder_state_t *state, int8_t luma_mode, const bool mip_flag = enable_mip; const bool is_transposed = luma_mode >= transp_off ? true : false; int8_t mip_mode = is_transposed ? luma_mode - transp_off : luma_mode; + // Write MIP flag - cabac->cur_ctx = &(cabac->ctx.mip_flag); + cabac->cur_ctx = &(cabac->ctx.mip_flag[mip_flag_ctx_id]); CABAC_BIN(cabac, mip_flag, "mip_flag"); if (mip_flag) { // Write MIP transpose flag & mode diff --git a/src/search_intra.h b/src/search_intra.h index 06ef77ec..659695b3 100644 --- a/src/search_intra.h +++ b/src/search_intra.h @@ -44,7 +44,7 @@ #include "intra.h" double kvz_luma_mode_bits(const encoder_state_t *state, - int8_t luma_mode, const int8_t *intra_preds, uint8_t multi_ref_idx, const uint8_t num_mip_modes); + int8_t luma_mode, const int8_t *intra_preds, uint8_t multi_ref_idx, const uint8_t num_mip_modes, int mip_flag_ctx_id); double kvz_chroma_mode_bits(const encoder_state_t *state, int8_t chroma_mode, int8_t luma_mode);