[mip] Combine mip mode search loop into the original intra mode search loop. Some code clean up.

This commit is contained in:
siivonek 2022-02-15 11:24:01 +02:00
parent ac45a5299c
commit 9b04a6f302
2 changed files with 44 additions and 69 deletions

View file

@ -1422,7 +1422,7 @@ static void intra_recon_tb_leaf(
if(intra_mode < 68) { if(intra_mode < 68) {
if (use_mip) { if (use_mip) {
assert(intra_mode >= 0 && intra_mode < 16 && "MIP mode must be between [0, 15]"); assert(intra_mode >= 0 && intra_mode < 16 && "MIP mode must be between [0, 15]");
kvz_mip_predict(state, &refs, width, height, color, pred, intra_mode, mip_transp); kvz_mip_predict(state, &refs, width, height, pred, intra_mode, mip_transp);
} }
else { else {
kvz_intra_predict(state, &refs, log2width, intra_mode, color, pred, filter_boundary, multi_ref_index); kvz_intra_predict(state, &refs, log2width, intra_mode, color, pred, filter_boundary, multi_ref_index);

View file

@ -757,20 +757,27 @@ static int8_t search_intra_rdo(encoder_state_t * const state,
// MIP_TODO: implement this inside the standard intra for loop. Code duplication is bad. // MIP_TODO: implement this inside the standard intra for loop. Code duplication is bad.
// MIP_TODO: loop through normal intra modes first // MIP_TODO: loop through normal intra modes first
// MIP search
const int transp_off = num_mip_modes_full >> 1;
// Derive mip flag context id
uint8_t ctx_id = kvz_get_mip_flag_context(x_px, y_px, width, height, lcu, NULL);
for (uint8_t mip_mode = 0; mip_mode < num_mip_modes_full; ++mip_mode) {
int rdo_bitcost = kvz_luma_mode_bits(state, mip_modes[mip_mode], intra_preds, 0, transp_off, ctx_id);
mip_costs[mip_mode] = rdo_bitcost * (int)(state->lambda + 0.5); for (int mip = 0; mip <= 1; mip++) {
const int transp_off = mip ? num_mip_modes_full >> 1 : 0;
uint8_t ctx_id = mip ? kvz_get_mip_flag_context(x_px, y_px, width, height, lcu, NULL) : 0;
uint8_t multi_ref_index = mip ? 0 : multi_ref_idx;
int *num_modes = mip ? &num_mip_modes_full : &modes_to_check;
const bool is_transposed = (mip_modes[mip_mode] >= transp_off ? true : false); for (uint8_t i = 0; i < *num_modes; i++) {
int8_t mode = mip ? mip_modes[i] : modes[i];
double *mode_cost_p = mip ? &mip_costs[i] : &costs[i];
int8_t *mode_trafo_p = mip ? &mip_trafo[i] : &trafo[i];
int rdo_bitcost = kvz_luma_mode_bits(state, mode, intra_preds, multi_ref_index, transp_off, ctx_id);
*mode_cost_p = rdo_bitcost * (int)(state->lambda + 0.5);
// Mip related stuff
// There can be 32 MIP modes, but only mode numbers [0, 15] are ever written to bitstream. // There can be 32 MIP modes, but only mode numbers [0, 15] are ever written to bitstream.
// Half of the modes [16, 31] are indicated with the separate transpose flag. // Half of the modes [16, 31] are indicated with the separate transpose flag.
// Number of possible modes is less for larger blocks. // Number of possible modes is less for larger blocks.
int8_t pred_mode = (is_transposed ? mip_modes[mip_mode] - transp_off : mip_modes[mip_mode]); const bool is_transposed = mip ? (mode >= transp_off ? true : false) : 0;
int8_t pred_mode = (is_transposed ? mode - transp_off : mode);
// Perform transform split search and save mode RD cost for the best one. // Perform transform split search and save mode RD cost for the best one.
cu_info_t pred_cu; cu_info_t pred_cu;
@ -779,9 +786,9 @@ static int8_t search_intra_rdo(encoder_state_t * const state,
pred_cu.part_size = ((depth == MAX_PU_DEPTH) ? SIZE_NxN : SIZE_2Nx2N); // TODO: non-square blocks pred_cu.part_size = ((depth == MAX_PU_DEPTH) ? SIZE_NxN : SIZE_2Nx2N); // TODO: non-square blocks
pred_cu.intra.mode = pred_mode; pred_cu.intra.mode = pred_mode;
pred_cu.intra.mode_chroma = pred_mode; pred_cu.intra.mode_chroma = pred_mode;
pred_cu.intra.multi_ref_idx = 0; pred_cu.intra.multi_ref_idx = multi_ref_index;
pred_cu.intra.mip_is_transposed = is_transposed; pred_cu.intra.mip_is_transposed = is_transposed;
pred_cu.intra.mip_flag = true; pred_cu.intra.mip_flag = mip ? true : false;
pred_cu.joint_cb_cr = 0; pred_cu.joint_cb_cr = 0;
FILL(pred_cu.cbf, 0); FILL(pred_cu.cbf, 0);
@ -789,46 +796,15 @@ static int8_t search_intra_rdo(encoder_state_t * const state,
kvz_lcu_fill_trdepth(lcu, x_px, y_px, depth, depth); kvz_lcu_fill_trdepth(lcu, x_px, y_px, depth, depth);
double mode_cost = search_intra_trdepth(state, x_px, y_px, depth, tr_depth, pred_mode, MAX_INT, &pred_cu, lcu, NULL, -1); double mode_cost = search_intra_trdepth(state, x_px, y_px, depth, tr_depth, pred_mode, MAX_INT, &pred_cu, lcu, NULL, -1);
mip_costs[mip_mode] += mode_cost; *mode_cost_p += mode_cost;
mip_trafo[mip_mode] = pred_cu.tr_idx; *mode_trafo_p = pred_cu.tr_idx;
// Early termination if no coefficients has to be coded // Early termination if no coefficients has to be coded
if (state->encoder_control->cfg.intra_rdo_et && !cbf_is_set_any(pred_cu.cbf, depth)) { if (state->encoder_control->cfg.intra_rdo_et && !cbf_is_set_any(pred_cu.cbf, depth)) {
modes_to_check = mip_mode + 1; *num_modes = i + 1;
break; break;
} }
} }
for(int rdo_mode = 0; rdo_mode < modes_to_check; rdo_mode ++) {
int rdo_bitcost = kvz_luma_mode_bits(state, modes[rdo_mode], intra_preds, multi_ref_idx, 0, 0);
costs[rdo_mode] = rdo_bitcost * (int)(state->lambda + 0.5);
// Perform transform split search and save mode RD cost for the best one.
cu_info_t pred_cu;
pred_cu.depth = depth;
pred_cu.type = CU_INTRA;
pred_cu.part_size = ((depth == MAX_PU_DEPTH) ? SIZE_NxN : SIZE_2Nx2N);
pred_cu.intra.mode = modes[rdo_mode];
pred_cu.intra.mode_chroma = modes[rdo_mode];
pred_cu.intra.multi_ref_idx = multi_ref_idx;
pred_cu.intra.mip_is_transposed = false;
pred_cu.intra.mip_flag = false;
pred_cu.joint_cb_cr = 0;
FILL(pred_cu.cbf, 0);
// Reset transform split data in lcu.cu for this area.
kvz_lcu_fill_trdepth(lcu, x_px, y_px, depth, depth);
double mode_cost = search_intra_trdepth(state, x_px, y_px, depth, tr_depth, modes[rdo_mode], MAX_INT, &pred_cu, lcu, NULL, -1);
costs[rdo_mode] += mode_cost;
trafo[rdo_mode] = pred_cu.tr_idx;
// Early termination if no coefficients has to be coded
if (state->encoder_control->cfg.intra_rdo_et && !cbf_is_set_any(pred_cu.cbf, depth)) {
modes_to_check = rdo_mode + 1;
break;
}
} }
// Update order according to new costs // Update order according to new costs
@ -851,14 +827,12 @@ static int8_t search_intra_rdo(encoder_state_t * const state,
pred_cu.type = CU_INTRA; pred_cu.type = CU_INTRA;
pred_cu.part_size = ((depth == MAX_PU_DEPTH) ? SIZE_NxN : SIZE_2Nx2N); pred_cu.part_size = ((depth == MAX_PU_DEPTH) ? SIZE_NxN : SIZE_2Nx2N);
if (use_mip) { if (use_mip) {
pred_cu.intra.mode = mip_modes[0];
pred_cu.intra.mode_chroma = 0;
pred_cu.intra.multi_ref_idx = 0;
int transp_off = num_mip_modes_full >> 1; int transp_off = num_mip_modes_full >> 1;
bool is_transposed = (mip_modes[0] >= transp_off ? true : false); bool is_transposed = (mip_modes[0] >= transp_off ? true : false);
int8_t pred_mode = (is_transposed ? mip_modes[0] - transp_off : mip_modes[0]); int8_t pred_mode = (is_transposed ? mip_modes[0] - transp_off : mip_modes[0]);
pred_cu.intra.mode = pred_mode; pred_cu.intra.mode = pred_mode;
pred_cu.intra.mode_chroma = 0; pred_cu.intra.mode_chroma = pred_mode;
pred_cu.intra.multi_ref_idx = 0;
pred_cu.intra.mip_flag = true; pred_cu.intra.mip_flag = true;
pred_cu.intra.mip_is_transposed = is_transposed; pred_cu.intra.mip_is_transposed = is_transposed;
} }
@ -873,6 +847,7 @@ static int8_t search_intra_rdo(encoder_state_t * const state,
search_intra_trdepth(state, x_px, y_px, depth, tr_depth, pred_cu.intra.mode, MAX_INT, &pred_cu, lcu, NULL, trafo[0]); search_intra_trdepth(state, x_px, y_px, depth, tr_depth, pred_cu.intra.mode, MAX_INT, &pred_cu, lcu, NULL, trafo[0]);
} }
// TODO: modes to check does not consider mip modes. Maybe replace with array when mip search is optimized?
return modes_to_check; return modes_to_check;
} }