From 6b3395797862784d693988a03c5786cbb45f050a Mon Sep 17 00:00:00 2001 From: siivonek Date: Mon, 24 Jan 2022 13:16:28 +0200 Subject: [PATCH] [mip] Implement MIP bit cost calculation. --- src/cu.h | 4 +- src/search.c | 3 +- src/search_intra.c | 106 +++++++++++++++++++++++++++------------------ src/search_intra.h | 4 +- 4 files changed, 70 insertions(+), 47 deletions(-) diff --git a/src/cu.h b/src/cu.h index 56ece914..4be18926 100644 --- a/src/cu.h +++ b/src/cu.h @@ -169,8 +169,8 @@ typedef struct int8_t mode; int8_t mode_chroma; uint8_t multi_ref_idx; - uint8_t mip_flag; - uint8_t mip_is_transposed; + int8_t mip_flag; + int8_t mip_is_transposed; } intra; struct { mv_t mv[2][2]; // \brief Motion vectors for L0 and L1 diff --git a/src/search.c b/src/search.c index 2c91a5d8..2cc14eb3 100644 --- a/src/search.c +++ b/src/search.c @@ -502,7 +502,8 @@ static double calc_mode_bits(const encoder_state_t *state, kvz_intra_get_dir_luma_predictor(x, y, candidate_modes, cur_cu, left_cu, above_cu); } - double mode_bits = kvz_luma_mode_bits(state, cur_cu->intra.mode, candidate_modes, cur_cu->intra.multi_ref_idx); + // MIP_TODO: calculation of MIP mode cost if this CU has MIP enabled. + double mode_bits = kvz_luma_mode_bits(state, cur_cu->intra.mode, candidate_modes, cur_cu->intra.multi_ref_idx, 0); if (((depth == 4 && x % 8 && y % 8) || (depth != 4)) && state->encoder_control->chroma_format != KVZ_CSP_400) { mode_bits += kvz_chroma_mode_bits(state, cur_cu->intra.mode_chroma, cur_cu->intra.mode); diff --git a/src/search_intra.c b/src/search_intra.c index 48e70914..46a34e2f 100644 --- a/src/search_intra.c +++ b/src/search_intra.c @@ -679,7 +679,7 @@ static int8_t search_intra_rough(encoder_state_t * const state, // affecting the halving search. int lambda_cost = (int)(state->lambda_sqrt + 0.5); for (int mode_i = 0; mode_i < modes_selected; ++mode_i) { - costs[mode_i] += lambda_cost * kvz_luma_mode_bits(state, modes[mode_i], intra_preds, 0); + costs[mode_i] += lambda_cost * kvz_luma_mode_bits(state, modes[mode_i], intra_preds, 0, 0); } #undef PARALLEL_BLKS @@ -759,7 +759,7 @@ static int8_t search_intra_rdo(encoder_state_t * const state, // MIP search const int transp_off = num_mip_modes >> 1; for (uint8_t mip_mode = 0; mip_mode < num_mip_modes; ++mip_mode) { - int rdo_bitcost = kvz_mip_mode_bits(state, mip_mode, num_mip_modes); + int rdo_bitcost = kvz_luma_mode_bits(state, mip_modes[mip_mode], intra_preds, 0, num_mip_modes); mip_costs[mip_mode] = rdo_bitcost * (int)(state->lambda + 0.5); // MIP_TODO: check if this is also correct in the case when MIP is used. @@ -797,7 +797,7 @@ static int8_t search_intra_rdo(encoder_state_t * const state, } for(int rdo_mode = 0; rdo_mode < modes_to_check; rdo_mode ++) { - int rdo_bitcost = kvz_luma_mode_bits(state, modes[rdo_mode], intra_preds, multi_ref_idx); + int rdo_bitcost = kvz_luma_mode_bits(state, modes[rdo_mode], intra_preds, multi_ref_idx, 0); costs[rdo_mode] = rdo_bitcost * (int)(state->lambda + 0.5); @@ -853,56 +853,80 @@ static int8_t search_intra_rdo(encoder_state_t * const state, } -double kvz_mip_mode_bits(const encoder_state_t *state, int mip_mode, int num_mip_modes) +double kvz_luma_mode_bits(const encoder_state_t *state, int8_t luma_mode, const int8_t *intra_preds, const uint8_t multi_ref_idx, const uint8_t num_mip_modes) { double mode_bits = 0.0; - // MIP_TODO: calculate bit costs of writing the following: mip_flag, mip_transpose_flag & mip_mode + bool enable_mip = state->encoder_control->cfg.mip ? (num_mip_modes > 0 ? true : false) : false; - return mode_bits; -} + if (enable_mip) { + // Make a copy of state->cabac for bit cost estimation. + cabac_data_t state_cabac_copy; + cabac_data_t* cabac; + memcpy(&state_cabac_copy, &state->cabac, sizeof(cabac_data_t)); + // Clear data and set mode to count only + state_cabac_copy.only_count = 1; + state_cabac_copy.num_buffered_bytes = 0; + state_cabac_copy.bits_left = 23; + cabac = &state_cabac_copy; -double kvz_luma_mode_bits(const encoder_state_t *state, int8_t luma_mode, const int8_t *intra_preds, const uint8_t multi_ref_idx) -{ - double mode_bits = 0.0; - - int8_t mode_in_preds = -1; - for (int i = 0; i < INTRA_MPM_COUNT; ++i) { - if (luma_mode == intra_preds[i]) { - mode_in_preds = i; - break; + // Do cabac writes as normal + const int transp_off = num_mip_modes >> 1; + bool mip_flag = enable_mip; + const bool is_transposed = luma_mode >= transp_off ? true : false; + int8_t mip_mode = is_transposed ? luma_mode - transp_off : luma_mode; + // Write MIP flag + cabac->cur_ctx = &(cabac->ctx.mip_flag); + CABAC_BIN(cabac, mip_flag, "mip_flag"); + if (mip_flag) { + // Write MIP transpose flag & mode + CABAC_BIN_EP(cabac, is_transposed, "mip_transposed"); + kvz_cabac_encode_trunc_bin(cabac, mip_mode, transp_off); } + + // Writes done. Get bit cost out of cabac + mode_bits += (23 - state_cabac_copy.bits_left) + (state_cabac_copy.num_buffered_bytes << 3); // MIP_TODO: check what this bit shifting means. } - - bool enable_mrl = state->encoder_control->cfg.mrl; - uint8_t multi_ref_index = enable_mrl ? multi_ref_idx : 0; - - const cabac_ctx_t* ctx = &(state->cabac.ctx.intra_luma_mpm_flag_model); - - if (multi_ref_index == 0) { - mode_bits += CTX_ENTROPY_FBITS(ctx, mode_in_preds != -1); - } - - // Add MRL bits. - if (enable_mrl && MAX_REF_LINE_IDX > 1) { - ctx = &(state->cabac.ctx.multi_ref_line[0]); - mode_bits += CTX_ENTROPY_FBITS(ctx, multi_ref_index != 0); - - if (multi_ref_index != 0 && MAX_REF_LINE_IDX > 2) { - ctx = &(state->cabac.ctx.multi_ref_line[1]); - mode_bits += CTX_ENTROPY_FBITS(ctx, multi_ref_index != 1); + else { + int8_t mode_in_preds = -1; + for (int i = 0; i < INTRA_MPM_COUNT; ++i) { + if (luma_mode == intra_preds[i]) { + mode_in_preds = i; + break; + } } - } - if (mode_in_preds != -1 || multi_ref_index != 0) { - ctx = &(state->cabac.ctx.luma_planar_model[0]); + bool enable_mrl = state->encoder_control->cfg.mrl; + uint8_t multi_ref_index = enable_mrl ? multi_ref_idx : 0; + + const cabac_ctx_t* ctx = &(state->cabac.ctx.intra_luma_mpm_flag_model); + if (multi_ref_index == 0) { - mode_bits += CTX_ENTROPY_FBITS(ctx, mode_in_preds>0); + mode_bits += CTX_ENTROPY_FBITS(ctx, mode_in_preds != -1); + } + + // Add MRL bits. + if (enable_mrl && MAX_REF_LINE_IDX > 1) { + ctx = &(state->cabac.ctx.multi_ref_line[0]); + mode_bits += CTX_ENTROPY_FBITS(ctx, multi_ref_index != 0); + + if (multi_ref_index != 0 && MAX_REF_LINE_IDX > 2) { + ctx = &(state->cabac.ctx.multi_ref_line[1]); + mode_bits += CTX_ENTROPY_FBITS(ctx, multi_ref_index != 1); + } + } + + if (mode_in_preds != -1 || multi_ref_index != 0) { + ctx = &(state->cabac.ctx.luma_planar_model[0]); + if (multi_ref_index == 0) { + mode_bits += CTX_ENTROPY_FBITS(ctx, mode_in_preds > 0); + } + mode_bits += MIN(4.0, mode_in_preds); + } + else { + mode_bits += 6.0; } - mode_bits += MIN(4.0,mode_in_preds); - } else { - mode_bits += 6.0; } return mode_bits; diff --git a/src/search_intra.h b/src/search_intra.h index 7d25e09d..06ef77ec 100644 --- a/src/search_intra.h +++ b/src/search_intra.h @@ -43,10 +43,8 @@ #include "global.h" // IWYU pragma: keep #include "intra.h" -double kvz_mip_mode_bits(const encoder_state_t *state, int mip_mode, int num_mip_modes); - double kvz_luma_mode_bits(const encoder_state_t *state, - int8_t luma_mode, const int8_t *intra_preds, uint8_t multi_ref_idx); + int8_t luma_mode, const int8_t *intra_preds, uint8_t multi_ref_idx, const uint8_t num_mip_modes); double kvz_chroma_mode_bits(const encoder_state_t *state, int8_t chroma_mode, int8_t luma_mode);