diff --git a/src/intra.c b/src/intra.c index bd8396a3..1925b8d7 100644 --- a/src/intra.c +++ b/src/intra.c @@ -725,10 +725,8 @@ void kvz_mip_pred_upsampling_1D(int* const dst, const int* const src, const int* /** \brief Matrix weighted intra prediction. */ -// MIP_TODO: remove color parameter if it is not used void kvz_mip_predict(encoder_state_t const* const state, kvz_intra_references* const refs, const uint16_t pred_block_width, const uint16_t pred_block_height, - const color_t color, kvz_pixel* dst, const int mip_mode, const bool mip_transp) { diff --git a/src/intra.h b/src/intra.h index 666044c5..44ab404d 100644 --- a/src/intra.h +++ b/src/intra.h @@ -160,7 +160,6 @@ void kvz_mip_predict( kvz_intra_references * refs, const uint16_t width, const uint16_t height, - const color_t color, kvz_pixel* dst, const int mip_mode, const bool mip_transp diff --git a/src/search.c b/src/search.c index 99a9df27..1bdc67d5 100644 --- a/src/search.c +++ b/src/search.c @@ -504,8 +504,11 @@ static double calc_mode_bits(const encoder_state_t *state, kvz_intra_get_dir_luma_predictor(x, y, candidate_modes, cur_cu, left_cu, above_cu); } - // MIP_TODO: calculation of MIP mode cost if this CU has MIP enabled. - double mode_bits = kvz_luma_mode_bits(state, cur_cu->intra.mode, candidate_modes, cur_cu->intra.multi_ref_idx, 0, 0); + int width = LCU_WIDTH >> depth; + int height = width; // TODO: height for non-square blocks + int num_mip_modes_half = NUM_MIP_MODES_HALF(width, height); + int mip_flag_ctx_id = kvz_get_mip_flag_context(x, y, width, height, lcu, NULL); + double mode_bits = kvz_luma_mode_bits(state, cur_cu->intra.mode, candidate_modes, cur_cu->intra.multi_ref_idx, num_mip_modes_half, mip_flag_ctx_id); if (((depth == 4 && x % 8 && y % 8) || (depth != 4)) && state->encoder_control->chroma_format != KVZ_CSP_400) { mode_bits += kvz_chroma_mode_bits(state, cur_cu->intra.mode_chroma, cur_cu->intra.mode); diff --git a/src/search.h b/src/search.h index 3694a2ff..044aa8a2 100644 --- a/src/search.h +++ b/src/search.h @@ -44,6 +44,9 @@ #include "image.h" #include "constraint.h" +#define NUM_MIP_MODES_FULL(width, height) (width == 4 && height == 4) ? 32 : (width == 4 || height == 4 || (width == 8 && height == 8) ? 16 : 12) +#define NUM_MIP_MODES_HALF(width, height) NUM_MIP_MODES_FULL(width, height) >> 1 + void kvz_sort_modes(int8_t *__restrict modes, double *__restrict costs, uint8_t length); void kvz_sort_modes_intra_luma(int8_t *__restrict modes, int8_t *__restrict trafo, double *__restrict costs, uint8_t length); diff --git a/src/search_intra.c b/src/search_intra.c index fa60eeb9..3b597c11 100644 --- a/src/search_intra.c +++ b/src/search_intra.c @@ -719,7 +719,7 @@ static int8_t search_intra_rdo(encoder_state_t * const state, int8_t *intra_preds, int modes_to_check, int8_t modes[67], int8_t trafo[67], double costs[67], - int num_mip_modes, + int num_mip_modes_full, int8_t mip_modes[32], int8_t mip_trafo[32], double mip_costs[32], lcu_t *lcu, uint8_t multi_ref_idx) @@ -756,14 +756,15 @@ static int8_t search_intra_rdo(encoder_state_t * const state, } // MIP_TODO: implement this inside the standard intra for loop. Code duplication is bad. + // MIP_TODO: loop through normal intra modes first // MIP search - const int transp_off = num_mip_modes >> 1; - for (uint8_t mip_mode = 0; mip_mode < num_mip_modes; ++mip_mode) { - // Derive mip flag context id - uint8_t ctx_id = kvz_get_mip_flag_context(x_px, y_px, width, height, lcu, NULL); - int rdo_bitcost = kvz_luma_mode_bits(state, mip_modes[mip_mode], intra_preds, 0, num_mip_modes, ctx_id); + const int transp_off = num_mip_modes_full >> 1; + // Derive mip flag context id + uint8_t ctx_id = kvz_get_mip_flag_context(x_px, y_px, width, height, lcu, NULL); + for (uint8_t mip_mode = 0; mip_mode < num_mip_modes_full; ++mip_mode) { + int rdo_bitcost = kvz_luma_mode_bits(state, mip_modes[mip_mode], intra_preds, 0, transp_off, ctx_id); - mip_costs[mip_mode] = rdo_bitcost * (int)(state->lambda + 0.5); // MIP_TODO: check if this is also correct in the case when MIP is used. + mip_costs[mip_mode] = rdo_bitcost * (int)(state->lambda + 0.5); const bool is_transposed = (mip_modes[mip_mode] >= transp_off ? true : false); // There can be 32 MIP modes, but only mode numbers [0, 15] are ever written to bitstream. @@ -791,7 +792,6 @@ static int8_t search_intra_rdo(encoder_state_t * const state, mip_costs[mip_mode] += mode_cost; mip_trafo[mip_mode] = pred_cu.tr_idx; - // MIP_TODO: check if ET is viable when MIP is used // Early termination if no coefficients has to be coded if (state->encoder_control->cfg.intra_rdo_et && !cbf_is_set_any(pred_cu.cbf, depth)) { modes_to_check = mip_mode + 1; @@ -834,8 +834,8 @@ static int8_t search_intra_rdo(encoder_state_t * const state, // Update order according to new costs kvz_sort_modes_intra_luma(modes, trafo, costs, modes_to_check); bool use_mip = false; - if (num_mip_modes) { - kvz_sort_modes_intra_luma(mip_modes, mip_trafo, mip_costs, num_mip_modes); + if (num_mip_modes_full) { + kvz_sort_modes_intra_luma(mip_modes, mip_trafo, mip_costs, num_mip_modes_full); if (costs[0] > mip_costs[0]) { use_mip = true; } @@ -854,7 +854,7 @@ static int8_t search_intra_rdo(encoder_state_t * const state, pred_cu.intra.mode = mip_modes[0]; pred_cu.intra.mode_chroma = 0; pred_cu.intra.multi_ref_idx = 0; - int transp_off = num_mip_modes >> 1; + int transp_off = num_mip_modes_full >> 1; bool is_transposed = (mip_modes[0] >= transp_off ? true : false); int8_t pred_mode = (is_transposed ? mip_modes[0] - transp_off : mip_modes[0]); pred_cu.intra.mode = pred_mode; @@ -877,12 +877,14 @@ static int8_t search_intra_rdo(encoder_state_t * const state, } -double kvz_luma_mode_bits(const encoder_state_t *state, int8_t luma_mode, const int8_t *intra_preds, const uint8_t multi_ref_idx, const uint8_t num_mip_modes, int mip_flag_ctx_id) +double kvz_luma_mode_bits(const encoder_state_t *state, int8_t luma_mode, const int8_t *intra_preds, const uint8_t multi_ref_idx, const uint8_t num_mip_modes_half, int mip_flag_ctx_id) { double mode_bits = 0.0; - bool enable_mip = state->encoder_control->cfg.mip ? (num_mip_modes > 0 ? true : false) : false; + bool enable_mip = state->encoder_control->cfg.mip; + bool mip_flag = enable_mip ? (num_mip_modes_half > 0 ? true : false) : false; + // Mip flag cost must be calculated even if mip is not used in this block if (enable_mip) { // Make a copy of state->cabac for bit cost estimation. cabac_data_t state_cabac_copy; @@ -896,24 +898,25 @@ double kvz_luma_mode_bits(const encoder_state_t *state, int8_t luma_mode, const cabac = &state_cabac_copy; // Do cabac writes as normal - const int transp_off = num_mip_modes >> 1; - bool mip_flag = enable_mip; + const int transp_off = num_mip_modes_half; const bool is_transposed = luma_mode >= transp_off ? true : false; int8_t mip_mode = is_transposed ? luma_mode - transp_off : luma_mode; // Write MIP flag cabac->cur_ctx = &(cabac->ctx.mip_flag[mip_flag_ctx_id]); CABAC_BIN(cabac, mip_flag, "mip_flag"); + if (mip_flag) { // Write MIP transpose flag & mode CABAC_BIN_EP(cabac, is_transposed, "mip_transposed"); kvz_cabac_encode_trunc_bin(cabac, mip_mode, transp_off); } - - // Writes done. Get bit cost out of cabac - mode_bits += (23 - state_cabac_copy.bits_left) + (state_cabac_copy.num_buffered_bytes << 3); // MIP_TODO: check what this bit shifting means. + + // Write is done. Get bit cost out of cabac + mode_bits += (23 - state_cabac_copy.bits_left) + (state_cabac_copy.num_buffered_bytes << 3); } - else { + + if (!mip_flag) { int8_t mode_in_preds = -1; for (int i = 0; i < INTRA_MPM_COUNT; ++i) { if (luma_mode == intra_preds[i]) { @@ -1211,18 +1214,9 @@ void kvz_search_cu_intra(encoder_state_t * const state, mip_modes[i] = i; mip_costs[i] = MAX_INT; } - // MIP_TODO: check for illegal block sizes. - if (width == 4 && height == 4) { - // Mip size_id = 0. Num modes = 32 - num_mip_modes = 32; - } - else if (width == 4 || height == 4 || (width == 8 && height == 8)) { - // Mip size_id = 1. Num modes = 16 - num_mip_modes = 16; - } - else { - // Mip size_id = 2. Num modes = 12 - num_mip_modes = 12; + // MIP is not allowed for 64 x 4 or 4 x 64 blocks + if (!((width == 64 && height == 4) || (width == 4 && height == 64))) { + num_mip_modes = NUM_MIP_MODES_FULL(width, height); } }